diff options
Diffstat (limited to 'mlir/test/python/dialects/transform.py')
| -rw-r--r-- | mlir/test/python/dialects/transform.py | 215 |
1 files changed, 142 insertions, 73 deletions
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py index 6c5e4e5505b1..f58442d04fc6 100644 --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -51,6 +51,26 @@ def testSequenceOp(module: Module): transform.AnyOpType.get(), ) with InsertionPoint(sequence.body): + res = transform.CastOp(transform.AnyOpType.get(), sequence.bodyTarget) + res2 = transform.cast(transform.any_op_t(), res.result) + transform.YieldOp([res2]) + # CHECK-LABEL: TEST: testSequenceOp + # CHECK: transform.sequence + # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): + # CHECK: %[[RES:.+]] = cast %[[ARG0]] : !transform.any_op to !transform.any_op + # CHECK: %[[RES2:.+]] = cast %[[RES]] : !transform.any_op to !transform.any_op + # CHECK: yield %[[RES2]] : !transform.any_op + # CHECK: } + + +@run +def testSequenceOp(module: Module): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [transform.AnyOpType.get()], + transform.AnyOpType.get(), + ) + with InsertionPoint(sequence.body): transform.YieldOp([sequence.bodyTarget]) # CHECK-LABEL: TEST: testSequenceOp # CHECK: = transform.sequence -> !transform.any_op failures(propagate) { @@ -58,6 +78,7 @@ def testSequenceOp(module: Module): # CHECK: yield %[[ARG0]] : !transform.any_op # CHECK: } + @run def testNestedSequenceOp(module: Module): sequence = transform.SequenceOp( @@ -103,55 +124,65 @@ def testSequenceOpWithExtras(module: Module): # CHECK-LABEL: TEST: testSequenceOpWithExtras # CHECK: transform.sequence failures(propagate) # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">): + sequence = transform.sequence( + transform.FailurePropagationMode.Propagate, + [], + transform.AnyOpType.get(), + [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")], + ) + with InsertionPoint(sequence.body): + transform.yield_() + # CHECK: transform.sequence failures(propagate) + # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">): @run def testNestedSequenceOpWithExtras(module: Module): - sequence = transform.SequenceOp( + sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get(), [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")], ) - with InsertionPoint(sequence.body): - nested = transform.SequenceOp( + with InsertionPoint(sequence.body): + nested = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget, sequence.bodyExtraArgs, ) - with InsertionPoint(nested.body): - transform.YieldOp() - transform.YieldOp() - # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras - # CHECK: transform.sequence failures(propagate) - # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">): - # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">) + with InsertionPoint(nested.body): + transform.YieldOp() + transform.YieldOp() + # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras + # CHECK: transform.sequence failures(propagate) + # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">): + # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">) @run def testTransformPDLOps(module: Module): - withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) - with InsertionPoint(withPdl.body): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, - [transform.AnyOpType.get()], - withPdl.bodyTarget, - ) - with InsertionPoint(sequence.body): - match = transform_pdl.PDLMatchOp( - transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher" - ) - transform.YieldOp(match) - # CHECK-LABEL: TEST: testTransformPDLOps - # CHECK: transform.with_pdl_patterns { - # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): - # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) { - # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]] - # CHECK: yield %[[RES]] : !transform.any_op - # CHECK: } - # CHECK: } + withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) + with InsertionPoint(withPdl.body): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [transform.AnyOpType.get()], + withPdl.bodyTarget, + ) + with InsertionPoint(sequence.body): + match = transform_pdl.PDLMatchOp( + transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher" + ) + transform.YieldOp(match) + # CHECK-LABEL: TEST: testTransformPDLOps + # CHECK: transform.with_pdl_patterns { + # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): + # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) { + # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): + # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]] + # CHECK: yield %[[RES]] : !transform.any_op + # CHECK: } + # CHECK: } @run @@ -161,32 +192,53 @@ def testNamedSequenceOp(module: Module): "__transform_main", [transform.AnyOpType.get()], [transform.AnyOpType.get()], - arg_attrs = [{"transform.consumed": UnitAttr.get()}]) + arg_attrs=[{"transform.consumed": UnitAttr.get()}], + ) with InsertionPoint(named_sequence.body): transform.YieldOp([named_sequence.bodyTarget]) # CHECK-LABEL: TEST: testNamedSequenceOp # CHECK: module attributes {transform.with_named_sequence} { - # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op { - # CHECK: yield %[[ARG0]] : !transform.any_op + # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op { + # CHECK: yield %[[ARG0]] : !transform.any_op + named_sequence = transform.named_sequence( + "other_seq", + [transform.AnyOpType.get()], + [transform.AnyOpType.get()], + arg_attrs=[{"transform.consumed": UnitAttr.get()}], + ) + with InsertionPoint(named_sequence.body): + transform.yield_([named_sequence.bodyTarget]) + # CHECK: transform.named_sequence @other_seq(%[[ARG1:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op { + # CHECK: yield %[[ARG1]] : !transform.any_op @run def testGetParentOp(module: Module): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - transform.GetParentOp( - transform.AnyOpType.get(), - sequence.bodyTarget, - isolated_from_above=True, - nth_parent=2, + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) - transform.YieldOp() - # CHECK-LABEL: TEST: testGetParentOp - # CHECK: transform.sequence - # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64} + with InsertionPoint(sequence.body): + transform.GetParentOp( + transform.AnyOpType.get(), + sequence.bodyTarget, + isolated_from_above=True, + nth_parent=2, + ) + transform.get_parent_op( + transform.AnyOpType.get(), + sequence.bodyTarget, + isolated_from_above=True, + nth_parent=2, + allow_empty_results=True, + op_name="func.func", + deduplicate=True, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testGetParentOp + # CHECK: transform.sequence + # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): + # CHECK: = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64} + # CHECK: = get_parent_op %[[ARG1]] {allow_empty_results, deduplicate, isolated_from_above, nth_parent = 2 : i64, op_name = "func.func"} @run @@ -195,43 +247,58 @@ def testMergeHandlesOp(module: Module): transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): - transform.MergeHandlesOp([sequence.bodyTarget]) + res = transform.MergeHandlesOp([sequence.bodyTarget]) + transform.merge_handles([res.result], deduplicate=True) transform.YieldOp() # CHECK-LABEL: TEST: testMergeHandlesOp # CHECK: transform.sequence # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: = merge_handles %[[ARG1]] + # CHECK: %[[RES1:.+]] = merge_handles %[[ARG1]] : !transform.any_op + # CHECK: = merge_handles deduplicate %[[RES1]] : !transform.any_op @run def testApplyPatternsOpCompact(module: Module): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): - transform.ApplyCanonicalizationPatternsOp() - transform.YieldOp() - # CHECK-LABEL: TEST: testApplyPatternsOpCompact - # CHECK: apply_patterns to - # CHECK: transform.apply_patterns.canonicalization - # CHECK: !transform.any_op + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): + transform.ApplyCanonicalizationPatternsOp() + with InsertionPoint( + transform.apply_patterns( + sequence.bodyTarget, + apply_cse=True, + max_iterations=3, + max_num_rewrites=5, + ).patterns + ): + transform.ApplyCanonicalizationPatternsOp() + transform.YieldOp() + # CHECK-LABEL: TEST: testApplyPatternsOpCompact + # CHECK: apply_patterns to + # CHECK: transform.apply_patterns.canonicalization + # CHECK: } : !transform.any_op + # CHECK: apply_patterns to + # CHECK: transform.apply_patterns.canonicalization + # CHECK: } {apply_cse, max_iterations = 3 : i64, max_num_rewrites = 5 : i64} : !transform.any_op @run def testApplyPatternsOpWithType(module: Module): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], - transform.OperationType.get('test.dummy') - ) - with InsertionPoint(sequence.body): - with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): - transform.ApplyCanonicalizationPatternsOp() - transform.YieldOp() - # CHECK-LABEL: TEST: testApplyPatternsOp - # CHECK: apply_patterns to - # CHECK: transform.apply_patterns.canonicalization - # CHECK: !transform.op<"test.dummy"> + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("test.dummy"), + ) + with InsertionPoint(sequence.body): + with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): + transform.ApplyCanonicalizationPatternsOp() + transform.YieldOp() + # CHECK-LABEL: TEST: testApplyPatternsOp + # CHECK: apply_patterns to + # CHECK: transform.apply_patterns.canonicalization + # CHECK: !transform.op<"test.dummy"> @run @@ -249,11 +316,13 @@ def testReplicateOp(module: Module): transform.AnyOpType.get(), sequence.bodyTarget, "second" ) transform.ReplicateOp(m1, [m2]) + transform.replicate(m1, [m2]) transform.YieldOp() # CHECK-LABEL: TEST: testReplicateOp # CHECK: %[[FIRST:.+]] = pdl_match # CHECK: %[[SECOND:.+]] = pdl_match # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]] + # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]] # CHECK-LABEL: TEST: testApplyRegisteredPassOp |
