summaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
diff options
context:
space:
mode:
authorChristian Sigg <csigg@google.com>2022-08-22 10:39:49 +0200
committerChristian Sigg <csigg@google.com>2022-08-22 11:14:04 +0200
commit459fd3fb342d565bbaff48673838c5ea138128f8 (patch)
tree79ca73f6fe73ca883544107333369f2b24af38ab /mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
parenta6e155fd9a392285ab37e12777a8c140f913bbc5 (diff)
[MLIR][GPU] Detect bounds with `arith.minsi` in loops-to-gpu
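Previously, only `arith.constant`, `arith.muli` and `affine.min` were supported when deriving static upper loop bounds while converting parallel loops to GPU. This change additionally derives a bound through `arith.minsi` (using the first operand that yields a static bound) and relaxes the `arith.muli` check so that it only aborts when the two constant factors have mixed signs. Reviewed By: akuegel. Differential Revision: https://reviews.llvm.org/D132354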
Diffstat (limited to 'mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp')
-rw-r--r--mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp11
1 files changed, 9 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 5e1c1d1bd857..f8ad965d4ae7 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -328,6 +328,13 @@ static Value deriveStaticUpperBound(Value upperBound,
}
}
+ if (auto minOp = upperBound.getDefiningOp<arith::MinSIOp>()) {
+ for (Value operand : {minOp.getLhs(), minOp.getRhs()}) {
+ if (auto staticBound = deriveStaticUpperBound(operand, rewriter))
+ return staticBound;
+ }
+ }
+
if (auto multiplyOp = upperBound.getDefiningOp<arith::MulIOp>()) {
if (auto lhs = dyn_cast_or_null<arith::ConstantIndexOp>(
deriveStaticUpperBound(multiplyOp.getOperand(0), rewriter)
@@ -336,8 +343,8 @@ static Value deriveStaticUpperBound(Value upperBound,
deriveStaticUpperBound(multiplyOp.getOperand(1), rewriter)
.getDefiningOp())) {
// Assumptions about the upper bound of minimum computations no longer
- // work if multiplied by a negative value, so abort in this case.
- if (lhs.value() < 0 || rhs.value() < 0)
+ // work if multiplied by mixed signs, so abort in this case.
+ if (lhs.value() < 0 != rhs.value() < 0)
return {};
return rewriter.create<arith::ConstantIndexOp>(