summaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms
diff options
context:
space:
mode:
authorMatthias Springer <mspringer@nvidia.com>2024-10-05 17:29:12 +0200
committerMatthias Springer <mspringer@nvidia.com>2024-10-12 10:45:42 +0200
commit16388fdda61e751c85a2dcb8beff8e2fa337b698 (patch)
tree83d4ff431204cc44e312006aba02322203b34de5 /mlir/lib/Transforms
parent9f24c145494ee238e65e25205a4dcb4451f009ae (diff)
[WIP] 1:N conversion patternusers/matthias-springer/one_to_n_pattern
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp45
1 files changed, 34 insertions, 11 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 97dd3ab1f482..0d13eb5dbb06 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -769,7 +769,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
LogicalResult remapValues(StringRef valueDiagTag,
std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVectorImpl<Value> &remapped);
+ SmallVector<SmallVector<Value, 1>> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
/// converted.
@@ -1089,7 +1089,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
LogicalResult ConversionPatternRewriterImpl::remapValues(
StringRef valueDiagTag, std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVectorImpl<Value> &remapped) {
+ SmallVector<SmallVector<Value, 1>> &remapped) {
remapped.reserve(llvm::size(values));
for (const auto &it : llvm::enumerate(values)) {
@@ -1101,7 +1101,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// The current pattern does not have a type converter. I.e., it does not
// distinguish between legal and illegal types. For each operand, simply
// pass through the most recently mapped value.
- remapped.push_back(mapping.lookupOrDefault(operand));
+ remapped.push_back({mapping.lookupOrDefault(operand)});
continue;
}
@@ -1123,7 +1123,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// improvements to the `ConversionValueMapping` (to be able to store 1:N
// mappings) and to the `ConversionPattern` adaptor handling (to be able
// to pass multiple remapped values for a single operand to the adaptor).
- remapped.push_back(mapping.lookupOrDefault(operand));
+ remapped.push_back({mapping.lookupOrDefault(operand)});
continue;
}
@@ -1143,7 +1143,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
mapping.map(newOperand, castValue);
newOperand = castValue;
}
- remapped.push_back(newOperand);
+ remapped.push_back({newOperand});
}
return success();
}
@@ -1523,11 +1523,12 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
- SmallVector<Value> remappedValues;
+ SmallVector<SmallVector<Value, 1>> remappedValues;
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
remappedValues)))
return nullptr;
- return remappedValues.front();
+ assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
+ return remappedValues.front().front();
}
LogicalResult
@@ -1535,8 +1536,15 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
SmallVectorImpl<Value> &results) {
if (keys.empty())
return success();
- return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
- results);
+ SmallVector<SmallVector<Value, 1>> remapped;
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+ remapped)))
+ return failure();
+ for (const auto &values : remapped) {
+ assert(values.size() == 1 && "1:N conversion not supported");
+ results.push_back(values.front());
+ }
+ return success();
}
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
@@ -1630,6 +1638,16 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
// ConversionPattern
//===----------------------------------------------------------------------===//
+SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
+ ArrayRef<ArrayRef<Value>> operands) {
+ SmallVector<Value> oneToOneOperands;
+ oneToOneOperands.reserve(operands.size());
+ for (ArrayRef<Value> operand : operands) {
+ assert(operand.size() == 1 && "pattern does not support 1:N conversion");
+ oneToOneOperands.push_back(operand.front());
+ }
+}
+
LogicalResult
ConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
@@ -1641,11 +1659,16 @@ ConversionPattern::matchAndRewrite(Operation *op,
getTypeConverter());
// Remap the operands of the operation.
- SmallVector<Value, 4> operands;
+ SmallVector<SmallVector<Value, 1>> remapped;
if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
- op->getOperands(), operands))) {
+ op->getOperands(), remapped))) {
return failure();
}
+ SmallVector<Value, 4> operands;
+ for (const auto &values : remapped) {
+ assert(values.size() == 1 && "1:N conversion not supported");
+ operands.push_back(values.front());
+ }
return matchAndRewrite(op, operands, dialectRewriter);
}