diff options
Diffstat (limited to 'mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp')
| -rw-r--r-- | mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 103 |
1 files changed, 52 insertions, 51 deletions
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 434d7df853a5..d43e6816641c 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -117,12 +117,12 @@ static Value getScalarOrVectorConstInt(Type type, uint64_t value, if (auto vectorType = dyn_cast<VectorType>(type)) { Attribute element = IntegerAttr::get(vectorType.getElementType(), value); auto attr = SplatElementsAttr::get(vectorType, element); - return builder.create<spirv::ConstantOp>(loc, vectorType, attr); + return spirv::ConstantOp::create(builder, loc, vectorType, attr); } if (auto intType = dyn_cast<IntegerType>(type)) - return builder.create<spirv::ConstantOp>( - loc, type, builder.getIntegerAttr(type, value)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getIntegerAttr(type, value)); return nullptr; } @@ -418,18 +418,19 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Type type = lhs.getType(); // Calculate the remainder with spirv.UMod. - Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); - Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); - Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); + Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs); + Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs); + Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs); // Fix the sign. Value isPositive; if (lhs == signOperand) - isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); + isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs); else - isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); - Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); - return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); + isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs); + Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs); + return spirv::SelectOp::create(builder, loc, type, isPositive, abs, + absNegate); } /// Converts arith.remsi to GLSL SPIR-V ops. @@ -601,13 +602,13 @@ struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> { Value allOnes; if (auto intTy = dyn_cast<IntegerType>(dstType)) { unsigned componentBitwidth = intTy.getWidth(); - allOnes = rewriter.create<spirv::ConstantOp>( - loc, intTy, + allOnes = spirv::ConstantOp::create( + rewriter, loc, intTy, rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) { unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); - allOnes = rewriter.create<spirv::ConstantOp>( - loc, vectorTy, + allOnes = spirv::ConstantOp::create( + rewriter, loc, vectorTy, SplatElementsAttr::get(vectorTy, APInt::getAllOnes(componentBitwidth))); } else { @@ -653,8 +654,8 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> { // First shift left to sequeeze out all leading bits beyond the original // bitwidth. Here we need to use the original source and result type's // bitwidth. - auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>( - op.getLoc(), dstType, adaptor.getIn(), shiftSize); + auto shiftLOp = spirv::ShiftLeftLogicalOp::create( + rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize); // Then we perform arithmetic right shift to make sure we have the right // sign bits for negative values. @@ -757,9 +758,9 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { auto srcType = adaptor.getOperands().front().getType(); // Check if (x & 1) == 1. Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); - Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( - loc, srcType, adaptor.getOperands()[0], mask); - Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); + Value maskedSrc = spirv::BitwiseAndOp::create( + rewriter, loc, srcType, adaptor.getOperands()[0], mask); + Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); @@ -914,9 +915,9 @@ public: if (auto vectorType = dyn_cast<VectorType>(dstType)) type = VectorType::get(vectorType.getShape(), type); Value extLhs = - rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs()); + arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs()); Value extRhs = - rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs()); + arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs()); rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs, extRhs); @@ -1067,12 +1068,12 @@ public: replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter); } } else { - Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); + replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan); if (op.getPredicate() == arith::CmpFPredicate::ORD) - replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); + replace = spirv::LogicalNotOp::create(rewriter, loc, replace); } rewriter.replaceOp(op, replace); @@ -1094,17 +1095,17 @@ public: ConversionPatternRewriter &rewriter) const override { Type dstElemTy = adaptor.getLhs().getType(); Location loc = op->getLoc(); - Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(), - adaptor.getRhs()); + Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(), + adaptor.getRhs()); - Value sumResult = rewriter.create<spirv::CompositeExtractOp>( - loc, result, llvm::ArrayRef(0)); - Value carryValue = rewriter.create<spirv::CompositeExtractOp>( - loc, result, llvm::ArrayRef(1)); + Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(0)); + Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(1)); // Convert the carry value to boolean. Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); - Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one); + Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one); rewriter.replaceOp(op, {sumResult, carryResult}); return success(); @@ -1125,12 +1126,12 @@ public: ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value result = - rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs()); + SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs()); - Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result, - llvm::ArrayRef(0)); - Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result, - llvm::ArrayRef(1)); + Value low = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(0)); + Value high = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(1)); rewriter.replaceOp(op, {low, high}); return success(); @@ -1183,20 +1184,20 @@ public: Location loc = op.getLoc(); Value spirvOp = - rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands()); + SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { rewriter.replaceOp(op, spirvOp); return success(); } - Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, - adaptor.getLhs(), spirvOp); - Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan, - adaptor.getRhs(), select1); + Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, + adaptor.getLhs(), spirvOp); + Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, + adaptor.getRhs(), select1); rewriter.replaceOp(op, select2); return success(); @@ -1237,7 +1238,7 @@ public: Location loc = op.getLoc(); Value spirvOp = - rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands()); + SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); if (!shouldInsertNanGuards<SPIRVOp>() || bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { @@ -1245,13 +1246,13 @@ public: return success(); } - Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, - adaptor.getRhs(), spirvOp); - Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan, - adaptor.getLhs(), select1); + Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, + adaptor.getRhs(), spirvOp); + Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, + adaptor.getLhs(), select1); rewriter.replaceOp(op, select2); return success(); |
