summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp')
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp39
1 files changed, 30 insertions, 9 deletions
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index f7b0b87085f3..e0cf353da207 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -338,11 +338,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
namespace {
/// Default function arg type converter: Use a fully dynamic layout map.
-BaseMemRefType
-defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
+BufferLikeType
+defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace,
func::FuncOp funcOp,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+ if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
+ return cast<BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace));
+ }
+
+ // If not builtin, fallback to TensorLikeType::getBufferType()
+ auto bufferType =
+ type.getBufferType(options, [&]() { return funcOp->emitError(); });
+ assert(succeeded(bufferType) &&
+ "a valid buffer is always expected at function boundary");
+ return *bufferType;
}
/// Default unknown type converter: Use a fully dynamic layout map.
BaseMemRefType
@@ -385,14 +395,25 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
void BufferizationOptions::setFunctionBoundaryTypeConversion(
LayoutMapOption layoutMapOption) {
- functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
+ functionArgTypeConverterFn = [=](TensorLikeType type, Attribute memorySpace,
func::FuncOp funcOp,
const BufferizationOptions &options) {
- if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
- return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
- memorySpace);
- return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
- memorySpace);
+ if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
+ if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
+ return cast<BufferLikeType>(
+ bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
+ memorySpace));
+ return cast<BufferLikeType>(
+ bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
+ memorySpace));
+ }
+
+ // If not builtin, fallback to TensorLikeType::getBufferType()
+ auto bufferType =
+ type.getBufferType(options, [&]() { return funcOp->emitError(); });
+ assert(succeeded(bufferType) &&
+ "a valid buffer is always expected at function boundary");
+ return *bufferType;
};
inferFunctionResultLayout =
layoutMapOption == LayoutMapOption::InferLayoutMap;