diff options
| author | Manish Gupta <manigupta@google.com> | 2023-06-01 02:00:56 +0000 |
|---|---|---|
| committer | Manish Gupta <manigupta@google.com> | 2023-06-05 23:22:20 +0000 |
| commit | 9a795f0c59b1707d1f4bdb352e8805133d72d9e2 (patch) | |
| tree | 3c2aa45f8d6707de04025b2e701ed239e0d9f2a0 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | f04cf6b73a27c84db6d1e0ecde8fa49c7bca89a4 (diff) | |
[mlir][Vector] Adds a pattern to fold `arith.extf` into `vector.contract`
Consider a mixed-precision data type, i.e., F16 input lhs, F16 input rhs, F32 accumulation, and F32 output. This is typically written as F32 <= F16*F16 + F32.
During vectorization from linalg to vector for a mixed-precision data type (F32 <= F16*F16 + F32), linalg.matmul introduces arith.extf on the lhs and rhs input operands.
"linalg.matmul"(%lhs, %rhs, %acc) ({
^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
%lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
%rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
%mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
%acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
"linalg.yield"(%acc) : (f32) -> ()
})
There are backends that natively support mixed-precision data types and do not need the arith.extf. For example, the NVIDIA A100 GPU has mma.sync.aligned.*.f32.f16.f16.f32, which supports the mixed-precision data type. However, the presence of arith.extf in the IR introduces unnecessary casting, which results in targeting F32 Tensor Cores instead of F16 Tensor Cores on the NVIDIA backend. This patch adds a folding pattern to fold arith.extf into vector.contract.
Differential Revision: https://reviews.llvm.org/D151918
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 3b0cf2f83f19..4fbddcee574a 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -709,6 +710,32 @@ struct TestVectorTransferTensorSlicePatterns } }; +struct TestFoldArithExtensionIntoVectorContractPatterns + : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns, + OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestFoldArithExtensionIntoVectorContractPatterns) + + StringRef getArgument() const final { + return "test-fold-arith-extf-into-vector-contract-patterns"; + } + StringRef getDescription() const final { + return "Test patterns that fold arithmetic extension ops into vector " + "contract ops"; + } + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect, + memref::MemRefDialect, scf::SCFDialect, + tensor::TensorDialect, vector::VectorDialect>(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateFoldArithExtensionPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; } // namespace namespace mlir { @@ -745,6 +772,8 @@ void registerTestVectorLowerings() { PassRegistration<TestVectorGatherLowering>(); PassRegistration<TestVectorTransferTensorSlicePatterns>(); + + PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>(); } } // namespace test } // namespace mlir |
