summaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/Utils/DialectConversion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Utils/DialectConversion.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp149
1 files changed, 28 insertions, 121 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d407d60334c7..e6c0ee2ab294 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -734,11 +734,6 @@ public:
return converterAndKind.getInt();
}
- /// Set the kind of this materialization.
- void setMaterializationKind(MaterializationKind kind) {
- converterAndKind.setInt(kind);
- }
-
/// Return the original illegal output type of the input values.
Type getOrigOutputType() const { return origOutputType; }
@@ -839,27 +834,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// Type Conversion
//===--------------------------------------------------------------------===//
- /// Attempt to convert the signature of the given block, if successful a new
- /// block is returned containing the new arguments. Returns `block` if it did
- /// not require conversion.
- FailureOr<Block *> convertBlockSignature(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
- TypeConverter::SignatureConversion *conversion = nullptr);
-
- /// Convert the types of non-entry block arguments within the given region.
- LogicalResult convertNonEntryRegionTypes(
- ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
- ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
-
- /// Apply a signature conversion on the given region, using `converter` for
- /// materializations if not null.
- Block *
- applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
- TypeConverter::SignatureConversion &conversion,
- const TypeConverter *converter);
-
/// Convert the types of block arguments within the given region.
FailureOr<Block *>
convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
@@ -1294,34 +1268,6 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
//===----------------------------------------------------------------------===//
// Type Conversion
-FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
- TypeConverter::SignatureConversion *conversion) {
- if (conversion)
- return applySignatureConversion(rewriter, block, converter, *conversion);
-
- // If a converter wasn't provided, and the block wasn't already converted,
- // there is nothing we can do.
- if (!converter)
- return failure();
-
- // Try to convert the signature for the block with the provided converter.
- if (auto conversion = converter->convertBlockSignature(block))
- return applySignatureConversion(rewriter, block, converter, *conversion);
- return failure();
-}
-
-Block *ConversionPatternRewriterImpl::applySignatureConversion(
- ConversionPatternRewriter &rewriter, Region *region,
- TypeConverter::SignatureConversion &conversion,
- const TypeConverter *converter) {
- if (!region->empty())
- return *convertBlockSignature(rewriter, &region->front(), converter,
- &conversion);
- return nullptr;
-}
-
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
@@ -1330,42 +1276,29 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
if (region->empty())
return nullptr;
- if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
- return failure();
-
- FailureOr<Block *> newEntry = convertBlockSignature(
- rewriter, &region->front(), &converter, entryConversion);
- return newEntry;
-}
-
-LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
- ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
- ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
- regionToConverter[region] = &converter;
- if (region->empty())
- return success();
-
- // Convert the arguments of each block within the region.
- int blockIdx = 0;
- assert((blockConversions.empty() ||
- blockConversions.size() == region->getBlocks().size() - 1) &&
- "expected either to provide no SignatureConversions at all or to "
- "provide a SignatureConversion for each non-entry block");
-
+ // Convert the arguments of each non-entry block within the region.
for (Block &block :
llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
- TypeConverter::SignatureConversion *blockConversion =
- blockConversions.empty()
- ? nullptr
- : const_cast<TypeConverter::SignatureConversion *>(
- &blockConversions[blockIdx++]);
-
- if (failed(convertBlockSignature(rewriter, &block, &converter,
- blockConversion)))
+ // Compute the signature for the block with the provided converter.
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter.convertBlockSignature(&block);
+ if (!conversion)
return failure();
- }
- return success();
+ // Convert the block with the computed signature.
+ applySignatureConversion(rewriter, &block, &converter, *conversion);
+ }
+
+ // Convert the entry block. If an entry signature conversion was provided,
+ // use that one. Otherwise, compute the signature with the type converter.
+ if (entryConversion)
+ return applySignatureConversion(rewriter, &region->front(), &converter,
+ *entryConversion);
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter.convertBlockSignature(&region->front());
+ if (!conversion)
+ return failure();
+ return applySignatureConversion(rewriter, &region->front(), &converter,
+ *conversion);
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
@@ -1676,12 +1609,12 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
}
Block *ConversionPatternRewriter::applySignatureConversion(
- Region *region, TypeConverter::SignatureConversion &conversion,
+ Block *block, TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
- assert(!impl->wasOpReplaced(region->getParentOp()) &&
+ assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->applySignatureConversion(*this, region, conversion, converter);
+ return impl->applySignatureConversion(*this, block, converter, conversion);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1693,16 +1626,6 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
return impl->convertRegionTypes(*this, region, converter, entryConversion);
}
-LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
- Region *region, const TypeConverter &converter,
- ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
- assert(!impl->wasOpReplaced(region->getParentOp()) &&
- "attempting to apply a signature conversion to a block within a "
- "replaced/erased op");
- return impl->convertNonEntryRegionTypes(*this, region, converter,
- blockConversions);
-}
-
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value to) {
LLVM_DEBUG({
@@ -2231,11 +2154,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// If the region of the block has a type converter, try to convert the block
// directly.
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
- if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(block);
+ if (!conversion) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
return failure();
}
+ impl.applySignatureConversion(rewriter, block, converter, *conversion);
continue;
}
@@ -2828,22 +2754,9 @@ static void computeNecessaryMaterializations(
// TODO: Avoid materializing other types of conversions here.
}
- // Check to see if this is an argument materialization.
- if (llvm::any_of(op->getOperands(), llvm::IsaPred<BlockArgument>) ||
- llvm::any_of(inverseMapping[op->getResult(0)],
- llvm::IsaPred<BlockArgument>)) {
- mat->setMaterializationKind(MaterializationKind::Argument);
- }
-
// If the materialization does not have any live users, we don't need to
// generate a user materialization for it.
- // FIXME: For argument materializations, we currently need to check if any
- // of the inverse mapped values are used because some patterns expect blind
- // value replacement even if the types differ in some cases. When those
- // patterns are fixed, we can drop the argument special case here.
bool isMaterializationLive = isLive(opResult);
- if (mat->getMaterializationKind() == MaterializationKind::Argument)
- isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
if (!isMaterializationLive)
continue;
if (!necessaryMaterializations.insert(mat))
@@ -2926,13 +2839,7 @@ static LogicalResult legalizeUnresolvedMaterialization(
// Try to materialize the conversion.
if (const TypeConverter *converter = mat.getConverter()) {
- // FIXME: Determine a suitable insertion location when there are multiple
- // inputs.
- if (inputOperands.size() == 1)
- rewriter.setInsertionPointAfterValue(inputOperands.front());
- else
- rewriter.setInsertionPoint(op);
-
+ rewriter.setInsertionPoint(op);
Value newMaterialization;
switch (mat.getMaterializationKind()) {
case MaterializationKind::Argument: