diff options
Diffstat (limited to 'mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp')
| -rw-r--r-- | mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp | 97 |
1 files changed, 67 insertions, 30 deletions
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp index 3069f6e07324..810f82f6442e 100644 --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -42,39 +42,76 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> { LogicalResult matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Check for unranked memref types which are currently not supported. - Type type = op.getType(); - if (isa<UnrankedMemRefType>(type)) { - return rewriter.notifyMatchFailure( - op, "UnrankedMemRefType is not supported."); - } - MemRefType memrefType = cast<MemRefType>(type); - MemRefLayoutAttrInterface layout; - auto allocType = - MemRefType::get(memrefType.getShape(), memrefType.getElementType(), - layout, memrefType.getMemorySpace()); - // Since this implementation always allocates, certain result types of the - // clone op cannot be lowered. - if (!memref::CastOp::areCastCompatible({allocType}, {memrefType})) - return failure(); - - // Transform a clone operation into alloc + copy operation and pay - // attention to the shape dimensions. Location loc = op->getLoc(); - SmallVector<Value, 4> dynamicOperands; - for (int i = 0; i < memrefType.getRank(); ++i) { - if (!memrefType.isDynamicDim(i)) - continue; - Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i); - dynamicOperands.push_back(dim); + + Type type = op.getType(); + Value alloc; + + if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) { + // Constants + Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); + Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); + + // Dynamically evaluate the size and shape of the unranked memref + Value rank = rewriter.create<memref::RankOp>(loc, op.getInput()); + MemRefType allocType = + MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank); + + // Create a loop to query dimension sizes, store them as a shape, and + // compute the total size of the memref + auto loopBody = [&](OpBuilder &builder, Location loc, Value i, + ValueRange args) { + auto acc = args.front(); + auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i); + + rewriter.create<memref::StoreOp>(loc, dim, shape, i); + acc = rewriter.create<arith::MulIOp>(loc, acc, dim); + + rewriter.create<scf::YieldOp>(loc, acc); + }; + auto size = rewriter + .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one), + loopBody) + .getResult(0); + + MemRefType memrefType = MemRefType::get({ShapedType::kDynamic}, + unrankedType.getElementType()); + + // Allocate new memref with 1D dynamic shape, then reshape into the + // shape of the original unranked memref + alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size); + alloc = + rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape); + } else { + MemRefType memrefType = cast<MemRefType>(type); + MemRefLayoutAttrInterface layout; + auto allocType = + MemRefType::get(memrefType.getShape(), memrefType.getElementType(), + layout, memrefType.getMemorySpace()); + // Since this implementation always allocates, certain result types of + // the clone op cannot be lowered. + if (!memref::CastOp::areCastCompatible({allocType}, {memrefType})) + return failure(); + + // Transform a clone operation into alloc + copy operation and pay + // attention to the shape dimensions. + SmallVector<Value, 4> dynamicOperands; + for (int i = 0; i < memrefType.getRank(); ++i) { + if (!memrefType.isDynamicDim(i)) + continue; + Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i); + dynamicOperands.push_back(dim); + } + + // Allocate a memref with identity layout. + alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands); + // Cast the allocation to the specified type if needed. + if (memrefType != allocType) + alloc = + rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc); } - // Allocate a memref with identity layout. - Value alloc = rewriter.create<memref::AllocOp>(op->getLoc(), allocType, - dynamicOperands); - // Cast the allocation to the specified type if needed. - if (memrefType != allocType) - alloc = rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc); rewriter.replaceOp(op, alloc); rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc); return success(); |
