diff options
Diffstat (limited to 'mlir/test/lib/Dialect/Test/TestPatterns.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Test/TestPatterns.cpp | 46 |
1 files changed, 45 insertions, 1 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index e931b394c862..f8df89ff83cc 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -902,6 +902,45 @@ struct TestUndoBlockArgReplace : public ConversionPattern { } }; +struct TestConvertBlockAndReplaceArg : public ConversionPattern { + TestConvertBlockAndReplaceArg(MLIRContext *ctx, + const TypeConverter &converter) + : ConversionPattern(converter, "test.convert_block_and_replace_arg", + /*benefit=*/1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + // Expect single region with single block with single block argument. + if (op->getNumRegions() != 1) + return failure(); + if (op->getRegion(0).getBlocks().size() != 1) + return failure(); + Block *block = &op->getRegion(0).front(); + if (block->getArguments().size() != 1) + return failure(); + + // Convert the block argument into to F64 block arguments. + TypeConverter::SignatureConversion result(1); + result.addInputs(0, {rewriter.getF64Type(), rewriter.getF64Type()}); + Block *newBlock = + rewriter.applySignatureConversion(block, result, getTypeConverter()); + + // Create a replacement value. + rewriter.setInsertionPointToStart(newBlock); + Value repl = rewriter.create<TestTypeProducerOp>(op->getLoc(), + rewriter.getF64Type()); + BlockArgument arg0 = newBlock->getArgument(0); + // Replace the block argument. + rewriter.replaceUsesOfBlockArgument(arg0, repl); + + // Mark the op as legal. + rewriter.modifyOpInPlace( + op, [&]() { op->setAttr("legal", rewriter.getUnitAttr()); }); + return success(); + } +}; + /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. /// This is to test the rollback logic. struct TestUndoMoveOpBefore : public ConversionPattern { @@ -1265,7 +1304,8 @@ struct TestLegalizePatternDriver TestCreateUnregisteredOp, TestUndoMoveOpBefore, TestUndoPropertiesModification, TestEraseOp>(&getContext()); patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp, - TestPassthroughInvalidOp>(&getContext(), converter); + TestPassthroughInvalidOp, TestConvertBlockAndReplaceArg>( + &getContext(), converter); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); mlir::populateCallOpTypeConversionPattern(patterns, converter); @@ -1317,6 +1357,10 @@ struct TestLegalizePatternDriver target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>( [](TestOpInPlaceSelfFold op) { return op.getFolded(); }); + target.addDynamicallyLegalOp( + OperationName("test.convert_block_and_replace_arg", &getContext()), + [](Operation *op) { return op->hasAttr("legal"); }); + // Handle a partial conversion. if (mode == ConversionMode::Partial) { DenseSet<Operation *> unlegalizedOps; |
