diff options
| author | Matthias Springer <me@m-sp.org> | 2025-11-06 01:29:04 +0000 |
|---|---|---|
| committer | Matthias Springer <me@m-sp.org> | 2025-11-09 01:17:40 +0000 |
| commit | b55e698aa6f595e4cc493c46ecf458948b510d37 (patch) | |
| tree | cb6bda77fb59b6815fe2c99044b4ead31bffb45e | |
| parent | 3bb903e3c0afcc2fb17ebf2b7da127d1595dba56 (diff) | |
[mlir][arith] Fix `arith.cmpf` lowering with unsupported FP typesusers/matthias-springer/fix_cmpf
4 files changed, 48 insertions, 31 deletions
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index e7ab63abfeaa..06934461a6bb 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -59,6 +59,12 @@ LogicalResult vectorOneToOneRewrite( ArrayRef<NamedAttribute> targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none); + +/// Return "true" if the given type is an unsupported floating point type. In +/// case of a vector type, return "true" if the element type is an unsupported +/// floating point type. +bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, + Type type); } // namespace detail } // namespace LLVM @@ -93,16 +99,6 @@ public: using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>; - /// Return the given type if it's a floating point type. If the given type is - /// a vector type, return its element type if it's a floating point type. - static FloatType getFloatingPointType(Type type) { - if (auto floatType = dyn_cast<FloatType>(type)) - return floatType; - if (auto vecType = dyn_cast<VectorType>(type)) - return dyn_cast<FloatType>(vecType.getElementType()); - return nullptr; - } - LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -110,26 +106,18 @@ public: std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, "expected single result op"); - // The pattern should not apply if a floating-point operand is converted to - // a non-floating-point type. This indicates that the floating point type - // is not supported by the LLVM lowering. (Such types are converted to - // integers.) - auto checkType = [&](Value v) -> LogicalResult { - FloatType floatType = getFloatingPointType(v.getType()); - if (!floatType) - return success(); - Type convertedType = this->getTypeConverter()->convertType(floatType); - if (!isa_and_nonnull<FloatType>(convertedType)) - return rewriter.notifyMatchFailure(op, - "unsupported floating point type"); - return success(); - }; + // Bail on unsupported floating point types. (These are type-converted to + // integer types.) if (FailOnUnsupportedFP) { for (Value operand : op->getOperands()) - if (failed(checkType(operand))) - return failure(); - if (failed(checkType(op->getResult(0)))) - return failure(); + if (LLVM::detail::isUnsupportedFloatingPointType( + *this->getTypeConverter(), operand.getType())) + return rewriter.notifyMatchFailure(op, + "unsupported floating point type"); + if (LLVM::detail::isUnsupportedFloatingPointType( + *this->getTypeConverter(), op->getResult(0).getType())) + return rewriter.notifyMatchFailure(op, + "unsupported floating point type"); } // Determine attributes for the target op diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index b6099902cc33..2a2d0cb5fb80 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -481,6 +481,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, LogicalResult CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(), + op.getLhs().getType())) + return rewriter.notifyMatchFailure(op, "unsupported floating point type"); + Type operandType = adaptor.getLhs().getType(); Type resultType = op.getResult().getType(); LLVM::FastmathFlags fmf = diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index e7dd0b506e12..b37b35d79901 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -131,3 +131,24 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite( return handleMultidimensionalVectors(op, operands, typeConverter, callback, rewriter); } + +/// Return the given type if it's a floating point type. If the given type is +/// a vector type, return its element type if it's a floating point type. +static FloatType getFloatingPointType(Type type) { + if (auto floatType = dyn_cast<FloatType>(type)) + return floatType; + if (auto vecType = dyn_cast<VectorType>(type)) + return dyn_cast<FloatType>(vecType.getElementType()); + return nullptr; +} + +bool LLVM::detail::isUnsupportedFloatingPointType( + const TypeConverter &typeConverter, Type type) { + FloatType floatType = getFloatingPointType(type); + if (!floatType) + return false; + Type convertedType = typeConverter.convertType(floatType); + if (!convertedType) + return true; + return !isa<FloatType>(convertedType); +} diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 5f1ec66234df..951db78fd7dd 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -754,12 +754,14 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> { // CHECK: arith.addf {{.*}} : f4E2M1FN // CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN> // CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN> +// CHECK: arith.cmpf {{.*}} : f4E2M1FN // CHECK: llvm.select {{.*}} : i1, i4 func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) { %0 = arith.addf %arg0, %arg0 : f4E2M1FN %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN> %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN> - %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN + %3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN + %4 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN return } @@ -769,9 +771,11 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2 // CHECK: llvm.fadd {{.*}} : f32 // CHECK: llvm.fadd {{.*}} : vector<4xf32> // CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32> -func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) { +// CHECK: llvm.fcmp {{.*}} : f32 +func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) { %0 = arith.addf %arg0, %arg0 : f32 %1 = arith.addf %arg1, %arg1 : vector<4xf32> %2 = arith.addf %arg2, %arg2 : vector<4x8xf32> - return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32> + %3 = arith.cmpf oeq, %arg0, %arg3 : f32 + return } |
