summaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp')
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp208
1 files changed, 207 insertions, 1 deletions
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index fe12f49c655b..d8e36ea294cd 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -2078,6 +2078,212 @@ private:
}
};
+class CmpCharOpConversion : public mlir::OpRewritePattern<hlfir::CmpCharOp> {
+public:
+ using mlir::OpRewritePattern<hlfir::CmpCharOp>::OpRewritePattern;
+
+ llvm::LogicalResult
+ matchAndRewrite(hlfir::CmpCharOp cmp,
+ mlir::PatternRewriter &rewriter) const override {
+
+ fir::FirOpBuilder builder{rewriter, cmp.getOperation()};
+ const mlir::Location &loc = cmp->getLoc();
+
+ auto toVariable =
+ [&builder,
+ &loc](mlir::Value val) -> std::pair<mlir::Value, hlfir::AssociateOp> {
+ mlir::Value opnd;
+ hlfir::AssociateOp associate;
+ if (mlir::isa<hlfir::ExprType>(val.getType())) {
+ hlfir::Entity entity{val};
+ mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr(builder);
+ associate = hlfir::genAssociateExpr(loc, builder, entity,
+ entity.getType(), "", byRefAttr);
+ opnd = associate.getBase();
+ } else {
+ opnd = val;
+ }
+ return {opnd, associate};
+ };
+
+ auto [lhsOpnd, lhsAssociate] = toVariable(cmp.getLchr());
+ auto [rhsOpnd, rhsAssociate] = toVariable(cmp.getRchr());
+
+ hlfir::Entity lhs{lhsOpnd};
+ hlfir::Entity rhs{rhsOpnd};
+
+ auto charTy = mlir::cast<fir::CharacterType>(lhs.getFortranElementType());
+ unsigned kind = charTy.getFKind();
+
+ auto bits = builder.getKindMap().getCharacterBitsize(kind);
+ auto intTy = builder.getIntegerType(bits);
+
+ auto idxTy = builder.getIndexType();
+ auto charLen1Ty =
+ fir::CharacterType::getSingleton(builder.getContext(), kind);
+ mlir::Type designatorType =
+ fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy));
+ auto idxAttr = builder.getIntegerAttr(idxTy, 0);
+
+ auto genExtractAndConvertToInt =
+ [&idxAttr, &intTy, &designatorType](
+ mlir::Location loc, fir::FirOpBuilder &builder,
+ hlfir::Entity &charStr, mlir::Value index, mlir::Value length) {
+ auto singleChr = hlfir::DesignateOp::create(
+ builder, loc, designatorType, charStr, /*component=*/{},
+ /*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{},
+ /*substring=*/mlir::ValueRange{index, index},
+ /*complexPart=*/std::nullopt,
+ /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{length},
+ fir::FortranVariableFlagsAttr{});
+ auto chrVal = fir::LoadOp::create(builder, loc, singleChr);
+ mlir::Value intVal = fir::ExtractValueOp::create(
+ builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr));
+ return intVal;
+ };
+
+ mlir::arith::CmpIPredicate predicate = cmp.getPredicate();
+ mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
+
+ mlir::Value lhsLen = builder.createConvert(
+ loc, idxTy, hlfir::genCharLength(loc, builder, lhs));
+ mlir::Value rhsLen = builder.createConvert(
+ loc, idxTy, hlfir::genCharLength(loc, builder, rhs));
+
+ enum class GenCmp { LeftToRight, LeftToBlank, BlankToRight };
+
+ mlir::Value zeroInt = builder.createIntegerConstant(loc, intTy, 0);
+ mlir::Value oneInt = builder.createIntegerConstant(loc, intTy, 1);
+ mlir::Value negOneInt = builder.createIntegerConstant(loc, intTy, -1);
+ mlir::Value blankInt = builder.createIntegerConstant(loc, intTy, ' ');
+
+ auto step = GenCmp::LeftToRight;
+ auto genCmp = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange index, mlir::ValueRange reductionArgs)
+ -> llvm::SmallVector<mlir::Value, 1> {
+ assert(index.size() == 1 && "expected single loop");
+ assert(reductionArgs.size() == 1 && "expected single reduction value");
+ mlir::Value inRes = reductionArgs[0];
+ auto accEQzero = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroInt);
+
+ mlir::Value res =
+ builder
+ .genIfOp(loc, {intTy}, accEQzero,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ mlir::Value offset =
+ builder.createConvert(loc, idxTy, index[0]);
+ mlir::Value lhsInt;
+ mlir::Value rhsInt;
+ if (step == GenCmp::LeftToRight) {
+ lhsInt = genExtractAndConvertToInt(loc, builder, lhs, offset,
+ oneIdx);
+ rhsInt = genExtractAndConvertToInt(loc, builder, rhs, offset,
+ oneIdx);
+ } else if (step == GenCmp::LeftToBlank) {
+ // lhsLen > rhsLen
+ offset =
+ mlir::arith::AddIOp::create(builder, loc, rhsLen, offset);
+
+ lhsInt = genExtractAndConvertToInt(loc, builder, lhs, offset,
+ oneIdx);
+ rhsInt = blankInt;
+ } else if (step == GenCmp::BlankToRight) {
+ // rhsLen > lhsLen
+ offset =
+ mlir::arith::AddIOp::create(builder, loc, lhsLen, offset);
+
+ lhsInt = blankInt;
+ rhsInt = genExtractAndConvertToInt(loc, builder, rhs, offset,
+ oneIdx);
+ } else {
+ llvm_unreachable(
+ "unknown compare step for CmpCharOp lowering");
+ }
+
+ mlir::Value newVal = mlir::arith::SelectOp::create(
+ builder, loc,
+ mlir::arith::CmpIOp::create(builder, loc,
+ mlir::arith::CmpIPredicate::ult,
+ lhsInt, rhsInt),
+ negOneInt, inRes);
+ newVal = mlir::arith::SelectOp::create(
+ builder, loc,
+ mlir::arith::CmpIOp::create(builder, loc,
+ mlir::arith::CmpIPredicate::ugt,
+ lhsInt, rhsInt),
+ oneInt, newVal);
+ fir::ResultOp::create(builder, loc, newVal);
+ })
+ .genElse([&]() { fir::ResultOp::create(builder, loc, inRes); })
+ .getResults()[0];
+
+ return {res};
+ };
+
+ // First generate comparison of two strings for the legth of the shorter
+ // one.
+ mlir::Value minLen = mlir::arith::SelectOp::create(
+ builder, loc,
+ mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::slt, lhsLen, rhsLen),
+ lhsLen, rhsLen);
+
+ llvm::SmallVector<mlir::Value, 1> loopOut =
+ hlfir::genLoopNestWithReductions(loc, builder, {minLen},
+ /*reductionInits=*/{zeroInt}, genCmp,
+ /*isUnordered=*/false);
+ mlir::Value partRes = loopOut[0];
+
+ auto lhsLonger = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::sgt, lhsLen, rhsLen);
+ mlir::Value tempRes =
+ builder
+ .genIfOp(loc, {intTy}, lhsLonger,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ // If left is the longer string generate compare left to blank.
+ step = GenCmp::LeftToBlank;
+ auto lenDiff =
+ mlir::arith::SubIOp::create(builder, loc, lhsLen, rhsLen);
+
+ llvm::SmallVector<mlir::Value, 1> output =
+ hlfir::genLoopNestWithReductions(loc, builder, {lenDiff},
+ /*reductionInits=*/{partRes},
+ genCmp,
+ /*isUnordered=*/false);
+ mlir::Value res = output[0];
+ fir::ResultOp::create(builder, loc, res);
+ })
+ .genElse([&]() {
+ // If right is the longer string generate compare blank to
+ // right.
+ step = GenCmp::BlankToRight;
+ auto lenDiff =
+ mlir::arith::SubIOp::create(builder, loc, rhsLen, lhsLen);
+ llvm::SmallVector<mlir::Value, 1> output =
+ hlfir::genLoopNestWithReductions(loc, builder, {lenDiff},
+ /*reductionInits=*/{partRes},
+ genCmp,
+ /*isUnordered=*/false);
+
+ mlir::Value res = output[0];
+ fir::ResultOp::create(builder, loc, res);
+ })
+ .getResults()[0];
+ if (lhsAssociate)
+ hlfir::EndAssociateOp::create(builder, loc, lhsAssociate);
+ if (rhsAssociate)
+ hlfir::EndAssociateOp::create(builder, loc, rhsAssociate);
+
+ auto finalCmpResult =
+ mlir::arith::CmpIOp::create(builder, loc, predicate, tempRes, zeroInt);
+ rewriter.replaceOp(cmp, finalCmpResult);
+ return mlir::success();
+ }
+};
+
template <typename Op>
class MatmulConversion : public mlir::OpRewritePattern<Op> {
public:
@@ -2748,8 +2954,8 @@ public:
patterns.insert<ReductionConversion<hlfir::SumOp>>(context);
patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
+ patterns.insert<CmpCharOpConversion>(context);
patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
-
patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
patterns.insert<ReductionConversion<hlfir::AllOp>>(context);