diff options
Diffstat (limited to 'mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index 0779cdb9667a..db1e219b601b 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -75,13 +75,15 @@ public: loc, resultTy, input, reverse2, bias, rewriter.getDenseI64ArrayAttr(convPad), rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo()); + rewriter.getDenseI64ArrayAttr({1, 1}), + /* acc_type = */ op.getAccType(), *op.getQuantizationInfo()); } else { conv2d = rewriter.create<tosa::Conv2DOp>( loc, resultTy, input, reverse2, bias, rewriter.getDenseI64ArrayAttr(convPad), rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr({1, 1})); + rewriter.getDenseI64ArrayAttr({1, 1}), + /* acc_type = */ op.getAccTypeAttr()); } rewriter.replaceOp(op, conv2d); @@ -139,7 +141,7 @@ public: weightPadding[5] = (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0; DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get( - RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding); + RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding); Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>( rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr); @@ -202,7 +204,7 @@ public: inputPadding[5] += restridedWeightTy.getDimSize(2) - 1; DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get( - RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding); + RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding); Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>( rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr); @@ -238,7 +240,7 @@ public: /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), - *op.getQuantizationInfo()) + /* acc_type = */ op.getAccType(), *op.getQuantizationInfo()) .getResult(); } else { conv2d = CreateOpAndInferShape<tosa::Conv2DOp>( @@ -246,7 +248,8 @@ public: weight, zeroBias, /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), - /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1})) + /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), + /* acc_type = */ op.getAccTypeAttr()) .getResult(); } @@ -314,7 +317,7 @@ public: resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2]; DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get( - RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding); + RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding); Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>( rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr); |
