diff options
| -rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h | 7 | ||||
| -rw-r--r-- | mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 45 | ||||
| -rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 14 | ||||
| -rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 23 | ||||
| -rw-r--r-- | mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 40 | ||||
| -rw-r--r-- | mlir/lib/ExecutionEngine/CMakeLists.txt | 12 | ||||
| -rw-r--r-- | mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir | 19 |
7 files changed, 159 insertions, 1 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 8ad9ed18aceb..8564d0f4205c 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -52,6 +52,13 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables = nullptr); +FailureOr<LLVM::LLVMFuncOp> +lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); +FailureOr<LLVM::LLVMFuncOp> +lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); + /// Declares a function to print a C-string. /// If a custom runtime function is defined via `runtimeFunctionName`, it must /// have the signature void(char const*). The default function is `printString`. diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 03ed4d51cc74..632e1a7f0260 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/TypeUtilities.h" @@ -572,6 +573,47 @@ void mlir::arith::registerConvertArithToLLVMInterface( }); } +struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arith::AddFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get APFloat adder function from runtime library. + auto parent = op->getParentOfType<ModuleOp>(); + if (!parent) + return failure(); + FailureOr<Operation *> adder = + LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent); + auto floatTy = cast<FloatType>(op.getType()); + + // Cast operands to 64-bit integers. + Location loc = op.getLoc(); + Value lhsBits = rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(), + adaptor.getLhs()); + Value rhsBits = rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(), + adaptor.getRhs()); + + // Call software implementation of floating point addition. + int32_t sem = + llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + Value semValue = rewriter.create<LLVM::ConstantOp>( + loc, rewriter.getI32Type(), + rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); + SmallVector<Value> params = {semValue, lhsBits, rhsBits}; + auto resultOp = + LLVM::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*adder), params); + + // Truncate result to the original width. + Value truncatedBits = rewriter.create<LLVM::TruncOp>( + loc, rewriter.getIntegerType(floatTy.getWidth()), + resultOp->getResult(0)); + rewriter.replaceOp(op, truncatedBits); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Pattern Population //===----------------------------------------------------------------------===// @@ -586,7 +628,8 @@ void mlir::arith::populateArithToLLVMConversionPatterns( // clang-format off patterns.add< - AddFOpLowering, + //AddFOpLowering, + FancyAddFLowering, AddIOpLowering, AndIOpLowering, AddUIExtendedOpLowering, diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 69a317ecd101..260c028ffd9c 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1654,6 +1654,20 @@ private: return failure(); } } + } else if (auto floatTy = dyn_cast<FloatType>(printType)) { + // Print other floating-point types using the APFloat runtime library. + int32_t sem = + llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + Value semValue = rewriter.create<LLVM::ConstantOp>( + loc, rewriter.getI32Type(), + rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); + Value floatBits = + rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(), value); + printer = + LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables); + emitCall(rewriter, loc, printer.value(), + ValueRange({semValue, floatBits})); + return success(); } else { return failure(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index feaffa34897b..8ee039be6056 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -30,6 +30,8 @@ static constexpr llvm::StringRef kPrintF16 = "printF16"; static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; +static constexpr llvm::StringRef kPrintApFloat = "printApFloat"; +static constexpr llvm::StringRef kApFloatAddF = "APFloat_add"; static constexpr llvm::StringRef kPrintString = "printString"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; @@ -160,6 +162,27 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } +FailureOr<LLVM::LLVMFuncOp> +mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { + return lookupOrCreateReservedFn( + b, moduleOp, kPrintApFloat, + {IntegerType::get(moduleOp->getContext(), 32), + IntegerType::get(moduleOp->getContext(), 64)}, + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); +} + +FailureOr<LLVM::LLVMFuncOp> +mlir::LLVM::lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { + return lookupOrCreateReservedFn( + b, moduleOp, kApFloatAddF, + {IntegerType::get(moduleOp->getContext(), 32), + IntegerType::get(moduleOp->getContext(), 64), + IntegerType::get(moduleOp->getContext(), 64)}, + IntegerType::get(moduleOp->getContext(), 64), symbolTables); +} + static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { return LLVM::LLVMPointerType::get(context); } diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp new file mode 100644 index 000000000000..7879c7580335 --- /dev/null +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -0,0 +1,40 @@ +//===- ArmRunnerUtils.cpp - Utilities for configuring architecture properties // +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/APFloat.h" +#include <iostream> + +#if (defined(_WIN32) || defined(__CYGWIN__)) +#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport) +#else +#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default"))) +#endif + +extern "C" { + +int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_add(int32_t semantics, + uint64_t a, uint64_t b) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); + llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); + auto status = lhs.add(rhs, llvm::RoundingMode::NearestTiesToEven); + return lhs.bitcastToAPInt().getZExtValue(); +} + +void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics, + uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + double d = x.convertToDouble(); + std::cout << d << std::endl; +} +} diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt index fdeb4dacf927..8c09e50e4de7 100644 --- a/mlir/lib/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/ExecutionEngine/CMakeLists.txt @@ -2,6 +2,7 @@ # is a big dependency which most don't need. set(LLVM_OPTIONAL_SOURCES + APFloatWrappers.cpp ArmRunnerUtils.cpp ArmSMEStubs.cpp AsyncRuntime.cpp @@ -167,6 +168,15 @@ if(LLVM_ENABLE_PIC) set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17) target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS) + add_mlir_library(mlir_apfloat_wrappers + SHARED + APFloatWrappers.cpp + + EXCLUDE_FROM_LIBMLIR + ) + set_property(TARGET mlir_apfloat_wrappers PROPERTY CXX_STANDARD 17) + target_compile_definitions(mlir_apfloat_wrappers PRIVATE mlir_apfloat_wrappers_EXPORTS) + add_subdirectory(SparseTensor) add_mlir_library(mlir_c_runner_utils @@ -177,6 +187,7 @@ if(LLVM_ENABLE_PIC) EXCLUDE_FROM_LIBMLIR LINK_LIBS PUBLIC + mlir_apfloat_wrappers mlir_float16_utils MLIRSparseTensorEnums MLIRSparseTensorRuntime @@ -191,6 +202,7 @@ if(LLVM_ENABLE_PIC) EXCLUDE_FROM_LIBMLIR LINK_LIBS PUBLIC + mlir_apfloat_wrappers mlir_float16_utils ) target_compile_definitions(mlir_runner_utils PRIVATE mlir_runner_utils_EXPORTS) diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir new file mode 100644 index 000000000000..5cd83688d171 --- /dev/null +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -0,0 +1,19 @@ +// Check that the ceildivsi lowering is correct. +// We do not check any poison or UB values, as it is not possible to catch them. + +// RUN: mlir-opt %s --convert-to-llvm + +// Put rhs into separate function so that it won't be constant-folded. +func.func @foo() -> f4E2M1FN { + %cst = arith.constant 5.0 : f4E2M1FN + return %cst : f4E2M1FN +} + +func.func @entry() { + %a = arith.constant 5.0 : f4E2M1FN + %b = func.call @foo() : () -> (f4E2M1FN) + %c = arith.addf %a, %b : f4E2M1FN + vector.print %c : f4E2M1FN + return +} + |
