diff options
| author | Andrei Golubev <andrey.golubev@intel.com> | 2025-09-24 12:09:27 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-24 13:09:27 +0200 |
| commit | ff4c4997ee72f4fda2d9939faefe8ef262d294a8 (patch) | |
| tree | f5509b24403d66c2c59582719cdd0b5d7f213bae | |
| parent | 59e74a0749203998b3e5fd9bc6525de148db28ab (diff) | |
[mlir][bufferization] Support custom types at function boundaries (#159766)
Support custom types (3/N): allow custom tensor and buffer types in
function signatures and at call-sites. This is one of the major building
blocks to move in the direction of module-level one-shot-bufferization
support.
To achieve this, `BufferizationOptions::FunctionArgTypeConverterFn`
callback is converted to work with tensor-like and buffer-like types,
instead of the builtin counterparts. The default behavior for builtins
remains unchanged, while custom types by default go through
`TensorLikeType::getBufferType()` which is a general conversion
interface.
5 files changed, 155 insertions, 59 deletions
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index f3b34f9fded7..dd693a25fd54 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -260,12 +260,12 @@ struct BufferizationOptions { std::function<LogicalResult(OpBuilder &, Location, Value, Value)>; /// Initializer function for analysis state. using AnalysisStateInitFn = std::function<void(AnalysisState &)>; - /// Tensor -> MemRef type converter. - /// Parameters: tensor type, memory space, func op, bufferization options + /// Tensor-like -> Buffer-like type conversion. + /// Parameters: tensor-like type, memory space, func op, bufferization options using FunctionArgTypeConverterFn = - std::function<BaseMemRefType(TensorType, Attribute memorySpace, + std::function<BufferLikeType(TensorLikeType, Attribute memorySpace, func::FuncOp, const BufferizationOptions &)>; - /// Tensor -> MemRef type converter. + /// Tensor -> MemRef type conversion. /// Parameters: tensor type, memory space, bufferization options using UnknownTypeConverterFn = std::function<BaseMemRefType( TensorType, Attribute memorySpace, const BufferizationOptions &)>; @@ -335,10 +335,12 @@ struct BufferizationOptions { /// predictable. void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption); - /// Type converter from tensors to memrefs. This type converter is used to - /// determine bufferized function argument and result types. By default, a - /// type converter that returns a memref type with a fully dynamic layout map - /// is used. + /// Type conversion from tensors to buffers. This type conversion is used to + /// determine bufferized function argument and result types. + /// + /// By default, if tensor is a (builtin) tensor type, it is converted to a + /// memref type with a fully dynamic layout map; if tensor is a (generic) + /// tensor-like type, it is converted using TensorLikeType::getBufferType(). /// /// If `bufferizeFunctionBoundaries` is not set, this function isn't used. FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr; @@ -350,10 +352,9 @@ struct BufferizationOptions { /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect. bool inferFunctionResultLayout = true; - /// Type converter from tensors to memrefs. This type converter is used if no - /// memref type could be inferred during bufferization. By default, a type - /// converter that returns a memref type with a fully dynamic layout map is - /// used. + /// Type conversion from tensors to memrefs. This type conversion is used if + /// no memref type could be inferred during bufferization. By default, returns + /// a memref type with a fully dynamic layout map. UnknownTypeConverterFn unknownTypeConverterFn = nullptr; // Use during type conversion to determine the memory space for memref based diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index f7b0b87085f3..e0cf353da207 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -338,11 +338,21 @@ bool OpFilter::isOpAllowed(Operation *op) const { namespace { /// Default function arg type converter: Use a fully dynamic layout map. -BaseMemRefType -defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace, +BufferLikeType +defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace, func::FuncOp funcOp, const BufferizationOptions &options) { - return getMemRefTypeWithFullyDynamicLayout(type, memorySpace); + if (auto tensorType = mlir::dyn_cast<TensorType>(type)) { + return cast<BufferLikeType>( + getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace)); + } + + // If not builtin, fallback to TensorLikeType::getBufferType() + auto bufferType = + type.getBufferType(options, [&]() { return funcOp->emitError(); }); + assert(succeeded(bufferType) && + "a valid buffer is always expected at function boundary"); + return *bufferType; } /// Default unknown type converter: Use a fully dynamic layout map. BaseMemRefType @@ -385,14 +395,25 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const { void BufferizationOptions::setFunctionBoundaryTypeConversion( LayoutMapOption layoutMapOption) { - functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace, + functionArgTypeConverterFn = [=](TensorLikeType type, Attribute memorySpace, func::FuncOp funcOp, const BufferizationOptions &options) { - if (layoutMapOption == LayoutMapOption::IdentityLayoutMap) - return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, - memorySpace); - return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, - memorySpace); + if (auto tensorType = mlir::dyn_cast<TensorType>(type)) { + if (layoutMapOption == LayoutMapOption::IdentityLayoutMap) + return cast<BufferLikeType>( + bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, + memorySpace)); + return cast<BufferLikeType>( + bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, + memorySpace)); + } + + // If not builtin, fallback to TensorLikeType::getBufferType() + auto bufferType = + type.getBufferType(options, [&]() { return funcOp->emitError(); }); + assert(succeeded(bufferType) && + "a valid buffer is always expected at function boundary"); + return *bufferType; }; inferFunctionResultLayout = layoutMapOption == LayoutMapOption::InferLayoutMap; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 68ef51992efe..701ab52a491a 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -401,7 +401,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, // Compute the new signature. SmallVector<Type> newTypes; for (BlockArgument &bbArg : block->getArguments()) { - auto tensorType = dyn_cast<TensorType>(bbArg.getType()); + auto tensorType = dyn_cast<TensorLikeType>(bbArg.getType()); if (!tensorType) { newTypes.push_back(bbArg.getType()); continue; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index f69efd1b3fa8..d9d69342e42a 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -49,29 +49,47 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) { #endif // NDEBUG } +// Note: this is a local adaptor to unify TensorType and TensorLikeType code +// paths that both work with BufferizationOptions. +static mlir::Attribute +getDefaultMemorySpace(const BufferizationOptions &options, + TensorLikeType type) { + if (auto tensorType = dyn_cast<TensorType>(type)) { + return *options.defaultMemorySpaceFn(tensorType); + } + return nullptr; +} + /// Return the index-th bufferized function argument type. This assumes that the /// specified argument is a tensor. If the tensor is ranked, a layout map may be /// specified by the user (as per `options.functionArgTypeConverterFn`). -static BaseMemRefType +static BufferLikeType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options) { - auto tensorType = - dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index)); - assert(tensorType && "expected TensorType"); - - BaseMemRefType memrefType = options.functionArgTypeConverterFn( - tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options); - - auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>( - index, BufferizationDialect::kBufferLayoutAttrName); - if (!layoutAttr) - return memrefType; - - auto rankedMemrefType = dyn_cast<MemRefType>(memrefType); - assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); - return MemRefType::get(rankedMemrefType.getShape(), - rankedMemrefType.getElementType(), layoutAttr, - rankedMemrefType.getMemorySpace()); + auto type = + dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index)); + assert(type && "expected TensorLikeType"); + + // Note: For builtin tensors there is additional logic related to layout. + if (auto tensorType = dyn_cast<TensorType>(type)) { + BufferLikeType memrefType = options.functionArgTypeConverterFn( + type, *options.defaultMemorySpaceFn(tensorType), funcOp, options); + + auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>( + index, BufferizationDialect::kBufferLayoutAttrName); + if (!layoutAttr) + return memrefType; + + auto rankedMemrefType = dyn_cast<MemRefType>(memrefType); + assert(rankedMemrefType && + "buffer layout not supported on unranked tensors"); + return cast<BufferLikeType>(MemRefType::get( + rankedMemrefType.getShape(), rankedMemrefType.getElementType(), + layoutAttr, rankedMemrefType.getMemorySpace())); + } + + return options.functionArgTypeConverterFn(type, /*memSpace=*/nullptr, funcOp, + options); } /// Return the FuncOp called by `callOp`. @@ -227,13 +245,13 @@ struct CallOpInterface FunctionType funcType = funcOp.getFunctionType(); Type resultType = funcType.getResult(cast<OpResult>(value).getResultNumber()); - if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType)) - return cast<BufferLikeType>(bufferizedType); + if (auto bufferizedType = dyn_cast<BufferLikeType>(resultType)) + return bufferizedType; // Otherwise, call the type converter to compute the bufferized type. - auto tensorType = cast<TensorType>(resultType); + auto tensorType = cast<TensorLikeType>(resultType); return cast<BufferLikeType>(options.functionArgTypeConverterFn( - tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, + tensorType, getDefaultMemorySpace(options, tensorType), funcOp, options)); } @@ -248,7 +266,7 @@ struct CallOpInterface SmallVector<Type> resultTypes; for (Value result : callOp.getResults()) { Type returnType = result.getType(); - if (!isa<TensorType>(returnType)) { + if (!isa<TensorLikeType>(returnType)) { // Non-tensor values are returned. resultTypes.push_back(returnType); continue; @@ -272,7 +290,7 @@ struct CallOpInterface for (OpOperand &opOperand : callOp->getOpOperands()) { // Non-tensor operands are just copied. - if (!isa<TensorType>(opOperand.get().getType())) { + if (!isa<TensorLikeType>(opOperand.get().getType())) { newOperands.push_back(opOperand.get()); continue; } @@ -285,8 +303,8 @@ struct CallOpInterface Value buffer = *maybeBuffer; // Caller / callee type mismatch is handled with castOrReallocMemRefValue. - auto memRefType = funcType.getInput(opOperand.getOperandNumber()); - if (!isa<BaseMemRefType>(memRefType)) { + auto bufferType = funcType.getInput(opOperand.getOperandNumber()); + if (!isa<BufferLikeType>(bufferType)) { // The called function was not bufferized yet. This can happen when // there cycles in the function call graph. Compute the bufferized // result type. @@ -296,7 +314,7 @@ struct CallOpInterface state); if (failed(maybeBufferType)) return failure(); - memRefType = *maybeBufferType; + bufferType = *maybeBufferType; } // Since we don't yet have a clear layout story, to_buffer may @@ -305,8 +323,8 @@ struct CallOpInterface // that will either canonicalize away or fail compilation until we can do // something better. Insert a reallocation + copy if it cannot be // statically guaranteed that a direct cast would be valid. - if (buffer.getType() != memRefType) { - auto memrefDstType = dyn_cast<MemRefType>(memRefType); + if (buffer.getType() != bufferType) { + auto memrefDstType = dyn_cast<MemRefType>(bufferType); assert(memrefDstType && "buffer layout not supported on unranked tensors"); FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue( @@ -370,7 +388,7 @@ struct FuncOpInterface static bool supportsUnstructuredControlFlow() { return true; } bool hasTensorSemantics(Operation *op) const { - auto isaTensor = llvm::IsaPred<TensorType>; + auto isaTensor = llvm::IsaPred<TensorLikeType>; // A function has tensor semantics if it has tensor arguments/results. auto funcOp = cast<FuncOp>(op); @@ -406,8 +424,8 @@ struct FuncOpInterface // Function arguments are special. if (bbArg.getOwner() == &funcOp.getBody().front()) - return cast<BufferLikeType>( - getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options)); + return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), + options); return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel:: getBufferType(op, value, options, state, invocationStack); @@ -430,7 +448,7 @@ struct FuncOpInterface SmallVector<Type> argTypes; for (const auto &it : llvm::enumerate(funcType.getInputs())) { Type argType = it.value(); - if (isa<TensorType>(argType)) { + if (isa<TensorLikeType>(argType)) { argTypes.push_back( getBufferizedFunctionArgType(funcOp, it.index(), options)); continue; @@ -441,9 +459,9 @@ struct FuncOpInterface // Compute the result types. SmallVector<Type> retTypes; for (Type resultType : funcType.getResults()) { - if (auto tensorType = dyn_cast<TensorType>(resultType)) { - BaseMemRefType resultType = options.functionArgTypeConverterFn( - tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, + if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) { + BufferLikeType resultType = options.functionArgTypeConverterFn( + tensorType, getDefaultMemorySpace(options, tensorType), funcOp, options); retTypes.push_back(resultType); continue; @@ -473,7 +491,7 @@ struct FuncOpInterface SmallVector<Value> returnValues; for (auto [returnVal, bufferizedType] : llvm::zip_equal(returnOp->getOperands(), retTypes)) { - auto tensorType = dyn_cast<TensorType>(returnVal.getType()); + auto tensorType = dyn_cast<TensorLikeType>(returnVal.getType()); rewriter.setInsertionPoint(returnOp); // If not a tensor type just forward it. diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir index 2efb5893c851..eb0093106dc1 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -810,3 +810,59 @@ module @inner_module { return %t : tensor<5xf32> } } + +// ----- + +// CHECK: func.func @custom_types( +// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64> +// CHECK-SAME: ) -> (!test.test_memref<[4, 8], f64>, +// CHECK-SAME: !test.test_memref<[4, 8], f64>) +func.func @custom_types(%arg: !test.test_tensor<[4, 4], f64>) + -> (!test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>) { + // CHECK: %[[out1:.*]] = "test.dummy_memref_op"(%[[arg]]) : + // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64> + %out1 = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 8], f64> + + // CHECK: %[[alloc:.*]] = "test.create_memref_op" + // CHECK: %[[out2:.*]] = "test.dummy_memref_op"(%[[alloc]]) + // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64> + %alloc = "test.create_tensor_op"() : () -> !test.test_tensor<[4, 4], f64> + %out2 = "test.dummy_tensor_op"(%alloc) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 8], f64> + + // CHECK: return %[[out1]], %[[out2]] + return %out1, %out2 : + !test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64> +} + +// ----- + +// CHECK: func.func @custom_types_foo( +// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64> +// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64> +func.func @custom_types_foo(%arg: !test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> { + // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[arg]]) + %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> + // CHECK: return %[[out]] + return %out : !test.test_tensor<[4, 4], f64> +} + +// CHECK: func.func @custom_types_bar( +// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64> +// CHECK-SAME: ) -> !test.test_memref<[4, 8], f64> +func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 8], f64> { + // CHECK: %[[call:.*]] = call @custom_types_foo(%[[arg]]) + %call = func.call @custom_types_foo(%arg) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> + + // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[call]]) + %out = "test.dummy_tensor_op"(%call) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 8], f64> + + // CHECK: return %[[out]] + return %out : !test.test_tensor<[4, 8], f64> +} |
