summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
diff options
context:
space:
mode:
authorAart Bik <ajcbik@google.com>2024-05-07 19:01:36 -0700
committerGitHub <noreply@github.com>2024-05-07 19:01:36 -0700
commitc4e5a8a4d3ef0948384d9411ea1e44fc113e5b5c (patch)
treef59d0612de11e27813ead7bf1244f9e0656274d5 /mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
parent584253c4e2f788f870488fc32193b52d67ddaccc (diff)
[mlir][sparse] support 'batch' dimensions in sparse_tensor.print (#91411)
Diffstat (limited to 'mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp')
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp12
1 files changed, 9 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index d9b203a88648..164e722c45db 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -417,11 +417,17 @@ static void genEndInsert(OpBuilder &builder, Location loc,
/// Generates a subview into the sizes.
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
Value sz) {
- auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
+ auto memTp = llvm::cast<MemRefType>(mem.getType());
+ // For higher-dimensional memrefs, we assume that the innermost
+ // dimension is always of the right size.
+ // TODO: generate complex truncating view here too?
+ if (memTp.getRank() > 1)
+ return mem;
+ // Truncate linear memrefs to given size.
return builder
.create<memref::SubViewOp>(
- loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
- ValueRange{}, ValueRange{sz}, ValueRange{},
+ loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
+ mem, ValueRange{}, ValueRange{sz}, ValueRange{},
ArrayRef<int64_t>{0}, // static offset
ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
ArrayRef<int64_t>{1}) // static stride