summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorManish Gupta <manigupta@google.com>2023-06-01 02:00:56 +0000
committerManish Gupta <manigupta@google.com>2023-06-05 23:22:20 +0000
commit9a795f0c59b1707d1f4bdb352e8805133d72d9e2 (patch)
tree3c2aa45f8d6707de04025b2e701ed239e0d9f2a0 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parentf04cf6b73a27c84db6d1e0ecde8fa49c7bca89a4 (diff)
[mlir][Vector] Adds a pattern to fold `arith.extf` into `vector.contract`
Consider mixed precision data type, i.e., F16 input lhs, F16 input rhs, F32 accumulation, and F32 output. This is typically written as F32 <= F16*F16 + F32. During vectorization from linalg to vector for mixed precision data type (F32 <= F16*F16 + F32), linalg.matmul introduces arith.extf on input lhs and rhs operands. "linalg.matmul"(%lhs, %rhs, %acc) ({ ^bb0(%arg1: f16, %arg2: f16, %arg3: f32): %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32 %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32 %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32 %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32 "linalg.yield"(%acc) : (f32) -> () }) There are backend that natively supports mixed-precision data type and does not need the arith.extf. For example, NVIDIA A100 GPU has mma.sync.aligned.*.f32.f16.f16.f32 that can support mixed-precision data type. However, the presence of arith.extf in the IR, introduces the unnecessary casting targeting F32 Tensor Cores instead of F16 Tensor Cores for NVIDIA backend. This patch adds a folding pattern to fold arith.extf into vector.contract Differential Revision: https://reviews.llvm.org/D151918
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp29
1 files changed, 29 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 3b0cf2f83f19..4fbddcee574a 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -709,6 +710,32 @@ struct TestVectorTransferTensorSlicePatterns
}
};
+struct TestFoldArithExtensionIntoVectorContractPatterns
+ : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestFoldArithExtensionIntoVectorContractPatterns)
+
+ StringRef getArgument() const final {
+ return "test-fold-arith-extf-into-vector-contract-patterns";
+ }
+ StringRef getDescription() const final {
+ return "Test patterns that fold arithmetic extension ops into vector "
+ "contract ops";
+ }
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect,
+ memref::MemRefDialect, scf::SCFDialect,
+ tensor::TensorDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateFoldArithExtensionPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
} // namespace
namespace mlir {
@@ -745,6 +772,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorGatherLowering>();
PassRegistration<TestVectorTransferTensorSlicePatterns>();
+
+ PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
}
} // namespace test
} // namespace mlir