diff options
| author | Tuomas Kärnä <tuomas.karna@intel.com> | 2025-11-11 13:57:54 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-11-11 11:57:54 +0000 |
| commit | 300750d4bea3fc2a17de13aa26f71aa10f2f5d2f (patch) | |
| tree | 83b52b16b5d8b51f1bc0a46276df20d03e1118da /mlir/include | |
| parent | b440fb758477e758a640842e0de9baac8616d822 (diff) | |
[MLIR][XeGPU][TransformOps] Add set_gpu_launch_threads op (#166865)
Adds `transform.xegpu.set_gpu_launch_threads` that overrides `gpu.launch` operation threads.
Diffstat (limited to 'mlir/include')
| -rw-r--r-- | mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td index 34f333e556de..f5e4afad535e 100644 --- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td @@ -161,4 +161,43 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [ }]; } +def SetGPULaunchThreadsOp + : Op<Transform_Dialect, "xegpu.set_gpu_launch_threads", [ + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + TransformOpInterface + ]> { + + let summary = "Set number of threads for a given gpu.launch operation"; + let description = [{ + Overrides the x,y,z threads operands of a given `gpu.launch` operation in-place. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + Variadic<TransformAnyParamTypeOrAnyHandle>:$threads, + DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads + ); + let results = (outs); + let builders = [ + OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedThreads)>, + ]; + + let assemblyFormat = [{ + $target + `threads` `=` custom<DynamicIndexList>($threads, $static_threads) + attr-dict `:` qualified(type(operands)) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::transform::TransformResults &transformResults, + ::mlir::transform::TransformState &state); + + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedThreads() { + Builder b(getContext()); + return getMixedValues(getStaticThreads(), getThreads(), b); + } + }]; +} + #endif // XEGPU_TRANSFORM_OPS |
