diff options
| author | jerryyin <zhuoryin@amd.com> | 2025-05-01 20:58:45 +0000 |
|---|---|---|
| committer | jerryyin <zhuoryin@amd.com> | 2025-05-01 20:58:45 +0000 |
| commit | 446a0a19295d49afb0879644475f5115408b54a3 (patch) | |
| tree | 53138439189975fd29db7b918d0d5072ef871b03 | |
| parent | 96f28e044f1faa36f116d7b29161a82dd9f729db (diff) | |
| -rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 0ddf03bf317e..0a9c3b7136e9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -15,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dominance.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SetOperations.h" @@ -304,14 +306,39 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, const PackInfo &packInfo) { Location loc = genericOp.getLoc(); SmallVector<Value> inputOperands; + SmallVector<Value> inputOperandsFromUnpackedSource; SmallVector<AffineMap> indexingMaps; + + bool isGenericUnary = isaElemwiseSingleUnaryOpInterface(genericOp); + // TODO: fix the condition + bool canUnpackPackCancelout = isGenericUnary; for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand( rewriter, loc, packInfo, genericOp, inputOperand); + + auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>(); + // If the GenericOp fits the elementwise unary op interface then we can skip + // using a redundant pack op as the operand and instead just use the source + // of the unpack op. + if (unpackOp) { + inputOperandsFromUnpackedSource.push_back(unpackOp.getSource()); + } else { + inputOperandsFromUnpackedSource.push_back(packedOperand); + } + inputOperands.push_back(packedOperand); indexingMaps.push_back(packedIndexingMap); } + // if The pack and unpack op has cancelled each other out, we don't care about the + // init tensor of the generic op and can instead just forward the new tensor.empty + // as a destination. + if (canUnpackPackCancelout){ + inputOperands = inputOperandsFromUnpackedSource; + if(auto destPack = dest.getDefiningOp<linalg::PackOp>()) + dest = destPack.getDest(); + } + int64_t numInnerLoops = packInfo.getNumTiledLoops(); SmallVector<utils::IteratorType> iterTypes = genericOp.getIteratorTypesArray(); |
