//===- TestEmulateNarrowType.cpp - Test Narrow Type Emulation ------*- c++ //-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; namespace { struct TestEmulateNarrowTypePass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateNarrowTypePass) TestEmulateNarrowTypePass() = default; TestEmulateNarrowTypePass(const TestEmulateNarrowTypePass &pass) : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { registry .insert(); } StringRef getArgument() const final { return "test-emulate-narrow-int"; } StringRef getDescription() const final { return "Function pass to test Narrow Integer Emulation"; } void runOnOperation() override { if (!llvm::isPowerOf2_32(loadStoreEmulateBitwidth) || loadStoreEmulateBitwidth < 8) { signalPassFailure(); return; } Operation *op = getOperation(); MLIRContext *ctx = op->getContext(); arith::NarrowTypeEmulationConverter typeConverter(loadStoreEmulateBitwidth); // Convert scalar type. typeConverter.addConversion([this](IntegerType ty) -> std::optional { unsigned width = ty.getWidth(); if (width >= arithComputeBitwidth) return ty; return IntegerType::get(ty.getContext(), arithComputeBitwidth); }); // Convert vector type. typeConverter.addConversion([this](VectorType ty) -> std::optional { auto intTy = dyn_cast(ty.getElementType()); if (!intTy) return ty; unsigned width = intTy.getWidth(); if (width >= arithComputeBitwidth) return ty; return VectorType::get( to_vector(ty.getShape()), IntegerType::get(ty.getContext(), arithComputeBitwidth)); }); // With the type converter enabled, we are effectively unable to write // negative tests. This is a workaround specifically for negative tests. if (!disableMemrefTypeConversion) memref::populateMemRefNarrowTypeEmulationConversions(typeConverter); ConversionTarget target(*ctx); target.addDynamicallyLegalOp([&typeConverter](Operation *op) { return typeConverter.isLegal(cast(op).getFunctionType()); }); auto opLegalCallback = [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }; target.addDynamicallyLegalOp(opLegalCallback); target.addDynamicallyLegalDialect< arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect, affine::AffineDialect>(opLegalCallback); RewritePatternSet patterns(ctx); arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns); memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns); vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns, disableAtomicRMW); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } Option loadStoreEmulateBitwidth{ *this, "memref-load-bitwidth", llvm::cl::desc("memref load/store emulation bit width"), llvm::cl::init(8)}; Option arithComputeBitwidth{ *this, "arith-compute-bitwidth", llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)}; Option disableMemrefTypeConversion{ *this, "skip-memref-type-conversion", llvm::cl::desc("disable memref type conversion (to test failures)"), llvm::cl::init(false)}; Option disableAtomicRMW{ *this, "disable-atomic-rmw", llvm::cl::desc("disable atomic read-modify-write and prefer generating " "normal sequence"), llvm::cl::init(false)}; }; struct TestMemRefFlattenAndVectorNarrowTypeEmulationPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestMemRefFlattenAndVectorNarrowTypeEmulationPass) TestMemRefFlattenAndVectorNarrowTypeEmulationPass() = default; TestMemRefFlattenAndVectorNarrowTypeEmulationPass( const TestMemRefFlattenAndVectorNarrowTypeEmulationPass &pass) : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { registry .insert(); } StringRef getArgument() const final { return "test-memref-flatten-and-vector-narrow-type-emulation"; } StringRef getDescription() const final { return "Test MemRef flattening and vector narrow type emulation patterns"; } void runOnOperation() override { Operation *op = getOperation(); MLIRContext *ctx = &getContext(); // Create a type converter for narrow type emulation (8-bit) arith::NarrowTypeEmulationConverter typeConverter(8); // Add conversions for memref types with i4 elements memref::populateMemRefNarrowTypeEmulationConversions(typeConverter); ConversionTarget target(*ctx); target.addDynamicallyLegalOp([&typeConverter](Operation *op) { return typeConverter.isLegal(cast(op).getFunctionType()); }); auto opLegalCallback = [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }; target.addDynamicallyLegalOp(opLegalCallback); target.addDynamicallyLegalDialect< arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect, affine::AffineDialect>(opLegalCallback); RewritePatternSet patterns(ctx); // This is necessary for the purpose of emulating `memref.alloc` and // function boundaries. memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns); vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns( typeConverter, patterns); // Apply partial conversion if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } }; } // namespace namespace mlir::test { void registerTestEmulateNarrowTypePass() { PassRegistration(); PassRegistration(); } } // namespace mlir::test