summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp')
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp17
1 files changed, 10 insertions, 7 deletions
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 0779cdb9667a..db1e219b601b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -75,13 +75,15 @@ public:
loc, resultTy, input, reverse2, bias,
rewriter.getDenseI64ArrayAttr(convPad),
rewriter.getDenseI64ArrayAttr(stride),
- rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
+ rewriter.getDenseI64ArrayAttr({1, 1}),
+ /* acc_type = */ op.getAccType(), *op.getQuantizationInfo());
} else {
conv2d = rewriter.create<tosa::Conv2DOp>(
loc, resultTy, input, reverse2, bias,
rewriter.getDenseI64ArrayAttr(convPad),
rewriter.getDenseI64ArrayAttr(stride),
- rewriter.getDenseI64ArrayAttr({1, 1}));
+ rewriter.getDenseI64ArrayAttr({1, 1}),
+ /* acc_type = */ op.getAccTypeAttr());
}
rewriter.replaceOp(op, conv2d);
@@ -139,7 +141,7 @@ public:
weightPadding[5] =
(weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
+ RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
@@ -202,7 +204,7 @@ public:
inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
+ RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding);
Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
@@ -238,7 +240,7 @@ public:
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
/*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
- *op.getQuantizationInfo())
+ /* acc_type = */ op.getAccType(), *op.getQuantizationInfo())
.getResult();
} else {
conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
@@ -246,7 +248,8 @@ public:
weight, zeroBias,
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
/*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
- /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}))
+ /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
+ /* acc_type = */ op.getAccTypeAttr())
.getResult();
}
@@ -314,7 +317,7 @@ public:
resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);
+ RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding);
Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);