diff options
| author | Aart Bik <ajcbik@google.com> | 2024-05-07 19:01:36 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-05-07 19:01:36 -0700 |
| commit | c4e5a8a4d3ef0948384d9411ea1e44fc113e5b5c (patch) | |
| tree | f59d0612de11e27813ead7bf1244f9e0656274d5 /mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp | |
| parent | 584253c4e2f788f870488fc32193b52d67ddaccc (diff) | |
[mlir][sparse] support 'batch' dimensions in sparse_tensor.print (#91411)
Diffstat (limited to 'mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp')
| -rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index d9b203a88648..164e722c45db 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -417,11 +417,17 @@ static void genEndInsert(OpBuilder &builder, Location loc, /// Generates a subview into the sizes. static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz) { - auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType(); + auto memTp = llvm::cast<MemRefType>(mem.getType()); + // For higher-dimensional memrefs, we assume that the innermost + // dimension is always of the right size. + // TODO: generate complex truncating view here too? + if (memTp.getRank() > 1) + return mem; + // Truncate linear memrefs to given size. return builder .create<memref::SubViewOp>( - loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem, - ValueRange{}, ValueRange{sz}, ValueRange{}, + loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()), + mem, ValueRange{}, ValueRange{sz}, ValueRange{}, ArrayRef<int64_t>{0}, // static offset ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size ArrayRef<int64_t>{1}) // static stride |
