summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp29
1 files changed, 29 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index 15eb51a6dcab..5e10ba378060 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -43,6 +44,33 @@ struct StructuredOpInterface
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
auto one = arith::ConstantIndexOp::create(builder, loc, 1);
+ Value iterationDomainIsNonDegenerate;
+ for (auto [start, end] : llvm::zip(starts, ends)) {
+ auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
+ auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
+
+ // Loop Trip count > 0 iff start < end
+ Value dimensionHasNonZeroTripCount = index::CmpOp::create(
+ builder, loc, index::IndexCmpPredicate::SLT, startValue, endValue);
+
+ if (!iterationDomainIsNonDegenerate) {
+ iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
+ } else {
+ // Iteration domain is non-degenerate iff all dimensions have loop trip
+ // count > 0
+ iterationDomainIsNonDegenerate =
+ arith::AndIOp::create(builder, loc, iterationDomainIsNonDegenerate,
+ dimensionHasNonZeroTripCount);
+ }
+ }
+
+ if (!iterationDomainIsNonDegenerate)
+ return;
+
+ auto ifOp = scf::IfOp::create(builder, loc, iterationDomainIsNonDegenerate,
+ /*withElseRegion=*/false);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Subtract one from the loop ends before composing with the indexing map
transform(ends, ends.begin(), [&](OpFoldResult end) {
auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
@@ -110,6 +138,7 @@ struct StructuredOpInterface
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
}
}
+ builder.setInsertionPointAfter(ifOp);
}
};