diff options
| author | Muzammiluddin Syed <muzasyed@amd.com> | 2025-11-21 10:04:29 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-11-21 10:04:29 -0500 |
| commit | 77c329f54ca5cc884001e216afa47990aad27de4 (patch) | |
| tree | 0e63b0a649cea932129afbcab9aa33378392d68c /mlir | |
| parent | 8c3f59f1b297ec65b5870ebfd727a40f632de966 (diff) | |
[mlir][ROCDL] Adds wmma scaled intrinsics for gfx1250 (#165915)
Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 112 |
| -rw-r--r-- | mlir/test/Dialect/LLVMIR/rocdl.mlir | 20 |
| -rw-r--r-- | mlir/test/Target/LLVMIR/rocdl.mlir | 153 |
3 files changed, 256 insertions, 29 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 19741f10ce8c..a384273ba30e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -599,25 +599,25 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic, [0], [0], [], 1, 0, 0, 0, [], []>, Arguments<(ins - LLVM_ScalarOrVectorOf<AB>:$A, - LLVM_ScalarOrVectorOf<AB>:$B, - LLVM_ScalarOrVectorOf<CD>:$C)> { + LLVM_ScalarOrVectorOf<AB>:$a, + LLVM_ScalarOrVectorOf<AB>:$b, + LLVM_ScalarOrVectorOf<CD>:$c)> { let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); let assemblyFormat = [{ - $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) }]; } class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic, [0], [1], [], 1, 0, 0, 0, [3], ["opsel"]>, Arguments<(ins - LLVM_ScalarOrVectorOf<AB>:$A, - LLVM_ScalarOrVectorOf<AB>:$B, - LLVM_ScalarOrVectorOf<CD>:$C, + LLVM_ScalarOrVectorOf<AB>:$a, + LLVM_ScalarOrVectorOf<AB>:$b, + LLVM_ScalarOrVectorOf<CD>:$c, DefaultValuedAttr<I1Attr, "0">:$opsel)> { let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); let assemblyFormat = [{ - $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) }]; } @@ -625,14 +625,14 @@ class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mne [0], [1], [], 1, 0, 0, 0, [0, 2, 5], ["signA", "signB", "clamp"]>, Arguments<(ins DefaultValuedAttr<I1Attr, "0">:$signA, - LLVM_ScalarOrVectorOf<AB>:$A, + LLVM_ScalarOrVectorOf<AB>:$a, DefaultValuedAttr<I1Attr, "0">:$signB, - LLVM_ScalarOrVectorOf<AB>:$B, - LLVM_ScalarOrVectorOf<CD>:$C, + LLVM_ScalarOrVectorOf<AB>:$b, + LLVM_ScalarOrVectorOf<CD>:$c, DefaultValuedAttr<I1Attr, "0">:$clamp)> { let results 
= (outs LLVM_ScalarOrVectorOf<CD>:$res); let assemblyFormat = [{ - $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) }]; } @@ -640,31 +640,31 @@ class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL [0], [1], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>, Arguments<(ins DefaultValuedAttr<I1Attr, "0">:$signA, - LLVM_ScalarOrVectorOf<AB>:$A, + LLVM_ScalarOrVectorOf<AB>:$a, DefaultValuedAttr<I1Attr, "0">:$signB, - LLVM_ScalarOrVectorOf<AB>:$B, + LLVM_ScalarOrVectorOf<AB>:$b, DefaultValuedAttr<I16Attr, "0">:$modC, - LLVM_ScalarOrVectorOf<CD>:$C, + LLVM_ScalarOrVectorOf<CD>:$c, DefaultValuedAttr<I1Attr, "0">:$reuseA, DefaultValuedAttr<I1Attr, "0">:$reuseB)> { let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); let assemblyFormat = [{ - $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) }]; } class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic, [0], [0], [], 1, 0, 0, 0, [2, 4, 5], ["modC","reuseA","reuseB"]>, Arguments<(ins - LLVM_ScalarOrVectorOf<AB>:$A, - LLVM_ScalarOrVectorOf<AB>:$B, + LLVM_ScalarOrVectorOf<AB>:$a, + LLVM_ScalarOrVectorOf<AB>:$b, DefaultValuedAttr<I16Attr, "0">:$modC, - LLVM_ScalarOrVectorOf<CD>:$C, + LLVM_ScalarOrVectorOf<CD>:$c, DefaultValuedAttr<I1Attr, "0">:$reuseA, DefaultValuedAttr<I1Attr, "0">:$reuseB)> { let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); let assemblyFormat = [{ - $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) }]; } @@ -672,16 +672,16 @@ class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> : [0], [1, 5], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>, Arguments<(ins DefaultValuedAttr<I1Attr, "0">:$signA, - LLVM_ScalarOrVectorOf<AB>:$A, + 
LLVM_ScalarOrVectorOf<AB>:$a, DefaultValuedAttr<I1Attr, "0">:$signB, - LLVM_ScalarOrVectorOf<AB>:$B, + LLVM_ScalarOrVectorOf<AB>:$b, DefaultValuedAttr<I16Attr, "0">:$modC, - LLVM_ScalarOrVectorOf<C>:$C, + LLVM_ScalarOrVectorOf<C>:$c, DefaultValuedAttr<I1Attr, "0">:$reuseA, DefaultValuedAttr<I1Attr, "0">:$reuseB)> { let results = (outs LLVM_ScalarOrVectorOf<D>:$res); let assemblyFormat = [{ - $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) }]; } @@ -689,15 +689,65 @@ class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp [0], [1], [], 1, 0, 0, 0, [0, 2, 5, 6], ["signA", "signB", "reuseA","reuseB"]>, Arguments<(ins DefaultValuedAttr<I1Attr, "0">:$signA, - LLVM_ScalarOrVectorOf<AB>:$A, + LLVM_ScalarOrVectorOf<AB>:$a, DefaultValuedAttr<I1Attr, "0">:$signB, - LLVM_ScalarOrVectorOf<AB>:$B, - LLVM_ScalarOrVectorOf<CD>:$C, + LLVM_ScalarOrVectorOf<AB>:$b, + LLVM_ScalarOrVectorOf<CD>:$c, DefaultValuedAttr<I1Attr, "0">:$reuseA, DefaultValuedAttr<I1Attr, "0">:$reuseB)> { let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); let assemblyFormat = [{ - $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) + }]; +} + +// Overloaded operands: [1, 3] refers to LLVM intrinsic parameter positions where +// A is at position 1 and B is at position 3 (after format parameters). 
+class ROCDL_WMMA_Scale_IntrOp<string mnemonic, Type AB, Type CD, Type ScaleExpTy> : ROCDL_IntrOp<mnemonic, + [0], [1, 3], [], 1, 0, 0, 0, [0, 2, 4, 6, 7, 9, 10, 12, 13], + ["fmtA", "fmtB", "modC", "scaleAType", "fmtScaleA", + "scaleBType", "fmtScaleB", "reuseA", "reuseB"]>, + Arguments<(ins + DefaultValuedAttr<I32Attr, "0">:$fmtA, + LLVM_ScalarOrVectorOf<AB>:$a, + DefaultValuedAttr<I32Attr, "0">:$fmtB, + LLVM_ScalarOrVectorOf<AB>:$b, + DefaultValuedAttr<I16Attr, "0">:$modC, + LLVM_ScalarOrVectorOf<CD>:$c, + DefaultValuedAttr<I32Attr, "0">:$scaleAType, + DefaultValuedAttr<I32Attr, "0">:$fmtScaleA, + ScaleExpTy:$scaleA, + DefaultValuedAttr<I32Attr, "0">:$scaleBType, + DefaultValuedAttr<I32Attr, "0">:$fmtScaleB, + ScaleExpTy:$scaleB, + DefaultValuedAttr<I1Attr, "0">:$reuseA, + DefaultValuedAttr<I1Attr, "0">:$reuseB)> { + let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); + let assemblyFormat = [{ + $a `,` $b `,` $c `,` $scaleA `,` $scaleB attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_Scale_F4_IntrOp<string mnemonic, Type AB, Type CD, Type ScaleExpTy> : ROCDL_IntrOp<mnemonic, + [0], [0, 1], [], 1, 0, 0, 0, [2, 4, 5, 7, 8, 10, 11], + ["modC", "scaleAType", "fmtScaleA", + "scaleBType", "fmtScaleB", "reuseA", "reuseB"]>, + Arguments<(ins + LLVM_ScalarOrVectorOf<AB>:$a, + LLVM_ScalarOrVectorOf<AB>:$b, + DefaultValuedAttr<I16Attr, "0">:$modC, + LLVM_ScalarOrVectorOf<CD>:$c, + DefaultValuedAttr<I32Attr, "0">:$scaleAType, + DefaultValuedAttr<I32Attr, "0">:$fmtScaleA, + ScaleExpTy:$scaleA, + DefaultValuedAttr<I32Attr, "0">:$scaleBType, + DefaultValuedAttr<I32Attr, "0">:$fmtScaleB, + ScaleExpTy:$scaleB, + DefaultValuedAttr<I1Attr, "0">:$reuseA, + DefaultValuedAttr<I1Attr, "0">:$reuseB)> { + let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); + let assemblyFormat = [{ + $a `,` $b `,` $c `,` $scaleA `,` $scaleB attr-dict `:` functional-type(operands, $res) }]; } @@ -739,6 +789,12 @@ def ROCDL_wmma_f16_16x16x128_bf8_fp8 : 
ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x1 def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16>; def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp<"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger>; +// Scaled wmma intrinsics (available from gfx1250) +def ROCDL_wmma_scale_f32_16x16x128_f8f6f4 : ROCDL_WMMA_Scale_IntrOp<"wmma.scale.f32.16x16x128.f8f6f4", AnyInteger, F32, I32>; +def ROCDL_wmma_scale16_f32_16x16x128_f8f6f4 : ROCDL_WMMA_Scale_IntrOp<"wmma.scale16.f32.16x16x128.f8f6f4", AnyInteger, F32, I64>; +def ROCDL_wmma_scale_f32_32x16x128_f4 : ROCDL_WMMA_Scale_F4_IntrOp<"wmma.scale.f32.32x16x128.f4", AnyInteger, F32, I32>; +def ROCDL_wmma_scale16_f32_32x16x128_f4 : ROCDL_WMMA_Scale_F4_IntrOp<"wmma.scale16.f32.32x16x128.f4", AnyInteger, F32, I64>; + //===---------------------------------------------------------------------===// // LDS transpose intrinsics (available in GFX950) diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 675975ae597a..27bf4163b9b7 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -1351,6 +1351,26 @@ llvm.func @rocdl.cvt.scalef32.sr.pk16(%v16xf32: vector<16xf32>, // ----- +// CHECK-LABEL: @rocdl_wmma_scale_ops +llvm.func @rocdl_wmma_scale_ops(%a_f8: vector<8xi32>, %a_f4: vector<4xi32>, %c_f32: vector<4xf32>, %c16_f32: vector<16xf32>, + %scale_i32: i32, %scale_i64: i64) { + // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + %r0 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %a_f8, %a_f8, %c_f32, %scale_i32, %scale_i32 : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i64, i64) -> vector<4xf32> + %r1 = 
rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %a_f8, %a_f8, %c_f32, %scale_i64, %scale_i64 : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i64, i64) -> vector<4xf32> + + // CHECK: rocdl.wmma.scale.f32.32x16x128.f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32> + %r2 = rocdl.wmma.scale.f32.32x16x128.f4 %a_f4, %a_f4, %c16_f32, %scale_i32, %scale_i32 : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32> + + // CHECK: rocdl.wmma.scale16.f32.32x16x128.f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i64, i64) -> vector<16xf32> + %r3 = rocdl.wmma.scale16.f32.32x16x128.f4 %a_f4, %a_f4, %c16_f32, %scale_i64, %scale_i64 : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i64, i64) -> vector<16xf32> + + llvm.return +} + +// ----- + // expected-error@below {{attribute attached to unexpected op}} func.func private @expected_llvm_func() attributes { rocdl.kernel } diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index dcf80ad4395d..86b69812787b 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -918,6 +918,30 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %arg4, %arg4, %arg3 {signA = false, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> + // Test signA=true, signB=false for iu8 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r5a = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = true, signB = false, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> 
vector<8xi32> + + // Test signA=false, signB=true for iu8 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r5b = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = false, signB = true, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=true, signB=true, clamp=true for iu8 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 true) + %r5c = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = true, signB = true, clamp = true} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=true, signB=false for iu4 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 true, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r6a = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = true, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=false, signB=true, clamp=true for iu4 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 true, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 true) + %r6b = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = false, signB = true, clamp = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=true, signB=true for iu4 gfx12 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 true, <2 x i32> %{{.*}} i1 true, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r6c = rocdl.wmma.i32.16x16x32.iu4 %arg4, %arg4, %arg3 {signA = true, signB = true, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> + // f32 -> f32 // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 false, <16 x float> %{{.*}} i1 
false, <16 x float> %{{.*}} i16 0, <4 x float> %{{.*}} i1 false, i1 false) %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %arg10, %arg10, %arg11 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf32>, vector<16xf32>, vector<4xf32>) -> vector<4xf32> @@ -981,7 +1005,7 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> @@ -995,6 +1019,26 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 false, i1 false) %r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = false, signB = false} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32> + // Test signA=true, signB=true for iu8 gfx1250 + // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 false, i1 false) + %r23a.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = true, signB = true} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32> + + // Test signA=true, signB=false, reuseA=true, reuseB=true for iu8 gfx1250 + // CHECK: call <64 x i32> 
@llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 true, i1 true) + %r23b.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = true, signB = false, reuseA = true, reuseB = true} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32> + + // Test signA=true, signB=true with modC=1 for f32 gfx1250 + // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 true, <16 x float> %{{.*}} i1 true, <16 x float> %{{.*}} i16 1, <4 x float> %{{.*}} i1 false, i1 false) + %r1a.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %arg10, %arg10, %arg11 {signA = true, signB = true, modC = 1 : i16, reuseA = false, reuseB = false} : (vector<16xf32>, vector<16xf32>, vector<4xf32>) -> vector<4xf32> + + // Test with modC=2 and signA=false, signB=true, reuseA=true for f16 gfx1250 + // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %{{.*}} i1 true, <16 x half> %{{.*}} i16 2, <32 x float> %{{.*}} i1 true, i1 false) + %r2a.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %arg1, %arg1, %arg12 {signA = false, signB = true, modC = 2 : i16, reuseA = true, reuseB = false} : (vector<16xf16>, vector<16xf16>, vector<32xf32>) -> vector<32xf32> + + // Test with modC=3 and signA=true, signB=true, reuseB=true for bf16 gfx1250 + // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 true, <16 x bfloat> %{{.*}} i1 true, <16 x bfloat> %{{.*}} i16 3, <32 x float> %{{.*}} i1 false, i1 true) + %r3a.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = true, signB = true, modC = 3 : i16, reuseA = false, reuseB = true} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xf32> + // ---- Wave64 ----- // f16 -> f32 @@ -1231,6 +1275,113 @@ llvm.func @rocdl.raw.ptr.buffer.load.lds(%rsrc : !llvm.ptr<8>, %dstLds : !llvm.p llvm.return } +llvm.func @rocdl.wmma.scale(%arg0: i32, %arg1: vector<4xf32>, %arg2: vector<8xi32>, + 
%arg3: vector<12xi32>, %arg5: vector<16xi32>, + %arg8: i64, %arg9: vector<8xf32>) -> vector<4xf32> { + // CHECK-LABEL: rocdl.wmma.scale + + // Test with default attributes (all zeros/false) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 0, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 false) + %r00 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 0 : i32, fmtB = 0 : i32, modC = 0 : i16, + scaleAType = 0 : i32, fmtScaleA = 0 : i32, + scaleBType = 0 : i32, fmtScaleB = 0 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with different matrix formats (FP8 x BF8) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 1, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 false, i1 false) + %r01 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 0 : i32, fmtB = 1 : i32, modC = 0 : i16, + scaleAType = 1 : i32, fmtScaleA = 1 : i32, + scaleBType = 1 : i32, fmtScaleB = 1 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with FP8 x FP6 (different vector sizes) and modC = 1 (negate) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 0, <16 x i32> %{{.*}}, i32 2, <12 x i32> %{{.*}}, i16 1, <4 x float> %{{.*}}, i32 2, i32 2, i32 %{{.*}}, i32 2, i32 2, i32 %{{.*}}, i1 false, i1 false) + %r02 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg3, %arg1, %arg0, %arg0 + {fmtA = 0 : i32, fmtB = 2 : i32, modC = 1 : i16, + scaleAType = 2 : i32, fmtScaleA = 2 : i32, + scaleBType = 2 : i32, fmtScaleB = 2 : i32, + reuseA = false, reuseB = false} : + 
(vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with BF8 x BF6 and modC = 2 (abs) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 1, <16 x i32> %{{.*}}, i32 3, <12 x i32> %{{.*}}, i16 2, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 false) + %r03 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg3, %arg1, %arg0, %arg0 + {fmtA = 1 : i32, fmtB = 3 : i32, modC = 2 : i16, + scaleAType = 0 : i32, fmtScaleA = 0 : i32, + scaleBType = 0 : i32, fmtScaleB = 0 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with FP8 x FP4 and modC = 3 (negate(abs)) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v8i32(i32 0, <16 x i32> %{{.*}}, i32 4, <8 x i32> %{{.*}}, i16 3, <4 x float> %{{.*}}, i32 3, i32 3, i32 %{{.*}}, i32 3, i32 3, i32 %{{.*}}, i1 false, i1 false) + %r04 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg2, %arg1, %arg0, %arg0 + {fmtA = 0 : i32, fmtB = 4 : i32, modC = 3 : i16, + scaleAType = 3 : i32, fmtScaleA = 3 : i32, + scaleBType = 3 : i32, fmtScaleB = 3 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with reuseA = true + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 2, <16 x i32> %{{.*}}, i32 2, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 true, i1 false) + %r10 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 2 : i32, fmtB = 2 : i32, modC = 0 : i16, + scaleAType = 0 : i32, fmtScaleA = 0 : i32, + scaleBType = 0 : i32, fmtScaleB = 0 : i32, + reuseA = true, reuseB = false} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with reuseB = true + // CHECK: 
call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 3, <16 x i32> %{{.*}}, i32 3, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 true) + %r11 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 3 : i32, fmtB = 3 : i32, modC = 0 : i16, + scaleAType = 0 : i32, fmtScaleA = 0 : i32, + scaleBType = 0 : i32, fmtScaleB = 0 : i32, + reuseA = false, reuseB = true} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with both reuseA and reuseB = true + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 4, <16 x i32> %{{.*}}, i32 4, <16 x i32> %{{.*}}, i16 1, <4 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 true, i1 true) + %r12 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 4 : i32, fmtB = 4 : i32, modC = 1 : i16, + scaleAType = 1 : i32, fmtScaleA = 1 : i32, + scaleBType = 1 : i32, fmtScaleB = 1 : i32, + reuseA = true, reuseB = true} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test scale16 variant with i64 scale exponents + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 1, <16 x i32> %{{.*}}, i16 2, <4 x float> %{{.*}}, i32 2, i32 2, i64 %{{.*}}, i32 2, i32 2, i64 %{{.*}}, i1 false, i1 false) + %r_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg8, %arg8 + {fmtA = 0 : i32, fmtB = 1 : i32, modC = 2 : i16, + scaleAType = 2 : i32, fmtScaleA = 2 : i32, + scaleBType = 2 : i32, fmtScaleB = 2 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i64, i64) -> vector<4xf32> + + // Test f4 variant (no matrix format parameters) + // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> 
%{{.*}}, <8 x i32> %{{.*}}, i16 0, <8 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 false, i1 false) + %r_f4 = rocdl.wmma.scale.f32.32x16x128.f4 %arg5, %arg2, %arg9, %arg0, %arg0 + {modC = 0 : i16, + scaleAType = 1 : i32, fmtScaleA = 1 : i32, + scaleBType = 1 : i32, fmtScaleB = 1 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32> + + // Test f4 scale16 variant with varied attributes + // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale16.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> %{{.*}}, <8 x i32> %{{.*}}, i16 3, <8 x float> %{{.*}}, i32 2, i32 3, i64 %{{.*}}, i32 3, i32 2, i64 %{{.*}}, i1 true, i1 true) + %r_f4_scale16 = rocdl.wmma.scale16.f32.32x16x128.f4 %arg5, %arg2, %arg9, %arg8, %arg8 + {modC = 3 : i16, + scaleAType = 2 : i32, fmtScaleA = 3 : i32, + scaleBType = 3 : i32, fmtScaleB = 2 : i32, + reuseA = true, reuseB = true} : + (vector<16xi32>, vector<8xi32>, vector<8xf32>, i64, i64) -> vector<8xf32> + + llvm.return %r00 : vector<4xf32> +} + llvm.func @rocdl.raw.ptr.buffer.atomic.f32(%rsrc : !llvm.ptr<8>, %offset : i32, %soffset : i32, %vdata1 : f32) { |
