diff options
| author | Matthias Springer <me@m-sp.org> | 2025-11-06 01:29:04 +0000 |
|---|---|---|
| committer | Matthias Springer <me@m-sp.org> | 2025-11-09 01:17:40 +0000 |
| commit | b55e698aa6f595e4cc493c46ecf458948b510d37 (patch) | |
| tree | cb6bda77fb59b6815fe2c99044b4ead31bffb45e /mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp | |
| parent | 3bb903e3c0afcc2fb17ebf2b7da127d1595dba56 (diff) | |
[mlir][arith] Fix `arith.cmpf` lowering with unsupported FP typesusers/matthias-springer/fix_cmpf
Diffstat (limited to 'mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp')
| -rw-r--r-- | mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index e7dd0b506e12..b37b35d79901 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -131,3 +131,24 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite( return handleMultidimensionalVectors(op, operands, typeConverter, callback, rewriter); } + +/// Return the given type if it's a floating point type. If the given type is +/// a vector type, return its element type if it's a floating point type. +static FloatType getFloatingPointType(Type type) { + if (auto floatType = dyn_cast<FloatType>(type)) + return floatType; + if (auto vecType = dyn_cast<VectorType>(type)) + return dyn_cast<FloatType>(vecType.getElementType()); + return nullptr; +} + +bool LLVM::detail::isUnsupportedFloatingPointType( + const TypeConverter &typeConverter, Type type) { + FloatType floatType = getFloatingPointType(type); + if (!floatType) + return false; + Type convertedType = typeConverter.convertType(floatType); + if (!convertedType) + return true; + return !isa<FloatType>(convertedType); +} |
