summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorJacques Pienaar <jpienaar@google.com>2025-11-21 15:31:46 +0200
committerGitHub <noreply@github.com>2025-11-21 13:31:46 +0000
commit5ab49edde282814f41b90431194afaff694deba7 (patch)
tree6ff67625c849b8296023831123269db48aa26e6d /mlir
parent4fca7b05e397e381466d6943a56f8407349c7594 (diff)
[mlir][py][c] Enable setting block arg locations. (#169033)
This enables changing the location of a block argument. Follows the approach for updating type of block arg.
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir-c/IR.h4
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp6
-rw-r--r--mlir/lib/CAPI/IR/IR.cpp5
-rw-r--r--mlir/test/python/ir/blocks.py15
4 files changed, 30 insertions, 0 deletions
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index c464e4da66f1..d2f476286ca6 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -1051,6 +1051,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirBlockArgumentGetArgNumber(MlirValue value);
MLIR_CAPI_EXPORTED void mlirBlockArgumentSetType(MlirValue value,
MlirType type);
+/// Sets the location of the block argument to the given location.
+MLIR_CAPI_EXPORTED void mlirBlockArgumentSetLocation(MlirValue value,
+ MlirLocation loc);
+
/// Returns an operation that produced this value as its result. Asserts if the
/// value is not an op result.
MLIR_CAPI_EXPORTED MlirOperation mlirOpResultGetOwner(MlirValue value);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 9d5bb9f54e93..03b540de97d4 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2347,6 +2347,12 @@ public:
return mlirBlockArgumentSetType(self.get(), type);
},
nb::arg("type"), "Sets the type of this block argument.");
+ c.def(
+ "set_location",
+ [](PyBlockArgument &self, PyLocation loc) {
+ return mlirBlockArgumentSetLocation(self.get(), loc);
+ },
+ nb::arg("loc"), "Sets the location of this block argument.");
}
};
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 188186598c5c..ffcbed8b340c 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -1129,6 +1129,11 @@ void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
blockArg.setType(unwrap(type));
}
+void mlirBlockArgumentSetLocation(MlirValue value, MlirLocation loc) {
+ if (auto blockArg = llvm::dyn_cast<BlockArgument>(unwrap(value)))
+ blockArg.setLoc(unwrap(loc));
+}
+
MlirOperation mlirOpResultGetOwner(MlirValue value) {
return wrap(llvm::dyn_cast<OpResult>(unwrap(value)).getOwner());
}
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index ced5fce43472..e876c00e0c52 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -191,3 +191,18 @@ def testBlockEraseArgs():
blocks[0].erase_argument(0)
# CHECK: ^bb0:
op.print(enable_debug_info=True)
+
+
+# CHECK-LABEL: TEST: testBlockArgSetLocation
+# CHECK: ^bb0(%{{.+}}: f32 loc("new_loc")):
+@run
+def testBlockArgSetLocation():
+ with Context() as ctx, Location.unknown(ctx) as loc:
+ ctx.allow_unregistered_dialects = True
+ f32 = F32Type.get()
+ op = Operation.create("test", regions=1, loc=Location.unknown())
+ blocks = op.regions[0].blocks
+ blocks.append(f32)
+ arg = blocks[0].arguments[0]
+ arg.set_location(Location.name("new_loc"))
+ op.print(enable_debug_info=True)