diff options
Diffstat (limited to 'mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp new file mode 100644 index 000000000000..8b455d7d68c3 --- /dev/null +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp @@ -0,0 +1,67 @@ +//===- TestLinalgRankReduceContractionOps.cpp -----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass for testing rank reduing patterns for named +// contraction ops with unit dims. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { + +struct TestLinalgRankReduceContractionOps + : public PassWrapper<TestLinalgRankReduceContractionOps, + OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestLinalgRankReduceContractionOps) + + TestLinalgRankReduceContractionOps() = default; + TestLinalgRankReduceContractionOps( + const TestLinalgRankReduceContractionOps &pass) + : PassWrapper(pass) {} + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<affine::AffineDialect, linalg::LinalgDialect, + memref::MemRefDialect, tensor::TensorDialect>(); + } + StringRef getArgument() const final { + return "test-linalg-rank-reduce-contraction-ops"; + } + StringRef getDescription() const final { + return "Test Linalg rank reduce contraction ops with unit dims"; + } + + void runOnOperation() override { + MLIRContext *context = &this->getContext(); + func::FuncOp funcOp = this->getOperation(); + + RewritePatternSet patterns(context); + linalg::populateContractionOpRankReducingPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), + std::move(patterns)))) + return signalPassFailure(); + return; + } +}; + +} // namespace + +namespace mlir { +namespace test { +void registerTestLinalgRankReduceContractionOps() { + PassRegistration<TestLinalgRankReduceContractionOps>(); +} +} // namespace test +} // namespace mlir |
