diff options
| author | Christian Sigg <csigg@google.com> | 2022-08-22 10:39:49 +0200 |
|---|---|---|
| committer | Christian Sigg <csigg@google.com> | 2022-08-22 11:14:04 +0200 |
| commit | 459fd3fb342d565bbaff48673838c5ea138128f8 (patch) | |
| tree | 79ca73f6fe73ca883544107333369f2b24af38ab /mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp | |
| parent | a6e155fd9a392285ab37e12777a8c140f913bbc5 (diff) | |
[MLIR][GPU] Detect bounds with `arith.minsi ` in loops-to-gpu
Previously, `arith.constant`, `arith.muli` and `affine.min` were supported when deriving upper loop bounds when converting parallel loops to GPU.
Reviewed By: akuegel
Differential Revision: https://reviews.llvm.org/D132354
Diffstat (limited to 'mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp')
| -rw-r--r-- | mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index 5e1c1d1bd857..f8ad965d4ae7 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -328,6 +328,13 @@ static Value deriveStaticUpperBound(Value upperBound, } } + if (auto minOp = upperBound.getDefiningOp<arith::MinSIOp>()) { + for (Value operand : {minOp.getLhs(), minOp.getRhs()}) { + if (auto staticBound = deriveStaticUpperBound(operand, rewriter)) + return staticBound; + } + } + if (auto multiplyOp = upperBound.getDefiningOp<arith::MulIOp>()) { if (auto lhs = dyn_cast_or_null<arith::ConstantIndexOp>( deriveStaticUpperBound(multiplyOp.getOperand(0), rewriter) @@ -336,8 +343,8 @@ static Value deriveStaticUpperBound(Value upperBound, deriveStaticUpperBound(multiplyOp.getOperand(1), rewriter) .getDefiningOp())) { // Assumptions about the upper bound of minimum computations no longer - // work if multiplied by a negative value, so abort in this case. - if (lhs.value() < 0 || rhs.value() < 0) + // work if multiplied by mixed signs, so abort in this case. + if (lhs.value() < 0 != rhs.value() < 0) return {}; return rewriter.create<arith::ConstantIndexOp>( |
