diff options
| author | Rolf Morel <rolf.morel@intel.com> | 2025-10-01 14:47:35 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-01 13:47:35 +0000 |
| commit | f4d18c0ef8e3207b8ee2363fea60f21d4fa325bc (patch) | |
| tree | 4fea1d9ceb566cfa12d5993009296a0dcfde0391 /mlir/python | |
| parent | a33544b83c80dcaa851fabd2979def6f68dd6e7a (diff) | |
[MLIR][Transform][Tune] Introduce `transform.tune.alternatives` op (#160724)
This op enables expressing uncertainty regarding what should be
happening at particular places in transform-dialect schedules. In
particular, it enables representing a choice among alternative regions.
This choice is resolved through providing a `selected_region` argument.
When this argument is provided, the semantics are such that it is valid
to rewrite the op through substituting in the selected region -- with
the op's interpreted semantics corresponding to exactly this.
This op represents another piece of the puzzle w.r.t. a toolkit for
expressing autotuning problems with the transform dialect. Note that
this goes beyond tuning knobs _on_ transforms, going further by making
it tunable which (sequences of) transforms are to be applied.
Diffstat (limited to 'mlir/python')
| -rw-r--r-- | mlir/python/mlir/dialects/transform/tune.py | 66 |
1 files changed, 63 insertions, 3 deletions
diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py index f63f88a38242..b3bfa8015c4d 100644 --- a/mlir/python/mlir/dialects/transform/tune.py +++ b/mlir/python/mlir/dialects/transform/tune.py @@ -6,6 +6,9 @@ from typing import Optional, Sequence from ...ir import ( Type, + Value, + Operation, + OpView, Attribute, ArrayAttr, StringAttr, @@ -19,7 +22,10 @@ from .._transform_tune_extension_ops_gen import * from .._transform_tune_extension_ops_gen import _Dialect try: - from .._ods_common import _cext as _ods_cext + from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + _cext as _ods_cext, + ) except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -36,7 +42,7 @@ class KnobOp(KnobOp): ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute ], *, - selected: Optional[Attribute] = None, + selected: Optional[Union[Attribute, bool, int, float, str]] = None, loc=None, ip=None, ): @@ -75,8 +81,62 @@ def knob( ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute ], *, - selected: Optional[Attribute] = None, + selected: Optional[Union[Attribute, bool, int, float, str]] = None, loc=None, ip=None, ): return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AlternativesOp(AlternativesOp): + def __init__( + self, + results: Sequence[Type], + name: Union[StringAttr, str], + num_alternatives: int, + *, + selected_region: Optional[ + Union[int, IntegerAttr, Value, Operation, OpView] + ] = None, + loc=None, + ip=None, + ): + if isinstance(name, str): + name = StringAttr.get(name) + + selected_region_attr = selected_region_param = None + if isinstance(selected_region, IntegerAttr): + selected_region_attr = selected_region + elif isinstance(selected_region, int): + selected_region_attr = IntegerAttr.get( + IntegerType.get_signless(32), selected_region + ) + elif isinstance(selected_region, (Value, Operation, OpView)): + selected_region_param = _get_op_result_or_value(selected_region) + + super().__init__( + results, + name, + num_alternatives, + selected_region_attr=selected_region_attr, + selected_region_param=selected_region_param, + loc=loc, + ip=ip, + ) + for region in self.regions: + region.blocks.append() + + +def alternatives( + results: Sequence[Type], + name: Union[StringAttr, str], + num_alternatives: int, + *, + selected_region: Optional[Union[int, IntegerAttr, Value, Operation, OpView]] = None, + loc=None, + ip=None, +): + return AlternativesOp( + results, name, num_alternatives, selected_region=selected_region, loc=loc, ip=ip + ) |
