summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjerryyin <zhuoryin@amd.com>2025-05-01 20:58:45 +0000
committerjerryyin <zhuoryin@amd.com>2025-05-01 20:58:45 +0000
commit446a0a19295d49afb0879644475f5115408b54a3 (patch)
tree53138439189975fd29db7b918d0d5072ef871b03
parent96f28e044f1faa36f116d7b29161a82dd9f729db (diff)
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp27
1 files changed, 27 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0ddf03bf317e..0a9c3b7136e9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -15,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetOperations.h"
@@ -304,14 +306,39 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
const PackInfo &packInfo) {
Location loc = genericOp.getLoc();
SmallVector<Value> inputOperands;
+ SmallVector<Value> inputOperandsFromUnpackedSource;
SmallVector<AffineMap> indexingMaps;
+
+ bool isGenericUnary = isaElemwiseSingleUnaryOpInterface(genericOp);
+ // TODO: fix the condition
+ bool canUnpackPackCancelout = isGenericUnary;
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
rewriter, loc, packInfo, genericOp, inputOperand);
+
+ auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
+ // If the GenericOp fits the elementwise unary op interface then we can skip
+ // using a redundant pack op as the operand and instead just use the source
+ // of the unpack op.
+ if (unpackOp) {
+ inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
+ } else {
+ inputOperandsFromUnpackedSource.push_back(packedOperand);
+ }
+
inputOperands.push_back(packedOperand);
indexingMaps.push_back(packedIndexingMap);
}
+ // if The pack and unpack op has cancelled each other out, we don't care about the
+ // init tensor of the generic op and can instead just forward the new tensor.empty
+ // as a destination.
+ if (canUnpackPackCancelout){
+ inputOperands = inputOperandsFromUnpackedSource;
+ if(auto destPack = dest.getDefiningOp<linalg::PackOp>())
+ dest = destPack.getDest();
+ }
+
int64_t numInnerLoops = packInfo.getNumTiledLoops();
SmallVector<utils::IteratorType> iterTypes =
genericOp.getIteratorTypesArray();