summaryrefslogtreecommitdiff
path: root/mlir/include
diff options
context:
space:
mode:
authorTuomas Kärnä <tuomas.karna@intel.com>2025-11-11 13:57:54 +0200
committerGitHub <noreply@github.com>2025-11-11 11:57:54 +0000
commit300750d4bea3fc2a17de13aa26f71aa10f2f5d2f (patch)
tree83b52b16b5d8b51f1bc0a46276df20d03e1118da /mlir/include
parentb440fb758477e758a640842e0de9baac8616d822 (diff)
[MLIR][XeGPU][TransformOps] Add set_gpu_launch_threads op (#166865)
Adds `transform.xegpu.set_gpu_launch_threads` that overrides `gpu.launch` operation threads.
Diffstat (limited to 'mlir/include')
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td39
1 files changed, 39 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 34f333e556de..f5e4afad535e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -161,4 +161,43 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
}];
}
+def SetGPULaunchThreadsOp
+ : Op<Transform_Dialect, "xegpu.set_gpu_launch_threads", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface
+ ]> {
+
+ let summary = "Set number of threads for a given gpu.launch operation";
+ let description = [{
+ Overrides the x,y,z threads operands of a given `gpu.launch` operation in-place.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ Variadic<TransformAnyParamTypeOrAnyHandle>:$threads,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads
+ );
+ let results = (outs);
+ let builders = [
+ OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedThreads)>,
+ ];
+
+ let assemblyFormat = [{
+ $target
+ `threads` `=` custom<DynamicIndexList>($threads, $static_threads)
+ attr-dict `:` qualified(type(operands))
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedThreads() {
+ Builder b(getContext());
+ return getMixedValues(getStaticThreads(), getThreads(), b);
+ }
+ }];
+}
+
#endif // XEGPU_TRANSFORM_OPS