summaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python/IRCore.cpp
diff options
context:
space:
mode:
authorvfdev <vfdev.5@gmail.com>2025-01-13 12:00:31 +0100
committerGitHub <noreply@github.com>2025-01-13 03:00:31 -0800
commitf136c800b60dbfacdbb645e7e92acba52e2f279f (patch)
treec3932fa35c97958bad2c9f91b23bb9f847f07bb6 /mlir/lib/Bindings/Python/IRCore.cpp
parent7e2eb0f83e1cf6861c8fd1f038a88a8ddd851c34 (diff)
Enabled freethreading support in MLIR python bindings (#122684)
Reland reverted https://github.com/llvm/llvm-project/pull/107103 with the fixes for Python 3.8 cc @jpienaar Co-authored-by: Peter Hawkins <phawkins@google.com>
Diffstat (limited to 'mlir/lib/Bindings/Python/IRCore.cpp')
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp31
1 files changed, 27 insertions, 4 deletions
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 453d4f7c7e8b..463ebdebb3f3 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -243,9 +243,15 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes,
/// Wrapper for the global LLVM debugging flag.
struct PyGlobalDebugFlag {
- static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
+ static void set(nb::object &o, bool enable) {
+ nb::ft_lock_guard lock(mutex);
+ mlirEnableGlobalDebug(enable);
+ }
- static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); }
+ static bool get(const nb::object &) {
+ nb::ft_lock_guard lock(mutex);
+ return mlirIsGlobalDebugEnabled();
+ }
static void bind(nb::module_ &m) {
// Debug flags.
@@ -255,6 +261,7 @@ struct PyGlobalDebugFlag {
.def_static(
"set_types",
[](const std::string &type) {
+ nb::ft_lock_guard lock(mutex);
mlirSetGlobalDebugType(type.c_str());
},
"types"_a, "Sets specific debug types to be produced by LLVM")
@@ -263,11 +270,17 @@ struct PyGlobalDebugFlag {
pointers.reserve(types.size());
for (const std::string &str : types)
pointers.push_back(str.c_str());
+ nb::ft_lock_guard lock(mutex);
mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
});
}
+
+private:
+ static nb::ft_mutex mutex;
};
+nb::ft_mutex PyGlobalDebugFlag::mutex;
+
struct PyAttrBuilderMap {
static bool dunderContains(const std::string &attributeKind) {
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
@@ -606,6 +619,7 @@ private:
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
nb::gil_scoped_acquire acquire;
+ nb::ft_lock_guard lock(live_contexts_mutex);
auto &liveContexts = getLiveContexts();
liveContexts[context.ptr] = this;
}
@@ -615,7 +629,10 @@ PyMlirContext::~PyMlirContext() {
// forContext method, which always puts the associated handle into
// liveContexts.
nb::gil_scoped_acquire acquire;
- getLiveContexts().erase(context.ptr);
+ {
+ nb::ft_lock_guard lock(live_contexts_mutex);
+ getLiveContexts().erase(context.ptr);
+ }
mlirContextDestroy(context);
}
@@ -632,6 +649,7 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
nb::gil_scoped_acquire acquire;
+ nb::ft_lock_guard lock(live_contexts_mutex);
auto &liveContexts = getLiveContexts();
auto it = liveContexts.find(context.ptr);
if (it == liveContexts.end()) {
@@ -647,12 +665,17 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
return PyMlirContextRef(it->second, std::move(pyRef));
}
+nb::ft_mutex PyMlirContext::live_contexts_mutex;
+
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
static LiveContextMap liveContexts;
return liveContexts;
}
-size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
+size_t PyMlirContext::getLiveCount() {
+ nb::ft_lock_guard lock(live_contexts_mutex);
+ return getLiveContexts().size();
+}
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }