diff options
| author | Matthias Springer <me@m-sp.org> | 2024-11-14 10:27:58 +0900 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-11-14 10:27:58 +0900 |
| commit | aed4356252df2a4ab2e430d77a29bdb3dfd874fc (patch) | |
| tree | 1a276e760bd3f9eba8b77e2c6515169b526facb8 /mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp | |
| parent | e5092c301959b599ffd51b7942a8bed5c4be54de (diff) | |
[mlir][Transforms] Dialect Conversion: Add `replaceOpWithMultiple` (#115816)
This commit adds a new function
`ConversionPatternRewriter::replaceOpWithMultiple`. This function is
similar to `replaceOp`, but it accepts multiple `ValueRange`
replacements, one per op result.
Note: This function is not an overload of `replaceOp` because of
ambiguous overload resolution that would make the API difficult to use.
This commit aligns "block signature conversions" with "op replacements":
both support 1:N replacements now. Due to incomplete 1:N support in the
dialect conversion driver, an argument materialization is inserted when
an SSA value is replaced with multiple values; same as block signature
conversions already work around the problem. These argument
materializations are going to be removed in a subsequent commit that
adds full 1:N support. The purpose of this PR is to add missing features
gradually in small increments.
This commit also updates two MLIR transformations that have their custom
workarounds around missing 1:N support. These can already start using
`replaceOpWithMultiple`.
Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
Diffstat (limited to 'mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp')
| -rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp | 43 |
1 files changed, 20 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 062a0ea6cc47..bf7b3f9bec55 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -600,8 +600,8 @@ public: flattenOperands(adaptor.getOperands(), flattened); auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(), finalRetTy, flattened); - // (2) Create cast operation for sparse tensor returns. - SmallVector<Value> castedRet; + // (2) Gather sparse tensor returns. + SmallVector<SmallVector<Value>> packedResultVals; // Tracks the offset of current return value (of the original call) // relative to the new call (after sparse tensor flattening); unsigned retOffset = 0; @@ -618,21 +618,22 @@ public: assert(!sparseFlat.empty()); if (sparseFlat.size() > 1) { auto flatSize = sparseFlat.size(); - ValueRange fields(iterator_range<ResultRange::iterator>( - newCall.result_begin() + retOffset, - newCall.result_begin() + retOffset + flatSize)); - castedRet.push_back(genTuple(rewriter, loc, retType, fields)); + packedResultVals.emplace_back(); + llvm::append_range(packedResultVals.back(), + newCall.getResults().slice(retOffset, flatSize)); retOffset += flatSize; } else { // If this is an 1:1 conversion, no need for casting. - castedRet.push_back(newCall.getResult(retOffset)); + packedResultVals.emplace_back(); + packedResultVals.back().push_back(newCall.getResult(retOffset)); retOffset++; } sparseFlat.clear(); } - assert(castedRet.size() == op.getNumResults()); - rewriter.replaceOp(op, castedRet); + assert(packedResultVals.size() == op.getNumResults()); + rewriter.replaceOpWithMultiple( + op, llvm::to_vector_of<ValueRange>(packedResultVals)); return success(); } }; @@ -776,7 +777,7 @@ public: // Reuses specifier. fields.push_back(desc.getSpecifier()); assert(fields.size() == desc.getNumFields()); - rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields)); + rewriter.replaceOpWithMultiple(op, {fields}); return success(); } @@ -796,7 +797,7 @@ public: sizeHint, lvlSizesValues, fields); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields)); + rewriter.replaceOpWithMultiple(op, {fields}); return success(); } @@ -837,7 +838,7 @@ public: sizeHint, lvlSizesValues, fields); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields)); + rewriter.replaceOpWithMultiple(op, {fields}); return success(); } @@ -893,7 +894,7 @@ public: if (op.getHasInserts()) genEndInsert(rewriter, op.getLoc(), desc); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc)); + rewriter.replaceOpWithMultiple(op, {desc.getFields()}); return success(); } }; @@ -1006,7 +1007,6 @@ public: rewriter.create<scf::YieldOp>(loc, insertRet); rewriter.setInsertionPointAfter(loop); - Value result = genTuple(rewriter, loc, dstType, loop->getResults()); // Deallocate the buffers on exit of the full loop nest. Operation *parent = getTop(op); rewriter.setInsertionPointAfter(parent); @@ -1014,7 +1014,7 @@ public: rewriter.create<memref::DeallocOp>(loc, filled); rewriter.create<memref::DeallocOp>(loc, added); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, result); + rewriter.replaceOpWithMultiple(op, {loop->getResults()}); return success(); } }; @@ -1041,8 +1041,7 @@ public: params, /*genCall=*/true); SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, - genTuple(rewriter, loc, op.getDest().getType(), ret)); + rewriter.replaceOpWithMultiple(op, {ret}); return success(); } }; @@ -1215,8 +1214,7 @@ public: return true; }); - rewriter.replaceOp( - op, genTuple(rewriter, loc, op.getResult().getType(), fields)); + rewriter.replaceOpWithMultiple(op, {fields}); return success(); } }; @@ -1271,8 +1269,7 @@ public: // NOTE: we can not generate tuples directly from descriptor here, as the // descriptor is holding the original type, yet we want the slice type // here (they shared every memref but with an updated specifier). - rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(), - desc.getFields())); + rewriter.replaceOpWithMultiple(op, {desc.getFields()}); return success(); } }; @@ -1403,7 +1400,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> { } desc.setValMemSize(rewriter, loc, memSize); - rewriter.replaceOp(op, genTuple(rewriter, loc, desc)); + rewriter.replaceOpWithMultiple(op, {desc.getFields()}); return success(); } }; @@ -1577,7 +1574,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> { EmitCInterface::Off); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields)); + rewriter.replaceOpWithMultiple(op, {fields}); return success(); } }; |
