summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py62
-rw-r--r--mlir/python/requirements.txt6
2 files changed, 43 insertions, 25 deletions
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py
index 1672656b3a1f..2235bb2865c0 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py
@@ -5,22 +5,25 @@
import sys
+
+def multiline_str_representer(dumper, data):
+ if len(data.splitlines()) > 1:
+ return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
+ else:
+ return dumper.represent_scalar("tag:yaml.org,2002:str", data)
+
+
try:
- import yaml
+ from yaml import YAMLObject as _YAMLObject, add_representer
+
+ add_representer(str, multiline_str_representer)
except ModuleNotFoundError as e:
- raise ModuleNotFoundError(
- f"This tool requires PyYAML but it was not installed. "
- f"Recommend: {sys.executable} -m pip install PyYAML"
- ) from e
-__all__ = [
- "yaml_dump",
- "yaml_dump_all",
- "YAMLObject",
-]
+ class _YAMLObject:
+ pass
-class YAMLObject(yaml.YAMLObject):
+class YAMLObject(_YAMLObject):
@classmethod
def to_yaml(cls, dumper, self):
"""Default to a custom dictionary mapping."""
@@ -33,21 +36,34 @@ class YAMLObject(yaml.YAMLObject):
return yaml_dump(self)
-def multiline_str_representer(dumper, data):
- if len(data.splitlines()) > 1:
- return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
- else:
- return dumper.represent_scalar("tag:yaml.org,2002:str", data)
+def yaml_dump(data, sort_keys=False, **kwargs):
+ try:
+ import yaml
+ return yaml.dump(data, sort_keys=sort_keys, **kwargs)
+ except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+ f"This tool requires PyYAML but it was not installed. "
+ f"Recommend: {sys.executable} -m pip install PyYAML"
+ ) from e
-yaml.add_representer(str, multiline_str_representer)
+def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs):
+ try:
+ import yaml
-def yaml_dump(data, sort_keys=False, **kwargs):
- return yaml.dump(data, sort_keys=sort_keys, **kwargs)
+ return yaml.dump_all(
+ data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs
+ )
+ except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+ f"This tool requires PyYAML but it was not installed. "
+ f"Recommend: {sys.executable} -m pip install PyYAML"
+ ) from e
-def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs):
- return yaml.dump_all(
- data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs
- )
+__all__ = [
+ "yaml_dump",
+ "yaml_dump_all",
+ "YAMLObject",
+]
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index abe09259bb1e..a1ff6e815d2f 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,7 +1,9 @@
+# BUILD dependencies
nanobind>=2.9, <3.0
-numpy>=1.19.5, <=2.1.2
pybind11>=2.10.0, <=2.13.6
PyYAML>=5.4.0, <=6.0.1
+typing_extensions>=4.12.2
+# RUN dependencies
+numpy>=1.19.5, <=2.1.2
ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16
ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13"
-typing_extensions>=4.12.2