summaryrefslogtreecommitdiff
path: root/mlir/python
diff options
context:
space:
mode:
authorRolf Morel <rolf.morel@intel.com>2025-10-01 14:47:35 +0100
committerGitHub <noreply@github.com>2025-10-01 13:47:35 +0000
commitf4d18c0ef8e3207b8ee2363fea60f21d4fa325bc (patch)
tree4fea1d9ceb566cfa12d5993009296a0dcfde0391 /mlir/python
parenta33544b83c80dcaa851fabd2979def6f68dd6e7a (diff)
[MLIR][Transform][Tune] Introduce `transform.tune.alternatives` op (#160724)
This op enables expressing uncertainty regarding what should be happening at particular places in transform-dialect schedules. In particular, it enables representing a choice among alternative regions. This choice is resolved through providing a `selected_region` argument. When this argument is provided, the semantics are such that it is valid to rewrite the op through substituting in the selected region -- with the op's interpreted semantics corresponding to exactly this. This op represents another piece of the puzzle w.r.t. a toolkit for expressing autotuning problems with the transform dialect. Note that this goes beyond tuning knobs _on_ transforms, going further by making it tunable which (sequences of) transforms are to be applied.
Diffstat (limited to 'mlir/python')
-rw-r--r--mlir/python/mlir/dialects/transform/tune.py66
1 files changed, 63 insertions, 3 deletions
diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py
index f63f88a38242..b3bfa8015c4d 100644
--- a/mlir/python/mlir/dialects/transform/tune.py
+++ b/mlir/python/mlir/dialects/transform/tune.py
@@ -6,6 +6,9 @@ from typing import Optional, Sequence
from ...ir import (
Type,
+ Value,
+ Operation,
+ OpView,
Attribute,
ArrayAttr,
StringAttr,
@@ -19,7 +22,10 @@ from .._transform_tune_extension_ops_gen import *
from .._transform_tune_extension_ops_gen import _Dialect
try:
- from .._ods_common import _cext as _ods_cext
+ from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ _cext as _ods_cext,
+ )
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@@ -36,7 +42,7 @@ class KnobOp(KnobOp):
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
],
*,
- selected: Optional[Attribute] = None,
+ selected: Optional[Union[Attribute, bool, int, float, str]] = None,
loc=None,
ip=None,
):
@@ -75,8 +81,62 @@ def knob(
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
],
*,
- selected: Optional[Attribute] = None,
+ selected: Optional[Union[Attribute, bool, int, float, str]] = None,
loc=None,
ip=None,
):
return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AlternativesOp(AlternativesOp):
+ def __init__(
+ self,
+ results: Sequence[Type],
+ name: Union[StringAttr, str],
+ num_alternatives: int,
+ *,
+ selected_region: Optional[
+ Union[int, IntegerAttr, Value, Operation, OpView]
+ ] = None,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(name, str):
+ name = StringAttr.get(name)
+
+ selected_region_attr = selected_region_param = None
+ if isinstance(selected_region, IntegerAttr):
+ selected_region_attr = selected_region
+ elif isinstance(selected_region, int):
+ selected_region_attr = IntegerAttr.get(
+ IntegerType.get_signless(32), selected_region
+ )
+ elif isinstance(selected_region, (Value, Operation, OpView)):
+ selected_region_param = _get_op_result_or_value(selected_region)
+
+ super().__init__(
+ results,
+ name,
+ num_alternatives,
+ selected_region_attr=selected_region_attr,
+ selected_region_param=selected_region_param,
+ loc=loc,
+ ip=ip,
+ )
+ for region in self.regions:
+ region.blocks.append()
+
+
+def alternatives(
+ results: Sequence[Type],
+ name: Union[StringAttr, str],
+ num_alternatives: int,
+ *,
+ selected_region: Optional[Union[int, IntegerAttr, Value, Operation, OpView]] = None,
+ loc=None,
+ ip=None,
+):
+ return AlternativesOp(
+ results, name, num_alternatives, selected_region=selected_region, loc=loc, ip=ip
+ )