summaryrefslogtreecommitdiff
path: root/mlir/include
diff options
context:
space:
mode:
authorMaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com>2025-11-10 09:01:01 -0800
committerGitHub <noreply@github.com>2025-11-10 09:01:01 -0800
commitfc093f13615269951c41a5acce445d299e31fa76 (patch)
tree5f726947942ebc05aea2b43648e05d245e84b6ad /mlir/include
parentfa98fcd02e3dec47e3a8ea340a247dfa8ef1f2f5 (diff)
[mlir][Interfaces] Add interface methods to allow reifying single result/single dim of result. (#162924)
Current implementation of `reifyResultShapes` forces all implementations to return all dimensions of all results. This can be wasteful when you only require dimensions of one result, or a single dimension of a result. Further this also creates issues with using patterns to resolve the `tensor.dim` and `memref.dim` operations since the extra operations created result in the pattern rewriter entering an infinite loop (eventually breaking out of the loop due to the iteration limit on the pattern rewriter). This is demonstrated by some of the test cases added here that hit this limit when using `--resolve-shaped-type-result-dims` and `--resolve-ranked-shaped-type-result-dims`. To resolve this issue the interface should allow for creating just the operations needed. This change is the first step in resolving this. The original implementation was done with the restriction in mind that it might not always be possible to compute dimension of a single result or one dimension of a single result in all cases. To account for such cases, two additional interface methods are added - `reifyShapeOfResult` (which allows reifying dimensions of just one result), has a default implementation that calls `reifyResultShapes` and returns the dimensions of a single result. - `reifyDimOfResult` (which allows reifying a single dimension of a single result) has a default implementation that calls `reifyDimOfResult` and returns the value for the dimension of the result (which in turn for the default case would call `reifyDimOfResult`). While this change sets up the interface, ideally most operations will implement the `refiyDimOfResult` when possible. For almost all operations in tree this is true. Subsequent commits will change those incrementally. Some of the tests added here that check that the default implementations for the above method work as expected, also end up hitting the pattern rewriter limit when using `--resolve-ranked-shaped-type-result-dims`/ `--resolve-ranked-shaped-type-result-dims`. For testing purposes, a flag is added to these passes that ignore the error returned by the pattern application (this flag is left on by default to maintain current state). Changes required downstream to integrate this change 1. In operation definitions in .td files, for those operations that implement the `ReifyRankedShapedTypeOpInterface`. ``` def <op-name> : Op<..., [..., DeclareOpInterfaceMethods[ReifyRankedShapedTypeOpInterface]]> ``` should be changed to ``` def <op-name> : Op<..., [..., DeclareOpInterfaceMethods[ReifyRankedShapedTypeOpInterface, [ "reifyResultShapes"]]]> ``` --------- Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Diffstat (limited to 'mlir/include')
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td6
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td3
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td3
-rw-r--r--mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td3
-rw-r--r--mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td10
-rw-r--r--mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td22
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td3
-rw-r--r--mlir/include/mlir/Interfaces/InferTypeOpInterface.h4
-rw-r--r--mlir/include/mlir/Interfaces/InferTypeOpInterface.td64
9 files changed, 101 insertions, 17 deletions
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 6724d4c48310..a9b2b9f39519 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -28,7 +28,8 @@ class Bufferization_Op<string mnemonic, list<Trait> traits = []>
def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
[AttrSizedOperandSegments, BufferizableOpInterface,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>]> {
let summary = "allocate buffer for a tensor";
let description = [{
@@ -219,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
: Bufferization_Op<"materialize_in_destination",
[AllElementTypesMatch<["source", "dest"]>,
BufferizableOpInterface, DestinationStyleOpInterface,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
DeclareOpInterfaceMethods<SubsetOpInterface,
["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 7ff44c2e1d2e..2754ee3b4f58 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -94,7 +94,8 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
def Linalg_SoftmaxOp : Linalg_Op<"softmax",
[DestinationStyleOpInterface,
PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>,
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<TilingInterface,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 6504ca8664d4..238fa42cae42 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -35,7 +35,8 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
ConditionallySpeculatable, NoMemoryEffect,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
TypesMatchWith<"result type matches type of dest",
"dest", "result",
"$_self">])> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 8965302a58c5..0bf22928f690 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1783,7 +1783,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>]> {
let summary = "operation to produce a memref with a higher rank.";
let description = [{
The `memref.expand_shape` op produces a new view with a higher rank whose
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index f3e40aaa2907..c403386bd214 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -164,6 +164,11 @@ def ResolveRankedShapeTypeResultDimsPass
implement the `ReifyRankedShapedTypeOpInterface` in terms of
shapes of its operands.
}];
+ let options = [
+ Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
+ /*default=*/"true",
+ "Throw an error when pattern rewriter hits iteration limit">,
+ ];
let dependentDialects = [
"memref::MemRefDialect", "tensor::TensorDialect"
];
@@ -177,6 +182,11 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
`ReifyRankedShapedTypeOpInterface` in terms of shapes of its
operands.
}];
+ let options = [
+ Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
+ /*default=*/"true",
+ "Throw an error when pattern rewriter hits iteration limit">,
+ ];
let dependentDialects = [
"affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2453cf5b5b5a..3e93e58575e6 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -131,7 +131,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [
def Tensor_ConcatOp : Tensor_Op<"concat",
[Pure,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
+ ]> {
let summary = "tensor concatenation operation";
let description = [{
The "concat" operation constructs a tensor out of a variadic list of input
@@ -261,7 +263,8 @@ def Tensor_DimOp : Tensor_Op<"dim", [
def Tensor_EmptyOp : Tensor_Op<"empty",
[Pure,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>]> {
let summary = "empty tensor operation";
let description = [{
@@ -358,7 +361,8 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
AttrSizedOperandSegments,
Pure,
OffsetSizeAndStrideOpInterface
@@ -740,7 +744,8 @@ def Tensor_GatherOp : Tensor_Op<"gather", [
def Tensor_GenerateOp : Tensor_Op<"generate", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
RecursiveMemoryEffects,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
let summary = "Creates a dynamically sized tensor from elements";
let description = [{
@@ -835,7 +840,8 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
AttrSizedOperandSegments,
DestinationStyleOpInterface,
Pure,
@@ -1256,7 +1262,8 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
def Tensor_PadOp : Tensor_Op<"pad", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
AttrSizedOperandSegments,
Pure,
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
@@ -1764,7 +1771,8 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [
def Tensor_SplatOp : Tensor_Op<"splat", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>,
Pure,
TypesMatchWith<"operand type matches element type of result",
"aggregate", "input",
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 467dba3232f2..31d1e80f2772 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2219,7 +2219,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
// Operator: transpose
//===----------------------------------------------------------------------===//
def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
- [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface ,
+ ["reifyResultShapes"]>,
AllElementTypesMatch<["input1", "output"]>]> {
let summary = "Transpose operator.";
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 4fcbeff9df56..1bfb66e681d8 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -33,6 +33,10 @@ using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>;
LogicalResult
reifyResultShapes(OpBuilder &b, Operation *op,
ReifiedRankedShapedTypeDims &reifiedReturnShapes);
+FailureOr<SmallVector<OpFoldResult>>
+reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex);
+FailureOr<OpFoldResult> reifyDimOfResult(OpBuilder &b, Operation *op,
+ int resultIndex, int dim);
/// Adaptor class to abstract the differences between whether value is from
/// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 1a2c05fc16ed..67568f731f59 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -361,20 +361,76 @@ def ReifyRankedShapedTypeOpInterface :
let methods = [
InterfaceMethod<
/*desc=*/[{
- Reify the shape of the result of an operation (typically in terms of the
- shape of its operands).
+ Reify the shapes of all the result of an operation (typically in terms
+ of the shape of its operands).
`reifiedReturnShapes` is populated with one vector per op result. Each
of those vectors contains an OpFoldResult for each dimension of the
shaped type. The given builder may be used to insert ops that compute
result shapes.
- If the shape of a particular result cannot be computed it must be empty.
+ If the shape of a particular result cannot be computed it in terms of
+ its operands it must be left empty. If any dimension of the result cannot
+ be computed it must be set to OpFoldResult().
}],
/*retTy=*/"::llvm::LogicalResult",
/*methodName=*/"reifyResultShapes",
/*args=*/(ins "::mlir::OpBuilder &":$builder,
- "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
+ "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{ return ::mlir::failure(); }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Reify the shape of a single result of an operation (typically in terms
+ of the shape of its operands).
+
+ Returns the shape of a single result of the operation as a
+ `SmallVector<OpFoldResult>`, one per dimension of the shaped type. The
+ given builder may be used to insert ops that compute result shapes.
+
+ If any dimension of the result cannot be computed it must be set to
+ OpFoldResult().
+ }],
+ /*retTy=*/"::llvm::FailureOr<::llvm::SmallVector<::mlir::OpFoldResult>>",
+ /*methodName=*/"reifyShapeOfResult",
+ /*args=*/(ins "::mlir::OpBuilder &":$builder,
+ "int":$resultIndex),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ ReifiedRankedShapedTypeDims reifiedShapes;
+ if (failed(cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyResultShapes(builder, reifiedShapes)))
+ return failure();
+ if (resultIndex < 0 || resultIndex >= static_cast<int>(reifiedShapes.size()))
+ return $_op.emitOpError("invalid result index");
+ return reifiedShapes[resultIndex];
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Reify the shape of a dimension of a given result of an operation
+ (typically in terms of the shape of its operands).
+
+ Returns the shape of a specific dimension of a result of the operation as
+ an OpFoldResult. The given builder may be used to insert ops that compute
+ the shapes.
+
+ If the dimension of the result cannot be computed the method must return
+ `failure()`.
+ }],
+ /*retTy=*/"::llvm::FailureOr<::mlir::OpFoldResult>",
+ /*methodName=*/"reifyDimOfResult",
+ /*args=*/(ins "::mlir::OpBuilder &":$builder,
+ "int":$resultIndex, "int":$dim),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto shapes = cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyShapeOfResult(builder, resultIndex);
+ if (failed(shapes))
+ return failure();
+ if (dim < 0 || dim >= static_cast<int>((*shapes).size()))
+ return $_op.emitOpError("invalid dimension");
+ return (*shapes)[dim];
+ }]
>
];
}