summaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python/Rewrite.cpp
diff options
context:
space:
mode:
authorKoakuma <koachan@protonmail.com>2024-07-08 19:19:54 +0700
committerKoakuma <koachan@protonmail.com>2024-07-08 19:19:54 +0700
commit5c4fdc2fd5898ebd9e89999a4f4b8aa289ca637f (patch)
treef3b92a07f3dfc6e70f36d1000605f36a3c15af46 /mlir/lib/Bindings/Python/Rewrite.cpp
parentdbda8e2f2cd8764e0badd983915d62a2c3377f4d (diff)
parente9b8cd0c806db00f0981fb36717077c941426302 (diff)
Created using spr 1.3.5 [skip ci]
Diffstat (limited to 'mlir/lib/Bindings/Python/Rewrite.cpp')
-rw-r--r--mlir/lib/Bindings/Python/Rewrite.cpp110
1 files changed, 110 insertions, 0 deletions
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
new file mode 100644
index 000000000000..1d8128be9f08
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -0,0 +1,110 @@
+//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Rewrite.h"
+
+#include "IRModule.h"
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Rewrite.h"
+#include "mlir/Config/mlir-config.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace py::literals;
+using namespace mlir::python;
+
+namespace {
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+/// Owning Wrapper around a PDLPatternModule.
+class PyPDLPatternModule {
+public:
+ PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
+ PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
+ : module(other.module) {
+ other.module.ptr = nullptr;
+ }
+ ~PyPDLPatternModule() {
+ if (module.ptr != nullptr)
+ mlirPDLPatternModuleDestroy(module);
+ }
+ MlirPDLPatternModule get() { return module; }
+
+private:
+ MlirPDLPatternModule module;
+};
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
+/// Owning Wrapper around a FrozenRewritePatternSet.
+class PyFrozenRewritePatternSet {
+public:
+ PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
+ PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
+ : set(other.set) {
+ other.set.ptr = nullptr;
+ }
+ ~PyFrozenRewritePatternSet() {
+ if (set.ptr != nullptr)
+ mlirFrozenRewritePatternSetDestroy(set);
+ }
+ MlirFrozenRewritePatternSet get() { return set; }
+
+ pybind11::object getCapsule() {
+ return py::reinterpret_steal<py::object>(
+ mlirPythonFrozenRewritePatternSetToCapsule(get()));
+ }
+
+ static pybind11::object createFromCapsule(pybind11::object capsule) {
+ MlirFrozenRewritePatternSet rawPm =
+ mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
+ if (rawPm.ptr == nullptr)
+ throw py::error_already_set();
+ return py::cast(PyFrozenRewritePatternSet(rawPm),
+ py::return_value_policy::move);
+ }
+
+private:
+ MlirFrozenRewritePatternSet set;
+};
+
+} // namespace
+
+/// Create the `mlir.rewrite` here.
+void mlir::python::populateRewriteSubmodule(py::module &m) {
+ //----------------------------------------------------------------------------
+ // Mapping of the top-level PassManager
+ //----------------------------------------------------------------------------
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+ py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
+ .def(py::init<>([](MlirModule module) {
+ return mlirPDLPatternModuleFromModule(module);
+ }),
+ "module"_a, "Create a PDL module from the given module.")
+ .def("freeze", [](PyPDLPatternModule &self) {
+ return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+ mlirRewritePatternSetFromPDLPatternModule(self.get())));
+ });
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg
+ py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
+ py::module_local())
+ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+ &PyFrozenRewritePatternSet::getCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyFrozenRewritePatternSet::createFromCapsule);
+ m.def(
+ "apply_patterns_and_fold_greedily",
+ [](MlirModule module, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+ if (mlirLogicalResultIsFailure(status))
+ // FIXME: Not sure this is the right error to throw here.
+ throw py::value_error("pattern application failed to converge");
+ },
+ "module"_a, "set"_a,
+ "Applys the given patterns to the given module greedily while folding "
+ "results.");
+}