diff options
Diffstat (limited to 'mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp | 39 |
1 files changed, 30 insertions, 9 deletions
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index f7b0b87085f3..e0cf353da207 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -338,11 +338,21 @@ bool OpFilter::isOpAllowed(Operation *op) const { namespace { /// Default function arg type converter: Use a fully dynamic layout map. -BaseMemRefType -defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace, +BufferLikeType +defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace, func::FuncOp funcOp, const BufferizationOptions &options) { - return getMemRefTypeWithFullyDynamicLayout(type, memorySpace); + if (auto tensorType = mlir::dyn_cast<TensorType>(type)) { + return cast<BufferLikeType>( + getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace)); + } + + // If not builtin, fallback to TensorLikeType::getBufferType() + auto bufferType = + type.getBufferType(options, [&]() { return funcOp->emitError(); }); + assert(succeeded(bufferType) && + "a valid buffer is always expected at function boundary"); + return *bufferType; } /// Default unknown type converter: Use a fully dynamic layout map. BaseMemRefType @@ -385,14 +395,25 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const { void BufferizationOptions::setFunctionBoundaryTypeConversion( LayoutMapOption layoutMapOption) { - functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace, + functionArgTypeConverterFn = [=](TensorLikeType type, Attribute memorySpace, func::FuncOp funcOp, const BufferizationOptions &options) { - if (layoutMapOption == LayoutMapOption::IdentityLayoutMap) - return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, - memorySpace); - return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, - memorySpace); + if (auto tensorType = mlir::dyn_cast<TensorType>(type)) { + if (layoutMapOption == LayoutMapOption::IdentityLayoutMap) + return cast<BufferLikeType>( + bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, + memorySpace)); + return cast<BufferLikeType>( + bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, + memorySpace)); + } + + // If not builtin, fallback to TensorLikeType::getBufferType() + auto bufferType = + type.getBufferType(options, [&]() { return funcOp->emitError(); }); + assert(succeeded(bufferType) && + "a valid buffer is always expected at function boundary"); + return *bufferType; }; inferFunctionResultLayout = layoutMapOption == LayoutMapOption::InferLayoutMap; |
