diff options
Diffstat (limited to 'flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp')
| -rw-r--r-- | flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp | 208 |
1 files changed, 207 insertions, 1 deletions
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index fe12f49c655b..d8e36ea294cd 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -2078,6 +2078,212 @@ private: } }; +class CmpCharOpConversion : public mlir::OpRewritePattern<hlfir::CmpCharOp> { +public: + using mlir::OpRewritePattern<hlfir::CmpCharOp>::OpRewritePattern; + + llvm::LogicalResult + matchAndRewrite(hlfir::CmpCharOp cmp, + mlir::PatternRewriter &rewriter) const override { + + fir::FirOpBuilder builder{rewriter, cmp.getOperation()}; + const mlir::Location &loc = cmp->getLoc(); + + auto toVariable = + [&builder, + &loc](mlir::Value val) -> std::pair<mlir::Value, hlfir::AssociateOp> { + mlir::Value opnd; + hlfir::AssociateOp associate; + if (mlir::isa<hlfir::ExprType>(val.getType())) { + hlfir::Entity entity{val}; + mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr(builder); + associate = hlfir::genAssociateExpr(loc, builder, entity, + entity.getType(), "", byRefAttr); + opnd = associate.getBase(); + } else { + opnd = val; + } + return {opnd, associate}; + }; + + auto [lhsOpnd, lhsAssociate] = toVariable(cmp.getLchr()); + auto [rhsOpnd, rhsAssociate] = toVariable(cmp.getRchr()); + + hlfir::Entity lhs{lhsOpnd}; + hlfir::Entity rhs{rhsOpnd}; + + auto charTy = mlir::cast<fir::CharacterType>(lhs.getFortranElementType()); + unsigned kind = charTy.getFKind(); + + auto bits = builder.getKindMap().getCharacterBitsize(kind); + auto intTy = builder.getIntegerType(bits); + + auto idxTy = builder.getIndexType(); + auto charLen1Ty = + fir::CharacterType::getSingleton(builder.getContext(), kind); + mlir::Type designatorType = + fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy)); + auto idxAttr = builder.getIntegerAttr(idxTy, 0); + + auto genExtractAndConvertToInt = + [&idxAttr, &intTy, &designatorType]( + mlir::Location loc, fir::FirOpBuilder &builder, + hlfir::Entity &charStr, mlir::Value index, mlir::Value length) { + auto singleChr = hlfir::DesignateOp::create( + builder, loc, designatorType, charStr, /*component=*/{}, + /*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{}, + /*substring=*/mlir::ValueRange{index, index}, + /*complexPart=*/std::nullopt, + /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{length}, + fir::FortranVariableFlagsAttr{}); + auto chrVal = fir::LoadOp::create(builder, loc, singleChr); + mlir::Value intVal = fir::ExtractValueOp::create( + builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr)); + return intVal; + }; + + mlir::arith::CmpIPredicate predicate = cmp.getPredicate(); + mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1); + + mlir::Value lhsLen = builder.createConvert( + loc, idxTy, hlfir::genCharLength(loc, builder, lhs)); + mlir::Value rhsLen = builder.createConvert( + loc, idxTy, hlfir::genCharLength(loc, builder, rhs)); + + enum class GenCmp { LeftToRight, LeftToBlank, BlankToRight }; + + mlir::Value zeroInt = builder.createIntegerConstant(loc, intTy, 0); + mlir::Value oneInt = builder.createIntegerConstant(loc, intTy, 1); + mlir::Value negOneInt = builder.createIntegerConstant(loc, intTy, -1); + mlir::Value blankInt = builder.createIntegerConstant(loc, intTy, ' '); + + auto step = GenCmp::LeftToRight; + auto genCmp = [&](mlir::Location loc, fir::FirOpBuilder &builder, + mlir::ValueRange index, mlir::ValueRange reductionArgs) + -> llvm::SmallVector<mlir::Value, 1> { + assert(index.size() == 1 && "expected single loop"); + assert(reductionArgs.size() == 1 && "expected single reduction value"); + mlir::Value inRes = reductionArgs[0]; + auto accEQzero = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroInt); + + mlir::Value res = + builder + .genIfOp(loc, {intTy}, accEQzero, + /*withElseRegion=*/true) + .genThen([&]() { + mlir::Value offset = + builder.createConvert(loc, idxTy, index[0]); + mlir::Value lhsInt; + mlir::Value rhsInt; + if (step == GenCmp::LeftToRight) { + lhsInt = genExtractAndConvertToInt(loc, builder, lhs, offset, + oneIdx); + rhsInt = genExtractAndConvertToInt(loc, builder, rhs, offset, + oneIdx); + } else if (step == GenCmp::LeftToBlank) { + // lhsLen > rhsLen + offset = + mlir::arith::AddIOp::create(builder, loc, rhsLen, offset); + + lhsInt = genExtractAndConvertToInt(loc, builder, lhs, offset, + oneIdx); + rhsInt = blankInt; + } else if (step == GenCmp::BlankToRight) { + // rhsLen > lhsLen + offset = + mlir::arith::AddIOp::create(builder, loc, lhsLen, offset); + + lhsInt = blankInt; + rhsInt = genExtractAndConvertToInt(loc, builder, rhs, offset, + oneIdx); + } else { + llvm_unreachable( + "unknown compare step for CmpCharOp lowering"); + } + + mlir::Value newVal = mlir::arith::SelectOp::create( + builder, loc, + mlir::arith::CmpIOp::create(builder, loc, + mlir::arith::CmpIPredicate::ult, + lhsInt, rhsInt), + negOneInt, inRes); + newVal = mlir::arith::SelectOp::create( + builder, loc, + mlir::arith::CmpIOp::create(builder, loc, + mlir::arith::CmpIPredicate::ugt, + lhsInt, rhsInt), + oneInt, newVal); + fir::ResultOp::create(builder, loc, newVal); + }) + .genElse([&]() { fir::ResultOp::create(builder, loc, inRes); }) + .getResults()[0]; + + return {res}; + }; + + // First generate comparison of two strings for the legth of the shorter + // one. + mlir::Value minLen = mlir::arith::SelectOp::create( + builder, loc, + mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::slt, lhsLen, rhsLen), + lhsLen, rhsLen); + + llvm::SmallVector<mlir::Value, 1> loopOut = + hlfir::genLoopNestWithReductions(loc, builder, {minLen}, + /*reductionInits=*/{zeroInt}, genCmp, + /*isUnordered=*/false); + mlir::Value partRes = loopOut[0]; + + auto lhsLonger = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::sgt, lhsLen, rhsLen); + mlir::Value tempRes = + builder + .genIfOp(loc, {intTy}, lhsLonger, + /*withElseRegion=*/true) + .genThen([&]() { + // If left is the longer string generate compare left to blank. + step = GenCmp::LeftToBlank; + auto lenDiff = + mlir::arith::SubIOp::create(builder, loc, lhsLen, rhsLen); + + llvm::SmallVector<mlir::Value, 1> output = + hlfir::genLoopNestWithReductions(loc, builder, {lenDiff}, + /*reductionInits=*/{partRes}, + genCmp, + /*isUnordered=*/false); + mlir::Value res = output[0]; + fir::ResultOp::create(builder, loc, res); + }) + .genElse([&]() { + // If right is the longer string generate compare blank to + // right. + step = GenCmp::BlankToRight; + auto lenDiff = + mlir::arith::SubIOp::create(builder, loc, rhsLen, lhsLen); + llvm::SmallVector<mlir::Value, 1> output = + hlfir::genLoopNestWithReductions(loc, builder, {lenDiff}, + /*reductionInits=*/{partRes}, + genCmp, + /*isUnordered=*/false); + + mlir::Value res = output[0]; + fir::ResultOp::create(builder, loc, res); + }) + .getResults()[0]; + if (lhsAssociate) + hlfir::EndAssociateOp::create(builder, loc, lhsAssociate); + if (rhsAssociate) + hlfir::EndAssociateOp::create(builder, loc, rhsAssociate); + + auto finalCmpResult = + mlir::arith::CmpIOp::create(builder, loc, predicate, tempRes, zeroInt); + rewriter.replaceOp(cmp, finalCmpResult); + return mlir::success(); + } +}; + template <typename Op> class MatmulConversion : public mlir::OpRewritePattern<Op> { public: @@ -2748,8 +2954,8 @@ public: patterns.insert<ReductionConversion<hlfir::SumOp>>(context); patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context); patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context); + patterns.insert<CmpCharOpConversion>(context); patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context); - patterns.insert<ReductionConversion<hlfir::CountOp>>(context); patterns.insert<ReductionConversion<hlfir::AnyOp>>(context); patterns.insert<ReductionConversion<hlfir::AllOp>>(context); |
