diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils/DialectConversion.cpp')
| -rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 149 |
1 files changed, 28 insertions, 121 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d407d60334c7..e6c0ee2ab294 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -734,11 +734,6 @@ public: return converterAndKind.getInt(); } - /// Set the kind of this materialization. - void setMaterializationKind(MaterializationKind kind) { - converterAndKind.setInt(kind); - } - /// Return the original illegal output type of the input values. Type getOrigOutputType() const { return origOutputType; } @@ -839,27 +834,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { // Type Conversion //===--------------------------------------------------------------------===// - /// Attempt to convert the signature of the given block, if successful a new - /// block is returned containing the new arguments. Returns `block` if it did - /// not require conversion. - FailureOr<Block *> convertBlockSignature( - ConversionPatternRewriter &rewriter, Block *block, - const TypeConverter *converter, - TypeConverter::SignatureConversion *conversion = nullptr); - - /// Convert the types of non-entry block arguments within the given region. - LogicalResult convertNonEntryRegionTypes( - ConversionPatternRewriter &rewriter, Region *region, - const TypeConverter &converter, - ArrayRef<TypeConverter::SignatureConversion> blockConversions = {}); - - /// Apply a signature conversion on the given region, using `converter` for - /// materializations if not null. - Block * - applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region, - TypeConverter::SignatureConversion &conversion, - const TypeConverter *converter); - /// Convert the types of block arguments within the given region. FailureOr<Block *> convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, @@ -1294,34 +1268,6 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { //===----------------------------------------------------------------------===// // Type Conversion -FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature( - ConversionPatternRewriter &rewriter, Block *block, - const TypeConverter *converter, - TypeConverter::SignatureConversion *conversion) { - if (conversion) - return applySignatureConversion(rewriter, block, converter, *conversion); - - // If a converter wasn't provided, and the block wasn't already converted, - // there is nothing we can do. - if (!converter) - return failure(); - - // Try to convert the signature for the block with the provided converter. - if (auto conversion = converter->convertBlockSignature(block)) - return applySignatureConversion(rewriter, block, converter, *conversion); - return failure(); -} - -Block *ConversionPatternRewriterImpl::applySignatureConversion( - ConversionPatternRewriter &rewriter, Region *region, - TypeConverter::SignatureConversion &conversion, - const TypeConverter *converter) { - if (!region->empty()) - return *convertBlockSignature(rewriter, ®ion->front(), converter, - &conversion); - return nullptr; -} - FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes( ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, @@ -1330,42 +1276,29 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes( if (region->empty()) return nullptr; - if (failed(convertNonEntryRegionTypes(rewriter, region, converter))) - return failure(); - - FailureOr<Block *> newEntry = convertBlockSignature( - rewriter, ®ion->front(), &converter, entryConversion); - return newEntry; -} - -LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( - ConversionPatternRewriter &rewriter, Region *region, - const TypeConverter &converter, - ArrayRef<TypeConverter::SignatureConversion> blockConversions) { - regionToConverter[region] = &converter; - if (region->empty()) - return success(); - - // Convert the arguments of each block within the region. - int blockIdx = 0; - assert((blockConversions.empty() || - blockConversions.size() == region->getBlocks().size() - 1) && - "expected either to provide no SignatureConversions at all or to " - "provide a SignatureConversion for each non-entry block"); - + // Convert the arguments of each non-entry block within the region. for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) { - TypeConverter::SignatureConversion *blockConversion = - blockConversions.empty() - ? nullptr - : const_cast<TypeConverter::SignatureConversion *>( - &blockConversions[blockIdx++]); - - if (failed(convertBlockSignature(rewriter, &block, &converter, - blockConversion))) + // Compute the signature for the block with the provided converter. + std::optional<TypeConverter::SignatureConversion> conversion = + converter.convertBlockSignature(&block); + if (!conversion) return failure(); - } - return success(); + // Convert the block with the computed signature. + applySignatureConversion(rewriter, &block, &converter, *conversion); + } + + // Convert the entry block. If an entry signature conversion was provided, + // use that one. Otherwise, compute the signature with the type converter. + if (entryConversion) + return applySignatureConversion(rewriter, ®ion->front(), &converter, + *entryConversion); + std::optional<TypeConverter::SignatureConversion> conversion = + converter.convertBlockSignature(®ion->front()); + if (!conversion) + return failure(); + return applySignatureConversion(rewriter, ®ion->front(), &converter, + *conversion); } Block *ConversionPatternRewriterImpl::applySignatureConversion( @@ -1676,12 +1609,12 @@ void ConversionPatternRewriter::eraseBlock(Block *block) { } Block *ConversionPatternRewriter::applySignatureConversion( - Region *region, TypeConverter::SignatureConversion &conversion, + Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter) { - assert(!impl->wasOpReplaced(region->getParentOp()) && + assert(!impl->wasOpReplaced(block->getParentOp()) && "attempting to apply a signature conversion to a block within a " "replaced/erased op"); - return impl->applySignatureConversion(*this, region, conversion, converter); + return impl->applySignatureConversion(*this, block, converter, conversion); } FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( @@ -1693,16 +1626,6 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( return impl->convertRegionTypes(*this, region, converter, entryConversion); } -LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes( - Region *region, const TypeConverter &converter, - ArrayRef<TypeConverter::SignatureConversion> blockConversions) { - assert(!impl->wasOpReplaced(region->getParentOp()) && - "attempting to apply a signature conversion to a block within a " - "replaced/erased op"); - return impl->convertNonEntryRegionTypes(*this, region, converter, - blockConversions); -} - void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, Value to) { LLVM_DEBUG({ @@ -2231,11 +2154,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // If the region of the block has a type converter, try to convert the block // directly. if (auto *converter = impl.regionToConverter.lookup(block->getParent())) { - if (failed(impl.convertBlockSignature(rewriter, block, converter))) { + std::optional<TypeConverter::SignatureConversion> conversion = + converter->convertBlockSignature(block); + if (!conversion) { LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " "block")); return failure(); } + impl.applySignatureConversion(rewriter, block, converter, *conversion); continue; } @@ -2828,22 +2754,9 @@ static void computeNecessaryMaterializations( // TODO: Avoid materializing other types of conversions here. } - // Check to see if this is an argument materialization. - if (llvm::any_of(op->getOperands(), llvm::IsaPred<BlockArgument>) || - llvm::any_of(inverseMapping[op->getResult(0)], - llvm::IsaPred<BlockArgument>)) { - mat->setMaterializationKind(MaterializationKind::Argument); - } - // If the materialization does not have any live users, we don't need to // generate a user materialization for it. - // FIXME: For argument materializations, we currently need to check if any - // of the inverse mapped values are used because some patterns expect blind - // value replacement even if the types differ in some cases. When those - // patterns are fixed, we can drop the argument special case here. bool isMaterializationLive = isLive(opResult); - if (mat->getMaterializationKind() == MaterializationKind::Argument) - isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive); if (!isMaterializationLive) continue; if (!necessaryMaterializations.insert(mat)) @@ -2926,13 +2839,7 @@ static LogicalResult legalizeUnresolvedMaterialization( // Try to materialize the conversion. if (const TypeConverter *converter = mat.getConverter()) { - // FIXME: Determine a suitable insertion location when there are multiple - // inputs. - if (inputOperands.size() == 1) - rewriter.setInsertionPointAfterValue(inputOperands.front()); - else - rewriter.setInsertionPoint(op); - + rewriter.setInsertionPoint(op); Value newMaterialization; switch (mat.getMaterializationKind()) { case MaterializationKind::Argument: |
