diff options
| author | vfdev <vfdev.5@gmail.com> | 2025-01-13 12:00:31 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-13 03:00:31 -0800 |
| commit | f136c800b60dbfacdbb645e7e92acba52e2f279f (patch) | |
| tree | c3932fa35c97958bad2c9f91b23bb9f847f07bb6 /mlir/lib/Bindings/Python/IRCore.cpp | |
| parent | 7e2eb0f83e1cf6861c8fd1f038a88a8ddd851c34 (diff) | |
Enabled freethreading support in MLIR python bindings (#122684)
Reland reverted https://github.com/llvm/llvm-project/pull/107103 with
the fixes for Python 3.8
cc @jpienaar
Co-authored-by: Peter Hawkins <phawkins@google.com>
Diffstat (limited to 'mlir/lib/Bindings/Python/IRCore.cpp')
| -rw-r--r-- | mlir/lib/Bindings/Python/IRCore.cpp | 31 |
1 files changed, 27 insertions, 4 deletions
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 453d4f7c7e8b..463ebdebb3f3 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -243,9 +243,15 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes, /// Wrapper for the global LLVM debugging flag. struct PyGlobalDebugFlag { - static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); } + static void set(nb::object &o, bool enable) { + nb::ft_lock_guard lock(mutex); + mlirEnableGlobalDebug(enable); + } - static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); } + static bool get(const nb::object &) { + nb::ft_lock_guard lock(mutex); + return mlirIsGlobalDebugEnabled(); + } static void bind(nb::module_ &m) { // Debug flags. @@ -255,6 +261,7 @@ struct PyGlobalDebugFlag { .def_static( "set_types", [](const std::string &type) { + nb::ft_lock_guard lock(mutex); mlirSetGlobalDebugType(type.c_str()); }, "types"_a, "Sets specific debug types to be produced by LLVM") @@ -263,11 +270,17 @@ struct PyGlobalDebugFlag { pointers.reserve(types.size()); for (const std::string &str : types) pointers.push_back(str.c_str()); + nb::ft_lock_guard lock(mutex); mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); }); } + +private: + static nb::ft_mutex mutex; }; +nb::ft_mutex PyGlobalDebugFlag::mutex; + struct PyAttrBuilderMap { static bool dunderContains(const std::string &attributeKind) { return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); @@ -606,6 +619,7 @@ private: PyMlirContext::PyMlirContext(MlirContext context) : context(context) { nb::gil_scoped_acquire acquire; + nb::ft_lock_guard lock(live_contexts_mutex); auto &liveContexts = getLiveContexts(); liveContexts[context.ptr] = this; } @@ -615,7 +629,10 @@ PyMlirContext::~PyMlirContext() { // forContext method, which always puts the associated handle into // liveContexts. nb::gil_scoped_acquire acquire; - getLiveContexts().erase(context.ptr); + { + nb::ft_lock_guard lock(live_contexts_mutex); + getLiveContexts().erase(context.ptr); + } mlirContextDestroy(context); } @@ -632,6 +649,7 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) { PyMlirContextRef PyMlirContext::forContext(MlirContext context) { nb::gil_scoped_acquire acquire; + nb::ft_lock_guard lock(live_contexts_mutex); auto &liveContexts = getLiveContexts(); auto it = liveContexts.find(context.ptr); if (it == liveContexts.end()) { @@ -647,12 +665,17 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) { return PyMlirContextRef(it->second, std::move(pyRef)); } +nb::ft_mutex PyMlirContext::live_contexts_mutex; + PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { static LiveContextMap liveContexts; return liveContexts; } -size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } +size_t PyMlirContext::getLiveCount() { + nb::ft_lock_guard lock(live_contexts_mutex); + return getLiveContexts().size(); +} size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } |
