//===- ExpressionSimplification.cpp - Simplify HLFIR expressions ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/HLFIR/Passes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace hlfir { #define GEN_PASS_DEF_EXPRESSIONSIMPLIFICATION #include "flang/Optimizer/HLFIR/Passes.h.inc" } // namespace hlfir // Get the first user of `op`. // Note that we consider the first user to be the one on the lowest line of // the emitted HLFIR. The user iterator considers the opposite. template static UserOp getFirstUser(mlir::Operation *op) { auto it = op->user_begin(), end = op->user_end(), prev = it; for (; it != end; prev = it++) ; if (prev != end) if (auto userOp = mlir::dyn_cast(*prev)) return userOp; return {}; } // Get the last user of `op`. // Note that we consider the last user to be the one on the highest line of // the emitted HLFIR. The user iterator considers the opposite. template static UserOp getLastUser(mlir::Operation *op) { if (!op->getUsers().empty()) if (auto userOp = mlir::dyn_cast(*op->user_begin())) return userOp; return {}; } namespace { // Trim operations can be erased in certain expressions, such as character // comparisons. // Since a character comparison appends spaces to the shorter character, // calls to trim() that are used only in the comparison can be eliminated. // // Example: // `trim(x) == trim(y)` // can be simplified to // `x == y` class EraseTrim : public mlir::OpRewritePattern { public: using mlir::OpRewritePattern::OpRewritePattern; llvm::LogicalResult matchAndRewrite(hlfir::CharTrimOp trimOp, mlir::PatternRewriter &rewriter) const override { int trimUses = std::distance(trimOp->use_begin(), trimOp->use_end()); auto cmpCharOp = getFirstUser(trimOp); auto destroyOp = getLastUser(trimOp); if (!cmpCharOp || !destroyOp || trimUses != 2) return rewriter.notifyMatchFailure( trimOp, "hlfir.char_trim is not used (only) by hlfir.cmpchar"); rewriter.eraseOp(destroyOp); rewriter.replaceOp(trimOp, trimOp.getChr()); return mlir::success(); } }; class ExpressionSimplificationPass : public hlfir::impl::ExpressionSimplificationBase< ExpressionSimplificationPass> { public: void runOnOperation() override { mlir::MLIRContext *context = &getContext(); mlir::GreedyRewriteConfig config; // Prevent the pattern driver from merging blocks. config.setRegionSimplificationLevel( mlir::GreedySimplifyRegionLevel::Disabled); mlir::RewritePatternSet patterns(context); patterns.insert(context); if (mlir::failed(mlir::applyPatternsGreedily( getOperation(), std::move(patterns), config))) { mlir::emitError(getOperation()->getLoc(), "failure in HLFIR expression simplification"); signalPassFailure(); } } }; } // namespace