summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2024-11-14 10:27:58 +0900
committerGitHub <noreply@github.com>2024-11-14 10:27:58 +0900
commitaed4356252df2a4ab2e430d77a29bdb3dfd874fc (patch)
tree1a276e760bd3f9eba8b77e2c6515169b526facb8 /mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
parente5092c301959b599ffd51b7942a8bed5c4be54de (diff)
[mlir][Transforms] Dialect Conversion: Add `replaceOpWithMultiple` (#115816)
This commit adds a new function `ConversionPatternRewriter::replaceOpWithMultiple`. This function is similar to `replaceOp`, but it accepts multiple `ValueRange` replacements, one per op result. Note: This function is not an overload of `replaceOp` because of ambiguous overload resolution that would make the API difficult to use. This commit aligns "block signature conversions" with "op replacements": both support 1:N replacements now. Due to incomplete 1:N support in the dialect conversion driver, an argument materialization is inserted when an SSA value is replaced with multiple values; same as block signature conversions already work around the problem. These argument materializations are going to be removed in a subsequent commit that adds full 1:N support. The purpose of this PR is to add missing features gradually in small increments. This commit also updates two MLIR transformations that have their custom workarounds around missing 1:N support. These can already start using `replaceOpWithMultiple`. Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
Diffstat (limited to 'mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp')
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp43
1 files changed, 20 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 062a0ea6cc47..bf7b3f9bec55 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -600,8 +600,8 @@ public:
flattenOperands(adaptor.getOperands(), flattened);
auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
finalRetTy, flattened);
- // (2) Create cast operation for sparse tensor returns.
- SmallVector<Value> castedRet;
+ // (2) Gather sparse tensor returns.
+ SmallVector<SmallVector<Value>> packedResultVals;
// Tracks the offset of current return value (of the original call)
// relative to the new call (after sparse tensor flattening);
unsigned retOffset = 0;
@@ -618,21 +618,22 @@ public:
assert(!sparseFlat.empty());
if (sparseFlat.size() > 1) {
auto flatSize = sparseFlat.size();
- ValueRange fields(iterator_range<ResultRange::iterator>(
- newCall.result_begin() + retOffset,
- newCall.result_begin() + retOffset + flatSize));
- castedRet.push_back(genTuple(rewriter, loc, retType, fields));
+ packedResultVals.emplace_back();
+ llvm::append_range(packedResultVals.back(),
+ newCall.getResults().slice(retOffset, flatSize));
retOffset += flatSize;
} else {
// If this is an 1:1 conversion, no need for casting.
- castedRet.push_back(newCall.getResult(retOffset));
+ packedResultVals.emplace_back();
+ packedResultVals.back().push_back(newCall.getResult(retOffset));
retOffset++;
}
sparseFlat.clear();
}
- assert(castedRet.size() == op.getNumResults());
- rewriter.replaceOp(op, castedRet);
+ assert(packedResultVals.size() == op.getNumResults());
+ rewriter.replaceOpWithMultiple(
+ op, llvm::to_vector_of<ValueRange>(packedResultVals));
return success();
}
};
@@ -776,7 +777,7 @@ public:
// Reuses specifier.
fields.push_back(desc.getSpecifier());
assert(fields.size() == desc.getNumFields());
- rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
@@ -796,7 +797,7 @@ public:
sizeHint, lvlSizesValues, fields);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
@@ -837,7 +838,7 @@ public:
sizeHint, lvlSizesValues, fields);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
@@ -893,7 +894,7 @@ public:
if (op.getHasInserts())
genEndInsert(rewriter, op.getLoc(), desc);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
+ rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
@@ -1006,7 +1007,6 @@ public:
rewriter.create<scf::YieldOp>(loc, insertRet);
rewriter.setInsertionPointAfter(loop);
- Value result = genTuple(rewriter, loc, dstType, loop->getResults());
// Deallocate the buffers on exit of the full loop nest.
Operation *parent = getTop(op);
rewriter.setInsertionPointAfter(parent);
@@ -1014,7 +1014,7 @@ public:
rewriter.create<memref::DeallocOp>(loc, filled);
rewriter.create<memref::DeallocOp>(loc, added);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, result);
+ rewriter.replaceOpWithMultiple(op, {loop->getResults()});
return success();
}
};
@@ -1041,8 +1041,7 @@ public:
params, /*genCall=*/true);
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op,
- genTuple(rewriter, loc, op.getDest().getType(), ret));
+ rewriter.replaceOpWithMultiple(op, {ret});
return success();
}
};
@@ -1215,8 +1214,7 @@ public:
return true;
});
- rewriter.replaceOp(
- op, genTuple(rewriter, loc, op.getResult().getType(), fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
};
@@ -1271,8 +1269,7 @@ public:
// NOTE: we can not generate tuples directly from descriptor here, as the
// descriptor is holding the original type, yet we want the slice type
// here (they shared every memref but with an updated specifier).
- rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
- desc.getFields()));
+ rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
@@ -1403,7 +1400,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
}
desc.setValMemSize(rewriter, loc, memSize);
- rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
+ rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
@@ -1577,7 +1574,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
EmitCInterface::Off);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
};