diff options
Diffstat (limited to 'mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index dc5c51b0abad..8b23fd1341bc 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -49,8 +49,13 @@ public: if (llvm::any_of(stride, [](int64_t v) { return v != 1; })) return failure(); - if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || - !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i)) + return failure(); + } + + if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return failure(); int64_t kernelHeight = weightTy.getDimSize(1); @@ -113,8 +118,13 @@ public: if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) return rewriter.notifyMatchFailure(op, "non-one stride found."); - if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || - !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i)) + return failure(); + } + + if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return failure(); int64_t batch = inputTy.getDimSize(0); |
