summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp')
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp18
1 files changed, 14 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index dc5c51b0abad..8b23fd1341bc 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -49,8 +49,13 @@ public:
if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
return failure();
- if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
- !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+ // Any dimensions other than batchSize cannot be dynamic for input/output
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return failure();
int64_t kernelHeight = weightTy.getDimSize(1);
@@ -113,8 +118,13 @@ public:
if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
return rewriter.notifyMatchFailure(op, "non-one stride found.");
- if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
- !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+ // Any dimensions other than batchSize cannot be dynamic for input/output
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return failure();
int64_t batch = inputTy.getDimSize(0);