summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Test/TestPatterns.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib/Dialect/Test/TestPatterns.cpp')
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp46
1 files changed, 45 insertions, 1 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e931b394c862..f8df89ff83cc 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -902,6 +902,45 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
}
};
+struct TestConvertBlockAndReplaceArg : public ConversionPattern {
+ TestConvertBlockAndReplaceArg(MLIRContext *ctx,
+ const TypeConverter &converter)
+ : ConversionPattern(converter, "test.convert_block_and_replace_arg",
+ /*benefit=*/1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Expect single region with single block with single block argument.
+ if (op->getNumRegions() != 1)
+ return failure();
+ if (op->getRegion(0).getBlocks().size() != 1)
+ return failure();
+ Block *block = &op->getRegion(0).front();
+ if (block->getArguments().size() != 1)
+ return failure();
+
+ // Convert the block argument into to F64 block arguments.
+ TypeConverter::SignatureConversion result(1);
+ result.addInputs(0, {rewriter.getF64Type(), rewriter.getF64Type()});
+ Block *newBlock =
+ rewriter.applySignatureConversion(block, result, getTypeConverter());
+
+ // Create a replacement value.
+ rewriter.setInsertionPointToStart(newBlock);
+ Value repl = rewriter.create<TestTypeProducerOp>(op->getLoc(),
+ rewriter.getF64Type());
+ BlockArgument arg0 = newBlock->getArgument(0);
+ // Replace the block argument.
+ rewriter.replaceUsesOfBlockArgument(arg0, repl);
+
+ // Mark the op as legal.
+ rewriter.modifyOpInPlace(
+ op, [&]() { op->setAttr("legal", rewriter.getUnitAttr()); });
+ return success();
+ }
+};
+
/// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
/// This is to test the rollback logic.
struct TestUndoMoveOpBefore : public ConversionPattern {
@@ -1265,7 +1304,8 @@ struct TestLegalizePatternDriver
TestCreateUnregisteredOp, TestUndoMoveOpBefore,
TestUndoPropertiesModification, TestEraseOp>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
- TestPassthroughInvalidOp>(&getContext(), converter);
+ TestPassthroughInvalidOp, TestConvertBlockAndReplaceArg>(
+ &getContext(), converter);
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1317,6 +1357,10 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
[](TestOpInPlaceSelfFold op) { return op.getFolded(); });
+ target.addDynamicallyLegalOp(
+ OperationName("test.convert_block_and_replace_arg", &getContext()),
+ [](Operation *op) { return op->hasAttr("legal"); });
+
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;