summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp')
-rw-r--r--mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp11
1 files changed, 8 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 5998133b7eab..3d99f3033cf5 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -57,7 +57,9 @@ void mlir::math::populateLegalizeToF32TypeConverter(
});
typeConverter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- return b.create<arith::ExtFOp>(loc, target, input);
+ auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ extFOp.setFastmath(arith::FastMathFlags::contract);
+ return extFOp;
});
}
@@ -84,8 +86,11 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
SmallVector<Value> results = (*legalized)->getResults();
for (auto [result, newType, origType] : llvm::zip_equal(
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
- if (newType != origType)
- result = rewriter.create<arith::TruncFOp>(loc, origType, result);
+ if (newType != origType) {
+ auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
+ truncFOp.setFastmath(arith::FastMathFlags::contract);
+ result = truncFOp.getResult();
+ }
}
rewriter.replaceOp(op, results);
return success();