diff options
| author | Peter Hawkins <phawkins@google.com> | 2024-12-18 14:16:11 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-18 11:16:11 -0800 |
| commit | 41bd35b58bb482fd466aa4b13aa44a810ad6470f (patch) | |
| tree | 4031a35c28a7191885555d86b086b7b42f81c535 /mlir/lib/Bindings/Python/MainModule.cpp | |
| parent | bfd05102d817fce38938ce864f89ad90ef0b6cda (diff) | |
[mlir python] Port Python core code to nanobind. (#118583)
Why? https://nanobind.readthedocs.io/en/latest/why.html says it better
than I can, but my primary motivation for this change is to improve MLIR
IR construction time from JAX.
For a complicated Google-internal LLM model in JAX, this change improves
the MLIR
lowering time by around 5s (out of around 30s), which is a significant
speedup for simply switching binding frameworks.
To a large extent, this is a mechanical change, for instance changing
`pybind11::`
to `nanobind::`.
Notes:
* this PR needs Nanobind 2.4.0, because it needs a bug fix
(https://github.com/wjakob/nanobind/pull/806) that landed in that
release.
* this PR does not port the in-tree dialect extension modules. They can
be ported in a future PR.
* I removed the py::sibling() annotations from def_static and def_class
in `PybindAdapters.h`. These ask pybind11 to try to form an overload
with an existing method, but it's not possible to form mixed
pybind11/nanobind overloads this ways and the parent class is now
defined in nanobind. Better solutions may be possible here.
* nanobind does not contain an exact equivalent of pybind11's buffer
protocol support. It was not hard to add a nanobind implementation of a
similar API.
* nanobind is pickier about casting to std::vector<bool>, expecting that
the input is a sequence of bool types, not truthy values. In a couple of
places I added code to support truthy values during casting.
* nanobind distinguishes bytes (`nb::bytes`) from strings (e.g.,
`std::string`). This required nb::bytes overloads in a few places.
Diffstat (limited to 'mlir/lib/Bindings/Python/MainModule.cpp')
| -rw-r--r-- | mlir/lib/Bindings/Python/MainModule.cpp | 56 |
1 files changed, 29 insertions, 27 deletions
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 7c27021902de..e5e64a921a79 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -6,29 +6,31 @@ // //===----------------------------------------------------------------------===// -#include "PybindUtils.h" +#include <nanobind/nanobind.h> +#include <nanobind/stl/string.h> #include "Globals.h" #include "IRModule.h" +#include "NanobindUtils.h" #include "Pass.h" #include "Rewrite.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; -using namespace py::literals; +using namespace nb::literals; using namespace mlir::python; // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- -PYBIND11_MODULE(_mlir, m) { +NB_MODULE(_mlir, m) { m.doc() = "MLIR Python Native Extension"; - py::class_<PyGlobals>(m, "_Globals", py::module_local()) - .def_property("dialect_search_modules", - &PyGlobals::getDialectSearchPrefixes, - &PyGlobals::setDialectSearchPrefixes) + nb::class_<PyGlobals>(m, "_Globals") + .def_prop_rw("dialect_search_modules", + &PyGlobals::getDialectSearchPrefixes, + &PyGlobals::setDialectSearchPrefixes) .def( "append_dialect_search_prefix", [](PyGlobals &self, std::string moduleName) { @@ -45,22 +47,21 @@ PYBIND11_MODULE(_mlir, m) { "dialect_namespace"_a, "dialect_class"_a, "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, - "operation_name"_a, "operation_class"_a, py::kw_only(), + "operation_name"_a, "operation_class"_a, nb::kw_only(), "replace"_a = false, "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage // it is necessary to make sure it is destroyed (and releases its python // resources) properly. - m.attr("globals") = - py::cast(new PyGlobals, py::return_value_policy::take_ownership); + m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership); // Registration decorators. m.def( "register_dialect", - [](py::type pyClass) { + [](nb::type_object pyClass) { std::string dialectNamespace = - pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); + nanobind::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE")); PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); return pyClass; }, @@ -68,45 +69,46 @@ PYBIND11_MODULE(_mlir, m) { "Class decorator for registering a custom Dialect wrapper"); m.def( "register_operation", - [](const py::type &dialectClass, bool replace) -> py::cpp_function { - return py::cpp_function( - [dialectClass, replace](py::type opClass) -> py::type { + [](const nb::type_object &dialectClass, bool replace) -> nb::object { + return nb::cpp_function( + [dialectClass, + replace](nb::type_object opClass) -> nb::type_object { std::string operationName = - opClass.attr("OPERATION_NAME").cast<std::string>(); + nanobind::cast<std::string>(opClass.attr("OPERATION_NAME")); PyGlobals::get().registerOperationImpl(operationName, opClass, replace); // Dict-stuff the new opClass by name onto the dialect class. - py::object opClassName = opClass.attr("__name__"); + nb::object opClassName = opClass.attr("__name__"); dialectClass.attr(opClassName) = opClass; return opClass; }); }, - "dialect_class"_a, py::kw_only(), "replace"_a = false, + "dialect_class"_a, nb::kw_only(), "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); m.def( MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { - return py::cpp_function([mlirTypeID, - replace](py::object typeCaster) -> py::object { + [](MlirTypeID mlirTypeID, bool replace) -> nb::object { + return nb::cpp_function([mlirTypeID, replace]( + nb::callable typeCaster) -> nb::object { PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); return typeCaster; }); }, - "typeid"_a, py::kw_only(), "replace"_a = false, + "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); m.def( MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { - return py::cpp_function( - [mlirTypeID, replace](py::object valueCaster) -> py::object { + [](MlirTypeID mlirTypeID, bool replace) -> nb::object { + return nb::cpp_function( + [mlirTypeID, replace](nb::callable valueCaster) -> nb::object { PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, replace); return valueCaster; }); }, - "typeid"_a, py::kw_only(), "replace"_a = false, + "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a value caster for casting MLIR values to custom user values."); // Define and populate IR submodule. |
