diff options
Diffstat (limited to 'mlir/lib/Transforms')
| -rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 45 |
1 files changed, 34 insertions, 11 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 97dd3ab1f482..0d13eb5dbb06 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -769,7 +769,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { LogicalResult remapValues(StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVectorImpl<Value> &remapped); + SmallVector<SmallVector<Value, 1>> &remapped); /// Return "true" if the given operation is ignored, and does not need to be /// converted. @@ -1089,7 +1089,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { LogicalResult ConversionPatternRewriterImpl::remapValues( StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVectorImpl<Value> &remapped) { + SmallVector<SmallVector<Value, 1>> &remapped) { remapped.reserve(llvm::size(values)); for (const auto &it : llvm::enumerate(values)) { @@ -1101,7 +1101,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply // pass through the most recently mapped value. - remapped.push_back(mapping.lookupOrDefault(operand)); + remapped.push_back({mapping.lookupOrDefault(operand)}); continue; } @@ -1123,7 +1123,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // improvements to the `ConversionValueMapping` (to be able to store 1:N // mappings) and to the `ConversionPattern` adaptor handling (to be able // to pass multiple remapped values for a single operand to the adaptor). - remapped.push_back(mapping.lookupOrDefault(operand)); + remapped.push_back({mapping.lookupOrDefault(operand)}); continue; } @@ -1143,7 +1143,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( mapping.map(newOperand, castValue); newOperand = castValue; } - remapped.push_back(newOperand); + remapped.push_back({newOperand}); } return success(); } @@ -1523,11 +1523,12 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, } Value ConversionPatternRewriter::getRemappedValue(Value key) { - SmallVector<Value> remappedValues; + SmallVector<SmallVector<Value, 1>> remappedValues; if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, remappedValues))) return nullptr; - return remappedValues.front(); + assert(remappedValues.front().size() == 1 && "1:N conversion not supported"); + return remappedValues.front().front(); } LogicalResult @@ -1535,8 +1536,15 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, SmallVectorImpl<Value> &results) { if (keys.empty()) return success(); - return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, - results); + SmallVector<SmallVector<Value, 1>> remapped; + if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, + remapped))) + return failure(); + for (const auto &values : remapped) { + assert(values.size() == 1 && "1:N conversion not supported"); + results.push_back(values.front()); + } + return success(); } void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, @@ -1630,6 +1638,16 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { // ConversionPattern //===----------------------------------------------------------------------===// +SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands( + ArrayRef<ArrayRef<Value>> operands) { + SmallVector<Value> oneToOneOperands; + oneToOneOperands.reserve(operands.size()); + for (ArrayRef<Value> operand : operands) { + assert(operand.size() == 1 && "pattern does not support 1:N conversion"); + oneToOneOperands.push_back(operand.front()); + } +} + LogicalResult ConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { @@ -1641,11 +1659,16 @@ ConversionPattern::matchAndRewrite(Operation *op, getTypeConverter()); // Remap the operands of the operation. - SmallVector<Value, 4> operands; + SmallVector<SmallVector<Value, 1>> remapped; if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, - op->getOperands(), operands))) { + op->getOperands(), remapped))) { return failure(); } + SmallVector<Value, 4> operands; + for (const auto &values : remapped) { + assert(values.size() == 1 && "1:N conversion not supported"); + operands.push_back(values.front()); + } return matchAndRewrite(op, operands, dialectRewriter); } |
