diff options
| author | NAKAMURA Takumi <geek4civic@gmail.com> | 2025-01-09 18:39:43 +0900 |
|---|---|---|
| committer | NAKAMURA Takumi <geek4civic@gmail.com> | 2025-01-09 18:39:43 +0900 |
| commit | c36c84047e92587931e74aea1b3d91342617400b (patch) | |
| tree | 3d25b78796205b1f3f1ee5f9c55da298f6449ce8 /mlir/lib/Dialect/Arith/IR/ArithOps.cpp | |
| parent | 122393694892e7a718e8c612b5650388075e2833 (diff) | |
| parent | bdcf47e4bcb92889665825654bb80a8bbe30379e (diff) | |
Merge branch 'users/chapuni/cov/single/base' into users/chapuni/cov/single/condopusers/chapuni/cov/single/condop
Conflicts:
clang/lib/CodeGen/CoverageMappingGen.cpp
Diffstat (limited to 'mlir/lib/Dialect/Arith/IR/ArithOps.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index d8b314a3fa43..e016a6e16e59 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -580,11 +580,31 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns( // DivUIOp //===----------------------------------------------------------------------===// +/// Fold `(a * b) / b -> a` +static Value foldDivMul(Value lhs, Value rhs, + arith::IntegerOverflowFlags ovfFlags) { + auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>(); + if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags)) + return {}; + + if (mul.getLhs() == rhs) + return mul.getRhs(); + + if (mul.getRhs() == rhs) + return mul.getLhs(); + + return {}; +} + OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) { // divui (x, 1) -> x. if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); + // (a * b) / b -> a + if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw)) + return val; + // Don't fold if it would require a division by zero. bool div0 = false; auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), @@ -621,6 +641,10 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) { if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); + // (a * b) / b -> a + if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw)) + return val; + // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = constFoldBinaryOp<IntegerAttr>( |
