diff options
| author | NAKAMURA Takumi <geek4civic@gmail.com> | 2025-01-09 18:49:54 +0900 |
|---|---|---|
| committer | NAKAMURA Takumi <geek4civic@gmail.com> | 2025-01-09 18:49:54 +0900 |
| commit | e2810c9a248f4c7fbfae84bb32b6f7e01027458b (patch) | |
| tree | ae0b02a8491b969a1cee94ea16ffe42c559143c5 /mlir | |
| parent | fa04eb4af95c1ca7377279728cb004bcd2324d01 (diff) | |
| parent | bdcf47e4bcb92889665825654bb80a8bbe30379e (diff) | |
Merge branch 'users/chapuni/cov/single/base' into users/chapuni/cov/single/switchusers/chapuni/cov/single/switch
Diffstat (limited to 'mlir')
153 files changed, 3408 insertions, 1295 deletions
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 5ea49c0dbfa7..a888ac243b04 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -170,7 +170,7 @@ configure_file( # The pybind11 library can be found (set with -DPYBIND_DIR=...) # The python executable is correct (set with -DPython3_EXECUTABLE=...) # By default, find_package and probing for installed pybind11 is performed. -# Super projects can set MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES=ON to +# Super projects can set MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES=ON to # disable all package setup and control it themselves. #------------------------------------------------------------------------------- @@ -196,8 +196,10 @@ endif() set(CMAKE_INCLUDE_CURRENT_DIR ON) -include_directories( "include") -include_directories( ${MLIR_INCLUDE_DIR}) +include_directories(BEFORE + "include" + ${MLIR_INCLUDE_DIR} + ) # Adding tools/mlir-tblgen here as calling add_tablegen sets some variables like # MLIR_TABLEGEN_EXE in PARENT_SCOPE which gets lost if that folder is included diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake index 9d4e06c7909c..717a503468a8 100644 --- a/mlir/cmake/modules/AddMLIRPython.cmake +++ b/mlir/cmake/modules/AddMLIRPython.cmake @@ -683,6 +683,13 @@ function(add_mlir_python_extension libname extname) ${eh_rtti_enable} ) endif() + + if(APPLE) + # NanobindAdaptors.h uses PyClassMethod_New to build `pure_subclass`es but nanobind + # doesn't declare this API as undefined in its linker flags. So we need to declare it as such + # for downstream users that do not do something like `-undefined dynamic_lookup`. + set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -Wl,-U -Wl,_PyClassMethod_New") + endif() endif() target_compile_options(${libname} PRIVATE ${eh_rtti_enable}) diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md index a0bd1cac118b..32df3310d811 100644 --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -1035,7 +1035,7 @@ class ConstantOp(_ods_ir.OpView): ... ``` -expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`). +expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`). Thus, a natural extension is a builder that accepts a MLIR type and a Python value and instantiates the appropriate `TypedAttr`: ```python @@ -1181,9 +1181,9 @@ make the passes available along with the dialect. Dialect functionality other than IR objects or passes, such as helper functions, can be exposed to Python similarly to attributes and types. C API is expected to exist for this functionality, which can then be wrapped using pybind11 and -`[include/mlir/Bindings/Python/PybindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)`, +[`include/mlir/Bindings/Python/PybindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h), or nanobind and -`[include/mlir/Bindings/Python/NanobindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h)` +[`include/mlir/Bindings/Python/NanobindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h) utilities to connect to the rest of Python API. The bindings can be located in a separate module or in the same module as attributes and types, and loaded along with the dialect. diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index 3168f5e13c75..abacd5a82c61 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -242,19 +242,6 @@ cannot. These materializations are used by the conversion framework to ensure type safety during the conversion process. There are several types of materializations depending on the situation. -* Argument Materialization - - - An argument materialization is used when converting the type of a block - argument during a [signature conversion](#region-signature-conversion). - The new block argument types are specified in a `SignatureConversion` - object. An original block argument can be converted into multiple - block arguments, which is not supported everywhere in the dialect - conversion. (E.g., adaptors support only a single replacement value for - each original value.) Therefore, an argument materialization is used to - convert potentially multiple new block arguments back into a single SSA - value. An argument materialization is also used when replacing an op - result with multiple values. - * Source Materialization - A source materialization is used when a value was replaced with a value @@ -344,17 +331,6 @@ class TypeConverter { /// persist after the conversion has finished. /// This method registers a materialization that will be called when - /// converting (potentially multiple) block arguments that were the result of - /// a signature conversion of a single block argument, to a single SSA value - /// with the old argument type. - template <typename FnT, - typename T = typename llvm::function_traits<FnT>::template arg_t<1>> - void addArgumentMaterialization(FnT &&callback) { - argumentMaterializations.emplace_back( - wrapMaterialization<T>(std::forward<FnT>(callback))); - } - - /// This method registers a materialization that will be called when /// converting a replacement value back to its original source type. /// This is used when some uses of the original value persist beyond the main /// conversion. @@ -406,12 +382,11 @@ done explicitly via a conversion pattern. To convert the types of block arguments within a Region, a custom hook on the `ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook uses a provided type converter to apply type conversions to all blocks of a -given region. As noted above, the conversions performed by this method use the -argument materialization hook on the `TypeConverter`. This hook also takes an -optional `TypeConverter::SignatureConversion` parameter that applies a custom -conversion to the entry block of the region. The types of the entry block -arguments are often tied semantically to the operation, e.g., -`func::FuncOp`, `AffineForOp`, etc. +given region. This hook also takes an optional +`TypeConverter::SignatureConversion` parameter that applies a custom conversion +to the entry block of the region. The types of the entry block arguments are +often tied semantically to the operation, e.g., `func::FuncOp`, `AffineForOp`, +etc. To convert the signature of just one given block, the `applySignatureConversion` hook can be used. diff --git a/mlir/docs/Tutorials/Toy/Ch-2.md b/mlir/docs/Tutorials/Toy/Ch-2.md index b807ee3a2049..039417c9c9a1 100644 --- a/mlir/docs/Tutorials/Toy/Ch-2.md +++ b/mlir/docs/Tutorials/Toy/Ch-2.md @@ -262,7 +262,7 @@ class ConstantOp : public mlir::Op< mlir::OpTrait::OneResult, /// We also provide a utility `getType` accessor that /// returns the TensorType of the single result. - mlir::OpTraits::OneTypedResult<TensorType>::Impl> { + mlir::OpTrait::OneTypedResult<TensorType>::Impl> { public: /// Inherit the constructors from the base Op class. diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index 3ad70e727969..123d114ae163 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -220,6 +220,7 @@ void ToyToLLVMLoweringPass::runOnOperation() { mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); + cf::populateAssertToLLVMConversionPattern(typeConverter, patterns); populateFuncToLLVMConversionPatterns(typeConverter, patterns); // The only remaining operation to lower from the `toy` dialect, is the diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 0992285f997e..26c4140757c3 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -45,6 +45,13 @@ MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, MlirType const *argumentTypes, bool isVarArg); +/// Returns the number of input types. +MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type); + +/// Returns the pos-th input type. +MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetInput(MlirType type, + intptr_t pos); + /// Returns `true` if the type is an LLVM dialect struct type. MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type); diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h index dfd358e7017a..b6d10ba0bea2 100644 --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -332,9 +332,11 @@ public: /// does not exist. template <typename StateT, typename AnchorT> const StateT *lookupState(AnchorT anchor) const { - auto it = - analysisStates.find({LatticeAnchor(anchor), TypeID::get<StateT>()}); - if (it == analysisStates.end()) + const auto &mapIt = analysisStates.find(LatticeAnchor(anchor)); + if (mapIt == analysisStates.end()) + return nullptr; + auto it = mapIt->second.find(TypeID::get<StateT>()); + if (it == mapIt->second.end()) return nullptr; return static_cast<const StateT *>(it->second.get()); } @@ -343,11 +345,7 @@ public: template <typename AnchorT> void eraseState(AnchorT anchor) { LatticeAnchor la(anchor); - - for (auto it = analysisStates.begin(); it != analysisStates.end(); ++it) { - if (it->first.first == la) - analysisStates.erase(it); - } + analysisStates.erase(LatticeAnchor(anchor)); } // Erase all analysis states @@ -426,7 +424,8 @@ private: /// A type-erased map of lattice anchors to associated analysis states for /// first-class lattice anchors. - DenseMap<std::pair<LatticeAnchor, TypeID>, std::unique_ptr<AnalysisState>> + DenseMap<LatticeAnchor, DenseMap<TypeID, std::unique_ptr<AnalysisState>>, + DenseMapInfo<LatticeAnchor::ParentTy>> analysisStates; /// Allow the base child analysis class to access the internals of the solver. @@ -643,7 +642,7 @@ AnalysisT *DataFlowSolver::load(Args &&...args) { template <typename StateT, typename AnchorT> StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) { std::unique_ptr<AnalysisState> &state = - analysisStates[{LatticeAnchor(anchor), TypeID::get<StateT>()}]; + analysisStates[LatticeAnchor(anchor)][TypeID::get<StateT>()]; if (!state) { state = std::unique_ptr<StateT>(new StateT(anchor)); #if LLVM_ENABLE_ABI_BREAKING_CHECKS @@ -689,10 +688,6 @@ struct DenseMapInfo<mlir::ProgramPoint> { } }; -template <> -struct DenseMapInfo<mlir::LatticeAnchor> - : public DenseMapInfo<mlir::LatticeAnchor::ParentTy> {}; - // Allow llvm::cast style functions. template <typename To> struct CastInfo<To, mlir::LatticeAnchor> diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h index b88c1e8b20f3..88f18022da9b 100644 --- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h +++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h @@ -29,6 +29,10 @@ namespace cf { /// Collect the patterns to convert from the ControlFlow dialect to LLVM. The /// conversion patterns capture the LLVMTypeConverter by reference meaning the /// references have to remain alive during the entire pattern lifetime. +/// +/// Note: This function does not populate the default cf.assert lowering. That +/// is because some platforms have a custom cf.assert lowering. The default +/// lowering can be populated with `populateAssertToLLVMConversionPattern`. void populateControlFlowToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h b/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h index 22df7f1c5dcf..acc39e6acf72 100644 --- a/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h +++ b/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h @@ -9,6 +9,7 @@ #ifndef MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H #define MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H +#include "mlir/Transforms/DialectConversion.h" #include <memory> namespace mlir { @@ -19,7 +20,8 @@ class RewritePatternSet; #include "mlir/Conversion/Passes.h.inc" /// Collect a set of patterns to convert SCF operations to the EmitC dialect. -void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns); +void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter); } // namespace mlir #endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index f5ca24389065..e2eab1fb2178 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -1083,6 +1083,9 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> { %indices_2 = affine.apply #map2()[%linear_index] ``` + In other words, `%0:3 = affine.delinearize_index %x into (B, C)` produces + `%0 = {%x / (B * C), (%x mod (B * C)) / C, %x mod C}`. + The basis may either contain `N` or `N-1` elements, where `N` is the number of results. If there are N basis elements, the first one will not be used during computations, but may be used during analysis and canonicalization to eliminate terms from @@ -1098,7 +1101,12 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> { %0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index ``` - Note that, due to the constraints of affine maps, all the basis elements must + Note that, for symmetry with `getPaddedBasis()`, if `hasOuterBound` is `true` + when one of the `OpFoldResult` builders is called but the first element of the + basis is `nullptr`, that first element is ignored and the builder proceeds as if + there was no outer bound. + + Due to the constraints of affine maps, all the basis elements must be strictly positive. A dynamic basis element being 0 or negative causes undefined behavior. }]; @@ -1136,6 +1144,11 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> { /// Return a vector that contains the basis of the operation, removing /// the outer bound if one is present. SmallVector<OpFoldResult> getEffectiveBasis(); + + /// Return the vector with one basis element per result of the operation. If + /// there is no outer bound specified, the leading entry of this result will be + /// nullptr. + SmallVector<OpFoldResult> getPaddedBasis(); }]; let hasVerifier = 1; @@ -1160,6 +1173,9 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index", sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j ``` + In other words, `%0 = affine.linearize_index [%z, %y, %x] by (Z, Y, X)` + gives `%0 = %x + %y * X + %z * X * Y`, or `%0 = %x + X * (%y + Y * (%z))`. + The basis may either have `N` or `N-1` elements, where `N` is the number of inputs to linearize_index. If `N` inputs are provided, the first one is not used in computation, but may be used during analysis or canonicalization as a bound @@ -1168,6 +1184,10 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index", If all `N` basis elements are provided, the linearize_index operation is said to "have an outer bound". + As a convenience, and for symmetry with `getPaddedBasis()`, ifg the first + element of a set of `OpFoldResult`s passed to the builders of this operation is + `nullptr`, that element is ignored. + If the `disjoint` property is present, this is an optimization hint that, for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index, except that `%idx_0` may be negative to make the index as a whole negative. @@ -1224,6 +1244,11 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index", /// Return a vector that contains the basis of the operation, removing /// the outer bound if one is present. SmallVector<OpFoldResult> getEffectiveBasis(); + + /// Return the vector with one basis element per index operand of the operation. + /// If there is no outer bound specified, the leading entry of this basis will be + /// nullptr. + SmallVector<OpFoldResult> getPaddedBasis(); }]; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index 983f7a29cb22..d1a102e2a6e4 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -456,7 +456,7 @@ public: /// read by themselves (e.g., ExtractSliceOp). bool isValueRead(Value value) const; - /// Starting from `value`, follow the use-def chain in reverse, always + /// Starting from `opOperand`, follow the use-def chain in reverse, always /// selecting the aliasing OpOperands. Find and return Values for which /// `condition` evaluates to true. OpOperands of such matching Values are not /// traversed any further, the visited aliasing opOperands will be preserved @@ -484,7 +484,7 @@ public: /// Additional stopping conditions for the traversal can be specified in /// `config`. SetVector<Value> findValueInReverseUseDefChain( - Value value, llvm::function_ref<bool(Value)> condition, + OpOperand *opOperand, llvm::function_ref<bool(Value)> condition, TraversalConfig config = TraversalConfig(), llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const; @@ -520,7 +520,7 @@ public: /// /// Note: OpResults of unknown ops are handled conservatively and assumed to /// be definitions. - SetVector<Value> findDefinitions(Value value) const; + SetVector<Value> findDefinitions(OpOperand *opOperand) const; /// Return `true` if the given OpResult has been decided to bufferize inplace. virtual bool isInPlace(OpOperand &opOperand) const; diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h index d50a3042aeea..bd23a19f7472 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -127,9 +127,9 @@ public: /// Return true if the buffer of the given tensor value is writable. bool isWritable(Value value) const; - /// Find the definitions of the given tensor value or retrieve them from the - /// cache. - const SetVector<Value> &findDefinitionsCached(Value value); + /// Find the definitions of the given operand's value or + /// retrieve them from the cache. + const SetVector<Value> &findDefinitionsCached(OpOperand *opOperand); /// Reset cached data structures. void resetCache() override; diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h index 892675954493..a4ee893ca534 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h @@ -10,7 +10,9 @@ #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TRANSFORMS_H #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/SubsetOpInterface.h" namespace mlir { namespace bufferization { @@ -34,13 +36,35 @@ struct OneShotBufferizationOptions; /// "tensor.empty" op. LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op); +/// A function type that defines a callback to control the construction +/// of the subset extraction of the `SubsetInsertionOpInterface`. +/// The subset extraction value can be used as a replacement for the +/// `emptyTensorOp` value which is being consumed by `user`, failing +/// of building such a value should be indicated with an empty value. +/// This function should guarantee the legality of the replacement, +/// i.e. the replacement should dominate the user of the `emptyTensorOp` +/// being eliminated. +using ControlBuildSubsetExtractionFn = + std::function<Value(RewriterBase &, SubsetInsertionOpInterface, + tensor::EmptyOp emptyTensorOp, Operation *user)>; + +/// This method builds and returns a subset extraction value for the +/// destination tensor that the given `op` inserts into. +/// It returns a value which should replace the `emptyTensorOp` use +/// that is being consumed by `user`. +/// If no such a value found it will return an empty Value. +Value buildSubsetExtraction(RewriterBase &rewriter, + SubsetInsertionOpInterface op, + tensor::EmptyOp emptyTensorOp, Operation *user); + /// Try to eliminate "tensor.empty" ops inside `op`. /// /// This function overload accepts an existing `OneShotAnalysisState`, which /// contains in-place bufferization decisions. This overload is useful if an /// existing analysis should be reused for empty tensor elimination. -LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op, - OneShotAnalysisState &state); +LogicalResult eliminateEmptyTensors( + RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state, + ControlBuildSubsetExtractionFn subsetsExtractionFn = buildSubsetExtraction); /// Within the given operation, hoist buffers from loops where possible. See /// "BufferLoopHoistingPass" for more information. diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index fc5a33541533..744a0dc4770e 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -727,7 +727,7 @@ def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">, Example: ```mlir - emitc.func @foo() : (i32) { + emitc.func @foo() -> (i32) { ... emitc.return %0 : i32 } @@ -1305,8 +1305,6 @@ def EmitC_IfOp : EmitC_Op<"if", Block* body = getBody(1); return OpBuilder::atBlockEnd(body, listener); } - Block* thenBlock(); - Block* elseBlock(); }]; let hasCustomAssemblyFormat = 1; } diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td index 237a825c1910..211201802b08 100644 --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td @@ -352,7 +352,7 @@ def ReturnOp : Func_Op<"return", [Pure, HasParent<"FuncOp">, Example: ```mlir - func.func @foo() : (i32, f8) { + func.func @foo() -> (i32, f8) { ... return %0, %1 : i32, f8 } diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index 42a017db300a..3adfd5f4f2c4 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1055,7 +1055,7 @@ def GPU_PrintfOp : GPU_Op<"printf", [MemoryEffects<[MemWrite]>]>, imposed by one's target platform. }]; let assemblyFormat = [{ - $format attr-dict ($args^ `:` type($args))? + $format attr-dict (`,` $args^ `:` type($args))? }]; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index e8eeafd09a9c..267389774bd5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -825,7 +825,7 @@ def LLVM_MemoryEffectsAttr : LLVM_Attr<"MemoryEffects", "memory_effects"> { def LLVM_AliasScopeDomainAttr : LLVM_Attr<"AliasScopeDomain", "alias_scope_domain"> { let parameters = (ins - "DistinctAttr":$id, + "Attribute":$id, OptionalParameter<"StringAttr">:$description ); @@ -853,7 +853,7 @@ def LLVM_AliasScopeDomainAttr : LLVM_Attr<"AliasScopeDomain", def LLVM_AliasScopeAttr : LLVM_Attr<"AliasScope", "alias_scope"> { let parameters = (ins - "DistinctAttr":$id, + "Attribute":$id, "AliasScopeDomainAttr":$domain, OptionalParameter<"StringAttr">:$description ); @@ -891,6 +891,8 @@ def LLVM_AliasScopeAttr : LLVM_Attr<"AliasScope", "alias_scope"> { } ``` + The first attribute can either be a DistinctAttr or a StringAttr. + See the following link for more details: https://llvm.org/docs/LangRef.html#noalias-and-alias-scope-metadata }]; @@ -898,6 +900,8 @@ def LLVM_AliasScopeAttr : LLVM_Attr<"AliasScope", "alias_scope"> { let summary = "LLVM dialect alias scope"; let assemblyFormat = "`<` struct(params) `>`"; + + let genVerifyDecl = 1; } def LLVM_AliasScopeArrayAttr diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 37eec6e07963..fff4048ee125 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -472,9 +472,6 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [ getRegionBuilder() { return regionBuilder; } - - static void createRegion(::mlir::OpBuilder &opBuilder, - ::mlir::OperationState & odsState); }]; let hasFolder = 1; diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h index b62c94179794..ba648181daec 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h @@ -40,12 +40,6 @@ void buildTerminatedBody(OpBuilder &builder, Location loc); namespace mlir { namespace scf { -// Insert `loop.yield` at the end of the only region's only block if it -// does not have a terminator already. If a new `loop.yield` is inserted, -// the location is specified by `loc`. If the region is empty, insert a new -// block first. -void ensureLoopTerminator(Region ®ion, Builder &builder, Location loc); - /// Returns the loop parent of an induction variable. If the provided value is /// not an induction variable, then return nullptr. ForOp getForInductionVarOwner(Value val); diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 23c597a1ca51..6f408b3c924d 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -302,7 +302,7 @@ def ForallOp : SCF_Op<"forall", [ AttrSizedOperandSegments, AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface, - ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars", + ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars", "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps", "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>, RecursiveMemoryEffects, @@ -671,7 +671,7 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [ "getNumRegionInvocations", "getRegionInvocationBounds", "getEntrySuccessorRegions"]>, InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">, - RecursiveMemoryEffects, NoRegionArguments]> { + RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> { let summary = "if-then-else operation"; let description = [{ The `scf.if` operation represents an if-then-else construct for diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h index b87407d302a8..18c9dfd205de 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h @@ -66,6 +66,9 @@ void populateSCFStructuralTypeConversionTarget( /// Populates the provided pattern set with patterns that do 1:N type /// conversions on (some) SCF ops. This is intended to be used with /// applyPartialOneToNConversion. +/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon. +/// 1:N support has been added to the regular dialect conversion driver. +/// Use populateSCFStructuralTypeConversions() instead. void populateSCFStructuralOneToNTypeConversions( const TypeConverter &typeConverter, RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 08a0398e74b0..8bccba426ab1 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -321,11 +321,6 @@ def Shape_DimOp : Shape_Op<"dim", let assemblyFormat = "$value `,` $index attr-dict `:` type($value) `,`" "type($index) `->` type($extent)"; - let builders = [ - // Builder that allows passing a constant dimension as a simple integer. - OpBuilder<(ins "Value":$value, "int64_t":$index)> - ]; - let extraClassDeclaration = [{ /// Get the `index` value as integer if it is constant. std::optional<int64_t> getConstantIndex(); diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index f5536927dc25..d3f12c34421b 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -126,11 +126,12 @@ def Tosa_ConvOpQuantInfoBuilder : OpBuilder< (ins "::mlir::Type":$outputType, "::mlir::Value":$input, "::mlir::Value":$weight, "::mlir::Value":$bias, "::mlir::DenseI64ArrayAttr":$pad, "::mlir::DenseI64ArrayAttr":$stride, - "::mlir::DenseI64ArrayAttr":$dilation), + "::mlir::DenseI64ArrayAttr":$dilation, + "::mlir::TypeAttr":$acc_type), [{ buildConvOpWithQuantInfo($_builder, $_state, outputType, input, weight, bias, - pad, stride, dilation); + pad, stride, dilation, acc_type); }]>; // Handles tosa.transpose_conv2d which has an outpad and output shape attribute. @@ -139,12 +140,13 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder< "::mlir::Value":$weight, "mlir::Value":$bias, "::mlir::DenseI64ArrayAttr":$outpad, "::mlir::DenseI64ArrayAttr":$stride, - "::mlir::DenseI64ArrayAttr":$outputShape), + "::mlir::DenseI64ArrayAttr":$outputShape, + "::mlir::TypeAttr":$acc_type), [{ buildTransConvOpWithQuantInfo($_builder, $_state, outputType, input, weight, bias, outpad, stride, - outputShape); + outputShape, acc_type); }]>; // The tosa.fully_connected op has its own builder as it does not have diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index e3c725801d16..6b43c9a259b1 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -57,7 +57,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> { // Accumulator types. //===----------------------------------------------------------------------===// -def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>; +def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>; //===----------------------------------------------------------------------===// // Operator: avg_pool2d @@ -106,6 +106,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> { Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, + TypeAttrOf<Tosa_AccType>:$acc_type, OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info, DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound ); @@ -135,6 +136,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> { Tosa_IntArrayAttr6:$pad, Tosa_IntArrayAttr3:$stride, Tosa_IntArrayAttr3:$dilation, + TypeAttrOf<Tosa_AccType>:$acc_type, OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info, DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound ); @@ -165,6 +167,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> { Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, + TypeAttrOf<Tosa_AccType>:$acc_type, OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info, DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound ); @@ -348,6 +351,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> { Tosa_IntArrayAttr4:$out_pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr4:$out_shape, + TypeAttrOf<Tosa_AccType>:$acc_type, OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info, DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound ); @@ -357,6 +361,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> { ); let builders = [Tosa_TransConvOpQuantInfoBuilder]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1552,21 +1557,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> { Example: ```mlir - %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> - tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>) + %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>) ``` Example 2: ```mlir - %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32> - tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>) + %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32> + tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>) ``` }]; let arguments = (ins Tosa_RankedTensor:$input1, - Tosa_Int32Or64Tensor:$padding, + TosaTensorRankOf<[Tosa_Int32Or64], [1]>:$padding, Optional<Tosa_ScalarTensor>:$pad_const, OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info ); @@ -1698,7 +1703,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> { // Operator: transpose //===----------------------------------------------------------------------===// def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose", - [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> { + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, + AllElementTypesMatch<["input1", "output"]>]> { let summary = "Transpose operator"; let description = [{ diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index a6d3163d4446..d3cc6e92bac2 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -65,17 +65,17 @@ def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32, // int8 : symmetric per tensor/per channel, signed // int16 : symmetric per tensor, signed //===----------------------------------------------------------------------===// -def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>, - Tosa_QuantizedType<"int4", [4, 0], 1>, - Tosa_QuantizedType<"int8", [8, 0], 1>, - Tosa_QuantizedType<"int16", [16, 0], 1>, - Tosa_QuantizedType<"int32", [32, 0], 1>]>; +def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>, + Tosa_QuantizedType<"int4", [4, 0], 1>, + Tosa_QuantizedType<"int8", [8, 0], 1>, + Tosa_QuantizedType<"int16", [16, 0], 1>, + Tosa_QuantizedType<"int32", [32, 0], 1>]>; //===----------------------------------------------------------------------===// // Multi-category types. //===----------------------------------------------------------------------===// def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat], - "number">; + "number">; // For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp, // tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp @@ -112,7 +112,7 @@ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks> def Tosa_I1Tensor : TosaTensorOf<[I1]>; def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>; -def Tosa_Int32Or64Tensor :TosaTensorOf<[Tosa_Int32Or64]>; +def Tosa_Int32Or64Tensor : TosaTensorOf<[Tosa_Int32Or64]>; def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>; diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 2aaa7fd4221a..4841f94de75f 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -164,11 +164,9 @@ def XeGPU_SGMapAttr : XeGPUAttr<"SGMap", "sg_map"> { }]; let parameters = (ins ArrayRefParameter<"uint32_t">:$wi_layout, - ArrayRefParameter<"uint32_t">:$wi_data); + ArrayRefParameter<"uint32_t">:$wi_data + ); - let builders = [ - AttrBuilder<(ins)> - ]; let hasCustomAssemblyFormat = 1; let genVerifyDecl = 1; diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index f3e5f6d88c53..fb24a6895dab 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -368,7 +368,6 @@ private: DenseSet<std::pair<TypeID, TypeID>> unresolvedPromisedInterfaces; friend class DialectRegistry; - friend void registerDialect(); friend class MLIRContext; }; diff --git a/mlir/include/mlir/IR/Dominance.h b/mlir/include/mlir/IR/Dominance.h index 16d17b9c0f3d..9e1254c1dfe1 100644 --- a/mlir/include/mlir/IR/Dominance.h +++ b/mlir/include/mlir/IR/Dominance.h @@ -113,12 +113,12 @@ protected: llvm::PointerIntPair<DomTree *, 1, bool> getDominanceInfo(Region *region, bool needsDomTree) const; - /// Return "true" if the specified block A properly (post)dominates block B. - bool properlyDominatesImpl(Block *a, Block *b) const; - - /// Return "true" if the specified op A properly (post)dominates op B. - bool properlyDominatesImpl(Operation *a, Operation *b, - bool enclosingOpOk = true) const; + /// Return "true" if block iterator A properly (post)dominates block iterator + /// B. If `enclosingOk` is set, A is considered to (post)dominate B if A + /// encloses B. + bool properlyDominatesImpl(Block *aBlock, Block::iterator aIt, Block *bBlock, + Block::iterator bIt, + bool enclosingOk = true) const; /// A mapping of regions to their base dominator tree and a cached /// "hasSSADominance" bit. This map does not contain dominator trees for @@ -151,9 +151,7 @@ public: /// The `enclosingOpOk` flag says whether we should return true if the B op /// is enclosed by a region on A. bool properlyDominates(Operation *a, Operation *b, - bool enclosingOpOk = true) const { - return super::properlyDominatesImpl(a, b, enclosingOpOk); - } + bool enclosingOpOk = true) const; /// Return true if operation A dominates operation B, i.e. if A and B are the /// same operation or A properly dominates B. @@ -188,8 +186,17 @@ public: /// Graph regions have only a single block. To be consistent with "proper /// dominance" of ops, the single block is considered to properly dominate /// itself in a graph region. - bool properlyDominates(Block *a, Block *b) const { - return super::properlyDominatesImpl(a, b); + bool properlyDominates(Block *a, Block *b) const; + + bool properlyDominates(Block *aBlock, Block::iterator aIt, Block *bBlock, + Block::iterator bIt, bool enclosingOk = true) const { + return super::properlyDominatesImpl(aBlock, aIt, bBlock, bIt, enclosingOk); + } + + bool dominates(Block *aBlock, Block::iterator aIt, Block *bBlock, + Block::iterator bIt, bool enclosingOk = true) const { + return (aBlock == bBlock && aIt == bIt) || + super::properlyDominatesImpl(aBlock, aIt, bBlock, bIt, enclosingOk); } }; @@ -200,9 +207,7 @@ public: /// Return true if operation A properly postdominates operation B. bool properlyPostDominates(Operation *a, Operation *b, - bool enclosingOpOk = true) const { - return super::properlyDominatesImpl(a, b, enclosingOpOk); - } + bool enclosingOpOk = true) const; /// Return true if operation A postdominates operation B. bool postDominates(Operation *a, Operation *b) const { @@ -210,14 +215,24 @@ public: } /// Return true if the specified block A properly postdominates block B. - bool properlyPostDominates(Block *a, Block *b) const { - return super::properlyDominatesImpl(a, b); - } + bool properlyPostDominates(Block *a, Block *b) const; /// Return true if the specified block A postdominates block B. bool postDominates(Block *a, Block *b) const { return a == b || properlyPostDominates(a, b); } + + bool properlyPostDominates(Block *aBlock, Block::iterator aIt, Block *bBlock, + Block::iterator bIt, + bool enclosingOk = true) const { + return super::properlyDominatesImpl(aBlock, aIt, bBlock, bIt, enclosingOk); + } + + bool postDominates(Block *aBlock, Block::iterator aIt, Block *bBlock, + Block::iterator bIt, bool enclosingOk = true) const { + return (aBlock == bBlock && aIt == bIt) || + super::properlyDominatesImpl(aBlock, aIt, bBlock, bIt, enclosingOk); + } }; } // namespace mlir diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index ef5b8b178fbc..5eb2d69134ea 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -693,9 +693,6 @@ public: /// Return the dialect this operation is registered to. Dialect &getDialect() const { return *getImpl()->getDialect(); } - /// Use the specified object to parse this ops custom assembly format. - ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const; - /// Represent the operation name as an opaque pointer. (Used to support /// PointerLikeTypeTraits). static RegisteredOperationName getFromOpaquePointer(const void *pointer) { @@ -1169,16 +1166,20 @@ public: OpPrintingFlags &skipRegions(bool skip = true); /// Do not verify the operation when using custom operation printers. - OpPrintingFlags &assumeVerified(); + OpPrintingFlags &assumeVerified(bool enable = true); /// Use local scope when printing the operation. This allows for using the /// printer in a more localized and thread-safe setting, but may not /// necessarily be identical to what the IR will look like when dumping /// the full module. - OpPrintingFlags &useLocalScope(); + OpPrintingFlags &useLocalScope(bool enable = true); /// Print users of values as comments. - OpPrintingFlags &printValueUsers(); + OpPrintingFlags &printValueUsers(bool enable = true); + + /// Print unique SSA ID numbers for values, block arguments and naming + /// conflicts across all regions + OpPrintingFlags &printUniqueSSAIDs(bool enable = true); /// Return if the given ElementsAttr should be elided. bool shouldElideElementsAttr(ElementsAttr attr) const; diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index eea0647895b0..33c9af7c6335 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -319,9 +319,13 @@ private: /// Appends the converted result type and operands of `callInst` to the /// `types` and `operands` arrays. For indirect calls, the method additionally /// inserts the called function at the beginning of the `operands` array. + /// If `allowInlineAsm` is set to false (the default), it will return failure + /// if the called operand is an inline asm which isn't convertible to MLIR as + /// a value. LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst, SmallVectorImpl<Type> &types, - SmallVectorImpl<Value> &operands); + SmallVectorImpl<Value> &operands, + bool allowInlineAsm = false); /// Converts the parameter attributes attached to `func` and adds them to the /// `funcOp`. void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp, diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 28150e886913..9a6975dcf8df 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -181,6 +181,10 @@ public: /// converting (potentially multiple) block arguments that were the result of /// a signature conversion of a single block argument, to a single SSA value /// with the old block argument type. + /// + /// Note: Argument materializations are used only with the 1:N dialect + /// conversion driver. The 1:N dialect conversion driver will be removed soon + /// and so will be argument materializations. template <typename FnT, typename T = typename llvm::function_traits< std::decay_t<FnT>>::template arg_t<1>> void addArgumentMaterialization(FnT &&callback) { @@ -880,15 +884,7 @@ public: void replaceOp(Operation *op, Operation *newOp) override; /// Replace the given operation with the new value ranges. The number of op - /// results and value ranges must match. If an original SSA value is replaced - /// by multiple SSA values (i.e., a value range has more than 1 element), the - /// conversion driver will insert an argument materialization to convert the - /// N SSA values back into 1 SSA value of the original type. The given - /// operation is erased. - /// - /// Note: The argument materialization is a workaround until we have full 1:N - /// support in the dialect conversion. (It is going to disappear from both - /// `replaceOpWithMultiple` and `applySignatureConversion`.) + /// results and value ranges must match. The given operation is erased. void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues); /// PatternRewriter hook for erasing a dead operation. The uses of this @@ -1285,8 +1281,8 @@ struct ConversionConfig { // represented at the moment. RewriterBase::Listener *listener = nullptr; - /// If set to "true", the dialect conversion attempts to build source/target/ - /// argument materializations through the type converter API in lieu of + /// If set to "true", the dialect conversion attempts to build source/target + /// materializations through the type converter API in lieu of /// "builtin.unrealized_conversion_cast ops". The conversion process fails if /// at least one materialization could not be built. /// diff --git a/mlir/include/mlir/Transforms/LocationSnapshot.h b/mlir/include/mlir/Transforms/LocationSnapshot.h index ccfdbac007ac..cefe005d2c4c 100644 --- a/mlir/include/mlir/Transforms/LocationSnapshot.h +++ b/mlir/include/mlir/Transforms/LocationSnapshot.h @@ -51,18 +51,6 @@ void generateLocationsFromIR(raw_ostream &os, StringRef fileName, StringRef tag, LogicalResult generateLocationsFromIR(StringRef fileName, StringRef tag, Operation *op, OpPrintingFlags flags); -/// Create a pass to generate new locations by snapshotting the IR to the given -/// file, and using the printed locations within that file. If `filename` is -/// empty, a temporary file is generated instead. If a 'tag' is non-empty, the -/// generated locations are represented as a NameLoc with the given tag as the -/// name, and then fused with the existing locations. Otherwise, the existing -/// locations are replaced. -std::unique_ptr<Pass> createLocationSnapshotPass(OpPrintingFlags flags, - StringRef fileName = "", - StringRef tag = ""); -/// Overload utilizing pass options for initialization. -std::unique_ptr<Pass> createLocationSnapshotPass(); - } // namespace mlir #endif // MLIR_TRANSFORMS_LOCATIONSNAPSHOT_H diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h index 7b4dd65cbff7..37a326818d64 100644 --- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h +++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h @@ -6,6 +6,9 @@ // //===----------------------------------------------------------------------===// // +// Note: The 1:N dialect conversion is deprecated and will be removed soon. +// 1:N support has been added to the regular dialect conversion driver. +// // This file provides utils for implementing (poor-man's) dialect conversion // passes with 1:N type conversions. // @@ -119,6 +122,8 @@ public: /// types must be the same as the result types of the op) and the new values /// (i.e., the converted types must be the same as the types of the new /// values). + /// FIXME: The 1:N dialect conversion is deprecated and will be removed soon. + /// Use replaceOpWithMultiple() instead. void replaceOp(Operation *op, ValueRange newValues, const OneToNTypeMapping &resultMapping); using PatternRewriter::replaceOp; @@ -251,6 +256,9 @@ public: /// or illegal types; the function simply applies the given patterns and does /// not fail if some ops or types remain unconverted (i.e., the conversion is /// only "partial"). +/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon. +/// 1:N support has been added to the regular dialect conversion driver. +/// Use applyPartialConversion() instead. LogicalResult applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, const FrozenRewritePatternSet &patterns); @@ -259,6 +267,9 @@ applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, /// FunctionOpInterface op with the given type converter. This only supports /// ops which use FunctionType to represent their type. This is intended to be /// used with the 1:N dialect conversion. +/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon. +/// 1:N support has been added to the regular dialect conversion driver. +/// Use populateFunctionOpInterfaceTypeConversionPattern() instead. void populateOneToNFunctionOpInterfaceTypeConversionPattern( StringRef functionLikeOpName, const TypeConverter &converter, RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 000d9f697618..c4a8e7a81fa4 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -331,13 +331,21 @@ def LocationSnapshot : Pass<"snapshot-op-locations"> { ... loc(fused["original_source.cpp":1:1, "snapshot"("snapshot_source.mlir":10:10)]) ``` }]; - let constructor = "mlir::createLocationSnapshotPass()"; let options = [ Option<"fileName", "filename", "std::string", /*default=*/"", "The filename to print the generated IR">, Option<"tag", "tag", "std::string", /*default=*/"", "A tag to use when fusing the new locations with the " "original. If unset, the locations are replaced.">, + Option<"enableDebugInfo", "print-debuginfo", "bool", /*default=*/"false", + "Print debug info in MLIR output">, + Option<"printGenericOpForm", "print-op-generic", "bool", /*default=*/"false", + "Print the generic op form">, + Option<"useLocalScope", "print-local-scope", "bool", /*default=*/"false", + "Print with local scope and inline information (eliding " + "aliases for attributes, types, and locations">, + Option<"printPrettyDebugInfo", "pretty-debuginfo", "bool", /*default=*/"false", + "Print pretty debug info in MLIR output">, ]; } diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 86afa956398a..453d4f7c7e8b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -272,13 +272,13 @@ struct PyAttrBuilderMap { static bool dunderContains(const std::string &attributeKind) { return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); } - static nb::callable dundeGetItemNamed(const std::string &attributeKind) { + static nb::callable dunderGetItemNamed(const std::string &attributeKind) { auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); if (!builder) throw nb::key_error(attributeKind.c_str()); return *builder; } - static void dundeSetItemNamed(const std::string &attributeKind, + static void dunderSetItemNamed(const std::string &attributeKind, nb::callable func, bool replace) { PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), replace); @@ -287,8 +287,8 @@ struct PyAttrBuilderMap { static void bind(nb::module_ &m) { nb::class_<PyAttrBuilderMap>(m, "AttrBuilder") .def_static("contains", &PyAttrBuilderMap::dunderContains) - .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed) - .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed, + .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed) + .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed, "attribute_kind"_a, "attr_builder"_a, "replace"_a = false, "Register an attribute builder for building MLIR " "attributes from python values."); @@ -2587,6 +2587,8 @@ private: //------------------------------------------------------------------------------ void mlir::python::populateIRCore(nb::module_ &m) { + // disable leak warnings which tend to be false positives. + nb::set_leak_warnings(false); //---------------------------------------------------------------------------- // Enums. //---------------------------------------------------------------------------- diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 6ed82ba1a025..da450dd3fd8a 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -55,6 +55,16 @@ MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg)); } +intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type) { + return llvm::cast<LLVM::LLVMFunctionType>(unwrap(type)).getNumParams(); +} + +MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos) { + assert(pos >= 0 && "pos in array must be positive"); + return wrap(llvm::cast<LLVM::LLVMFunctionType>(unwrap(type)) + .getParamType(static_cast<unsigned>(pos))); +} + bool mlirTypeIsALLVMStructType(MlirType type) { return isa<LLVM::LLVMStructType>(unwrap(type)); } diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index 8672e7b849d9..d0ffb94f3f96 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -215,7 +215,6 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns) { // clang-format off patterns.add< - AssertOpLowering, BranchOpLowering, CondBranchOpLowering, SwitchOpLowering>(converter); @@ -258,6 +257,7 @@ struct ConvertControlFlowToLLVM LLVMTypeConverter converter(ctx, options); RewritePatternSet patterns(ctx); mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); + mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -286,6 +286,7 @@ struct ControlFlowToLLVMDialectInterface RewritePatternSet &patterns) const final { mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); + mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns); } }; } // namespace diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index b3c3fd4956d0..544fc57949e2 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -19,6 +19,59 @@ using namespace mlir; +LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, + Location loc, OpBuilder &b, + StringRef name, + LLVM::LLVMFunctionType type) { + LLVM::LLVMFuncOp ret; + if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleOp.getBody()); + ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External); + } + return ret; +} + +static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp, + StringRef prefix) { + // Get a unique global name. + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (prefix + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + return stringConstName; +} + +LLVM::GlobalOp +mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, + gpu::GPUModuleOp moduleOp, Type llvmI8, + StringRef namePrefix, StringRef str, + uint64_t alignment, unsigned addrSpace) { + llvm::SmallString<20> nullTermStr(str); + nullTermStr.push_back('\0'); // Null terminate for C + auto globalType = + LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes()); + StringAttr attr = b.getStringAttr(nullTermStr); + + // Try to find existing global. + for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>()) + if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && + globalOp.getValueAttr() == attr && + globalOp.getAlignment().value_or(0) == alignment && + globalOp.getAddrSpace() == addrSpace) + return globalOp; + + // Not found: create new global. + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleOp.getBody()); + SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); + return b.create<LLVM::GlobalOp>(loc, globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, + name, attr, alignment, addrSpace); +} + LogicalResult GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, return success(); } -static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) { - const char formatStringPrefix[] = "printfFormat_"; - // Get a unique global name. - unsigned stringNumber = 0; - SmallString<16> stringConstName; - do { - stringConstName.clear(); - (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); - } while (moduleOp.lookupSymbol(stringConstName)); - return stringConstName; -} - -/// Create an global that contains the given format string. If a global with -/// the same format string exists already in the module, return that global. -static LLVM::GlobalOp getOrCreateFormatStringConstant( - OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8, - StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) { - llvm::SmallString<20> formatString(str); - formatString.push_back('\0'); // Null terminate for C - auto globalType = - LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes()); - StringAttr attr = b.getStringAttr(formatString); - - // Try to find existing global. - for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>()) - if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && - globalOp.getValueAttr() == attr && - globalOp.getAlignment().value_or(0) == alignment && - globalOp.getAddrSpace() == addrSpace) - return globalOp; - - // Not found: create new global. - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(moduleOp.getBody()); - SmallString<16> name = getUniqueFormatGlobalName(moduleOp); - return b.create<LLVM::GlobalOp>(loc, globalType, - /*isConstant=*/true, LLVM::Linkage::Internal, - name, attr, alignment, addrSpace); -} - -template <typename T> -static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, - ConversionPatternRewriter &rewriter, - StringRef name, - LLVM::LLVMFunctionType type) { - LLVM::LLVMFuncOp ret; - if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { - ConversionPatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type, - LLVM::Linkage::External); - } - return ret; -} - LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( Value printfDesc = printfBeginCall.getResult(); // Create the global op or find an existing one. - LLVM::GlobalOp global = getOrCreateFormatStringConstant( - rewriter, loc, moduleOp, llvmI8, adaptor.getFormat()); + LLVM::GlobalOp global = getOrCreateStringConstant( + rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element and pass it to printf() Value globalPtr = rewriter.create<LLVM::AddressOfOp>( @@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType); // Create the global op or find an existing one. - LLVM::GlobalOp global = getOrCreateFormatStringConstant( - rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0, - addressSpace); + LLVM::GlobalOp global = getOrCreateStringConstant( + rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(), + /*alignment=*/0, addressSpace); // Get a pointer to the format string's first element Value globalPtr = rewriter.create<LLVM::AddressOfOp>( @@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType); // Create the global op or find an existing one. - LLVM::GlobalOp global = getOrCreateFormatStringConstant( - rewriter, loc, moduleOp, llvmI8, adaptor.getFormat()); + LLVM::GlobalOp global = getOrCreateStringConstant( + rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h index 444a07a93ca3..e73a74845d2b 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -14,6 +14,27 @@ namespace mlir { +//===----------------------------------------------------------------------===// +// Helper Functions +//===----------------------------------------------------------------------===// + +/// Find or create an external function declaration in the given module. +LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc, + OpBuilder &b, StringRef name, + LLVM::LLVMFunctionType type); + +/// Create a global that contains the given string. If a global with the same +/// string already exists in the module, return that global. +LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc, + gpu::GPUModuleOp moduleOp, Type llvmI8, + StringRef namePrefix, StringRef str, + uint64_t alignment = 0, + unsigned addrSpace = 0); + +//===----------------------------------------------------------------------===// +// Lowering Patterns +//===----------------------------------------------------------------------===// + /// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first /// create a 0-sized global array symbol similar as LLVM expects. It constructs /// a memref descriptor with these values and return it. diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index e022d3ce6f63..2768929f460e 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -25,6 +25,7 @@ #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" @@ -236,6 +237,103 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> { } }; +/// Lowering of cf.assert into a conditional __assertfail. +struct AssertOpToAssertfailLowering + : public ConvertOpToLLVMPattern<cf::AssertOp> { + using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = rewriter.getContext(); + Location loc = assertOp.getLoc(); + Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8)); + Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32)); + Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64)); + Type ptrType = LLVM::LLVMPointerType::get(ctx); + Type voidType = LLVM::LLVMVoidType::get(ctx); + + // Find or create __assertfail function declaration. + auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>(); + auto assertfailType = LLVM::LLVMFunctionType::get( + voidType, {ptrType, ptrType, i32Type, ptrType, i64Type}); + LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction( + moduleOp, loc, rewriter, "__assertfail", assertfailType); + assertfailDecl.setPassthroughAttr( + ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn"))); + + // Split blocks and insert conditional branch. + // ^before: + // ... + // cf.cond_br %condition, ^after, ^assert + // ^assert: + // cf.assert + // cf.br ^after + // ^after: + // ... + Block *beforeBlock = assertOp->getBlock(); + Block *assertBlock = + rewriter.splitBlock(beforeBlock, assertOp->getIterator()); + Block *afterBlock = + rewriter.splitBlock(assertBlock, ++assertOp->getIterator()); + rewriter.setInsertionPointToEnd(beforeBlock); + rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock, + assertBlock); + rewriter.setInsertionPointToEnd(assertBlock); + rewriter.create<cf::BranchOp>(loc, afterBlock); + + // Continue cf.assert lowering. + rewriter.setInsertionPoint(assertOp); + + // Populate file name, file number and function name from the location of + // the AssertOp. + StringRef fileName = "(unknown)"; + StringRef funcName = "(unknown)"; + int32_t fileLine = 0; + while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc)) + loc = callSiteLoc.getCallee(); + if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) { + fileName = fileLineColLoc.getFilename().strref(); + fileLine = fileLineColLoc.getStartLine(); + } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) { + funcName = nameLoc.getName().strref(); + if (auto fileLineColLoc = + dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) { + fileName = fileLineColLoc.getFilename().strref(); + fileLine = fileLineColLoc.getStartLine(); + } + } + + // Create constants. + auto getGlobal = [&](LLVM::GlobalOp global) { + // Get a pointer to the format string's first element. + Value globalPtr = rewriter.create<LLVM::AddressOfOp>( + loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()), + global.getSymNameAttr()); + Value start = + rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); + return start; + }; + Value assertMessage = getGlobal(getOrCreateStringConstant( + rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg())); + Value assertFile = getGlobal(getOrCreateStringConstant( + rewriter, loc, moduleOp, i8Type, "assert_file_", fileName)); + Value assertFunc = getGlobal(getOrCreateStringConstant( + rewriter, loc, moduleOp, i8Type, "assert_func_", funcName)); + Value assertLine = + rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine); + Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1); + + // Insert function call to __assertfail. + SmallVector<Value> arguments{assertMessage, assertFile, assertLine, + assertFunc, c1}; + rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl, + arguments); + return success(); + } +}; + /// Import the GPU Ops to NVVM Patterns. #include "GPUToNVVM.cpp.inc" @@ -358,7 +456,8 @@ void mlir::populateGpuToNVVMConversionPatterns( using gpu::index_lowering::IndexKind; using gpu::index_lowering::IntrType; populateWithGenerated(patterns); - patterns.add<GPUPrintfOpToVPrintfLowering>(converter); + patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>( + converter); patterns.add< gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp, NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>( diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index d52a86987b1c..afebded1c3ea 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -47,7 +47,6 @@ #include "../GPUCommon/GPUOpsLowering.h" #include "../GPUCommon/IndexIntrinsicsOpLowering.h" -#include "../GPUCommon/OpToFuncCallLowering.h" namespace mlir { #define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS @@ -297,6 +296,7 @@ struct LowerGpuOpsToROCDLOpsPass populateVectorToLLVMConversionPatterns(converter, llvmPatterns); populateMathToLLVMConversionPatterns(converter, llvmPatterns); cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns); + cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns); populateFuncToLLVMConversionPatterns(converter, llvmPatterns); populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns); populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime); @@ -346,16 +346,6 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) { target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>(); } -template <typename OpTy> -static void populateOpPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns, StringRef f32Func, - StringRef f64Func, StringRef f32ApproxFunc, - StringRef f16Func) { - patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter); - patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f32ApproxFunc, - f16Func); -} - void mlir::populateGpuToROCDLConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns, mlir::gpu::amd::Runtime runtime) { diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 49e2d9432866..72799e42cf3f 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -85,7 +85,7 @@ static Value unrankedMemRefMaterialization(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter) { - // An argument materialization must return a value of type + // A source materialization must return a value of type // `resultType`, so insert a cast from the memref descriptor type // (!llvm.struct) to the original memref type. Value packed = @@ -101,7 +101,7 @@ static Value rankedMemRefMaterialization(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter) { - // An argument materialization must return a value of type `resultType`, + // A source materialization must return a value of type `resultType`, // so insert a cast from the memref descriptor type (!llvm.struct) to the // original memref type. Value packed = @@ -234,19 +234,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, .getResult(0); }); - // Argument materializations convert from the new block argument types + // Source materializations convert from the new block argument types // (multiple SSA values that make up a memref descriptor) back to the // original block argument type. - addArgumentMaterialization([&](OpBuilder &builder, - UnrankedMemRefType resultType, - ValueRange inputs, Location loc) { - return unrankedMemRefMaterialization(builder, resultType, inputs, loc, - *this); - }); - addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType, - ValueRange inputs, Location loc) { - return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this); - }); addSourceMaterialization([&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc) { diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 58fd3d565fce..5d0003911bca 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -304,6 +304,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() { LLVMTypeConverter converter(&getContext()); arith::populateArithToLLVMConversionPatterns(converter, patterns); cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); + cf::populateAssertToLLVMConversionPattern(converter, patterns); populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns); populateFuncToLLVMConversionPatterns(converter, patterns); populateOpenMPToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt b/mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt index 79119d374f7a..af5493be8a4b 100644 --- a/mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRSCFToEmitC LINK_LIBS PUBLIC MLIRArithDialect MLIREmitCDialect + MLIREmitCTransforms MLIRSCFDialect MLIRTransforms ) diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index 67a43c43d608..92523ca4f12b 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -39,21 +40,22 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> { // Lower scf::for to emitc::for, implementing result values using // emitc::variable's updated within the loop body. -struct ForLowering : public OpRewritePattern<ForOp> { - using OpRewritePattern<ForOp>::OpRewritePattern; +struct ForLowering : public OpConversionPattern<ForOp> { + using OpConversionPattern<ForOp>::OpConversionPattern; - LogicalResult matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(ForOp forOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; // Create an uninitialized emitc::variable op for each result of the given op. template <typename T> -static SmallVector<Value> createVariablesForResults(T op, - PatternRewriter &rewriter) { - SmallVector<Value> resultVariables; - +static LogicalResult +createVariablesForResults(T op, const TypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + SmallVector<Value> &resultVariables) { if (!op.getNumResults()) - return resultVariables; + return success(); Location loc = op->getLoc(); MLIRContext *context = op.getContext(); @@ -62,7 +64,9 @@ static SmallVector<Value> createVariablesForResults(T op, rewriter.setInsertionPoint(op); for (OpResult result : op.getResults()) { - Type resultType = result.getType(); + Type resultType = typeConverter->convertType(result.getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "result type conversion failed"); Type varType = emitc::LValueType::get(resultType); emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, ""); emitc::VariableOp var = @@ -70,13 +74,13 @@ static SmallVector<Value> createVariablesForResults(T op, resultVariables.push_back(var); } - return resultVariables; + return success(); } // Create a series of assign ops assigning given values to given variables at // the current insertion point of given rewriter. -static void assignValues(ValueRange values, SmallVector<Value> &variables, - PatternRewriter &rewriter, Location loc) { +static void assignValues(ValueRange values, ValueRange variables, + ConversionPatternRewriter &rewriter, Location loc) { for (auto [value, var] : llvm::zip(values, variables)) rewriter.create<emitc::AssignOp>(loc, var, value); } @@ -89,18 +93,25 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables, }); } -static void lowerYield(SmallVector<Value> &resultVariables, - PatternRewriter &rewriter, scf::YieldOp yield) { +static LogicalResult lowerYield(Operation *op, ValueRange resultVariables, + ConversionPatternRewriter &rewriter, + scf::YieldOp yield) { Location loc = yield.getLoc(); - ValueRange operands = yield.getOperands(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(yield); - assignValues(operands, resultVariables, rewriter, loc); + SmallVector<Value> yieldOperands; + if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) { + return rewriter.notifyMatchFailure(op, "failed to lower yield operands"); + } + + assignValues(yieldOperands, resultVariables, rewriter, loc); rewriter.create<emitc::YieldOp>(loc); rewriter.eraseOp(yield); + + return success(); } // Lower the contents of an scf::if/scf::index_switch regions to an @@ -108,27 +119,32 @@ static void lowerYield(SmallVector<Value> &resultVariables, // moved into the respective lowered region, but the scf::yield is replaced not // only with an emitc::yield, but also with a sequence of emitc::assign ops that // set the yielded values into the result variables. -static void lowerRegion(SmallVector<Value> &resultVariables, - PatternRewriter &rewriter, Region ®ion, - Region &loweredRegion) { +static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables, + ConversionPatternRewriter &rewriter, + Region ®ion, Region &loweredRegion) { rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end()); Operation *terminator = loweredRegion.back().getTerminator(); - lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator)); + return lowerYield(op, resultVariables, rewriter, + cast<scf::YieldOp>(terminator)); } -LogicalResult ForLowering::matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const { +LogicalResult +ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = forOp.getLoc(); // Create an emitc::variable op for each result. These variables will be // assigned to by emitc::assign ops within the loop body. - SmallVector<Value> resultVariables = - createVariablesForResults(forOp, rewriter); + SmallVector<Value> resultVariables; + if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter, + resultVariables))) + return rewriter.notifyMatchFailure(forOp, + "create variables for results failed"); - assignValues(forOp.getInits(), resultVariables, rewriter, loc); + assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc); emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>( - loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()); + loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep()); Block *loweredBody = loweredFor.getBody(); @@ -143,13 +159,27 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, rewriter.restoreInsertionPoint(ip); + // Convert the original region types into the new types by adding unrealized + // casts in the beginning of the loop. This performs the conversion in place. + if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), + *getTypeConverter(), nullptr))) { + return rewriter.notifyMatchFailure(forOp, "region types conversion failed"); + } + + // Register the replacements for the block arguments and inline the body of + // the scf.for loop into the body of the emitc::for loop. + Block *scfBody = &(forOp.getRegion().front()); SmallVector<Value> replacingValues; replacingValues.push_back(loweredFor.getInductionVar()); replacingValues.append(iterArgsValues.begin(), iterArgsValues.end()); + rewriter.mergeBlocks(scfBody, loweredBody, replacingValues); - rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues); - lowerYield(resultVariables, rewriter, - cast<scf::YieldOp>(loweredBody->getTerminator())); + auto result = lowerYield(forOp, resultVariables, rewriter, + cast<scf::YieldOp>(loweredBody->getTerminator())); + + if (failed(result)) { + return result; + } // Load variables into SSA values after the for loop. SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc); @@ -160,38 +190,66 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, // Lower scf::if to emitc::if, implementing result values as emitc::variable's // updated within the then and else regions. -struct IfLowering : public OpRewritePattern<IfOp> { - using OpRewritePattern<IfOp>::OpRewritePattern; +struct IfLowering : public OpConversionPattern<IfOp> { + using OpConversionPattern<IfOp>::OpConversionPattern; - LogicalResult matchAndRewrite(IfOp ifOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; } // namespace -LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, - PatternRewriter &rewriter) const { +LogicalResult +IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = ifOp.getLoc(); // Create an emitc::variable op for each result. These variables will be // assigned to by emitc::assign ops within the then & else regions. - SmallVector<Value> resultVariables = - createVariablesForResults(ifOp, rewriter); - - Region &thenRegion = ifOp.getThenRegion(); - Region &elseRegion = ifOp.getElseRegion(); + SmallVector<Value> resultVariables; + if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter, + resultVariables))) + return rewriter.notifyMatchFailure(ifOp, + "create variables for results failed"); + + // Utility function to lower the contents of an scf::if region to an emitc::if + // region. The contents of the scf::if regions is moved into the respective + // emitc::if regions, but the scf::yield is replaced not only with an + // emitc::yield, but also with a sequence of emitc::assign ops that set the + // yielded values into the result variables. + auto lowerRegion = [&resultVariables, &rewriter, + &ifOp](Region ®ion, Region &loweredRegion) { + rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end()); + Operation *terminator = loweredRegion.back().getTerminator(); + auto result = lowerYield(ifOp, resultVariables, rewriter, + cast<scf::YieldOp>(terminator)); + if (failed(result)) { + return result; + } + return success(); + }; + + Region &thenRegion = adaptor.getThenRegion(); + Region &elseRegion = adaptor.getElseRegion(); bool hasElseBlock = !elseRegion.empty(); auto loweredIf = - rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false); + rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false); Region &loweredThenRegion = loweredIf.getThenRegion(); - lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion); + auto result = lowerRegion(thenRegion, loweredThenRegion); + if (failed(result)) { + return result; + } if (hasElseBlock) { Region &loweredElseRegion = loweredIf.getElseRegion(); - lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion); + auto result = lowerRegion(elseRegion, loweredElseRegion); + if (failed(result)) { + return result; + } } rewriter.setInsertionPointAfter(ifOp); @@ -203,37 +261,46 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, // Lower scf::index_switch to emitc::switch, implementing result values as // emitc::variable's updated within the case and default regions. -struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> { - using OpRewritePattern<IndexSwitchOp>::OpRewritePattern; +struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; -LogicalResult -IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp, - PatternRewriter &rewriter) const { +LogicalResult IndexSwitchOpLowering::matchAndRewrite( + IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = indexSwitchOp.getLoc(); // Create an emitc::variable op for each result. These variables will be // assigned to by emitc::assign ops within the case and default regions. - SmallVector<Value> resultVariables = - createVariablesForResults(indexSwitchOp, rewriter); + SmallVector<Value> resultVariables; + if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(), + rewriter, resultVariables))) { + return rewriter.notifyMatchFailure(indexSwitchOp, + "create variables for results failed"); + } auto loweredSwitch = rewriter.create<emitc::SwitchOp>( - loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(), - indexSwitchOp.getNumCases()); + loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases()); // Lowering all case regions. - for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(), - loweredSwitch.getCaseRegions())) { - lowerRegion(resultVariables, rewriter, std::get<0>(pair), - std::get<1>(pair)); + for (auto pair : + llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) { + if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter, + *std::get<0>(pair), std::get<1>(pair)))) { + return failure(); + } } // Lowering default region. - lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(), - loweredSwitch.getDefaultRegion()); + if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter, + adaptor.getDefaultRegion(), + loweredSwitch.getDefaultRegion()))) { + return failure(); + } rewriter.setInsertionPointAfter(indexSwitchOp); SmallVector<Value> results = loadValues(resultVariables, rewriter, loc); @@ -242,15 +309,22 @@ IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp, return success(); } -void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) { - patterns.add<ForLowering>(patterns.getContext()); - patterns.add<IfLowering>(patterns.getContext()); - patterns.add<IndexSwitchOpLowering>(patterns.getContext()); +void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add<ForLowering>(typeConverter, patterns.getContext()); + patterns.add<IfLowering>(typeConverter, patterns.getContext()); + patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext()); } void SCFToEmitCPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - populateSCFToEmitCConversionPatterns(patterns); + TypeConverter typeConverter; + // Fallback converter + // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter + // Type converters are called most to least recently inserted + typeConverter.addConversion([](Type t) { return t; }); + populateEmitCSizeTTypeConversions(typeConverter); + populateSCFToEmitCConversionPatterns(patterns, typeConverter); // Configure conversion to lower out SCF operations. ConversionTarget target(getContext()); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index 6f085cb6ed06..b5a0da15e780 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -338,11 +338,6 @@ public: padOp, "tosa.pad was unable to determine the pad constant value."); } - Value lowIndex = - rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0)); - Value highIndex = - rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1)); - SmallVector<OpFoldResult, 3> lowValues; SmallVector<OpFoldResult, 3> highValues; @@ -350,11 +345,12 @@ public: highValues.reserve(rank); for (int i = 0; i < rank; i++) { - Value inputIndex = rewriter.create<arith::ConstantIndexOp>(loc, i); + Value lowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i); + Value highIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1); Value lowVal = rewriter.createOrFold<tensor::ExtractOp>( - loc, padding, ValueRange({inputIndex, lowIndex})); + loc, padding, ValueRange({lowIndex})); Value highVal = rewriter.createOrFold<tensor::ExtractOp>( - loc, padding, ValueRange({inputIndex, highIndex})); + loc, padding, ValueRange({highIndex})); lowVal = rewriter.createOrFold<arith::IndexCastOp>( loc, rewriter.getIndexType(), lowVal); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index dceebbfec586..b45829bcf6d2 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -4520,6 +4520,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, Value linearIndex, ValueRange basis, bool hasOuterBound) { + if (hasOuterBound && !basis.empty() && basis.front() == nullptr) { + hasOuterBound = false; + basis = basis.drop_front(); + } SmallVector<Value> dynamicBasis; SmallVector<int64_t> staticBasis; dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis, @@ -4533,6 +4537,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, Value linearIndex, ArrayRef<OpFoldResult> basis, bool hasOuterBound) { + if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) { + hasOuterBound = false; + basis = basis.drop_front(); + } SmallVector<Value> dynamicBasis; SmallVector<int64_t> staticBasis; dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis); @@ -4654,6 +4662,13 @@ SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() { return getMixedValues(getStaticBasis(), getDynamicBasis(), builder); } +SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() { + SmallVector<OpFoldResult> ret = getMixedBasis(); + if (!hasOuterBound()) + ret.insert(ret.begin(), OpFoldResult()); + return ret; +} + namespace { // Drops delinearization indices that correspond to unit-extent basis @@ -4672,25 +4687,27 @@ struct DropUnitExtentBasis return zero.value(); }; - bool hasOuterBound = delinearizeOp.hasOuterBound(); // Replace all indices corresponding to unit-extent basis with 0. // Remaining basis can be used to get a new `affine.delinearize_index` op. SmallVector<OpFoldResult> newBasis; - for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) { - std::optional<int64_t> basisVal = getConstantIntValue(basis); + for (auto [index, basis] : + llvm::enumerate(delinearizeOp.getPaddedBasis())) { + std::optional<int64_t> basisVal = + basis ? getConstantIntValue(basis) : std::nullopt; if (basisVal && *basisVal == 1) - replacements[index + (hasOuterBound ? 0 : 1)] = getZero(); + replacements[index] = getZero(); else newBasis.push_back(basis); } - if (newBasis.size() == delinearizeOp.getStaticBasis().size()) + if (newBasis.size() == delinearizeOp.getNumResults()) return rewriter.notifyMatchFailure(delinearizeOp, "no unit basis elements"); - if (!newBasis.empty() || !hasOuterBound) { + if (!newBasis.empty()) { + // Will drop the leading nullptr from `basis` if there was no outer bound. auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>( - loc, delinearizeOp.getLinearIndex(), newBasis, hasOuterBound); + loc, delinearizeOp.getLinearIndex(), newBasis); int newIndex = 0; // Map back the new delinearized indices to the values they replace. for (auto &replacement : replacements) { @@ -4871,6 +4888,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, ValueRange multiIndex, ValueRange basis, bool disjoint) { + if (!basis.empty() && basis.front() == Value()) + basis = basis.drop_front(); SmallVector<Value> dynamicBasis; SmallVector<int64_t> staticBasis; dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis, @@ -4883,6 +4902,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder, ValueRange multiIndex, ArrayRef<OpFoldResult> basis, bool disjoint) { + if (!basis.empty() && basis.front() == OpFoldResult()) + basis = basis.drop_front(); SmallVector<Value> dynamicBasis; SmallVector<int64_t> staticBasis; dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis); @@ -4965,7 +4986,14 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() { builder); } - return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder); + return getMixedValues(getStaticBasis(), getDynamicBasis(), builder); +} + +SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() { + SmallVector<OpFoldResult> ret = getMixedBasis(); + if (!hasOuterBound()) + ret.insert(ret.begin(), OpFoldResult()); + return ret; } namespace { @@ -5027,38 +5055,228 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final } }; -/// Cancel out linearize_index(delinearize_index(x, B), B). +/// Return the product of `terms`, creating an `affine.apply` if any of them are +/// non-constant values. If any of `terms` is `nullptr`, return `nullptr`. +static OpFoldResult computeProduct(Location loc, OpBuilder &builder, + ArrayRef<OpFoldResult> terms) { + int64_t nDynamic = 0; + SmallVector<Value> dynamicPart; + AffineExpr result = builder.getAffineConstantExpr(1); + for (OpFoldResult term : terms) { + if (!term) + return term; + std::optional<int64_t> maybeConst = getConstantIntValue(term); + if (maybeConst) { + result = result * builder.getAffineConstantExpr(*maybeConst); + } else { + dynamicPart.push_back(term.get<Value>()); + result = result * builder.getAffineSymbolExpr(nDynamic++); + } + } + if (auto constant = dyn_cast<AffineConstantExpr>(result)) + return getAsIndexOpFoldResult(builder.getContext(), constant.getValue()); + return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult(); +} + +/// If conseceutive outputs of a delinearize_index are linearized with the same +/// bounds, canonicalize away the redundant arithmetic. +/// +/// That is, if we have +/// ``` +/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b) +/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d] +/// by (...e, B1, B2, ..., BK, ...f) +/// ``` /// -/// That is, rewrite +/// We can rewrite this to /// ``` -/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN) -/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ... -/// %bN) +/// B = B1 * B2 ... BK +/// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b) +/// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f) /// ``` -/// to replacing `%y` with `%x`. -struct CancelLinearizeOfDelinearizeExact final +/// where we replace all results of %s unaffected by the change with results +/// from %sMerged. +/// +/// As a special case, if all results of the delinearize are merged in this way +/// we can replace those usages with %x, thus cancelling the delinearization +/// entirely, as in +/// ``` +/// %s:3 = affine.delinearize_index %x into (2, 4, 8) +/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16) +/// ``` +/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)` +struct CancelLinearizeOfDelinearizePortion final : OpRewritePattern<affine::AffineLinearizeIndexOp> { using OpRewritePattern::OpRewritePattern; +private: + // Struct representing a case where the cancellation pattern + // applies. A `Match` means that `length` inputs to the linearize operation + // starting at `linStart` can be cancelled with `length` outputs of + // `delinearize`, starting from `delinStart`. + struct Match { + AffineDelinearizeIndexOp delinearize; + unsigned linStart = 0; + unsigned delinStart = 0; + unsigned length = 0; + }; + +public: LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp, PatternRewriter &rewriter) const override { - auto delinearizeOp = linearizeOp.getMultiIndex() - .front() - .getDefiningOp<affine::AffineDelinearizeIndexOp>(); - if (!delinearizeOp) - return rewriter.notifyMatchFailure( - linearizeOp, "last entry doesn't come from a delinearize"); + SmallVector<Match> matches; + + const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis(); + ArrayRef<OpFoldResult> linBasisRef = linBasis; + + ValueRange multiIndex = linearizeOp.getMultiIndex(); + unsigned numLinArgs = multiIndex.size(); + unsigned linArgIdx = 0; + // We only want to replace one run from the same delinearize op per + // pattern invocation lest we run into invalidation issues. + llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize; + while (linArgIdx < numLinArgs) { + auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]); + if (!asResult) { + linArgIdx++; + continue; + } - if (linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis()) - return rewriter.notifyMatchFailure( - linearizeOp, "basis of linearize and delinearize don't match exactly " - "(excluding outer bounds)"); + auto delinearizeOp = + dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner()); + if (!delinearizeOp) { + linArgIdx++; + continue; + } + + /// Result 0 of the delinearize and argument 0 of the linearize can + /// leave their maximum value unspecified. However, even if this happens + /// we can still sometimes start the match process. Specifically, if + /// - The argument we're matching is result 0 and argument 0 (so the + /// bounds don't matter). For example, + /// + /// %0:2 = affine.delinearize_index %x into (8) : index, index + /// %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...) + /// allows cancellation + /// - The delinearization doesn't specify a bound, but the linearization + /// is `disjoint`, which asserts that the bound on the linearization is + /// correct. + unsigned delinArgIdx = asResult.getResultNumber(); + SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis(); + OpFoldResult firstDelinBound = delinBasis[delinArgIdx]; + OpFoldResult firstLinBound = linBasis[linArgIdx]; + bool boundsMatch = firstDelinBound == firstLinBound; + bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0; + bool knownByDisjoint = + linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound; + if (!boundsMatch && !bothAtFront && !knownByDisjoint) { + linArgIdx++; + continue; + } + + unsigned j = 1; + unsigned numDelinOuts = delinearizeOp.getNumResults(); + for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts; + ++j) { + if (multiIndex[linArgIdx + j] != + delinearizeOp.getResult(delinArgIdx + j)) + break; + if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j]) + break; + } + // If there're multiple matches against the same delinearize_index, + // only rewrite the first one we find to prevent invalidations. The next + // ones will be taken care of by subsequent pattern invocations. + if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) { + linArgIdx++; + continue; + } + matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j}); + linArgIdx += j; + } - if (delinearizeOp.getResults() != linearizeOp.getMultiIndex()) + if (matches.empty()) return rewriter.notifyMatchFailure( - linearizeOp, "not all indices come from delinearize"); + linearizeOp, "no run of delinearize outputs to deal with"); + + // Record all the delinearize replacements so we can do them after creating + // the new linearization operation, since the new operation might use + // outputs of something we're replacing. + SmallVector<SmallVector<Value>> delinearizeReplacements; + + SmallVector<Value> newIndex; + newIndex.reserve(numLinArgs); + SmallVector<OpFoldResult> newBasis; + newBasis.reserve(numLinArgs); + unsigned prevMatchEnd = 0; + for (Match m : matches) { + unsigned gap = m.linStart - prevMatchEnd; + llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap)); + llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap)); + // Update here so we don't forget this during early continues + prevMatchEnd = m.linStart + m.length; + + PatternRewriter::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(m.delinearize); + + ArrayRef<OpFoldResult> basisToMerge = + linBasisRef.slice(m.linStart, m.length); + // We use the slice from the linearize's basis above because of the + // "bounds inferred from `disjoint`" case above. + OpFoldResult newSize = + computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge); + + // Trivial case where we can just skip past the delinearize all together + if (m.length == m.delinearize.getNumResults()) { + newIndex.push_back(m.delinearize.getLinearIndex()); + newBasis.push_back(newSize); + // Pad out set of replacements so we don't do anything with this one. + delinearizeReplacements.push_back(SmallVector<Value>()); + continue; + } + + SmallVector<Value> newDelinResults; + SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis(); + newDelinBasis.erase(newDelinBasis.begin() + m.delinStart, + newDelinBasis.begin() + m.delinStart + m.length); + newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize); + auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>( + m.delinearize.getLoc(), m.delinearize.getLinearIndex(), + newDelinBasis); + + // Since there may be other uses of the indices we just merged together, + // create a residual affine.delinearize_index that delinearizes the + // merged output into its component parts. + Value combinedElem = newDelinearize.getResult(m.delinStart); + auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>( + m.delinearize.getLoc(), combinedElem, basisToMerge); + + // Swap all the uses of the unaffected delinearize outputs to the new + // delinearization so that the old code can be removed if this + // linearize_index is the only user of the merged results. + llvm::append_range(newDelinResults, + newDelinearize.getResults().take_front(m.delinStart)); + llvm::append_range(newDelinResults, residualDelinearize.getResults()); + llvm::append_range( + newDelinResults, + newDelinearize.getResults().drop_front(m.delinStart + 1)); + + delinearizeReplacements.push_back(newDelinResults); + newIndex.push_back(combinedElem); + newBasis.push_back(newSize); + } + llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd)); + llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd)); + rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>( + linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint()); + + for (auto [m, newResults] : + llvm::zip_equal(matches, delinearizeReplacements)) { + if (newResults.empty()) + continue; + rewriter.replaceOp(m.delinearize, newResults); + } - rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex()); return success(); } }; @@ -5096,7 +5314,7 @@ struct DropLinearizeLeadingZero final void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { - patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero, + patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero, DropLinearizeUnitComponentsIfDisjointOrZero>(context); } diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp index 82a9fb0d4908..e93b99b4f498 100644 --- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -91,6 +91,64 @@ struct AffineMaxOpInterface }; }; +struct AffineDelinearizeIndexOpInterface + : public ValueBoundsOpInterface::ExternalModel< + AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> { + void populateBoundsForIndexValue(Operation *rawOp, Value value, + ValueBoundsConstraintSet &cstr) const { + auto op = cast<AffineDelinearizeIndexOp>(rawOp); + auto result = cast<OpResult>(value); + assert(result.getOwner() == rawOp && + "bounded value isn't a result of this delinearize_index"); + unsigned resIdx = result.getResultNumber(); + + AffineExpr linearIdx = cstr.getExpr(op.getLinearIndex()); + + SmallVector<OpFoldResult> basis = op.getPaddedBasis(); + AffineExpr divisor = cstr.getExpr(1); + for (OpFoldResult basisElem : llvm::drop_begin(basis, resIdx + 1)) + divisor = divisor * cstr.getExpr(basisElem); + + if (resIdx == 0) { + cstr.bound(value) == linearIdx.floorDiv(divisor); + if (!basis.front().isNull()) + cstr.bound(value) < cstr.getExpr(basis.front()); + return; + } + AffineExpr thisBasis = cstr.getExpr(basis[resIdx]); + cstr.bound(value) == (linearIdx % (thisBasis * divisor)).floorDiv(divisor); + } +}; + +struct AffineLinearizeIndexOpInterface + : public ValueBoundsOpInterface::ExternalModel< + AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> { + void populateBoundsForIndexValue(Operation *rawOp, Value value, + ValueBoundsConstraintSet &cstr) const { + auto op = cast<AffineLinearizeIndexOp>(rawOp); + assert(value == op.getResult() && + "value isn't the result of this linearize"); + + AffineExpr bound = cstr.getExpr(0); + AffineExpr stride = cstr.getExpr(1); + SmallVector<OpFoldResult> basis = op.getPaddedBasis(); + OperandRange multiIndex = op.getMultiIndex(); + unsigned numArgs = multiIndex.size(); + for (auto [revArgNum, length] : llvm::enumerate(llvm::reverse(basis))) { + unsigned argNum = numArgs - (revArgNum + 1); + if (argNum == 0) + break; + OpFoldResult indexAsFoldRes = getAsOpFoldResult(multiIndex[argNum]); + bound = bound + cstr.getExpr(indexAsFoldRes) * stride; + stride = stride * cstr.getExpr(length); + } + bound = bound + cstr.getExpr(op.getMultiIndex().front()) * stride; + cstr.bound(value) == bound; + if (op.getDisjoint() && !basis.front().isNull()) { + cstr.bound(value) < stride *cstr.getExpr(basis.front()); + } + } +}; } // namespace } // namespace mlir @@ -100,6 +158,10 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels( AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx); AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx); AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx); + AffineDelinearizeIndexOp::attachInterface< + AffineDelinearizeIndexOpInterface>(*ctx); + AffineLinearizeIndexOp::attachInterface<AffineLinearizeIndexOpInterface>( + *ctx); }); } diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 0f2c889d4f39..4e02559a0894 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -919,8 +919,9 @@ static void generateUnrolledLoop( // 'forOp'. auto builder = OpBuilder::atBlockTerminator(loopBodyBlock); + constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {}; if (!annotateFn) - annotateFn = [](unsigned, Operation *, OpBuilder) {}; + annotateFn = defaultAnnotateFn; // Keep a pointer to the last non-terminator operation in the original block // so that we know what to clone (since we are doing this in-place). diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 4d3ead20fb5c..9e3257a62b12 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -51,12 +51,14 @@ public: loc(loc) {} template <typename OpTy> - Value buildBinaryExpr(AffineBinaryOpExpr expr) { + Value buildBinaryExpr(AffineBinaryOpExpr expr, + arith::IntegerOverflowFlags overflowFlags = + arith::IntegerOverflowFlags::none) { auto lhs = visit(expr.getLHS()); auto rhs = visit(expr.getRHS()); if (!lhs || !rhs) return nullptr; - auto op = builder.create<OpTy>(loc, lhs, rhs); + auto op = builder.create<OpTy>(loc, lhs, rhs, overflowFlags); return op.getResult(); } @@ -65,7 +67,8 @@ public: } Value visitMulExpr(AffineBinaryOpExpr expr) { - return buildBinaryExpr<arith::MulIOp>(expr); + return buildBinaryExpr<arith::MulIOp>(expr, + arith::IntegerOverflowFlags::nsw); } /// Euclidean modulo operation: negative RHS is not allowed. diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index d8b314a3fa43..e016a6e16e59 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -580,11 +580,31 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns( // DivUIOp //===----------------------------------------------------------------------===// +/// Fold `(a * b) / b -> a` +static Value foldDivMul(Value lhs, Value rhs, + arith::IntegerOverflowFlags ovfFlags) { + auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>(); + if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags)) + return {}; + + if (mul.getLhs() == rhs) + return mul.getRhs(); + + if (mul.getRhs() == rhs) + return mul.getLhs(); + + return {}; +} + OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) { // divui (x, 1) -> x. if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); + // (a * b) / b -> a + if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw)) + return val; + // Don't fold if it would require a division by zero. bool div0 = false; auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), @@ -621,6 +641,10 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) { if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); + // (a * b) / b -> a + if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw)) + return val; + // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = constFoldBinaryOp<IntegerAttr>( diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 61767f3b21c9..12c65a72babc 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -17,7 +17,7 @@ #include "mlir/Dialect/ArmSME/Transforms/Passes.h" #include "mlir/Dialect/ArmSME/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -25,7 +25,8 @@ #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" -#include "mlir/Transforms/OneToNTypeConversion.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "arm-sme-vector-legalization" @@ -172,12 +173,12 @@ int getNumberOfSMETilesForVectorType(VectorType type) { /// Legalize `arith.constant dense<value>` splat operations to fit within SME /// tiles by decomposing them into tile-sized operations. struct LegalizeArithConstantOpsByDecomposition - : public OneToNOpConversionPattern<arith::ConstantOp> { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern<arith::ConstantOp> { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + ConversionPatternRewriter &rewriter) const override { auto vectorType = dyn_cast<VectorType>(constantOp.getType()); auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr()); if (!vectorType || !denseAttr || !denseAttr.isSplat()) @@ -191,8 +192,8 @@ struct LegalizeArithConstantOpsByDecomposition auto tileCount = getNumberOfSMETilesForVectorType(vectorType); auto tileSplat = rewriter.create<arith::ConstantOp>( constantOp.getLoc(), denseAttr.resizeSplat(smeTileType)); - rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat), - adaptor.getResultMapping()); + SmallVector<Value> repl(tileCount, tileSplat); + rewriter.replaceOpWithMultiple(constantOp, {repl}); return success(); } @@ -201,12 +202,13 @@ struct LegalizeArithConstantOpsByDecomposition /// Legalize `vector.outerproduct` operations to fit within SME tiles by /// decomposing them into tile-sized operations. struct LegalizeVectorOuterProductOpsByDecomposition - : public OneToNOpConversionPattern<vector::OuterProductOp> { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern<vector::OuterProductOp> { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::OuterProductOp outerProductOp, + OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto vectorType = outerProductOp.getResultVectorType(); if (!isMultipleOfSMETileVectorType(vectorType)) return rewriter.notifyMatchFailure(outerProductOp, @@ -219,6 +221,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition auto maskOp = outerProductOp.getMaskingOp(); mask = maskOp.getMask(); rootOp = maskOp; + rewriter.setInsertionPoint(rootOp); } if (!isSupportedMaskOp(mask)) @@ -248,7 +251,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition resultSMETiles.push_back(maskedOuterProduct->getResult(0)); } - rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping()); + rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles}); return success(); } }; @@ -259,12 +262,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition // (invalid). This pattern matches on `vector.mask` then calls into the // `vector.outerproduct` pattern to work around this issue. struct LegalizeMaskedVectorOuterProductOpsByDecomposition - : public OneToNOpConversionPattern<vector::MaskOp> { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern<vector::MaskOp> { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>( maskOp.getMaskableOp())) { LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(), @@ -279,12 +282,12 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition /// Legalize `vector.transfer_read` operations to fit within SME tiles by /// decomposing them into tile-sized operations. struct LegalizeTransferReadOpsByDecomposition - : public OneToNOpConversionPattern<vector::TransferReadOp> { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern<vector::TransferReadOp> { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto vectorType = readOp.getVectorType(); if (!isMultipleOfSMETileVectorType(vectorType)) return rewriter.notifyMatchFailure(readOp, @@ -319,7 +322,7 @@ struct LegalizeTransferReadOpsByDecomposition resultSMETiles.push_back(smeRead); } - rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping()); + rewriter.replaceOpWithMultiple(readOp, {resultSMETiles}); return success(); } }; @@ -327,12 +330,12 @@ struct LegalizeTransferReadOpsByDecomposition /// Legalize `vector.transfer_write` operations to fit within SME tiles by /// decomposing them into tile-sized operations. struct LegalizeTransferWriteOpsByDecomposition - : public OneToNOpConversionPattern<vector::TransferWriteOp> { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern<vector::TransferWriteOp> { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto vectorType = writeOp.getVectorType(); if (!isMultipleOfSMETileVectorType(vectorType)) return rewriter.notifyMatchFailure(writeOp, @@ -409,12 +412,12 @@ struct LegalizeTransferWriteOpsByDecomposition /// } /// ``` struct LegalizeMultiTileTransferWriteAsStoreLoop - : public OneToNOpConversionPattern<vector::TransferWriteOp> { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern<vector::TransferWriteOp> { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { if (writeOp.hasPureTensorSemantics()) return rewriter.notifyMatchFailure( writeOp, "TODO: tensor semantics are unsupported"); @@ -936,10 +939,16 @@ struct VectorLegalizationPass return success(); }); - patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks, - LiftIllegalVectorTransposeToMemory, - ConvertIllegalShapeCastOpsToTransposes, - LowerIllegalTransposeStoreViaZA>(context); + // Apply preprocessing patterns. + RewritePatternSet rewritePatterns(context); + rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks, + LiftIllegalVectorTransposeToMemory, + ConvertIllegalShapeCastOpsToTransposes, + LowerIllegalTransposeStoreViaZA>(context); + if (failed( + applyPatternsGreedily(getOperation(), std::move(rewritePatterns)))) + return signalPassFailure(); + // Note: These two patterns are added with a high benefit to ensure: // - Masked outer products are handled before unmasked ones // - Multi-tile writes are lowered as a store loop (if possible) @@ -950,11 +959,20 @@ struct VectorLegalizationPass LegalizeVectorOuterProductOpsByDecomposition, LegalizeTransferReadOpsByDecomposition, LegalizeTransferWriteOpsByDecomposition>(converter, context); - populateFuncTypeConversionPatterns(converter, patterns); - scf::populateSCFStructuralOneToNTypeConversions(converter, patterns); - - if (failed(applyPartialOneToNConversion(getOperation(), converter, - std::move(patterns)))) + populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, + converter); + populateCallOpTypeConversionPattern(patterns, converter); + populateReturnOpTypeConversionPattern(patterns, converter); + scf::populateSCFStructuralTypeConversions(converter, patterns); + + ConversionTarget target(getContext()); + target.markUnknownOpDynamicallyLegal( + [&](Operation *op) { return converter.isLegal(op); }); + target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { + return converter.isSignatureLegal(op.getFunctionType()); + }); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index 845a32c4d97b..2bdb640699d0 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -20,22 +20,6 @@ using namespace mlir; using namespace mlir::arm_sve; -template <typename OpTy> -class ForwardOperands : public OpConversionPattern<OpTy> { - using OpConversionPattern<OpTy>::OpConversionPattern; - - LogicalResult - matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (adaptor.getOperands().getTypes() == op->getOperands().getTypes()) - return rewriter.notifyMatchFailure(op, "operand types already match"); - - rewriter.modifyOpInPlace(op, - [&]() { op->setOperands(adaptor.getOperands()); }); - return success(); - } -}; - using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>; using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>; using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>; @@ -204,10 +188,6 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns( // Populate conversion patterns // clang-format off - patterns.add<ForwardOperands<func::CallOp>, - ForwardOperands<func::CallIndirectOp>, - ForwardOperands<func::ReturnOp>>(converter, - &converter.getContext()); patterns.add<SdotOpLowering, SmmlaOpLowering, UdotOpLowering, diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 349841f06959..1eb27e44810b 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -480,18 +480,21 @@ bool AnalysisState::isValueRead(Value value) const { return false; } -// Starting from `value`, follow the use-def chain in reverse, always selecting -// the aliasing OpOperands. Find and return Values for which `condition` -// evaluates to true. OpOperands of such matching Values are not traversed any -// further, the visited aliasing opOperands will be preserved through -// `visitedOpOperands`. +// Starting from `opOperand`, follow the use-def chain in reverse, always +// selecting the aliasing OpOperands. Find and return Values for which +// `condition` evaluates to true. Uses of such matching Values are not +// traversed any further, the visited aliasing opOperands will be preserved +// through `visitedOpOperands`. llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( - Value value, llvm::function_ref<bool(Value)> condition, + OpOperand *opOperand, llvm::function_ref<bool(Value)> condition, TraversalConfig config, llvm::DenseSet<OpOperand *> *visitedOpOperands) const { llvm::DenseSet<Value> visited; llvm::SetVector<Value> result, workingSet; - workingSet.insert(value); + workingSet.insert(opOperand->get()); + + if (visitedOpOperands) + visitedOpOperands->insert(opOperand); while (!workingSet.empty()) { Value value = workingSet.pop_back_val(); @@ -563,12 +566,14 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( return result; } -// Find the values that define the contents of the given value. -llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const { +// Find the values that define the contents of the given operand's value. +llvm::SetVector<Value> +AnalysisState::findDefinitions(OpOperand *opOperand) const { TraversalConfig config; config.alwaysIncludeLeaves = false; return findValueInReverseUseDefChain( - value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config); + opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, + config); } AnalysisState::AnalysisState(const BufferizationOptions &options) @@ -892,7 +897,7 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite( config.alwaysIncludeLeaves = false; for (AliasingOpOperand alias : opOperands) { if (!state - .findValueInReverseUseDefChain(alias.opOperand->get(), + .findValueInReverseUseDefChain(alias.opOperand, isMemoryWriteInsideOp, config) .empty()) return true; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index abc0635a2cdf..2c4e362101f8 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -93,8 +93,31 @@ findValidInsertionPoint(Operation *emptyTensorOp, Operation *user, return nullptr; } +Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter, + SubsetInsertionOpInterface op, + tensor::EmptyOp emptyTensorOp, + Operation *user) { + + mlir::OpBuilder::InsertionGuard guard(rewriter); + // All values that are needed to create the replacement op. + SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction(); + // Find a suitable insertion point. If no suitable insertion point + // for the replacement can be found, return an empty value to skip + // this replacement. + Operation *insertionPoint = + findValidInsertionPoint(emptyTensorOp, user, neededValues); + if (!insertionPoint) + return {}; + + rewriter.setInsertionPoint(insertionPoint); + Value replacement = + op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc()); + return replacement; +} + LogicalResult mlir::bufferization::eliminateEmptyTensors( - RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { + RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state, + ControlBuildSubsetExtractionFn subsetsExtractionFn) { OpBuilder::InsertionGuard g(rewriter); llvm::DenseSet<OpOperand *> visitedOpOperands; op->walk([&](SubsetInsertionOpInterface op) { @@ -105,10 +128,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( if (!state.isInPlace(source)) return WalkResult::skip(); - // All values that are needed to create the replacement op. - SmallVector<Value> neededValues = - op.getValuesNeededToBuildSubsetExtraction(); - // Find tensor.empty ops on the reverse SSA use-def chain. Only follow // equivalent tensors. I.e., stop when there are ops such as extract_slice // on the path. @@ -124,35 +143,23 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( // %3 = tensor.insert_slice %2 into ... config.followSameTypeOrCastsOnly = true; SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( - source.get(), /*condition=*/ + &source, /*condition=*/ [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config, &visitedOpOperands); for (Value v : emptyTensors) { - Operation *emptyTensorOp = v.getDefiningOp(); - + auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>(); + assert(emptyTensorOp && "expected tensor.empty op"); // Find the use to be replaced from the use-def chain. auto iter = llvm::find_if( visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) { return llvm::count(emptyTensorOp->getUses(), *opOperand); }); - // This could be achieved when a use of `emptyTensorOp` is being - // consumed by `SubsetInsertionOpInterface`'s source directly. - if (iter == visitedOpOperands.end()) - continue; + + assert(iter != visitedOpOperands.end() && "could not find use"); OpOperand *useToBeReplaced = *iter; Operation *user = useToBeReplaced->getOwner(); - - // Find a suitable insertion point. If no suitable insertion point for - // the replacement can be found, skip this replacement. - Operation *insertionPoint = - findValidInsertionPoint(emptyTensorOp, user, neededValues); - if (!insertionPoint) - continue; - - rewriter.setInsertionPoint(insertionPoint); - Value replacement = - op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc()); + auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user); if (!replacement) continue; if (emptyTensorOp == replacement.getDefiningOp()) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index d1e6acef324f..fc1b221b4f03 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -196,7 +196,12 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) { // If there is no preceding definition, the tensor contents are // undefined. - if (findDefinitionsCached(opResult).empty()) + if (opResult.getUses().empty()) + continue; + // It does not really matter which use to take to search about + // the value's definitions. + OpOperand *opOperand = &(*opResult.getUses().begin()); + if (findDefinitionsCached(opOperand).empty()) for (OpOperand &use : opResult.getUses()) undefinedTensorUses.insert(&use); } @@ -464,7 +469,8 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, /// indexing. I.e., the tensor types do not change along the use-def chain, /// apart from static <-> dynamic dim casts. static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state, - Value start, Value other) { + OpOperand *start, + Value other) { TraversalConfig config; config.followEquivalentOnly = true; config.alwaysIncludeLeaves = false; @@ -475,9 +481,10 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state, .empty(); } -/// Return "true" if `value` is originating from a subset that is equivalent to -/// the subset that `subsetOp` inserts into. -static bool matchesInsertDestination(const AnalysisState &state, Value value, +/// Return "true" if the given operand's value is originating from a subset +/// that is equivalent to the subset that `subsetOp` inserts into. +static bool matchesInsertDestination(const AnalysisState &state, + OpOperand *opOperand, SubsetInsertionOpInterface subsetOp) { auto matchingSubset = [&](Value val) { if (auto opResult = dyn_cast<OpResult>(val)) @@ -490,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value, // There may be multiple leaves at which the reverse SSA use-def chain lookup // terminates. All of them must be equivalent subsets. SetVector<Value> backwardSlice = - state.findValueInReverseUseDefChain(value, matchingSubset); + state.findValueInReverseUseDefChain(opOperand, matchingSubset); return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset)); } @@ -516,7 +523,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead, // {inplace= [true] } if (uRead == &subsetOp.getDestinationOperand() && - matchesInsertDestination(state, uConflictingWrite->get(), subsetOp)) + matchesInsertDestination(state, uConflictingWrite, subsetOp)) // Case 1: The main insight is that InsertSliceOp reads only part of // the destination tensor. The overwritten area is not read. If // uConflictingWrite writes into exactly the memory location that is @@ -533,7 +540,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead, if (uRead == &subsetOp.getSourceOperand() && uConflictingWrite == &subsetOp.getDestinationOperand() && - matchesInsertDestination(state, uRead->get(), subsetOp)) + matchesInsertDestination(state, uRead, subsetOp)) // Case 2: The read of the source tensor and the write to the dest // tensor via an InsertSliceOp is not a conflict if the read is // reading exactly that part of an equivalent tensor that the @@ -567,8 +574,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead, if (uConflictingWrite == &subsetOp.getDestinationOperand() && state.areEquivalentBufferizedValues( uRead->get(), subsetOp.getSourceOperand().get()) && - matchesInsertDestination(state, subsetOp.getSourceOperand().get(), - subsetOp)) + matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp)) return true; return false; @@ -600,9 +606,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // even though that op just bufferizes to an allocation but does define // the contents of the buffer. SetVector<Value> definitionsOrLeaves = - state.findValueInReverseUseDefChain( - uConflictingWrite->get(), - [&](Value v) { return state.bufferizesToMemoryWrite(v); }); + state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) { + return state.bufferizesToMemoryWrite(v); + }); assert(!definitionsOrLeaves.empty() && "expected at least one definition or leaf"); @@ -641,8 +647,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // In the above example, if uRead is the OpOperand of reading_op, the // definition is %0. Note that operations that create an alias but do not // bufferize to a memory write (such as ExtractSliceOp) are skipped. - const SetVector<Value> &definitions = - state.findDefinitionsCached(uRead->get()); + const SetVector<Value> &definitions = state.findDefinitionsCached(uRead); if (definitions.empty()) { // Fast path: No conflict if there are no definitions. LLVM_DEBUG(llvm::dbgs() @@ -714,9 +719,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, if (bufferizableOp.bufferizesToElementwiseAccess( state, {uRead, uConflictingWrite})) { if (hasEquivalentValueInReverseUseDefChain( - state, uRead->get(), uConflictingWrite->get()) || + state, uRead, uConflictingWrite->get()) || hasEquivalentValueInReverseUseDefChain( - state, uConflictingWrite->get(), uRead->get())) { + state, uConflictingWrite, uRead->get())) { LLVM_DEBUG( llvm::dbgs() << " no conflict: op bufferizes to element-wise access\n"); @@ -965,11 +970,12 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand, // Bufferization analyses. //===----------------------------------------------------------------------===// -// Find the values that define the contents of the given value. +// Find the values that define the contents of the given operand's value. const llvm::SetVector<Value> & -OneShotAnalysisState::findDefinitionsCached(Value value) { +OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) { + Value value = opOperand->get(); if (!cachedDefinitions.count(value)) - cachedDefinitions[value] = findDefinitions(value); + cachedDefinitions[value] = findDefinitions(opOperand); return cachedDefinitions[value]; } diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp index 0b3a494794f3..72c8fd0f3248 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp @@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) { converter.addSourceMaterialization(materializeAsUnrealizedCast); converter.addTargetMaterialization(materializeAsUnrealizedCast); - converter.addArgumentMaterialization(materializeAsUnrealizedCast); } /// Get an unsigned integer or size data type corresponding to \p ty. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp index c7ddc1b36f4d..ff1636bc121b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -48,10 +48,28 @@ void LLVMDialect::registerAttributes() { addAttributes< #define GET_ATTRDEF_LIST #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" + >(); } //===----------------------------------------------------------------------===// +// AliasScopeAttr +//===----------------------------------------------------------------------===// + +LogicalResult +AliasScopeAttr::verify(function_ref<InFlightDiagnostic()> emitError, + Attribute id, AliasScopeDomainAttr domain, + StringAttr description) { + (void)domain; + (void)description; + if (!llvm::isa<StringAttr, DistinctAttr>(id)) + return emitError() + << "id of an alias scope must be a StringAttr or a DistrinctAttr"; + + return success(); +} + +//===----------------------------------------------------------------------===// // DINodeAttr //===----------------------------------------------------------------------===// @@ -232,7 +250,7 @@ DIRecursiveTypeAttrInterface DISubprogramAttr::withRecId(DistinctAttr recId) { DIRecursiveTypeAttrInterface DISubprogramAttr::getRecSelf(DistinctAttr recId) { return DISubprogramAttr::get(recId.getContext(), recId, /*isRecSelf=*/true, - {}, {}, {}, {}, {}, 0, 0, {}, {}, {}, {}, {}); + {}, {}, {}, {}, {}, {}, 0, 0, {}, {}, {}, {}); } //===----------------------------------------------------------------------===// @@ -288,6 +306,16 @@ TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context, })); } +TargetFeaturesAttr +TargetFeaturesAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, + MLIRContext *context, + llvm::ArrayRef<StringRef> features) { + return Base::getChecked(emitError, context, + llvm::map_to_vector(features, [&](StringRef feature) { + return StringAttr::get(context, feature); + })); +} + TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context, StringRef targetFeatures) { SmallVector<StringRef> features; @@ -296,6 +324,16 @@ TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context, return get(context, features); } +TargetFeaturesAttr +TargetFeaturesAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, + MLIRContext *context, StringRef targetFeatures) { + SmallVector<StringRef> features; + targetFeatures.split(features, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + ArrayRef featuresRef(features); + return getChecked(emitError, context, featuresRef); +} + LogicalResult TargetFeaturesAttr::verify(function_ref<InFlightDiagnostic()> emitError, llvm::ArrayRef<StringAttr> features) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index 6801b68a8538..6c1087730ebb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -553,7 +553,7 @@ Value linalg::bufferizeToAllocation( Value alloc = createAllocationForTensor( rewriter, op->getLoc(), operand->get(), options, memorySpace); allocs.push_back(alloc); - if (!state.findDefinitions(operand->get()).empty()) { + if (!state.findDefinitions(operand).empty()) { // Initialize buffer with a copy of the operand data. Not needed if the // tensor is uninitialized. createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 0e651f4cee4c..fc6671ef8117 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -154,7 +154,6 @@ public: }); addSourceMaterialization(sourceMaterializationCallback); - addArgumentMaterialization(sourceMaterializationCallback); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp index 4776883ed95c..b710bde87f9f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp @@ -59,7 +59,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep( config.followEquivalentOnly = true; config.alwaysIncludeLeaves = false; SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( - in->get(), /*condition=*/ + in, /*condition=*/ [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>() && val.getType() == in->get().getType(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 60cf897b00de..50593b08ad74 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1656,8 +1656,8 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, } void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) { - // TODO: Add and test patterns for tensor.unpack patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext()); + patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(patterns.getContext()); } void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) { diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp index 619127226628..71b88d1be1b0 100644 --- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp @@ -56,7 +56,6 @@ public: addConversion(convertQuantizedType); addConversion(convertTensorType); - addArgumentMaterialization(materializeConversion); addSourceMaterialization(materializeConversion); addTargetMaterialization(materializeConversion); } diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index eded1c394f12..83ae79ce4826 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -839,8 +839,7 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, namespace { // Fold away ForOp iter arguments when: // 1) The op yields the iter arguments. -// 2) The iter arguments have no use and the corresponding outer region -// iterators (inputs) are yielded. +// 2) The argument's corresponding outer region iterators (inputs) are yielded. // 3) The iter arguments have no use and the corresponding (operation) results // have no use. // @@ -872,30 +871,28 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { newIterArgs.reserve(forOp.getInitArgs().size()); newYieldValues.reserve(numResults); newResultValues.reserve(numResults); - for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside - forOp.getRegionIterArgs(), // iter inside region - forOp.getResults(), // op results - forOp.getYieldedValues() // iter yield - )) { + for (auto [init, arg, result, yielded] : + llvm::zip(forOp.getInitArgs(), // iter from outside + forOp.getRegionIterArgs(), // iter inside region + forOp.getResults(), // op results + forOp.getYieldedValues() // iter yield + )) { // Forwarded is `true` when: // 1) The region `iter` argument is yielded. - // 2) The region `iter` argument has no use, and the corresponding iter - // operand (input) is yielded. + // 2) The region `iter` argument the corresponding input is yielded. // 3) The region `iter` argument has no use, and the corresponding op // result has no use. - bool forwarded = ((std::get<1>(it) == std::get<3>(it)) || - (std::get<1>(it).use_empty() && - (std::get<0>(it) == std::get<3>(it) || - std::get<2>(it).use_empty()))); + bool forwarded = (arg == yielded) || (init == yielded) || + (arg.use_empty() && result.use_empty()); keepMask.push_back(!forwarded); canonicalize |= forwarded; if (forwarded) { - newBlockTransferArgs.push_back(std::get<0>(it)); - newResultValues.push_back(std::get<0>(it)); + newBlockTransferArgs.push_back(init); + newResultValues.push_back(init); continue; } - newIterArgs.push_back(std::get<0>(it)); - newYieldValues.push_back(std::get<3>(it)); + newIterArgs.push_back(init); + newYieldValues.push_back(yielded); newBlockTransferArgs.push_back(Value()); // placeholder with null value newResultValues.push_back(Value()); // placeholder with null value } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 41410a0a56aa..6cda7100fe07 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -329,8 +329,9 @@ static void generateUnrolledLoop( // 'forOp'. auto builder = OpBuilder::atBlockTerminator(loopBodyBlock); + constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {}; if (!annotateFn) - annotateFn = [](unsigned, Operation *, OpBuilder) {}; + annotateFn = defaultAnnotateFn; // Keep a pointer to the last non-terminator operation in the original block // so that we know what to clone (since we are doing this in-place). diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp index 834e3634cc13..8bbb2cac5efd 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp @@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { // Required by scf.for 1:N type conversion. addSourceMaterialization(materializeTuple); - - // Required as a workaround until we have full 1:N support. - addArgumentMaterialization(materializeTuple); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index f79c774ceb3e..24a1d5531531 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -4795,6 +4795,44 @@ static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op, return newOperands; } +// Given the (potentially) updated packed type, `newPackedTy`, generates an +// updated mixed-tile-sizes attribute. A tile size is updated only +// when: +// * a dim from newPackedTy is static, and +// * the corresponding size from mixedTiles is still dynamic. +// Otherwise, the original tile size is preserved. +// Note - packed-type-dim and mixed-tile-size should always match! +static SmallVector<OpFoldResult> +getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, + SmallVector<OpFoldResult> mixedTiles) { + SmallVector<OpFoldResult> newMixedTileSizes; + for (auto it : llvm::zip(cast<ShapedType>(newPackedTy) + .getShape() + .take_back(mixedTiles.size()), + mixedTiles)) { + int64_t shape = std::get<0>(it); + if (shape == ShapedType::kDynamic) { + newMixedTileSizes.push_back(std::get<1>(it)); + continue; + } + + // If the current result dim is static, update the dynamic mixed-size + // (provided the original value is dynamic). + OpFoldResult tile = std::get<1>(it); + if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) { + // Already a constant + newMixedTileSizes.push_back(tile); + } else { + assert(getConstantIntValue(tile).value() == shape && + "tile size and dim size don't match!"); + newMixedTileSizes.push_back( + (rewriter.getIntegerAttr(rewriter.getIndexType(), shape))); + } + } + + return newMixedTileSizes; +} + /// Folds a tensor.cast op into a consuming tensor::PackOp op if the /// `tensor.cast` has source that is more static than the consuming op. /// @@ -4821,31 +4859,13 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> { SmallVector<Value> newOperands = getNewOperands(op, newResultTypes); // Get the updated mixed-tile-sizes attribute. - SmallVector<OpFoldResult> newMixedTileSizes; - for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0]) - .getShape() - .take_back(op.getMixedTiles().size()), - op.getMixedTiles())) { - int64_t shape = std::get<0>(it); - if (shape == ShapedType::kDynamic) { - newMixedTileSizes.push_back(std::get<1>(it)); - continue; - } - - if (Attribute attr = - llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) { - // Already a constant - newMixedTileSizes.push_back(std::get<1>(it)); - } else { - int64_t tileSize = getConstantIntValue(std::get<1>(it)).value(); - assert(tileSize == shape && "tile size and dim size don't match!"); - (void)tileSize; - newMixedTileSizes.push_back( - (rewriter.getIntegerAttr(rewriter.getIndexType(), shape))); - } - } + SmallVector<OpFoldResult> newMixedTileSizes = + getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles()); // Clone op. + // TODO: Strictly speaking, discardable attributes should be _discarded_ at + // this point. However, in practice, we use them for things that we'd like + // to preserve. Implement a better abstraction. PackOp newOp = rewriter.create<PackOp>( op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(), newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm()); @@ -4865,6 +4885,59 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> { } }; +/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the +/// `tensor.cast` has source that is more static than the consuming op. +/// +/// Example: +/// ```mlir +/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32> +/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32> +/// ``` +/// +/// folds into: +/// +/// ```mlir +/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32> +/// ``` +struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> { + using OpRewritePattern<UnPackOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(UnPackOp op, + PatternRewriter &rewriter) const override { + if (!foldTensorCastPrecondition(op)) + return failure(); + + SmallVector<Type> newResultTypes(op->getResultTypes()); + SmallVector<Value> newOperands = getNewOperands(op, newResultTypes); + Value sourceTensor = newOperands[0]; + + // Get the updated mixed-tile-sizes attribute. + SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes( + rewriter, sourceTensor.getType(), op.getMixedTiles()); + + // Clone op. + // TODO: Strictly speaking, discardable attributes should be _discarded_ at + // this point. However, in practice, we use them for things that we'd like + // to preserve. Implement a better abstraction. + UnPackOp newOp = rewriter.create<UnPackOp>( + op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(), + newMixedTileSizes, op.getOuterDimsPerm()); + newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary()); + + // Replace op. + Value oldResult = op.getResult(); + Value newResult = newOp.getResult(); + Value replacement = (newResult.getType() != oldResult.getType()) + ? rewriter.create<tensor::CastOp>( + op->getLoc(), oldResult.getType(), newResult) + : newResult; + + rewriter.replaceOp(op, {replacement}); + + return success(); + } +}; + /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if /// the `tensor.cast` has source that is more static than the consuming op. /// @@ -4890,7 +4963,8 @@ struct FoldTensorCastProducerOp PatternRewriter &rewriter) const override { // Reject tensor::PackOp - there's dedicated pattern for that instead. - if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op)) + if (!foldTensorCastPrecondition(op) || + isa<tensor::PackOp, tensor::UnPackOp>(*op)) return failure(); SmallVector<Type> newResultTypes(op->getResultTypes()); @@ -4923,6 +4997,7 @@ struct FoldTensorCastProducerOp void TensorDialect::getCanonicalizationPatterns( RewritePatternSet &results) const { results.add<FoldTensorCastPackOp>(getContext()); + results.add<FoldTensorCastUnPackOp>(getContext()); results.add<FoldTensorCastProducerOp>(getContext()); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 39d0ee122b16..f51c3dbce6ee 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1002,10 +1002,6 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { return input.reshape(resultTy); } - // Transpose does not change the input type. - if (getInput1().getType() != getType()) - return {}; - // Transpose is not the identity transpose. SmallVector<int32_t> perms; if (getConstantPerms(perms).failed()) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 631d3c48f2df..764a5db48e07 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -210,7 +210,12 @@ template <typename T> static LogicalResult verifyConvOp(T op) { // All TOSA conv ops have an input() and weight(). auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType()); - auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType()); + + RankedTensorType weightType; + if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) + weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType()); + else + weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType()); // Must be ranked tensor types if (!inputType) { @@ -218,7 +223,13 @@ static LogicalResult verifyConvOp(T op) { return failure(); } if (!weightType) { - op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight(); + if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) { + op.emitOpError("expect a ranked tensor for filter, got ") + << op.getFilter(); + } else { + op.emitOpError("expect a ranked tensor for weight, got ") + << op.getWeight(); + } return failure(); } @@ -271,6 +282,38 @@ LogicalResult tosa::ConstOp::verify() { return success(); } +template <typename T> +static LogicalResult verifyConvOpModes(T op) { + auto inputEType = + llvm::cast<ShapedType>(op.getInput().getType()).getElementType(); + + if (auto quantType = + llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) + inputEType = quantType.getStorageType(); + + auto accType = op.getAccType(); + if (inputEType.isInteger(8) && !accType.isInteger(32)) + return op.emitOpError("accumulator type for i8 tensor is not i32"); + + if (inputEType.isInteger(16) && !accType.isInteger(48)) + return op.emitOpError("accumulator type for i16 tensor is not i48"); + + if ((inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3()) && + !accType.isF16()) + return op.emitOpError("accumulator type for f8 tensor is not f16"); + + if (inputEType.isF16() && !(accType.isF16() || accType.isF32())) + return op.emitOpError("accumulator type for f16 tensor is not f16/f32"); + + if (inputEType.isBF16() && !accType.isF32()) + return op.emitOpError("accumulator type for bf16 tensor is not f32"); + + if (inputEType.isF32() && !accType.isF32()) + return op.emitOpError("accumulator type for f32 tensor is not f32"); + + return success(); +} + LogicalResult tosa::ArgMaxOp::verify() { // Ensure output is of 32-bit integer const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType(); @@ -368,12 +411,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, - DenseI64ArrayAttr dilation) { + DenseI64ArrayAttr dilation, + TypeAttr accType) { result.addOperands({input, weight, bias}); result.addAttribute("pad", pad); result.addAttribute("stride", stride); result.addAttribute("dilation", dilation); + result.addAttribute("acc_type", accType); auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight); if (quantAttr) { @@ -390,11 +435,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, static void buildTransConvOpWithQuantInfo( OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, - DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) { + DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) { result.addOperands({input, weight, bias}); result.addAttribute("out_pad", outpad); result.addAttribute("stride", stride); result.addAttribute("out_shape", outputShape); + result.addAttribute("acc_type", accType); auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight); if (quantAttr) { @@ -787,7 +833,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents( return success(); } - outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic); + outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } @@ -823,13 +869,17 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents( LogicalResult tosa::PadOp::verify() { RankedTensorType inputType = getInput1().getType(); RankedTensorType outputType = getOutput().getType(); - TensorType paddingType = getPadding().getType(); + RankedTensorType paddingType = getPadding().getType(); if (inputType.getRank() != outputType.getRank()) return emitOpError() << "expect same input and output tensor rank."; - if (paddingType.hasRank() && paddingType.getRank() != 2) - return emitOpError() << "expect 'padding' tensor rank equal to 2."; + if (!paddingType.isDynamicDim(0) && + paddingType.getDimSize(0) != inputType.getRank() * 2) + return emitOpError() << "expected padding tensor dim 0 to have size " + << inputType.getRank() * 2 + << " (2*rank(shape1)) but got size " + << paddingType.getDimSize(0); return success(); } @@ -1595,7 +1645,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents( return success(); } -LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); } +LogicalResult Conv2DOp::verify() { + if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) + return failure(); + return success(); +} LogicalResult Conv3DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional<Location> location, @@ -1667,7 +1721,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents( return success(); } -LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); } +LogicalResult Conv3DOp::verify() { + if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) + return failure(); + return success(); +} LogicalResult AvgPool2dOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional<Location> location, @@ -1762,7 +1820,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents( return success(); } -LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); } +LogicalResult DepthwiseConv2DOp::verify() { + if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) + return failure(); + return success(); +} LogicalResult TransposeConv2DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional<Location> location, @@ -1828,6 +1890,12 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents( return success(); } +LogicalResult TransposeConv2DOp::verify() { + if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) + return failure(); + return success(); +} + LogicalResult IfOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional<Location> location, IfOp::Adaptor adaptor, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp index 44f64f76e9b0..04a709c59677 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp @@ -81,7 +81,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> { } } - auto padSizeTy = RankedTensorType::get({4, 2}, rewriter.getI64Type()); + auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type()); auto padSize = DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad)); Value padSizeVal = diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index e6fba211dc37..14f392ab8c45 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -108,7 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> { } } - auto padSizeTy = RankedTensorType::get({5, 2}, rewriter.getI64Type()); + auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type()); auto padSize = DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad)); Value padSizeVal = diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index 0779cdb9667a..db1e219b601b 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -75,13 +75,15 @@ public: loc, resultTy, input, reverse2, bias, rewriter.getDenseI64ArrayAttr(convPad), rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo()); + rewriter.getDenseI64ArrayAttr({1, 1}), + /* acc_type = */ op.getAccType(), *op.getQuantizationInfo()); } else { conv2d = rewriter.create<tosa::Conv2DOp>( loc, resultTy, input, reverse2, bias, rewriter.getDenseI64ArrayAttr(convPad), rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr({1, 1})); + rewriter.getDenseI64ArrayAttr({1, 1}), + /* acc_type = */ op.getAccTypeAttr()); } rewriter.replaceOp(op, conv2d); @@ -139,7 +141,7 @@ public: weightPadding[5] = (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0; DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get( - RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding); + RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding); Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>( rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr); @@ -202,7 +204,7 @@ public: inputPadding[5] += restridedWeightTy.getDimSize(2) - 1; DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get( - RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding); + RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding); Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>( rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr); @@ -238,7 +240,7 @@ public: /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), - *op.getQuantizationInfo()) + /* acc_type = */ op.getAccType(), *op.getQuantizationInfo()) .getResult(); } else { conv2d = CreateOpAndInferShape<tosa::Conv2DOp>( @@ -246,7 +248,8 @@ public: weight, zeroBias, /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), - /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1})) + /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), + /* acc_type = */ op.getAccTypeAttr()) .getResult(); } @@ -314,7 +317,7 @@ public: resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2]; DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get( - RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding); + RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding); Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>( rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 6fd671051362..8588c878bfe4 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -542,9 +542,13 @@ bool TosaValidation::isValidElementType(Type type) { void TosaValidation::runOnOperation() { configLevelAndProfile(); + + TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>(); + if (!tosaDialect) + return; + getOperation().walk([&](Operation *op) { - if (!op->getDialect() || - op->getDialect()->getNamespace() != TosaDialect::getDialectNamespace()) + if (op->getDialect() != tosaDialect) return; for (Value operand : op->getOperands()) { diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 106a79473509..798853a75441 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -2840,6 +2840,7 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter, llvm::outs() << "top-level ]]]\n"; state.getTopLevel()->print(llvm::outs(), printFlags); llvm::outs() << "\n"; + llvm::outs().flush(); return DiagnosedSilenceableFailure::success(); } @@ -2849,6 +2850,7 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter, llvm::outs() << "\n"; } + llvm::outs().flush(); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 757631944f22..68535ae5a7a5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( return builder.create<vector::ShapeCastOp>(loc, type, inputs.front()); }; - typeConverter.addArgumentMaterialization(materializeCast); typeConverter.addSourceMaterialization(materializeCast); typeConverter.addTargetMaterialization(materializeCast); target.markUnknownOpDynamicallyLegal( diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 6fe96504ae10..c603db450cbd 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -284,22 +284,29 @@ OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) { } /// Do not verify the operation when using custom operation printers. -OpPrintingFlags &OpPrintingFlags::assumeVerified() { - assumeVerifiedFlag = true; +OpPrintingFlags &OpPrintingFlags::assumeVerified(bool enable) { + assumeVerifiedFlag = enable; return *this; } /// Use local scope when printing the operation. This allows for using the /// printer in a more localized and thread-safe setting, but may not necessarily /// be identical of what the IR will look like when dumping the full module. -OpPrintingFlags &OpPrintingFlags::useLocalScope() { - printLocalScope = true; +OpPrintingFlags &OpPrintingFlags::useLocalScope(bool enable) { + printLocalScope = enable; return *this; } /// Print users of values as comments. -OpPrintingFlags &OpPrintingFlags::printValueUsers() { - printValueUsersFlag = true; +OpPrintingFlags &OpPrintingFlags::printValueUsers(bool enable) { + printValueUsersFlag = enable; + return *this; +} + +/// Print unique SSA ID numbers for values, block arguments and naming conflicts +/// across all regions +OpPrintingFlags &OpPrintingFlags::printUniqueSSAIDs(bool enable) { + printUniqueSSAIDsFlag = enable; return *this; } diff --git a/mlir/lib/IR/Dominance.cpp b/mlir/lib/IR/Dominance.cpp index 406e0f2d62d6..1c54e09d29b9 100644 --- a/mlir/lib/IR/Dominance.cpp +++ b/mlir/lib/IR/Dominance.cpp @@ -213,61 +213,73 @@ DominanceInfoBase<IsPostDom>::findNearestCommonDominator(Block *a, return getDomTree(a->getParent()).findNearestCommonDominator(a, b); } -/// Return true if the specified block A properly dominates block B. -template <bool IsPostDom> -bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(Block *a, - Block *b) const { - assert(a && b && "null blocks not allowed"); +/// Returns the given block iterator if it lies within the region region. +/// Otherwise, otherwise finds the ancestor of the given block iterator that +/// lies within the given region. Returns and "empty" iterator if the latter +/// fails. +/// +/// Note: This is a variant of Region::findAncestorOpInRegion that operates on +/// block iterators instead of ops. +static std::pair<Block *, Block::iterator> +findAncestorIteratorInRegion(Region *r, Block *b, Block::iterator it) { + // Case 1: The iterator lies within the region region. + if (b->getParent() == r) + return std::make_pair(b, it); + + // Otherwise: Find ancestor iterator. Bail if we run out of parent ops. + Operation *parentOp = b->getParentOp(); + if (!parentOp) + return std::make_pair(static_cast<Block *>(nullptr), Block::iterator()); + Operation *op = r->findAncestorOpInRegion(*parentOp); + if (!op) + return std::make_pair(static_cast<Block *>(nullptr), Block::iterator()); + return std::make_pair(op->getBlock(), op->getIterator()); +} - // A block dominates, but does not properly dominate, itself unless this - // is a graph region. +/// Given two iterators into the same block, return "true" if `a` is before `b. +/// Note: This is a variant of Operation::isBeforeInBlock that operates on +/// block iterators instead of ops. +static bool isBeforeInBlock(Block *block, Block::iterator a, + Block::iterator b) { if (a == b) - return !hasSSADominance(a); - - // If both blocks are not in the same region, `a` properly dominates `b` if - // `b` is defined in an operation region that (recursively) ends up being - // dominated by `a`. Walk up the list of containers enclosing B. - Region *regionA = a->getParent(); - if (regionA != b->getParent()) { - b = regionA ? regionA->findAncestorBlockInRegion(*b) : nullptr; - // If we could not find a valid block b then it is a not a dominator. - if (!b) - return false; - - // Check to see if the ancestor of `b` is the same block as `a`. A properly - // dominates B if it contains an op that contains the B block. - if (a == b) - return true; - } - - // Otherwise, they are two different blocks in the same region, use DomTree. - return getDomTree(regionA).properlyDominates(a, b); + return false; + if (a == block->end()) + return false; + if (b == block->end()) + return true; + return a->isBeforeInBlock(&*b); } template <bool IsPostDom> bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl( - Operation *a, Operation *b, bool enclosingOpOk) const { - Block *aBlock = a->getBlock(), *bBlock = b->getBlock(); - assert(aBlock && bBlock && "operations must be in a block"); + Block *aBlock, Block::iterator aIt, Block *bBlock, Block::iterator bIt, + bool enclosingOk) const { + assert(aBlock && bBlock && "expected non-null blocks"); - // An operation (pos)dominates, but does not properly (pos)dominate, itself - // unless this is a graph region. - if (a == b) + // A block iterator (post)dominates, but does not properly (post)dominate, + // itself unless this is a graph region. + if (aBlock == bBlock && aIt == bIt) return !hasSSADominance(aBlock); - // If these ops are in different regions, then normalize one into the other. + // If the iterators are in different regions, then normalize one into the + // other. Region *aRegion = aBlock->getParent(); if (aRegion != bBlock->getParent()) { - // Scoot up b's region tree until we find an operation in A's region that + // Scoot up b's region tree until we find a location in A's region that // encloses it. If this fails, then we know there is no (post)dom relation. - b = aRegion ? aRegion->findAncestorOpInRegion(*b) : nullptr; - if (!b) + if (!aRegion) { + bBlock = nullptr; + bIt = Block::iterator(); + } else { + std::tie(bBlock, bIt) = + findAncestorIteratorInRegion(aRegion, bBlock, bIt); + } + if (!bBlock) return false; - bBlock = b->getBlock(); - assert(bBlock->getParent() == aRegion); + assert(bBlock->getParent() == aRegion && "expected block in regionA"); // If 'a' encloses 'b', then we consider it to (post)dominate. - if (a == b && enclosingOpOk) + if (aBlock == bBlock && aIt == bIt && enclosingOk) return true; } @@ -279,9 +291,9 @@ bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl( if (!hasSSADominance(aBlock)) return true; if constexpr (IsPostDom) { - return b->isBeforeInBlock(a); + return isBeforeInBlock(aBlock, bIt, aIt); } else { - return a->isBeforeInBlock(b); + return isBeforeInBlock(aBlock, aIt, bIt); } } @@ -309,6 +321,18 @@ template class detail::DominanceInfoBase</*IsPostDom=*/false>; // DominanceInfo //===----------------------------------------------------------------------===// +bool DominanceInfo::properlyDominates(Operation *a, Operation *b, + bool enclosingOpOk) const { + return super::properlyDominatesImpl(a->getBlock(), a->getIterator(), + b->getBlock(), b->getIterator(), + enclosingOpOk); +} + +bool DominanceInfo::properlyDominates(Block *a, Block *b) const { + return super::properlyDominatesImpl(a, a->begin(), b, b->begin(), + /*enclosingOk=*/true); +} + /// Return true if the `a` value properly dominates operation `b`, i.e if the /// operation that defines `a` properlyDominates `b` and the operation that /// defines `a` does not contain `b`. @@ -322,3 +346,19 @@ bool DominanceInfo::properlyDominates(Value a, Operation *b) const { // `b`, but `a` does not itself enclose `b` in one of its regions. return properlyDominates(a.getDefiningOp(), b, /*enclosingOpOk=*/false); } + +//===----------------------------------------------------------------------===// +// PostDominanceInfo +//===----------------------------------------------------------------------===// + +bool PostDominanceInfo::properlyPostDominates(Operation *a, Operation *b, + bool enclosingOpOk) const { + return super::properlyDominatesImpl(a->getBlock(), a->getIterator(), + b->getBlock(), b->getIterator(), + enclosingOpOk); +} + +bool PostDominanceInfo::properlyPostDominates(Block *a, Block *b) const { + return super::properlyDominatesImpl(a, a->end(), b, b->end(), + /*enclosingOk=*/true); +} diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index cf58bc5d8f47..659ab1227f11 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -237,15 +237,7 @@ public: generateMetadata(value.getInt(), "maxnreg"); } else if (attribute.getName() == NVVM::NVVMDialect::getKernelFuncAttrName()) { - llvm::Metadata *llvmMetadataKernel[] = { - llvm::ValueAsMetadata::get(llvmFunc), - llvm::MDString::get(llvmContext, "kernel"), - llvm::ValueAsMetadata::get( - llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 1))}; - llvm::MDNode *llvmMetadataNode = - llvm::MDNode::get(llvmContext, llvmMetadataKernel); - moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations") - ->addOperand(llvmMetadataNode); + llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel); } return success(); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 9a30266103b1..87cb7f03fec6 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -150,10 +150,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { << " operation"; }; - auto checkAligned = [&todo](auto op, LogicalResult &result) { - if (!op.getAlignedVars().empty() || op.getAlignments()) - result = todo("aligned"); - }; auto checkAllocate = [&todo](auto op, LogicalResult &result) { if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty()) result = todo("allocate"); @@ -275,7 +271,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { }) .Case([&](omp::ParallelOp op) { checkAllocate(op, result); }) .Case([&](omp::SimdOp op) { - checkAligned(op, result); checkLinear(op, result); checkNontemporal(op, result); checkPrivate(op, result); @@ -2302,6 +2297,24 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars; llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder()); + llvm::BasicBlock *sourceBlock = builder.GetInsertBlock(); + std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments(); + mlir::OperandRange operands = simdOp.getAlignedVars(); + for (size_t i = 0; i < operands.size(); ++i) { + llvm::Value *alignment = nullptr; + llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]); + llvm::Type *ty = llvmVal->getType(); + if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) { + alignment = builder.getInt64(intAttr.getInt()); + assert(ty->isPointerTy() && "Invalid type for aligned variable"); + assert(alignment && "Invalid alignment value"); + auto curInsert = builder.saveIP(); + builder.SetInsertPoint(sourceBlock->getTerminator()); + llvmVal = builder.CreateLoad(ty, llvmVal); + builder.restoreIP(curInsert); + alignedVars[llvmVal] = alignment; + } + } ompBuilder->applySimd(loopInfo, alignedVars, simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr()) @@ -2575,6 +2588,7 @@ static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst); if (failed(checkImplementationStatus(opInst))) @@ -2582,6 +2596,10 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, Value symAddr = threadprivateOp.getSymAddr(); auto *symOp = symAddr.getDefiningOp(); + + if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp)) + symOp = asCast.getOperand().getDefiningOp(); + if (!isa<LLVM::AddressOfOp>(symOp)) return opInst.emitError("Addressing symbol not found"); LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp); @@ -2589,17 +2607,20 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::GlobalOp global = addressOfOp.getGlobal(moduleTranslation.symbolTable()); llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global); - llvm::Type *type = globalValue->getValueType(); - llvm::TypeSize typeSize = - builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize( - type); - llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue()); - llvm::StringRef suffix = llvm::StringRef(".cache", 6); - std::string cacheName = (Twine(global.getSymName()).concat(suffix)).str(); - llvm::Value *callInst = - moduleTranslation.getOpenMPBuilder()->createCachedThreadPrivate( - ompLoc, globalValue, size, cacheName); - moduleTranslation.mapValue(opInst.getResult(0), callInst); + + if (!ompBuilder->Config.isTargetDevice()) { + llvm::Type *type = globalValue->getValueType(); + llvm::TypeSize typeSize = + builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize( + type); + llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue()); + llvm::Value *callInst = ompBuilder->createCachedThreadPrivate( + ompLoc, globalValue, size, global.getSymName() + ".cache"); + moduleTranslation.mapValue(opInst.getResult(0), callInst); + } else { + moduleTranslation.mapValue(opInst.getResult(0), globalValue); + } + return success(); } @@ -4199,6 +4220,14 @@ static bool isTargetDeviceOp(Operation *op) { if (op->getParentOfType<omp::TargetOp>()) return true; + // Certain operations return results, and whether utilised in host or + // target there is a chance an LLVM Dialect operation depends on it + // by taking it in as an operand, so we must always lower these in + // some manner or result in an ICE (whether they end up in a no-op + // or otherwise). + if (mlir::isa<omp::ThreadprivateOp>(op)) + return true; + if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>()) if (auto declareTargetIface = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>( diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index b0d5e635248d..2d8d7745eca9 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -427,19 +427,33 @@ ModuleImport::processAliasScopeMetadata(const llvm::MDNode *node) { return node->getNumOperands() != 0 && node == dyn_cast<llvm::MDNode>(node->getOperand(0)); }; + auto verifySelfRefOrString = [](const llvm::MDNode *node) { + return node->getNumOperands() != 0 && + (node == dyn_cast<llvm::MDNode>(node->getOperand(0)) || + isa<llvm::MDString>(node->getOperand(0))); + }; // Helper that verifies the given operand is a string or does not exist. auto verifyDescription = [](const llvm::MDNode *node, unsigned idx) { return idx >= node->getNumOperands() || isa<llvm::MDString>(node->getOperand(idx)); }; + + auto getIdAttr = [&](const llvm::MDNode *node) -> Attribute { + if (verifySelfRef(node)) + return DistinctAttr::create(builder.getUnitAttr()); + + auto name = cast<llvm::MDString>(node->getOperand(0)); + return builder.getStringAttr(name->getString()); + }; + // Helper that creates an alias scope domain attribute. auto createAliasScopeDomainOp = [&](const llvm::MDNode *aliasDomain) { StringAttr description = nullptr; if (aliasDomain->getNumOperands() >= 2) if (auto *operand = dyn_cast<llvm::MDString>(aliasDomain->getOperand(1))) description = builder.getStringAttr(operand->getString()); - return builder.getAttr<AliasScopeDomainAttr>( - DistinctAttr::create(builder.getUnitAttr()), description); + Attribute idAttr = getIdAttr(aliasDomain); + return builder.getAttr<AliasScopeDomainAttr>(idAttr, description); }; // Collect the alias scopes and domains to translate them. @@ -452,10 +466,11 @@ ModuleImport::processAliasScopeMetadata(const llvm::MDNode *node) { // verifying its domain. Perform the verification before looking it up in // the alias scope mapping since it could have been inserted as a domain // node before. - if (!verifySelfRef(scope) || !domain || !verifyDescription(scope, 2)) + if (!verifySelfRefOrString(scope) || !domain || + !verifyDescription(scope, 2)) return emitError(loc) << "unsupported alias scope node: " << diagMD(scope, llvmModule.get()); - if (!verifySelfRef(domain) || !verifyDescription(domain, 1)) + if (!verifySelfRefOrString(domain) || !verifyDescription(domain, 1)) return emitError(loc) << "unsupported alias domain node: " << diagMD(domain, llvmModule.get()); @@ -473,9 +488,10 @@ ModuleImport::processAliasScopeMetadata(const llvm::MDNode *node) { StringAttr description = nullptr; if (!aliasScope.getName().empty()) description = builder.getStringAttr(aliasScope.getName()); + Attribute idAttr = getIdAttr(scope); auto aliasScopeOp = builder.getAttr<AliasScopeAttr>( - DistinctAttr::create(builder.getUnitAttr()), - cast<AliasScopeDomainAttr>(it->second), description); + idAttr, cast<AliasScopeDomainAttr>(it->second), description); + aliasScopeMapping.try_emplace(aliasScope.getNode(), aliasScopeOp); } } @@ -1473,18 +1489,20 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch, return success(); } -LogicalResult -ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst, - SmallVectorImpl<Type> &types, - SmallVectorImpl<Value> &operands) { +LogicalResult ModuleImport::convertCallTypeAndOperands( + llvm::CallBase *callInst, SmallVectorImpl<Type> &types, + SmallVectorImpl<Value> &operands, bool allowInlineAsm) { if (!callInst->getType()->isVoidTy()) types.push_back(convertType(callInst->getType())); if (!callInst->getCalledFunction()) { - FailureOr<Value> called = convertValue(callInst->getCalledOperand()); - if (failed(called)) - return failure(); - operands.push_back(*called); + if (!allowInlineAsm || + !isa<llvm::InlineAsm>(callInst->getCalledOperand())) { + FailureOr<Value> called = convertValue(callInst->getCalledOperand()); + if (failed(called)) + return failure(); + operands.push_back(*called); + } } SmallVector<llvm::Value *> args(callInst->args()); FailureOr<SmallVector<Value>> arguments = convertValues(args); @@ -1579,7 +1597,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { SmallVector<Type> types; SmallVector<Value> operands; - if (failed(convertCallTypeAndOperands(callInst, types, operands))) + if (failed(convertCallTypeAndOperands(callInst, types, operands, + /*allowInlineAsm=*/true))) return failure(); auto funcTy = @@ -1587,45 +1606,59 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { if (!funcTy) return failure(); - CallOp callOp; - - if (llvm::Function *callee = callInst->getCalledFunction()) { - callOp = builder.create<CallOp>( - loc, funcTy, SymbolRefAttr::get(context, callee->getName()), - operands); + if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) { + auto callOp = builder.create<InlineAsmOp>( + loc, funcTy.getReturnType(), operands, + builder.getStringAttr(asmI->getAsmString()), + builder.getStringAttr(asmI->getConstraintString()), + /*has_side_effects=*/true, + /*is_align_stack=*/false, /*asm_dialect=*/nullptr, + /*operand_attrs=*/nullptr); + if (!callInst->getType()->isVoidTy()) + mapValue(inst, callOp.getResult(0)); + else + mapNoResultOp(inst, callOp); } else { - callOp = builder.create<CallOp>(loc, funcTy, operands); + CallOp callOp; + + if (llvm::Function *callee = callInst->getCalledFunction()) { + callOp = builder.create<CallOp>( + loc, funcTy, SymbolRefAttr::get(context, callee->getName()), + operands); + } else { + callOp = builder.create<CallOp>(loc, funcTy, operands); + } + callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv())); + callOp.setTailCallKind( + convertTailCallKindFromLLVM(callInst->getTailCallKind())); + setFastmathFlagsAttr(inst, callOp); + + // Handle function attributes. + if (callInst->hasFnAttr(llvm::Attribute::Convergent)) + callOp.setConvergent(true); + if (callInst->hasFnAttr(llvm::Attribute::NoUnwind)) + callOp.setNoUnwind(true); + if (callInst->hasFnAttr(llvm::Attribute::WillReturn)) + callOp.setWillReturn(true); + + llvm::MemoryEffects memEffects = callInst->getMemoryEffects(); + ModRefInfo othermem = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::Other)); + ModRefInfo argMem = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem)); + ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem)); + auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem, + argMem, inaccessibleMem); + // Only set the attribute when it does not match the default value. + if (!memAttr.isReadWrite()) + callOp.setMemoryEffectsAttr(memAttr); + + if (!callInst->getType()->isVoidTy()) + mapValue(inst, callOp.getResult()); + else + mapNoResultOp(inst, callOp); } - callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv())); - callOp.setTailCallKind( - convertTailCallKindFromLLVM(callInst->getTailCallKind())); - setFastmathFlagsAttr(inst, callOp); - - // Handle function attributes. - if (callInst->hasFnAttr(llvm::Attribute::Convergent)) - callOp.setConvergent(true); - if (callInst->hasFnAttr(llvm::Attribute::NoUnwind)) - callOp.setNoUnwind(true); - if (callInst->hasFnAttr(llvm::Attribute::WillReturn)) - callOp.setWillReturn(true); - - llvm::MemoryEffects memEffects = callInst->getMemoryEffects(); - ModRefInfo othermem = convertModRefInfoFromLLVM( - memEffects.getModRef(llvm::MemoryEffects::Location::Other)); - ModRefInfo argMem = convertModRefInfoFromLLVM( - memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem)); - ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM( - memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem)); - auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem, argMem, - inaccessibleMem); - // Only set the attribute when it does not match the default value. - if (!memAttr.isReadWrite()) - callOp.setMemoryEffectsAttr(memAttr); - - if (!callInst->getType()->isVoidTy()) - mapValue(inst, callOp.getResult()); - else - mapNoResultOp(inst, callOp); return success(); } if (inst->getOpcode() == llvm::Instruction::LandingPad) { diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index ad62ae0cef57..4367100e3aca 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1724,25 +1724,36 @@ ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr) { aliasScopeAttr.getDomain(), nullptr); if (insertedDomain) { llvm::SmallVector<llvm::Metadata *, 2> operands; - // Placeholder for self-reference. + // Placeholder for potential self-reference. operands.push_back(dummy.get()); if (StringAttr description = aliasScopeAttr.getDomain().getDescription()) operands.push_back(llvm::MDString::get(ctx, description)); domainIt->second = llvm::MDNode::get(ctx, operands); // Self-reference for uniqueness. - domainIt->second->replaceOperandWith(0, domainIt->second); + llvm::Metadata *replacement; + if (auto stringAttr = + dyn_cast<StringAttr>(aliasScopeAttr.getDomain().getId())) + replacement = llvm::MDString::get(ctx, stringAttr.getValue()); + else + replacement = domainIt->second; + domainIt->second->replaceOperandWith(0, replacement); } // Convert the scope metadata node. assert(domainIt->second && "Scope's domain should already be valid"); llvm::SmallVector<llvm::Metadata *, 3> operands; - // Placeholder for self-reference. + // Placeholder for potential self-reference. operands.push_back(dummy.get()); operands.push_back(domainIt->second); if (StringAttr description = aliasScopeAttr.getDescription()) operands.push_back(llvm::MDString::get(ctx, description)); scopeIt->second = llvm::MDNode::get(ctx, operands); // Self-reference for uniqueness. - scopeIt->second->replaceOperandWith(0, scopeIt->second); + llvm::Metadata *replacement; + if (auto stringAttr = dyn_cast<StringAttr>(aliasScopeAttr.getId())) + replacement = llvm::MDString::get(ctx, stringAttr.getValue()); + else + replacement = scopeIt->second; + scopeIt->second->replaceOperandWith(0, replacement); return scopeIt->second; } diff --git a/mlir/lib/Transforms/LocationSnapshot.cpp b/mlir/lib/Transforms/LocationSnapshot.cpp index b85850acda91..f701c8b4f0a9 100644 --- a/mlir/lib/Transforms/LocationSnapshot.cpp +++ b/mlir/lib/Transforms/LocationSnapshot.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/AsmState.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/FileUtilities.h" #include "llvm/Support/FileSystem.h" @@ -131,29 +132,23 @@ LogicalResult mlir::generateLocationsFromIR(StringRef fileName, StringRef tag, namespace { struct LocationSnapshotPass : public impl::LocationSnapshotBase<LocationSnapshotPass> { - LocationSnapshotPass() = default; - LocationSnapshotPass(OpPrintingFlags flags, StringRef fileName, StringRef tag) - : flags(flags) { - this->fileName = fileName.str(); - this->tag = tag.str(); - } + using impl::LocationSnapshotBase<LocationSnapshotPass>::LocationSnapshotBase; void runOnOperation() override { Operation *op = getOperation(); - if (failed(generateLocationsFromIR(fileName, op, OpPrintingFlags(), tag))) + if (failed(generateLocationsFromIR(fileName, op, getFlags(), tag))) return signalPassFailure(); } - /// The printing flags to use when creating the snapshot. - OpPrintingFlags flags; +private: + /// build the flags from the command line arguments to the pass + OpPrintingFlags getFlags() { + OpPrintingFlags flags; + flags.enableDebugInfo(enableDebugInfo, printPrettyDebugInfo); + flags.printGenericOpForm(printGenericOpForm); + if (useLocalScope) + flags.useLocalScope(); + return flags; + } }; } // namespace - -std::unique_ptr<Pass> mlir::createLocationSnapshotPass(OpPrintingFlags flags, - StringRef fileName, - StringRef tag) { - return std::make_unique<LocationSnapshotPass>(flags, fileName, tag); -} -std::unique_ptr<Pass> mlir::createLocationSnapshotPass() { - return std::make_unique<LocationSnapshotPass>(); -} diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 255b0ba2559e..403321d40d53 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Iterators.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -63,11 +64,55 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) { return OpBuilder::InsertPoint(insertBlock, insertPt); } +/// Helper function that computes an insertion point where the given values are +/// defined and can be used without a dominance violation. +static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) { + assert(!vals.empty() && "expected at least one value"); + DominanceInfo domInfo; + OpBuilder::InsertPoint pt = computeInsertPoint(vals.front()); + for (Value v : vals.drop_front()) { + // Choose the "later" insertion point. + OpBuilder::InsertPoint nextPt = computeInsertPoint(v); + if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(), + nextPt.getPoint())) { + // pt is before nextPt => choose nextPt. + pt = nextPt; + } else { +#ifndef NDEBUG + // nextPt should be before pt => choose pt. + // If pt, nextPt are no dominance relationship, then there is no valid + // insertion point at which all given values are defined. + bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(), + pt.getBlock(), pt.getPoint()); + assert(dom && "unable to find valid insertion point"); +#endif // NDEBUG + } + } + return pt; +} + //===----------------------------------------------------------------------===// // ConversionValueMapping //===----------------------------------------------------------------------===// +/// A vector of SSA values, optimized for the most common case of a single +/// value. +using ValueVector = SmallVector<Value, 1>; + namespace { + +/// Helper class to make it possible to use `ValueVector` as a key in DenseMap. +struct ValueVectorMapInfo { + static ValueVector getEmptyKey() { return ValueVector{Value()}; } + static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; } + static ::llvm::hash_code getHashValue(const ValueVector &val) { + return ::llvm::hash_combine_range(val.begin(), val.end()); + } + static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) { + return LHS == RHS; + } +}; + /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { @@ -75,68 +120,128 @@ struct ConversionValueMapping { /// false positives. bool isMappedTo(Value value) const { return mappedTo.contains(value); } - /// Lookup the most recently mapped value with the desired type in the + /// Lookup the most recently mapped values with the desired types in the /// mapping. /// /// Special cases: - /// - If the desired type is "null", simply return the most recently mapped - /// value. - /// - If there is no mapping to the desired type, also return the most - /// recently mapped value. - /// - If there is no mapping for the given value at all, return the given + /// - If the desired type range is empty, simply return the most recently + /// mapped values. + /// - If there is no mapping to the desired types, also return the most + /// recently mapped values. + /// - If there is no mapping for the given values at all, return the given /// value. - Value lookupOrDefault(Value from, Type desiredType = nullptr) const; + ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; - /// Lookup a mapped value within the map, or return null if a mapping does not - /// exist. If a mapping exists, this follows the same behavior of - /// `lookupOrDefault`. - Value lookupOrNull(Value from, Type desiredType = nullptr) const; + /// Lookup the given value within the map, or return an empty vector if the + /// value is not mapped. If it is mapped, this follows the same behavior + /// as `lookupOrDefault`. + ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; - /// Map a value to the one provided. - void map(Value oldVal, Value newVal) { + template <typename T> + struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {}; + + /// Map a value vector to the one provided. + template <typename OldVal, typename NewVal> + std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value> + map(OldVal &&oldVal, NewVal &&newVal) { LLVM_DEBUG({ - for (Value it = newVal; it; it = mapping.lookupOrNull(it)) - assert(it != oldVal && "inserting cyclic mapping"); + ValueVector next(newVal); + while (true) { + assert(next != oldVal && "inserting cyclic mapping"); + auto it = mapping.find(next); + if (it == mapping.end()) + break; + next = it->second; + } }); - mapping.map(oldVal, newVal); - mappedTo.insert(newVal); + for (Value v : newVal) + mappedTo.insert(v); + + mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal); } - /// Drop the last mapping for the given value. - void erase(Value value) { mapping.erase(value); } + /// Map a value vector or single value to the one provided. + template <typename OldVal, typename NewVal> + std::enable_if_t<!IsValueVector<OldVal>::value || + !IsValueVector<NewVal>::value> + map(OldVal &&oldVal, NewVal &&newVal) { + if constexpr (IsValueVector<OldVal>{}) { + map(std::forward<OldVal>(oldVal), ValueVector{newVal}); + } else if constexpr (IsValueVector<NewVal>{}) { + map(ValueVector{oldVal}, std::forward<NewVal>(newVal)); + } else { + map(ValueVector{oldVal}, ValueVector{newVal}); + } + } + + /// Drop the last mapping for the given values. + void erase(const ValueVector &value) { mapping.erase(value); } private: /// Current value mappings. - IRMapping mapping; + DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping; /// All SSA values that are mapped to. May contain false positives. DenseSet<Value> mappedTo; }; } // namespace -Value ConversionValueMapping::lookupOrDefault(Value from, - Type desiredType) const { - // Try to find the deepest value that has the desired type. If there is no - // such value, simply return the deepest value. - Value desiredValue; +ValueVector +ConversionValueMapping::lookupOrDefault(Value from, + TypeRange desiredTypes) const { + // Try to find the deepest values that have the desired types. If there is no + // such mapping, simply return the deepest values. + ValueVector desiredValue; + ValueVector current{from}; do { - if (!desiredType || from.getType() == desiredType) - desiredValue = from; + // Store the current value if the types match. + if (TypeRange(ValueRange(current)) == desiredTypes) + desiredValue = current; + + // If possible, Replace each value with (one or multiple) mapped values. + ValueVector next; + for (Value v : current) { + auto it = mapping.find({v}); + if (it != mapping.end()) { + llvm::append_range(next, it->second); + } else { + next.push_back(v); + } + } + if (next != current) { + // If at least one value was replaced, continue the lookup from there. + current = std::move(next); + continue; + } - Value mappedValue = mapping.lookupOrNull(from); - if (!mappedValue) + // Otherwise: Check if there is a mapping for the entire vector. Such + // mappings are materializations. (N:M mapping are not supported for value + // replacements.) + // + // Note: From a correctness point of view, materializations do not have to + // be stored (and looked up) in the mapping. But for performance reasons, + // we choose to reuse existing IR (when possible) instead of creating it + // multiple times. + auto it = mapping.find(current); + if (it == mapping.end()) { + // No mapping found: The lookup stops here. break; - from = mappedValue; + } + current = it->second; } while (true); - // If the desired value was found use it, otherwise default to the leaf value. - return desiredValue ? desiredValue : from; + // If the desired values were found use them, otherwise default to the leaf + // values. + // Note: If `desiredTypes` is empty, this function always returns `current`. + return !desiredValue.empty() ? std::move(desiredValue) : std::move(current); } -Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const { - Value result = lookupOrDefault(from, desiredType); - if (result == from || (desiredType && result.getType() != desiredType)) - return nullptr; +ValueVector ConversionValueMapping::lookupOrNull(Value from, + TypeRange desiredTypes) const { + ValueVector result = lookupOrDefault(from, desiredTypes); + if (result == ValueVector{from} || + (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes)) + return {}; return result; } @@ -651,10 +756,6 @@ public: /// The type of materialization. enum MaterializationKind { - /// This materialization materializes a conversion for an illegal block - /// argument type, to the original one. - Argument, - /// This materialization materializes a conversion from an illegal type to a /// legal one. Target, @@ -673,7 +774,7 @@ public: UnrealizedConversionCastOp op, const TypeConverter *converter, MaterializationKind kind, Type originalType, - Value mappedValue); + ValueVector mappedValues); static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::UnresolvedMaterialization; @@ -708,9 +809,9 @@ private: /// materializations. Type originalType; - /// The value in the conversion value mapping that is being replaced by the + /// The values in the conversion value mapping that are being replaced by the /// results of this unresolved materialization. - Value mappedValue; + ValueVector mappedValues; }; } // namespace @@ -779,7 +880,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { LogicalResult remapValues(StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVector<SmallVector<Value>> &remapped); + SmallVector<ValueVector> &remapped); /// Return "true" if the given operation is ignored, and does not need to be /// converted. @@ -820,39 +921,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// If a cast op was built, it can optionally be returned with the `castOp` /// output argument. /// - /// If `valueToMap` is set to a non-null Value, then that value is mapped to + /// If `valuesToMap` is set to a non-null Value, then that value is mapped to /// the results of the unresolved materialization in the conversion value /// mapping. ValueRange buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - Value valueToMap, ValueRange inputs, TypeRange outputTypes, + ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp = nullptr); - Value buildUnresolvedMaterialization( - MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - Value valueToMap, ValueRange inputs, Type outputType, Type originalType, - const TypeConverter *converter, - UnrealizedConversionCastOp *castOp = nullptr) { - return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs, - TypeRange(outputType), originalType, - converter, castOp) - .front(); - } - - /// Build an N:1 materialization for the given original value that was - /// replaced with the given replacement values. - /// - /// This is a workaround around incomplete 1:N support in the dialect - /// conversion driver. The conversion mapping can store only 1:1 replacements - /// and the conversion patterns only support single Value replacements in the - /// adaptor, so N values must be converted back to a single value. This - /// function will be deleted when full 1:N support has been added. - /// - /// This function inserts an argument materialization back to the original - /// type. - void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc, - ValueRange replacements, Value originalValue, - const TypeConverter *converter); /// Find a replacement value for the given SSA value in the conversion value /// mapping. The replacement value must have the same type as the given SSA @@ -862,16 +938,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { Value findOrBuildReplacementValue(Value value, const TypeConverter *converter); - /// Unpack an N:1 materialization and return the inputs of the - /// materialization. This function unpacks only those materializations that - /// were built with `insertNTo1Materialization`. - /// - /// This is a workaround around incomplete 1:N support in the dialect - /// conversion driver. It allows us to write 1:N conversion patterns while - /// 1:N support is still missing in the conversion value mapping. This - /// function will be deleted when full 1:N support has been added. - SmallVector<Value> unpackNTo1Materialization(Value value); - //===--------------------------------------------------------------------===// // Rewriter Notification Hooks //===--------------------------------------------------------------------===// @@ -974,10 +1040,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *> unresolvedMaterializations; - /// A set of all N:1 materializations that were added to work around - /// incomplete 1:N support in the dialect conversion driver. - DenseSet<UnrealizedConversionCastOp> nTo1TempMaterializations; - /// The current type converter, or nullptr if no type converter is currently /// active. const TypeConverter *currentTypeConverter = nullptr; @@ -1041,7 +1103,7 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { }); } -void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); } +void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); } void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { auto *listener = @@ -1082,7 +1144,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { void ReplaceOperationRewrite::rollback() { for (auto result : op->getResults()) - rewriterImpl.mapping.erase(result); + rewriterImpl.mapping.erase({result}); } void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) { @@ -1101,20 +1163,19 @@ void CreateOperationRewrite::rollback() { UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite( ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op, const TypeConverter *converter, MaterializationKind kind, Type originalType, - Value mappedValue) + ValueVector mappedValues) : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), converterAndKind(converter, kind), originalType(originalType), - mappedValue(mappedValue) { + mappedValues(std::move(mappedValues)) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); rewriterImpl.unresolvedMaterializations[op] = this; } void UnresolvedMaterializationRewrite::rollback() { - if (mappedValue) - rewriterImpl.mapping.erase(mappedValue); + if (!mappedValues.empty()) + rewriterImpl.mapping.erase(mappedValues); rewriterImpl.unresolvedMaterializations.erase(getOperation()); - rewriterImpl.nTo1TempMaterializations.erase(getOperation()); op->erase(); } @@ -1160,7 +1221,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { LogicalResult ConversionPatternRewriterImpl::remapValues( StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVector<SmallVector<Value>> &remapped) { + SmallVector<ValueVector> &remapped) { remapped.reserve(llvm::size(values)); for (const auto &it : llvm::enumerate(values)) { @@ -1168,18 +1229,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( Type origType = operand.getType(); Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); - // Find the most recently mapped value. Unpack all temporary N:1 - // materializations. Such conversions are a workaround around missing - // 1:N support in the ConversionValueMapping. (The conversion patterns - // already support 1:N replacements.) - Value repl = mapping.lookupOrDefault(operand); - SmallVector<Value> unpacked = unpackNTo1Materialization(repl); - if (!currentTypeConverter) { // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply - // pass through the most recently mapped value. - remapped.push_back(std::move(unpacked)); + // pass through the most recently mapped values. + remapped.push_back(mapping.lookupOrDefault(operand)); continue; } @@ -1192,51 +1246,28 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( }); return failure(); } - // If a type is converted to 0 types, there is nothing to do. if (legalTypes.empty()) { remapped.push_back({}); continue; } - if (legalTypes.size() != 1) { - // TODO: This is a 1:N conversion. The conversion value mapping does not - // store such materializations yet. If the types of the most recently - // mapped values do not match, build a target materialization. - ValueRange unpackedRange(unpacked); - if (TypeRange(unpackedRange) == legalTypes) { - remapped.push_back(std::move(unpacked)); - continue; - } - - // Insert a target materialization if the current pattern expects - // different legalized types. - ValueRange targetMat = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(repl), operandLoc, - /*valueToMap=*/Value(), /*inputs=*/unpacked, - /*outputType=*/legalTypes, /*originalType=*/origType, - currentTypeConverter); - remapped.push_back(targetMat); + ValueVector repl = mapping.lookupOrDefault(operand, legalTypes); + if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) { + // Mapped values have the correct type or there is an existing + // materialization. Or the operand is not mapped at all and has the + // correct type. + remapped.push_back(std::move(repl)); continue; } - // Handle 1->1 type conversions. - Type desiredType = legalTypes.front(); - // Try to find a mapped value with the desired type. (Or the operand itself - // if the value is not mapped at all.) - Value newOperand = mapping.lookupOrDefault(operand, desiredType); - if (newOperand.getType() != desiredType) { - // If the looked up value's type does not have the desired type, it means - // that the value was replaced with a value of different type and no - // target materialization was created yet. - Value castValue = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(newOperand), - operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked, - /*outputType=*/desiredType, /*originalType=*/origType, - currentTypeConverter); - newOperand = castValue; - } - remapped.push_back({newOperand}); + // Create a materialization for the most recently mapped values. + repl = mapping.lookupOrDefault(operand); + ValueRange castValues = buildUnresolvedMaterialization( + MaterializationKind::Target, computeInsertPoint(repl), operandLoc, + /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, + /*originalType=*/origType, currentTypeConverter); + remapped.push_back(castValues); } return success(); } @@ -1353,7 +1384,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( buildUnresolvedMaterialization( MaterializationKind::Source, OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), - /*valueToMap=*/origArg, /*inputs=*/ValueRange(), + /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(), /*outputType=*/origArgType, /*originalType=*/Type(), converter); appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); continue; @@ -1369,19 +1400,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( continue; } - // This is a 1->1+ mapping. 1->N mappings are not fully supported in the - // dialect conversion. Therefore, we need an argument materialization to - // turn the replacement block arguments into a single SSA value that can be - // used as a replacement. + // This is a 1->1+ mapping. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - if (replArgs.size() == 1) { - mapping.map(origArg, replArgs.front()); - } else { - insertNTo1Materialization( - OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), - /*replacements=*/replArgs, /*outputValue=*/origArg, converter); - } + ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs); + mapping.map(origArg, std::move(replArgVals)); appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); } @@ -1402,20 +1425,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /// of input operands. ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - Value valueToMap, ValueRange inputs, TypeRange outputTypes, + ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); - - // Avoid materializing an unnecessary cast. - if (TypeRange(inputs) == outputTypes) { - if (valueToMap) { - assert(inputs.size() == 1 && "1:N mapping is not supported"); - mapping.map(valueToMap, inputs.front()); - } - return inputs; - } + assert(TypeRange(inputs) != outputTypes && + "materialization is not necessary"); // Create an unresolved materialization. We use a new OpBuilder to avoid // tracking the materialization like we do for other operations. @@ -1423,37 +1439,23 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs); - if (valueToMap) { - assert(outputTypes.size() == 1 && "1:N mapping is not supported"); - mapping.map(valueToMap, convertOp.getResult(0)); - } + if (!valuesToMap.empty()) + mapping.map(valuesToMap, convertOp.getResults()); if (castOp) *castOp = convertOp; - appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind, - originalType, valueToMap); + appendRewrite<UnresolvedMaterializationRewrite>( + convertOp, converter, kind, originalType, std::move(valuesToMap)); return convertOp.getResults(); } -void ConversionPatternRewriterImpl::insertNTo1Materialization( - OpBuilder::InsertPoint ip, Location loc, ValueRange replacements, - Value originalValue, const TypeConverter *converter) { - // Insert argument materialization back to the original type. - Type originalType = originalValue.getType(); - UnrealizedConversionCastOp argCastOp; - buildUnresolvedMaterialization( - MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue, - /*inputs=*/replacements, originalType, - /*originalType=*/Type(), converter, &argCastOp); - if (argCastOp) - nTo1TempMaterializations.insert(argCastOp); -} - Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( Value value, const TypeConverter *converter) { - // Find a replacement value with the same type. - Value repl = mapping.lookupOrNull(value, value.getType()); - if (repl) - return repl; + // Try to find a replacement value with the same type in the conversion value + // mapping. This includes cached materializations. We try to reuse those + // instead of generating duplicate IR. + ValueVector repl = mapping.lookupOrNull(value, value.getType()); + if (!repl.empty()) + return repl.front(); // Check if the value is dead. No replacement value is needed in that case. // This is an approximate check that may have false negatives but does not @@ -1468,7 +1470,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // (regardless of the type) and build a source materialization to the // original type. repl = mapping.lookupOrNull(value); - if (!repl) { + if (repl.empty()) { // No replacement value is registered in the mapping. This means that the // value is dropped and no longer needed. (If the value were still needed, // a source materialization producing a replacement value "out of thin air" @@ -1476,34 +1478,22 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // `applySignatureConversion`.) return Value(); } - Value castValue = buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(), - /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(), - /*originalType=*/Type(), converter); - mapping.map(value, castValue); - return castValue; -} -SmallVector<Value> -ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) { - // Unpack unrealized_conversion_cast ops that were inserted as a N:1 - // workaround. - auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>(); - if (!castOp) - return {value}; - if (!nTo1TempMaterializations.contains(castOp)) - return {value}; - assert(castOp->getNumResults() == 1 && "expected single result"); - - SmallVector<Value> result; - for (Value v : castOp.getOperands()) { - // Keep unpacking if possible. This is needed because during block - // signature conversions and 1:N op replacements, the driver may have - // inserted two materializations back-to-back: first an argument - // materialization, then a target materialization. - llvm::append_range(result, unpackNTo1Materialization(v)); - } - return result; + // Note: `computeInsertPoint` computes the "earliest" insertion point at + // which all values in `repl` are defined. It is important to emit the + // materialization at that location because the same materialization may be + // reused in a different context. (That's because materializations are cached + // in the conversion value mapping.) The insertion point of the + // materialization must be valid for all future users that may be created + // later in the conversion process. + Value castValue = + buildUnresolvedMaterialization(MaterializationKind::Source, + computeInsertPoint(repl), value.getLoc(), + /*valuesToMap=*/repl, /*inputs=*/repl, + /*outputType=*/value.getType(), + /*originalType=*/Type(), converter) + .front(); + return castValue; } //===----------------------------------------------------------------------===// @@ -1554,7 +1544,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( // Materialize a replacement value "out of thin air". buildUnresolvedMaterialization( MaterializationKind::Source, computeInsertPoint(result), - result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(), + result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(), /*outputType=*/result.getType(), /*originalType=*/Type(), currentTypeConverter); continue; @@ -1572,16 +1562,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( // Remap result to replacement value. if (repl.empty()) continue; - - if (repl.size() == 1) { - // Single replacement value: replace directly. - mapping.map(result, repl.front()); - } else { - // Multiple replacement values: insert N:1 materialization. - insertNTo1Materialization(computeInsertPoint(result), result.getLoc(), - /*replacements=*/repl, /*outputValue=*/result, - currentTypeConverter); - } + mapping.map(result, repl); } appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter); @@ -1660,8 +1641,13 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); SmallVector<ValueRange> newVals; - for (size_t i = 0; i < newValues.size(); ++i) - newVals.push_back(newValues.slice(i, 1)); + for (size_t i = 0; i < newValues.size(); ++i) { + if (newValues[i]) { + newVals.push_back(newValues.slice(i, 1)); + } else { + newVals.push_back(ValueRange()); + } + } impl->notifyOpReplaced(op, newVals); } @@ -1733,7 +1719,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, } Value ConversionPatternRewriter::getRemappedValue(Value key) { - SmallVector<SmallVector<Value>> remappedValues; + SmallVector<ValueVector> remappedValues; if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, remappedValues))) return nullptr; @@ -1746,7 +1732,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, SmallVectorImpl<Value> &results) { if (keys.empty()) return success(); - SmallVector<SmallVector<Value>> remapped; + SmallVector<ValueVector> remapped; if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, remapped))) return failure(); @@ -1872,7 +1858,7 @@ ConversionPattern::matchAndRewrite(Operation *op, getTypeConverter()); // Remap the operands of the operation. - SmallVector<SmallVector<Value>> remapped; + SmallVector<ValueVector> remapped; if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, op->getOperands(), remapped))) { return failure(); @@ -2625,19 +2611,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, rewriter.setInsertionPoint(op); SmallVector<Value> newMaterialization; switch (rewrite->getMaterializationKind()) { - case MaterializationKind::Argument: { - // Try to materialize an argument conversion. - assert(op->getNumResults() == 1 && "expected single result"); - Value argMat = converter->materializeArgumentConversion( - rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands); - if (argMat) { - newMaterialization.push_back(argMat); - break; - } - } - // If an argument materialization failed, fallback to trying a target - // materialization. - [[fallthrough]]; case MaterializationKind::Target: newMaterialization = converter->materializeTargetConversion( rewriter, op->getLoc(), op.getResultTypes(), inputOperands, diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index c5cb22c6dccb..d021dde05dd8 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -58,6 +58,7 @@ def get_include_dirs() -> Sequence[str]: # needs. _dialect_registry = None +_load_on_create_dialects = None def get_dialect_registry(): @@ -71,6 +72,21 @@ def get_dialect_registry(): return _dialect_registry +def append_load_on_create_dialect(dialect: str): + global _load_on_create_dialects + if _load_on_create_dialects is None: + _load_on_create_dialects = [dialect] + else: + _load_on_create_dialects.append(dialect) + + +def get_load_on_create_dialects(): + global _load_on_create_dialects + if _load_on_create_dialects is None: + _load_on_create_dialects = [] + return _load_on_create_dialects + + def _site_initialize(): import importlib import itertools @@ -132,15 +148,35 @@ def _site_initialize(): break class Context(ir._BaseContext): - def __init__(self, *args, **kwargs): + def __init__(self, load_on_create_dialects=None, *args, **kwargs): super().__init__(*args, **kwargs) self.append_dialect_registry(get_dialect_registry()) for hook in post_init_hooks: hook(self) if not disable_multithreading: self.enable_multithreading(True) - if not disable_load_all_available_dialects: - self.load_all_available_dialects() + if load_on_create_dialects is not None: + logger.debug( + "Loading all dialects from load_on_create_dialects arg %r", + load_on_create_dialects, + ) + for dialect in load_on_create_dialects: + # This triggers loading the dialect into the context. + _ = self.dialects[dialect] + else: + if disable_load_all_available_dialects: + dialects = get_load_on_create_dialects() + if dialects: + logger.debug( + "Loading all dialects from global load_on_create_dialects %r", + dialects, + ) + for dialect in dialects: + # This triggers loading the dialect into the context. + _ = self.dialects[dialect] + else: + logger.debug("Loading all available dialects") + self.load_all_available_dialects() if init_module: logger.debug( "Registering translations from initializer %r", init_module diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index 9121aa8e4023..bf40cc532065 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -141,6 +141,77 @@ class FuseIntoContainingOp(FuseIntoContainingOp): @_ods_cext.register_operation(_Dialect, replace=True) +class FuseOp(FuseOp): + """Specialization for FuseOp class.""" + + @overload + def __init__( + self, + loop_types: Union[Type, Sequence[Type]], + target: Union[Operation, Value, OpView], + *, + tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + tile_interchange: OptionalIntList = None, + apply_cleanup: Optional[bool] = False, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + tile_interchange: OptionalIntList = None, + apply_cleanup: Optional[bool] = False, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value], + target_or_none: Optional[Union[Operation, Value, OpView]] = None, + *, + tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + tile_interchange: OptionalIntList = None, + apply_cleanup: Optional[bool] = False, + loc=None, + ip=None, + ): + tile_sizes = tile_sizes if tile_sizes else [] + tile_interchange = tile_interchange if tile_interchange else [] + _, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes) + _, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange) + num_loops = sum(0 if v == 0 else 1 for v in tile_sizes) + + if isinstance(loop_types_or_target, (Operation, Value, OpView)): + loop_types = [transform.AnyOpType.get()] * num_loops + target = loop_types_or_target + assert target_or_none is None, "Cannot construct FuseOp with two targets." + else: + loop_types = ( + ([loop_types_or_target] * num_loops) + if isinstance(loop_types_or_target, Type) + else loop_types_or_target + ) + target = target_or_none + super().__init__( + target.type, + loop_types, + target, + tile_sizes=tile_sizes, + tile_interchange=tile_interchange, + apply_cleanup=apply_cleanup, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) class GeneralizeOp(GeneralizeOp): """Specialization for GeneralizeOp class.""" diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 9a6ce462047a..6f37266d5bf3 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -5,7 +5,11 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug from ._mlir_libs._mlir import register_type_caster, register_value_caster -from ._mlir_libs import get_dialect_registry +from ._mlir_libs import ( + get_dialect_registry, + append_load_on_create_dialect, + get_load_on_create_dialects, +) # Convenience decorator for registering user-friendly Attribute builders. diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir index 58580a194df0..f2e0306073f2 100644 --- a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir @@ -26,7 +26,7 @@ func.func @affine_vector_store(%arg0 : index) { // CHECK: %[[buf:.*]] = memref.alloc // CHECK: %[[val:.*]] = arith.constant dense // CHECK: %[[c_1:.*]] = arith.constant -1 : index -// CHECK-NEXT: %[[a:.*]] = arith.muli %arg0, %[[c_1]] : index +// CHECK-NEXT: %[[a:.*]] = arith.muli %arg0, %[[c_1]] overflow<nsw> : index // CHECK-NEXT: %[[b:.*]] = arith.addi %{{.*}}, %[[a]] : index // CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index // CHECK-NEXT: %[[c:.*]] = arith.addi %[[b]], %[[c7]] : index diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir index 00d7b6b8d65f..550ea71882e1 100644 --- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir @@ -156,7 +156,7 @@ func.func private @get_idx() -> (index) // CHECK-NEXT: %[[v0:.*]] = call @get_idx() : () -> index // CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index // CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index -// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index +// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index // CHECK-NEXT: %[[c20:.*]] = arith.constant 20 : index // CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index // CHECK-NEXT: %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index @@ -177,7 +177,7 @@ func.func @if_only() { // CHECK-NEXT: %[[v0:.*]] = call @get_idx() : () -> index // CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index // CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index -// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index +// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index // CHECK-NEXT: %[[c20:.*]] = arith.constant 20 : index // CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index // CHECK-NEXT: %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index @@ -202,7 +202,7 @@ func.func @if_else() { // CHECK-NEXT: %[[v0:.*]] = call @get_idx() : () -> index // CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index // CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index -// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index +// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index // CHECK-NEXT: %[[c20:.*]] = arith.constant 20 : index // CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index // CHECK-NEXT: %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index @@ -272,7 +272,7 @@ func.func @if_with_yield() -> (i64) { // CHECK-NEXT: %[[v0:.*]] = call @get_idx() : () -> index // CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index // CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index -// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index +// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index // CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %{{.*}} : index // CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index // CHECK-NEXT: %[[v3:.*]] = arith.addi %[[v2]], %[[c1]] : index @@ -316,7 +316,7 @@ func.func @if_for() { %i = call @get_idx() : () -> (index) // CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index // CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index -// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index +// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index // CHECK-NEXT: %[[c20:.*]] = arith.constant 20 : index // CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index // CHECK-NEXT: %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index @@ -371,7 +371,7 @@ func.func @if_for() { // CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index // CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[c42]] step %[[c1]] { // CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index -// CHECK-NEXT: %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] : index +// CHECK-NEXT: %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] overflow<nsw> : index // CHECK-NEXT: %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} : index // CHECK-NEXT: %[[max:.*]] = arith.maxsi %{{.*}}, %[[add0]] : index // CHECK-NEXT: %[[c10:.*]] = arith.constant 10 : index @@ -448,22 +448,22 @@ func.func @affine_applies(%arg0 : index) { %one = affine.apply #map3(%symbZero)[%zero] // CHECK-NEXT: %[[c2:.*]] = arith.constant 2 : index -// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] : index +// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] overflow<nsw> : index // CHECK-NEXT: %[[v3:.*]] = arith.addi %arg0, %[[v2]] : index // CHECK-NEXT: %[[c3:.*]] = arith.constant 3 : index -// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] : index +// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] overflow<nsw> : index // CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] : index // CHECK-NEXT: %[[c4:.*]] = arith.constant 4 : index -// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] : index +// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] overflow<nsw> : index // CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v5]], %[[v6]] : index // CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : index -// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] : index +// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] overflow<nsw> : index // CHECK-NEXT: %[[v9:.*]] = arith.addi %[[v7]], %[[v8]] : index // CHECK-NEXT: %[[c6:.*]] = arith.constant 6 : index -// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] : index +// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] overflow<nsw> : index // CHECK-NEXT: %[[v11:.*]] = arith.addi %[[v9]], %[[v10]] : index // CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index -// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] : index +// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] overflow<nsw> : index // CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] : index %four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0] return @@ -610,7 +610,7 @@ func.func @affine_store(%arg0 : index) { affine.store %1, %0[%i0 - symbol(%arg0) + 7] : memref<10xf32> } // CHECK: %[[cm1:.*]] = arith.constant -1 : index -// CHECK-NEXT: %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] : index +// CHECK-NEXT: %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] overflow<nsw> : index // CHECK-NEXT: %[[b:.*]] = arith.addi %{{.*}}, %[[a]] : index // CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index // CHECK-NEXT: %[[c:.*]] = arith.addi %[[b]], %[[c7]] : index diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index 748dfe8c68fc..f52dd6c0d0ce 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -633,7 +633,7 @@ gpu.module @test_module_29 { // CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)> // CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : f64, !llvm.ptr // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32 - gpu.printf "Hello: %d\n" %arg0, %arg1 : i32, f32 + gpu.printf "Hello: %d\n", %arg0, %arg1 : i32, f32 gpu.return } } @@ -969,6 +969,35 @@ gpu.module @test_module_50 { } } +// CHECK-LABEL: gpu.module @test_module_51 +// CHECK: llvm.mlir.global internal constant @[[func_name:.*]]("(unknown)\00") {addr_space = 0 : i32} +// CHECK: llvm.mlir.global internal constant @[[file_name:.*]]("{{.*}}gpu-to-nvvm.mlir{{.*}}") {addr_space = 0 : i32} +// CHECK: llvm.mlir.global internal constant @[[message:.*]]("assert message\00") {addr_space = 0 : i32} +// CHECK: llvm.func @__assertfail(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) attributes {passthrough = ["noreturn"]} +// CHECK: llvm.func @test_assert(%[[cond:.*]]: i1) attributes {gpu.kernel, nvvm.kernel} { +// CHECK: llvm.cond_br %[[cond]], ^[[after_block:.*]], ^[[assert_block:.*]] +// CHECK: ^[[assert_block]]: +// CHECK: %[[message_ptr:.*]] = llvm.mlir.addressof @[[message]] : !llvm.ptr +// CHECK: %[[message_start:.*]] = llvm.getelementptr %[[message_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<15 x i8> +// CHECK: %[[file_ptr:.*]] = llvm.mlir.addressof @[[file_name]] : !llvm.ptr +// CHECK: %[[file_start:.*]] = llvm.getelementptr %[[file_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8> +// CHECK: %[[func_ptr:.*]] = llvm.mlir.addressof @[[func_name]] : !llvm.ptr +// CHECK: %[[func_start:.*]] = llvm.getelementptr %[[func_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8> +// CHECK: %[[line_num:.*]] = llvm.mlir.constant({{.*}} : i32) : i32 +// CHECK: %[[ptr:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: llvm.call @__assertfail(%[[message_start]], %[[file_start]], %[[line_num]], %[[func_start]], %[[ptr]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) -> () +// CHECK: llvm.br ^[[after_block]] +// CHECK: ^[[after_block]]: +// CHECK: llvm.return +// CHECK: } + +gpu.module @test_module_51 { + gpu.func @test_assert(%arg0: i1) kernel { + cf.assert %arg0, "assert message" + gpu.return + } +} + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) { %gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir index 1b904fa142ba..2dc6a5ab2a86 100644 --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir @@ -36,7 +36,7 @@ gpu.module @test_module { // CHECK-NEXT: %[[NARGS1:.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK-NEXT: %[[ARG0_64:.*]] = llvm.zext %[[ARG0]] : i32 to i64 // CHECK-NEXT: %{{.*}} = llvm.call @__ockl_printf_append_args(%[[DESC1]], %[[NARGS1]], %[[ARG0_64]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[ISLAST]]) : (i64, i32, i64, i64, i64, i64, i64, i64, i64, i32) -> i64 - gpu.printf "Hello: %d\n" %arg0 : i32 + gpu.printf "Hello: %d\n", %arg0 : i32 gpu.return } } diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir index 870f5c5016ec..00d1d7d85268 100644 --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir @@ -9,7 +9,7 @@ gpu.module @test_module { // CHECK: %[[IMM0:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL]] : !llvm.ptr<4> // CHECK-NEXT: %[[IMM2:.*]] = llvm.getelementptr %[[IMM0]][0, 0] : (!llvm.ptr<4>) -> !llvm.ptr<4>, !llvm.array<11 x i8> // CHECK-NEXT: %{{.*}} = llvm.call @printf(%[[IMM2]], %[[ARG0]]) vararg(!llvm.func<i32 (ptr<4>, ...)>) : (!llvm.ptr<4>, i32) -> i32 - gpu.printf "Hello: %d\n" %arg0 : i32 + gpu.printf "Hello: %d\n", %arg0 : i32 gpu.return } } diff --git a/mlir/test/Conversion/GPUToSPIRV/printf.mlir b/mlir/test/Conversion/GPUToSPIRV/printf.mlir index bc091124ea4c..7fe9752b088d 100644 --- a/mlir/test/Conversion/GPUToSPIRV/printf.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/printf.mlir @@ -62,7 +62,7 @@ module attributes { // CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> // CHECK-NEXT: [[FMTSTR_PTR1:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant> // CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR1]] {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i8, UniformConstant>, i32, f32, i32 -> i32 - gpu.printf "\nHello, world : %d %f \n Thread id: %d\n" %arg0, %arg1, %2: i32, f32, index + gpu.printf "\nHello, world : %d %f \n Thread id: %d\n", %arg0, %arg1, %2: i32, f32, index // CHECK: spirv.Return gpu.return diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir index a78db9733b7e..1fe4217cde98 100644 --- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir @@ -59,7 +59,7 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64 // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64 // CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64 - // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64 + // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64 // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index // CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64 // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64 @@ -95,10 +95,10 @@ func.func @subview_non_zero_addrspace(%0 : memref<64x4xf32, strided<[4, 1], offs // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64 - // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64 + // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64 // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index // CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64 - // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64 + // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64 // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index // CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64 // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> @@ -131,10 +131,10 @@ func.func @subview_const_size(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64 - // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[C4]] : i64 + // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[C4]] overflow<nsw> : i64 // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index // CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64 - // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64 + // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64 // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index // CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64 // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -168,8 +168,8 @@ func.func @subview_const_stride(%0 : memref<64x4xf32, strided<[4, 1], offset: 0> // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64 - // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG0]], %[[C4]] : i64 - // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[ARG1]] : i64 + // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG0]], %[[C4]] overflow<nsw> : i64 + // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[ARG1]] : i64 // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index // CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64 // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -234,12 +234,12 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64 - // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64 + // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64 // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index // CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64 - // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG1]], %[[STRIDE0]] : i64 + // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG1]], %[[STRIDE0]] overflow<nsw> : i64 // CHECK: %[[BASE_OFF:.*]] = llvm.mlir.constant(8 : index) : i64 - // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[BASE_OFF]] : i64 + // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[BASE_OFF]] : i64 // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index // CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64 // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -301,7 +301,7 @@ func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?x // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // Compute and insert offset from 2 + dynamic value. // CHECK: %[[CST_OFF0:.*]] = llvm.mlir.constant(2 : index) : i64 - // CHECK: %[[OFF0:.*]] = llvm.mul %[[STRIDE0]], %[[CST_OFF0]] : i64 + // CHECK: %[[OFF0:.*]] = llvm.mul %[[STRIDE0]], %[[CST_OFF0]] overflow<nsw> : i64 // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF0]] : i64 to index // CHECK: %[[OFF0:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64 // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -425,7 +425,7 @@ func.func @collapse_shape_dynamic_with_non_identity_layout( // CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] : i64 +// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] overflow<nsw> : i64 // CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index // CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64 // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -547,7 +547,7 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32> // CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64 -// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] : i64 +// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] overflow<nsw> : i64 // CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index // CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64 // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> diff --git a/mlir/test/Conversion/SCFToEmitC/for.mlir b/mlir/test/Conversion/SCFToEmitC/for.mlir index 83592187a9b6..7f41e636936b 100644 --- a/mlir/test/Conversion/SCFToEmitC/for.mlir +++ b/mlir/test/Conversion/SCFToEmitC/for.mlir @@ -7,8 +7,11 @@ func.func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) { return } // CHECK-LABEL: func.func @simple_std_for_loop( -// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) { -// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) { +// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t { // CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK-NEXT: } // CHECK-NEXT: return @@ -24,10 +27,13 @@ func.func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) { return } // CHECK-LABEL: func.func @simple_std_2_for_loops( -// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) { -// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) { +// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t { // CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { +// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t { // CHECK-NEXT: %[[VAL_6:.*]] = arith.constant 1 : index // CHECK-NEXT: } // CHECK-NEXT: } @@ -44,14 +50,17 @@ func.func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32) return %result#0, %result#1 : f32, f32 } // CHECK-LABEL: func.func @for_yield( -// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> (f32, f32) { +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> (f32, f32) { +// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t // CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32 // CHECK-NEXT: %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<f32> // CHECK-NEXT: %[[VAL_6:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<f32> // CHECK-NEXT: emitc.assign %[[VAL_3]] : f32 to %[[VAL_5]] : <f32> // CHECK-NEXT: emitc.assign %[[VAL_4]] : f32 to %[[VAL_6]] : <f32> -// CHECK-NEXT: emitc.for %[[VAL_7:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { +// CHECK-NEXT: emitc.for %[[VAL_7:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t { // CHECK-NEXT: %[[VAL_8:.*]] = emitc.load %[[VAL_5]] : <f32> // CHECK-NEXT: %[[VAL_9:.*]] = emitc.load %[[VAL_6]] : <f32> // CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_8]], %[[VAL_9]] : f32 @@ -75,15 +84,18 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 return %r : f32 } // CHECK-LABEL: func.func @nested_for_yield( -// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> f32 { +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> f32 { +// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t // CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32 // CHECK-NEXT: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<f32> // CHECK-NEXT: emitc.assign %[[VAL_3]] : f32 to %[[VAL_4]] : <f32> -// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { +// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t { // CHECK-NEXT: %[[VAL_6:.*]] = emitc.load %[[VAL_4]] : <f32> // CHECK-NEXT: %[[VAL_7:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<f32> // CHECK-NEXT: emitc.assign %[[VAL_6]] : f32 to %[[VAL_7]] : <f32> -// CHECK-NEXT: emitc.for %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { +// CHECK-NEXT: emitc.for %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t { // CHECK-NEXT: %[[VAL_9:.*]] = emitc.load %[[VAL_7]] : <f32> // CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_9]], %[[VAL_9]] : f32 // CHECK-NEXT: emitc.assign %[[VAL_10]] : f32 to %[[VAL_7]] : <f32> @@ -94,3 +106,60 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 // CHECK-NEXT: %[[VAL_12:.*]] = emitc.load %[[VAL_4]] : <f32> // CHECK-NEXT: return %[[VAL_12]] : f32 // CHECK-NEXT: } + +func.func @for_yield_index(%arg0 : index, %arg1 : index, %arg2 : index) -> index { + %zero = arith.constant 0 : index + %r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index { + scf.yield %acc : index + } + return %r : index +} + +// CHECK-LABEL: func.func @for_yield_index( +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t +// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<!emitc.size_t> +// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : <!emitc.size_t> +// CHECK: emitc.for %[[VAL_5:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] : !emitc.size_t { +// CHECK: %[[V:.*]] = emitc.load %[[VAL_4]] : <!emitc.size_t> +// CHECK: emitc.assign %[[V]] : !emitc.size_t to %[[VAL_4]] : <!emitc.size_t> +// CHECK: } +// CHECK: %[[V2:.*]] = emitc.load %[[VAL_4]] : <!emitc.size_t> +// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[V2]] : !emitc.size_t to index +// CHECK: return %[[VAL_8]] : index +// CHECK: } + + +func.func @for_yield_update_loop_carried_var(%arg0 : index, %arg1 : index, %arg2 : index) -> index { + %zero = arith.constant 0 : index + %r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index { + %sn = arith.addi %acc, %acc : index + scf.yield %sn: index + } + return %r : index + } + +// CHECK-LABEL: func.func @for_yield_update_loop_carried_var( +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t +// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<!emitc.size_t> +// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : <!emitc.size_t> +// CHECK: emitc.for %[[ARG_3:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] : !emitc.size_t { +// CHECK: %[[V:.*]] = emitc.load %[[VAL_4]] : <!emitc.size_t> +// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[V]] : !emitc.size_t to index +// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : index +// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : index to !emitc.size_t +// CHECK: emitc.assign %[[VAL_8]] : !emitc.size_t to %[[VAL_4]] : <!emitc.size_t> +// CHECK: } +// CHECK: %[[V2:.*]] = emitc.load %[[VAL_4]] : <!emitc.size_t> +// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[V2]] : !emitc.size_t to index +// CHECK: return %[[VAL_9]] : index +// CHECK: } diff --git a/mlir/test/Conversion/SCFToEmitC/switch.mlir b/mlir/test/Conversion/SCFToEmitC/switch.mlir index 86d96ed21f1b..61015b0ae483 100644 --- a/mlir/test/Conversion/SCFToEmitC/switch.mlir +++ b/mlir/test/Conversion/SCFToEmitC/switch.mlir @@ -1,7 +1,8 @@ // RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-emitc %s | FileCheck %s // CHECK-LABEL: func.func @switch_no_result( -// CHECK-SAME: %[[VAL_0:.*]]: index) { +// CHECK-SAME: %[[ARG_0:.*]]: index) { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t // CHECK: emitc.switch %[[VAL_0]] // CHECK: case 2 { // CHECK: %[[VAL_1:.*]] = arith.constant 10 : i32 @@ -33,7 +34,8 @@ func.func @switch_no_result(%arg0 : index) { } // CHECK-LABEL: func.func @switch_one_result( -// CHECK-SAME: %[[VAL_0:.*]]: index) { +// CHECK-SAME: %[[ARG_0:.*]]: index) { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t // CHECK: %[[VAL_1:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<i32> // CHECK: emitc.switch %[[VAL_0]] // CHECK: case 2 { @@ -70,7 +72,8 @@ func.func @switch_one_result(%arg0 : index) { } // CHECK-LABEL: func.func @switch_two_results( -// CHECK-SAME: %[[VAL_0:.*]]: index) -> (i32, f32) { +// CHECK-SAME: %[[ARG_0:.*]]: index) -> (i32, f32) { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t // CHECK: %[[VAL_1:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<i32> // CHECK: %[[VAL_2:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<f32> // CHECK: emitc.switch %[[VAL_0]] diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir index bfdc72ee07e9..453a8610e716 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -510,7 +510,7 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>) func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<1xf32>) -> () { // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32> // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) { - %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32> return } @@ -531,7 +531,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> - %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32> return } @@ -552,7 +552,7 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27 // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%1 : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32> // HWCF: linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32> return } @@ -571,7 +571,7 @@ func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27 // CHECK: linalg.yield %[[IN]] : f32 // CHECK: } -> tensor<?x45x40x28xf32> // CHECK: %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor<?x45x40x28xf32>) -> tensor<?x45x40x28xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<?x45x40x28xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<?x45x40x28xf32> return } @@ -627,7 +627,7 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x // CHECK: } -> tensor<1x?x?x28xf32> // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%17 : tensor<1x?x?x28xf32>) -> tensor<1x?x?x28xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32> return } @@ -650,7 +650,7 @@ func.func @conv2d_dyn_output(%input: tensor<2x6x5x4xf32>, %weights: tensor<4x3x3 // linalg.yield %[[ADD]] : f32 // } -> tensor<?x4x3x4xf32> - %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32 >, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32 >, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32> return } @@ -662,7 +662,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28 // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: tensor.yield %[[C0]] // CHECK: linalg.conv_2d_nhwc_fhwc - %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32> return } @@ -674,7 +674,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: tensor.yield %[[C22]] // CHECK: linalg.conv_2d_nhwc_fhwc_q - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32> + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32> return } @@ -696,7 +696,7 @@ func.func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf // CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32 // CHECK: linalg.yield [[ADD]] : f32 // CHECK: } -> tensor<1x5x5x33xf32> - %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32> + %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32> return } @@ -712,7 +712,7 @@ func.func @depthwise_conv_scalar_bias(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tenso // CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32 // CHECK: linalg.yield [[ADD]] : f32 // CHECK: } -> tensor<1x5x5x33xf32> - %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32> + %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32> return } @@ -736,7 +736,7 @@ func.func @depthwise_conv_dyn(%arg0 : tensor<?x7x5x3xf32>, %arg1 : tensor<3x1x3x // CHECK: %[[ADD:.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32 // CHECK: linalg.yield %[[ADD]] : f32 // CHECK: } -> tensor<?x5x5x33xf32> - %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<?x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<?x5x5x33xf32> + %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<?x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<?x5x5x33xf32> return } @@ -758,7 +758,7 @@ func.func @depthwise_conv_strides(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3 // CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32 // CHECK: linalg.yield [[ADD]] : f32 // CHECK: } -> tensor<1x5x5x33xf32> - %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1> } : (tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32> + %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1> } : (tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32> return } @@ -786,7 +786,7 @@ func.func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3 // CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32 // CHECK: linalg.yield [[ADD]] : i32 // CHECK: } -> tensor<1x12x12x512xi32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32> return } @@ -810,7 +810,7 @@ func.func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 : // CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32 // CHECK: linalg.yield [[ADD]] : i32 // CHECK: } -> tensor<1x10x10x512xi32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 2> } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 2> } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32> return } @@ -826,7 +826,7 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x // CHECK: } : tensor<2x?x?x3xf32> to tensor<2x?x?x3xf32> // CHECK: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} ins(%[[PADDED]], %arg1 : tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>) outs(%{{.*}} : tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x3x5xf32> // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[CONV]] {{\[}}[0], [1], [2], [3, 4]] - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 1, 2, 3, 4>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 2>} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 1, 2, 3, 4>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 2>} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32> return } @@ -850,7 +850,7 @@ func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4 // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>) // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32> - %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>) -> tensor<1x47x45x43x28xf32> + %0 = tosa.conv3d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>) -> tensor<1x47x45x43x28xf32> return } @@ -864,7 +864,7 @@ func.func @conv3d_scalar_bias_f32(%input: tensor<1x49x48x47x27xf32>, %weights: t // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32> // CHECK: %[[BROADCAST:.+]] = linalg.generic // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} - %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32> + %0 = tosa.conv3d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32> return } @@ -892,7 +892,7 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5 // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32) // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32> - %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32> + %0 = tosa.conv3d %input, %weights, %bias {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32> return } diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir index 1e62e25176a0..0b9a64494bc0 100644 --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir @@ -459,85 +459,65 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) { // CHECK-LABEL: @pad_float // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) { - %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> - // TODO: Output contains multiple "arith.constant 1 : index". - // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index - // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index - // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index - // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index + %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32 - // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { + // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] { // CHECK: tensor.yield [[CST]] // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32> - %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>) + %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>) return %1 : tensor<4x9xf32> } func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) { - %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> // CHECK: [[CST:%.+]] = arith.constant 0 : i32 // CHECK: tensor.pad // CHECK: tensor.yield [[CST]] - %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>) + %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<4xi32>) -> (tensor<4x9xi32>) return %1 : tensor<4x9xi32> } func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) { - %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> // CHECK: [[CST:%.+]] = arith.constant 42 : i32 // CHECK: tensor.pad // CHECK: tensor.yield [[CST]] - %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>) + %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<4xi32>) -> (tensor<4x9xi32>) return %1 : tensor<4x9xi32> } // ----- func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) { - %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> - // TODO: Output contains multiple "arith.constant 1 : index". - // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index - // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index - // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index - // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index + %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32 - // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { + // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] { // CHECK: tensor.yield [[CST]] // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32> %1 = arith.constant dense<42.0> : tensor<f32> - %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<2x2xi32>, tensor<f32>) -> (tensor<4x9xf32>) + %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<4xi32>, tensor<f32>) -> (tensor<4x9xf32>) return %2 : tensor<4x9xf32> } // ----- func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) { - %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> - // TODO: Output contains multiple "arith.constant 1 : index". - // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index - // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index - // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index - // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index + %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32 - // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { + // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] { // CHECK: tensor.yield [[CST]] // CHECK: } : tensor<?x2xf32> to tensor<?x9xf32> - %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>) + %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>) return %1 : tensor<?x9xf32> } func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) { - %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32> - // TODO: Output contains multiple "arith.constant 1 : index". - // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index - // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index - // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index - // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index + %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32> // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32 - // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] { + // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] { // CHECK: tensor.yield [[CST]] // CHECK: } : tensor<1x2xf32> to tensor<?x9xf32> - %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>) + %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>) return %1 : tensor<?x9xf32> } diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir index 717004eb50c0..a9ac13ad7162 100644 --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -1917,12 +1917,12 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index { // ----- -// CHECK-LABEL: func @cancel_linearize_denearize_exact( +// CHECK-LABEL: func @cancel_linearize_delinearize_exact( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index) // CHECK: return %[[ARG0]] -func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index { +func.func @cancel_linearize_delinearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index { %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index return %1 : index @@ -1930,12 +1930,12 @@ func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: i // ----- -// CHECK-LABEL: func @cancel_linearize_denearize_linearize_extra_bound( +// CHECK-LABEL: func @cancel_linearize_delinearize_linearize_extra_bound( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index) // CHECK: return %[[ARG0]] -func.func @cancel_linearize_denearize_linearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index { +func.func @cancel_linearize_delinearize_linearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index { %0:3 = affine.delinearize_index %arg0 into (4, %arg2) : index, index, index %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index return %1 : index @@ -1943,12 +1943,12 @@ func.func @cancel_linearize_denearize_linearize_extra_bound(%arg0: index, %arg1: // ----- -// CHECK-LABEL: func @cancel_linearize_denearize_delinearize_extra_bound( +// CHECK-LABEL: func @cancel_linearize_delinearize_delinearize_extra_bound( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index) // CHECK: return %[[ARG0]] -func.func @cancel_linearize_denearize_delinearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index { +func.func @cancel_linearize_delinearize_delinearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index { %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (4, %arg2) : index return %1 : index @@ -1956,31 +1956,252 @@ func.func @cancel_linearize_denearize_delinearize_extra_bound(%arg0: index, %arg // ----- +// CHECK-LABEL: func @cancel_linearize_delinearize_head( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index) +// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (12, 8) +// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[ARG1]]] by (12, 16) +// CHECK: return %[[LIN]] +func.func @cancel_linearize_delinearize_head(%arg0: index, %arg1: index) -> index { + %0:3 = affine.delinearize_index %arg0 into (3, 4, 8) : index, index, index + %1 = affine.linearize_index [%0#0, %0#1, %arg1] by (3, 4, 16) : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @cancel_linearize_delinearize_head_delinearize_unbounded( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index) +// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (12, 8) +// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[ARG1]]] by (12, 16) +// CHECK: return %[[LIN]] +func.func @cancel_linearize_delinearize_head_delinearize_unbounded(%arg0: index, %arg1: index) -> index { + %0:3 = affine.delinearize_index %arg0 into (4, 8) : index, index, index + %1 = affine.linearize_index [%0#0, %0#1, %arg1] by (3, 4, 16) : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @cancel_linearize_delinearize_head_linearize_unbounded( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index) +// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (8) +// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[ARG1]]] by (16) +// CHECK: return %[[LIN]] +func.func @cancel_linearize_delinearize_head_linearize_unbounded(%arg0: index, %arg1: index) -> index { + %0:3 = affine.delinearize_index %arg0 into (3, 4, 8) : index, index, index + %1 = affine.linearize_index [%0#0, %0#1, %arg1] by (4, 16) : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @cancel_linearize_delinearize_head_both_unbounded( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index) +// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (8) +// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[ARG1]]] by (16) +// CHECK: return %[[LIN]] +func.func @cancel_linearize_delinearize_head_both_unbounded(%arg0: index, %arg1: index) -> index { + %0:3 = affine.delinearize_index %arg0 into (4, 8) : index, index, index + %1 = affine.linearize_index [%0#0, %0#1, %arg1] by (4, 16) : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @cancel_linearize_delinearize_tail( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index) +// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (3, 32) +// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG1]], %[[DELIN]]#1] by (5, 32) +// CHECK: return %[[LIN]] +func.func @cancel_linearize_delinearize_tail(%arg0: index, %arg1: index) -> index { + %0:3 = affine.delinearize_index %arg0 into (3, 4, 8) : index, index, index + %1 = affine.linearize_index [%arg1, %0#1, %0#2] by (5, 4, 8) : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @cancel_linearize_delinearize_middle_exact( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index) +// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG1]], %[[ARG0]], %[[ARG2]]] by (9, 30, 7) +// CHECK: return %[[LIN]] +func.func @cancel_linearize_delinearize_middle_exact(%arg0: index, %arg1: index, %arg2: index) -> index { + %0:3 = affine.delinearize_index %arg0 into (2, 3, 5) : index, index, index + %1 = affine.linearize_index [%arg1, %0#0, %0#1, %0#2, %arg2] by (9, 2, 3, 5, 7) : index + return %1 : index +} + +// ----- + +// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) * 16)> + +// CHECK-LABEL: func @cancel_linearize_delinearize_middle_exact_dynamic_basis( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index) +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[SIZEPROD:.+]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]] +// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[C1]], %[[ARG0]], %[[C1]]] by (3, %[[SIZEPROD]], 4) +// CHECK: return %[[LIN]] +func.func @cancel_linearize_delinearize_middle_exact_dynamic_basis(%arg0: index, %arg1: index, %arg2: index) -> index { + %c1 = arith.constant 1 : index + %0:4 = affine.delinearize_index %arg0 into (2, %arg1, %arg2, 8) : index, index, index, index + %1 = affine.linearize_index [%c1, %0#0, %0#1, %0#2, %0#3, %c1] by (3, 2, %arg1, %arg2, 8, 4) : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @cancel_linearize_delinearize_middle_exact_delinearize_unbounded_disjoint( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index) +// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG0]], %[[ARG2]]] by (9, 30, 7) +// CHECK: return %[[LIN]] +func.func @cancel_linearize_delinearize_middle_exact_delinearize_unbounded_disjoint(%arg0: index, %arg1: index, %arg2: index) -> index { + %0:3 = affine.delinearize_index %arg0 into (3, 5) : index, index, index + %1 = affine.linearize_index disjoint [%arg1, %0#0, %0#1, %0#2, %arg2] by (9, 2, 3, 5, 7) : index + return %1 : index +} + +// ----- + +// Unlike in the test above, the linerize indices aren't asserted to be disjoint, so +// we can't know if the `2` from the basis is a correct bound. +// CHECK-LABEL: func @dont_cancel_linearize_delinearize_middle_exact_delinearize_unbounded( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index) +// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (3) +// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG1]], %[[DELIN]]#0, %[[DELIN]]#1, %[[ARG2]]] by (9, 2, 3, 7) +// CHECK: return %[[LIN]] + +func.func @dont_cancel_linearize_delinearize_middle_exact_delinearize_unbounded(%arg0: index, %arg1: index, %arg2: index) -> index { + %0:2 = affine.delinearize_index %arg0 into (3) : index, index + %1 = affine.linearize_index [%arg1, %0#0, %0#1, %arg2] by (9, 2, 3, 7) : index + return %1 : index +} + +// ----- + +// The presence of a `disjoint` here tells us that the "unbounded" term on the +// delinearization can't have been above 2. +// CHECK-LABEL: func @cancel_linearize_delinearize_middle_delinearize_unbounded_disjoint_implied_bound( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index) +// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (6, 5) +// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG1]], %[[DELIN]]#0, %[[ARG2]]] by (9, 6, 7) +// CHECK: return %[[LIN]] +func.func @cancel_linearize_delinearize_middle_delinearize_unbounded_disjoint_implied_bound(%arg0: index, %arg1: index, %arg2: index) -> index { + %0:3 = affine.delinearize_index %arg0 into (3, 5) : index, index, index + %1 = affine.linearize_index disjoint [%arg1, %0#0, %0#1, %arg2] by (9, 2, 3, 7) : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @cancel_linearize_delinearize_multiple_matches( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index) +// CHECK: %[[C0:.+]] = arith.constant 0 +// CHECK: %[[DELIN:.+]]:4 = affine.delinearize_index %[[ARG0]] into (4, 16, 4, 64) +// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG1]], %[[DELIN]]#1, %[[C0]], %[[DELIN]]#3] by (4, 16, 4, 64) +// CHECK: return %[[LIN]] +func.func @cancel_linearize_delinearize_multiple_matches(%arg0: index, %arg1: index) -> index { + %c0 = arith.constant 0 : index + %0:7 = affine.delinearize_index %arg0 into (4, 4, 4, 4, 4, 4, 4) : index, index, index, index, index, index, index + %1 = affine.linearize_index [%arg1, %0#1, %0#2, %c0, %0#4, %0#5, %0#6] by (4, 4, 4, 4, 4, 4, 4) : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @cancel_linearize_delinearize_multiple_delinearizes( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index) +// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (32, 32) +// CHECK: return %[[LIN]] +func.func @cancel_linearize_delinearize_multiple_delinearizes(%arg0: index, %arg1: index) -> index { + %0:2 = affine.delinearize_index %arg0 into (4, 8) : index, index + %1:2 = affine.delinearize_index %arg1 into (2, 16) : index, index + %2 = affine.linearize_index [%0#0, %0#1, %1#0, %1#1] by (4, 8, 2, 16) : index + return %2 : index +} + +// ----- + // Don't cancel because the values from the delinearize aren't used in order -// CHECK-LABEL: func @no_cancel_linearize_denearize_permuted( +// CHECK-LABEL: func @no_cancel_linearize_delinearize_permuted( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index) // CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]]) -// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], 4, %[[ARG2]]) +// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], %[[ARG2]], 4) // CHECK: return %[[LIN]] -func.func @no_cancel_linearize_denearize_permuted(%arg0: index, %arg1: index, %arg2: index) -> index { +func.func @no_cancel_linearize_delinearize_permuted(%arg0: index, %arg1: index, %arg2: index) -> index { %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index - %1 = affine.linearize_index [%0#0, %0#2, %0#1] by (%arg1, 4, %arg2) : index + %1 = affine.linearize_index [%0#0, %0#2, %0#1] by (%arg1, %arg2, 4) : index + return %1 : index +} + +// ----- + +// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 3)> +// But these cancel because they're a contiguous segment +// CHECK-LABEL: func @partial_cancel_linearize_delinearize_not_fully_permuted( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index) +// CHECK: %[[SIZEPROD:.+]] = affine.apply #[[$MAP]]()[%[[ARG2]]] +// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[SIZEPROD]]) +// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], %[[SIZEPROD]], 4) +// CHECK: return %[[LIN]] +func.func @partial_cancel_linearize_delinearize_not_fully_permuted(%arg0: index, %arg1: index, %arg2: index) -> index { + %0:4 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2, 3) : index, index, index, index + %1 = affine.linearize_index [%0#0, %0#2, %0#3, %0#1] by (%arg1, %arg2, 3, 4) : index return %1 : index } // ----- +// Ensure we don't get SSA errors when creating new `affine.delinearize` operations. +// CHECK-LABEL: func @cancel_linearize_delinearize_placement +// CHECK-SAME: (%[[ARG0:.+]]: index) +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[NEW_DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (8, 32) : index, index +// CHECK-NEXT: %[[DELIN_PART:.+]]:2 = affine.delinearize_index %[[NEW_DELIN]]#1 into (8, 4) : index, index +// CHECK-NEXT: %[[L1:.+]] = affine.linearize_index disjoint [%[[DELIN_PART]]#1, %[[NEW_DELIN]]#0, %[[C0]], %[[C0]]] by (4, 8, 4, 8) +// CHECK-NEXT: %[[L2:.+]] = affine.linearize_index disjoint [%[[NEW_DELIN]]#1, %[[C0]], %[[C0]]] by (32, 8, 4) +// CHECK-NEXT: %[[L3:.+]] = affine.linearize_index disjoint [%[[DELIN_PART]]#0, %[[NEW_DELIN]]#0, %[[C0]], %[[C0]]] by (8, 8, 4, 4) +// CHECK-NEXT: return %[[L1]], %[[L2]], %[[L3]] +func.func @cancel_linearize_delinearize_placement(%arg0: index) -> (index, index, index) { + %c0 = arith.constant 0 : index + %0:3 = affine.delinearize_index %arg0 into (8, 8, 4) : index, index, index + %1 = affine.linearize_index disjoint [%0#2, %0#0, %c0, %c0] by (4, 8, 4, 8) : index + %2 = affine.linearize_index disjoint [%0#1, %0#2, %c0, %c0] by (8, 4, 8, 4) : index + %3 = affine.linearize_index disjoint [%0#1, %0#0, %c0, %c0] by (8, 8, 4, 4) : index + return %1, %2, %3 : index, index, index +} + +// ----- + // Won't cancel because the linearize and delinearize are using a different basis -// CHECK-LABEL: func @no_cancel_linearize_denearize_different_basis( +// CHECK-LABEL: func @no_cancel_linearize_delinearize_different_basis( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index) // CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]]) // CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] by (%[[ARG1]], 8, %[[ARG2]]) // CHECK: return %[[LIN]] -func.func @no_cancel_linearize_denearize_different_basis(%arg0: index, %arg1: index, %arg2: index) -> index { +func.func @no_cancel_linearize_delinearize_different_basis(%arg0: index, %arg1: index, %arg2: index) -> index { %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 8, %arg2) : index return %1 : index diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir index 935c08aceff5..5354eb38d7b0 100644 --- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir @@ -155,3 +155,84 @@ func.func @compare_maps(%a: index, %b: index) { : (index, index, index, index) -> () return } + +// ----- + +// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 floordiv 15)> +// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> ((s0 mod 15) floordiv 5)> +// CHECK-DAG: #[[$map3:.+]] = affine_map<()[s0] -> (s0 mod 5)> +// CHECK-LABEL: func.func @delinearize_static +// CHECK-SAME: (%[[arg0:.+]]: index) +// CHECK-DAG: %[[v1:.+]] = affine.apply #[[$map1]]()[%[[arg0]]] +// CHECK-DAG: %[[v2:.+]] = affine.apply #[[$map2]]()[%[[arg0]]] +// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map3]]()[%[[arg0]]] +// CHECK: return %[[v1]], %[[v2]], %[[v3]] +func.func @delinearize_static(%arg0: index) -> (index, index, index) { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %0:3 = affine.delinearize_index %arg0 into (2, 3, 5) : index, index, index + %1 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index) + %2 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index) + %3 = "test.reify_bound"(%0#2) {type = "EQ"} : (index) -> (index) + // expected-remark @below{{true}} + "test.compare"(%0#0, %c2) {cmp = "LT"} : (index, index) -> () + // expected-remark @below{{true}} + "test.compare"(%0#1, %c3) {cmp = "LT"} : (index, index) -> () + return %1, %2, %3 : index, index, index +} + +// ----- + +// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 floordiv 15)> +// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> ((s0 mod 15) floordiv 5)> +// CHECK-DAG: #[[$map3:.+]] = affine_map<()[s0] -> (s0 mod 5)> +// CHECK-LABEL: func.func @delinearize_static_no_outer_bound +// CHECK-SAME: (%[[arg0:.+]]: index) +// CHECK-DAG: %[[v1:.+]] = affine.apply #[[$map1]]()[%[[arg0]]] +// CHECK-DAG: %[[v2:.+]] = affine.apply #[[$map2]]()[%[[arg0]]] +// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map3]]()[%[[arg0]]] +// CHECK: return %[[v1]], %[[v2]], %[[v3]] +func.func @delinearize_static_no_outer_bound(%arg0: index) -> (index, index, index) { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %0:3 = affine.delinearize_index %arg0 into (3, 5) : index, index, index + %1 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index) + %2 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index) + %3 = "test.reify_bound"(%0#2) {type = "EQ"} : (index) -> (index) + "test.compaare"(%0#0, %c2) {cmp = "LT"} : (index, index) -> () + // expected-remark @below{{true}} + "test.compare"(%0#1, %c3) {cmp = "LT"} : (index, index) -> () + return %1, %2, %3 : index, index, index +} + +// ----- + +// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)> +// CHECK-LABEL: func.func @linearize_static +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index) +// CHECK: %[[v1:.+]] = affine.apply #[[$map]]()[%[[arg1]], %[[arg0]]] +// CHECK: return %[[v1]] +func.func @linearize_static(%arg0: index, %arg1: index) -> index { + %c6 = arith.constant 6 : index + %0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 3) : index + %1 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index) + // expected-remark @below{{true}} + "test.compare"(%0, %c6) {cmp = "LT"} : (index, index) -> () + return %1 : index +} + +// ----- + +// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)> +// CHECK-LABEL: func.func @linearize_static_no_outer_bound +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index) +// CHECK: %[[v1:.+]] = affine.apply #[[$map]]()[%[[arg1]], %[[arg0]]] +// CHECK: return %[[v1]] +func.func @linearize_static_no_outer_bound(%arg0: index, %arg1: index) -> index { + %c6 = arith.constant 6 : index + %0 = affine.linearize_index disjoint [%arg0, %arg1] by (3) : index + %1 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index) + // expected-error @below{{unknown}} + "test.compare"(%0, %c6) {cmp = "LT"} : (index, index) -> () + return %1 : index +} diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 6a186a0c6cec..522711b08f28 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2060,6 +2060,70 @@ func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) { // ----- +func.func @fold_divui_of_muli_0(%arg0 : index, %arg1 : index) -> index { + %0 = arith.muli %arg0, %arg1 overflow<nuw> : index + %1 = arith.divui %0, %arg0 : index + return %1 : index +} +// CHECK-LABEL: func @fold_divui_of_muli_0( +// CHECK-SAME: %[[ARG0:.+]]: index, +// CHECK-SAME: %[[ARG1:.+]]: index) +// CHECK: return %[[ARG1]] + +func.func @fold_divui_of_muli_1(%arg0 : index, %arg1 : index) -> index { + %0 = arith.muli %arg0, %arg1 overflow<nuw> : index + %1 = arith.divui %0, %arg1 : index + return %1 : index +} +// CHECK-LABEL: func @fold_divui_of_muli_1( +// CHECK-SAME: %[[ARG0:.+]]: index, +// CHECK-SAME: %[[ARG1:.+]]: index) +// CHECK: return %[[ARG0]] + +func.func @fold_divsi_of_muli_0(%arg0 : index, %arg1 : index) -> index { + %0 = arith.muli %arg0, %arg1 overflow<nsw> : index + %1 = arith.divsi %0, %arg0 : index + return %1 : index +} +// CHECK-LABEL: func @fold_divsi_of_muli_0( +// CHECK-SAME: %[[ARG0:.+]]: index, +// CHECK-SAME: %[[ARG1:.+]]: index) +// CHECK: return %[[ARG1]] + +func.func @fold_divsi_of_muli_1(%arg0 : index, %arg1 : index) -> index { + %0 = arith.muli %arg0, %arg1 overflow<nsw> : index + %1 = arith.divsi %0, %arg1 : index + return %1 : index +} +// CHECK-LABEL: func @fold_divsi_of_muli_1( +// CHECK-SAME: %[[ARG0:.+]]: index, +// CHECK-SAME: %[[ARG1:.+]]: index) +// CHECK: return %[[ARG0]] + +// Do not fold divui(mul(a, v), v) -> a with nuw attribute. +func.func @no_fold_divui_of_muli(%arg0 : index, %arg1 : index) -> index { + %0 = arith.muli %arg0, %arg1 : index + %1 = arith.divui %0, %arg0 : index + return %1 : index +} +// CHECK-LABEL: func @no_fold_divui_of_muli +// CHECK: %[[T0:.+]] = arith.muli +// CHECK: %[[T1:.+]] = arith.divui %[[T0]], +// CHECK: return %[[T1]] + +// Do not fold divsi(mul(a, v), v) -> a with nuw attribute. +func.func @no_fold_divsi_of_muli(%arg0 : index, %arg1 : index) -> index { + %0 = arith.muli %arg0, %arg1 : index + %1 = arith.divsi %0, %arg0 : index + return %1 : index +} +// CHECK-LABEL: func @no_fold_divsi_of_muli +// CHECK: %[[T0:.+]] = arith.muli +// CHECK: %[[T1:.+]] = arith.divsi %[[T0]], +// CHECK: return %[[T1]] + +// ----- + // CHECK-LABEL: @test_cmpf( func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) { // CHECK-DAG: %[[T:.*]] = arith.constant true diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir index 26434774730e..820fb3dfa5e5 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir @@ -465,3 +465,14 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t : tensor<5x6x64xf32> into tensor<5x6x128xf32> return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32> } + +// ----- + +// CHECK-LABEL: func.func @direct_use_of_tensor_empty +func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> { + // CHECK-NOT: memref.alloc + %empty_1 = tensor.empty() : tensor<5x6x64xf32> + %inserted_slice_1 = tensor.insert_slice %empty_1 into %arg0[0, 0, 0][5, 6, 64][1, 1, 1] + : tensor<5x6x64xf32> into tensor<5x6x128xf32> + return %inserted_slice_1 : tensor<5x6x128xf32> +} diff --git a/mlir/test/Dialect/GPU/indirect-device-func-call.mlir b/mlir/test/Dialect/GPU/indirect-device-func-call.mlir index 91d7f1cd6c67..85805da3ac10 100644 --- a/mlir/test/Dialect/GPU/indirect-device-func-call.mlir +++ b/mlir/test/Dialect/GPU/indirect-device-func-call.mlir @@ -6,7 +6,7 @@ gpu.module @kernels { func.func @hello(%arg0 : f32) { %tid_x = gpu.thread_id x %csti8 = arith.constant 2 : i8 - gpu.printf "Hello from %lld, %d, %f\n" %tid_x, %csti8, %arg0 : index, i8, f32 + gpu.printf "Hello from %lld, %d, %f\n", %tid_x, %csti8, %arg0 : index, i8, f32 return } // CHECK-LABEL: @hello_indirect diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir index c0ff2044b76c..99915c493ea4 100644 --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -229,9 +229,22 @@ module attributes {gpu.container_module} { // CHECK-LABEL: gpu.func @printf_test // CHECK: (%[[ARG0:.*]]: i32) - // CHECK: gpu.printf "Value: %d" %[[ARG0]] : i32 + // CHECK: gpu.printf "Value: %d", %[[ARG0]] : i32 gpu.func @printf_test(%arg0 : i32) { - gpu.printf "Value: %d" %arg0 : i32 + gpu.printf "Value: %d", %arg0 : i32 + gpu.return + } + + // CHECK-LABEL: gpu.func @printf_empty + // CHECK: gpu.printf "]" + // CHECK: scf.if + // CHECK: gpu.printf ", " + gpu.func @printf_empty(%arg0 : i32) { + gpu.printf "]" + %1 = arith.cmpi slt, %arg0, %arg0 : i32 + scf.if %1 { + gpu.printf ", " + } gpu.return } diff --git a/mlir/test/Dialect/GPU/test-nvvm-pipeline.mlir b/mlir/test/Dialect/GPU/test-nvvm-pipeline.mlir index 732f40c4333d..f02b26dba97d 100644 --- a/mlir/test/Dialect/GPU/test-nvvm-pipeline.mlir +++ b/mlir/test/Dialect/GPU/test-nvvm-pipeline.mlir @@ -23,7 +23,7 @@ func.func @test_math(%arg0 : f32) { threads(%6, %7, %8) in (%9 = %c2, %10 = %c1, %11 = %c1) { // CHECK-NVVM: __nv_expf %s1 = math.exp %arg0 : f32 - gpu.printf "%f" %s1 : f32 + gpu.printf "%f", %s1 : f32 gpu.terminator } return diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir index f6d3387d99b3..2785b5088612 100644 --- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir +++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir @@ -28,7 +28,7 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in // CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64> // CHECK-DAG: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64 - // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64 + // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64 // CHECK-DAG: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64 // CHECK-DAG: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir index a74553cc2268..c1f30c7eaf64 100644 --- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir +++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir @@ -27,7 +27,7 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in // CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64> // CHECK-DAG: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64 - // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64 + // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64 // CHECK-DAG: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64 // CHECK-DAG: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index aebfd7492093..88660ce598f3 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -750,6 +750,16 @@ llvm.func @experimental_noalias_scope_decl() { llvm.return } +#alias_scope_domain2 = #llvm.alias_scope_domain<id = "domainid", description = "The domain"> +#alias_scope2 = #llvm.alias_scope<id = "stringid", domain = #alias_scope_domain2, description = "The domain"> + +// CHECK-LABEL: @experimental_noalias_scope_with_string_id +llvm.func @experimental_noalias_scope_with_string_id() { + // CHECK: llvm.intr.experimental.noalias.scope.decl #{{.*}} + llvm.intr.experimental.noalias.scope.decl #alias_scope2 + llvm.return +} + // CHECK-LABEL: @experimental_constrained_fptrunc llvm.func @experimental_constrained_fptrunc(%in: f64) { // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32 diff --git a/mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir b/mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir index 6d9709caf709..0dbdf470bbfc 100644 --- a/mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir +++ b/mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir @@ -1,4 +1,7 @@ -// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-decompose-tensor-unpack" %s | FileCheck %s +// RUN: mlir-opt -split-input-file -transform-interpreter --canonicalize \ +// RUN: -transform-preload-library='transform-library-paths=%p/td/decompose-unpack.mlir' \ +// RUN: -transform-interpreter=entry-point=decompose_unpack \ +// RUN: -transform-interpreter %s | FileCheck %s func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> { %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32> diff --git a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir index bd60504f5334..ba1f21495256 100644 --- a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir @@ -1,4 +1,6 @@ -// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-decompose-tensor-unpack" %s | FileCheck %s +// RUN: mlir-opt -split-input-file \ +// RUN: -transform-preload-library='transform-library-paths=%p/td/decompose-unpack.mlir' \ +// RUN: -transform-interpreter=entry-point=decompose_unpack %s | FileCheck %s func.func @simple_KCRSsr_to_KCRS(%arg0: tensor<1x1x1x1x8x32xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32> { %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x1x1x8x32xf32> -> tensor<1x1x32x8xf32> diff --git a/mlir/test/Dialect/Linalg/td/decompose-unpack.mlir b/mlir/test/Dialect/Linalg/td/decompose-unpack.mlir new file mode 100644 index 000000000000..11243634262e --- /dev/null +++ b/mlir/test/Dialect/Linalg/td/decompose-unpack.mlir @@ -0,0 +1,12 @@ +module @transforms attributes { transform.with_named_sequence } { + transform.named_sequence @decompose_unpack(%module: !transform.any_op {transform.readonly}) { + %pack = transform.structured.match ops{["tensor.unpack"]} in %module : (!transform.any_op) -> !transform.any_op + + %1 = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %1 { + transform.apply_patterns.linalg.decompose_pack_unpack + } : !transform.any_op + + transform.yield + } +} diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 8c4e7a41ee6b..828758df6d31 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -408,6 +408,20 @@ func.func @for_yields_4() -> i32 { // ----- +// CHECK-LABEL: @constant_iter_arg +func.func @constant_iter_arg(%arg0: index, %arg1: index, %arg2: index) { + %c0_i32 = arith.constant 0 : i32 + // CHECK: scf.for %arg3 = %arg0 to %arg1 step %arg2 { + %0 = scf.for %i = %arg0 to %arg1 step %arg2 iter_args(%arg3 = %c0_i32) -> i32 { + // CHECK-NEXT: "test.use"(%c0_i32) + "test.use"(%arg3) : (i32) -> () + scf.yield %c0_i32 : i32 + } + return +} + +// ----- + // CHECK-LABEL: @replace_true_if func.func @replace_true_if() { %true = arith.constant true @@ -1789,7 +1803,7 @@ module { } // CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall // CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) { -// CHECK: %[[RESULT:.*]] = scf.forall +// CHECK: %[[RESULT:.*]] = scf.forall // CHECK-SAME: shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) { // CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]] // CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]] @@ -1832,7 +1846,7 @@ module { } // CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall // CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> { -// CHECK: %[[RESULT:.*]] = scf.forall +// CHECK: %[[RESULT:.*]] = scf.forall // CHECK-SAME: shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) { // CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]] // CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]] @@ -1856,7 +1870,7 @@ func.func @index_switch_fold() -> (f32, f32) { %y = arith.constant 42.0 : f32 scf.yield %y : f32 } - + %switch_cst_2 = arith.constant 2: index %1 = scf.index_switch %switch_cst_2 -> f32 case 0 { @@ -1867,7 +1881,7 @@ func.func @index_switch_fold() -> (f32, f32) { %y = arith.constant 42.0 : f32 scf.yield %y : f32 } - + return %0, %1 : f32, f32 } diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index e8fc4ce834e1..01d14871072c 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index return %0#1 : index } + // ----- // CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size @@ -2794,7 +2795,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x // CHECK-SAME: %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> { // CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32) // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] -// CHECK-SAME: some_attr +// CHECK-SAME: test_attr // CHECK-SAME: : tensor<7x?xi32> -> tensor<1x1x8x1xi32> // CHECK: return %[[PACK]] : tensor<1x1x8x1xi32> func.func @fold_cast_pack_dynamic_tile_size( @@ -2807,13 +2808,33 @@ func.func @fold_cast_pack_dynamic_tile_size( %pack = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] - into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32> + into %cast {test_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32> %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32> return %res : tensor<1x1x8x1xi32> } // ----- +// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size( +// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>, +// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> { +// CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32> +// CHECK: return %[[RES]] : tensor<7x?xi32> +func.func @fold_cast_unpack_dynamic_tile_size( + %src: tensor<1x1x8x1xi32>, + %res: tensor<7x?xi32>) -> tensor<7x?xi32> { + + %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32> + %c8 = arith.constant 8 : index + %unpack = tensor.unpack %cast + inner_dims_pos = [0, 1] + inner_tiles = [%c8, 1] + into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32> + return %unpack : tensor<7x?xi32> +} + +// ----- + // CHECK-LABEL: func.func @pack_dont_drop_attributes( // CHECK: tensor.pack {{.*}} {test_attr} func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> { diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 83cb4b9d4ab2..1de3e281bc46 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -699,7 +699,7 @@ func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor // ----- -func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> { +func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> { // expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}} %0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32> return %0 : tensor<256x128xf32> diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 67cd01f62f0b..60121bb0ea2f 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -162,7 +162,7 @@ func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32 // CHECK: tosa.conv2d %weight = "tosa.const"() {value = dense<[[[[1.0, 1.0]]], [[[1.0, 1.0]]], [[[1.0, 1.0]]]]> : tensor<3x1x1x2xf32>} : ()-> tensor<3x1x1x2xf32> %bias = "tosa.const"() {value = dense<0.0> : tensor<3xf32>} : ()-> tensor<3xf32> - %0 = tosa.conv2d %arg0, %weight, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> + %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> return %0 : tensor<4x10x10x3xf32> } @@ -173,7 +173,7 @@ func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf // CHECK: tosa.conv2d %weight = "tosa.const"() {value = dense<[[[[1.0], [1.0]], [[1.0], [1.0]]]]> : tensor<1x2x2x1xf32>} : ()-> tensor<1x2x2x1xf32> %bias = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : ()-> tensor<1xf32> - %0 = tosa.conv2d %arg0, %weight, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x10x10x1xf32> + %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x10x10x1xf32> return %0 : tensor<4x10x10x1xf32> } @@ -182,7 +182,7 @@ func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf // CHECK-LABEL: @depthwise_conv2d_stride_2 func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { // CHECK: tosa.depthwise_conv2d - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> return %0 : tensor<4x10x10x6xf32> } @@ -191,7 +191,7 @@ func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor // CHECK-LABEL: @depthwise_conv2d_weight_2x2 func.func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x2x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { // CHECK: tosa.depthwise_conv2d - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<2x2x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<2x2x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> return %0 : tensor<4x10x10x6xf32> } @@ -210,8 +210,8 @@ func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf3 // CHECK-LABEL: @pad_noop func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> { // CHECK: return %arg0 - %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32> - %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32> + %0 = "tosa.const"() { value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32> return %1 : tensor<?x?xf32> } @@ -221,8 +221,8 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> { func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> { // CHECK: %[[PAD:.+]] = tosa.pad // CHECK: return %[[PAD]] - %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<2x2xi32>} : () -> tensor<2x2xi32> - %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32> + %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32> return %1 : tensor<?x?xf32> } @@ -234,42 +234,39 @@ func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32> // CHECK: return %[[PAD]] %c0_i32 = arith.constant 0 : i32 - %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<1x2xi32> + %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<2xi32> - %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<1x2xi32>) -> tensor<?xf32> + %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<2xi32>) -> tensor<?xf32> return %0 : tensor<?xf32> } // ----- // CHECK-LABEL: @pad_determine_val_i32 -func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> { +func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> { // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>} // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]] - %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> - %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32> + %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32> return %1 : tensor<?x?xi32> } // ----- // CHECK-LABEL: @pad_determine_val_f32 -func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> { +func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<4xi32>) -> tensor<?x?xf32> { // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>} // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]] - %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> - %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32> + %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32> return %1 : tensor<?x?xf32> } // ----- // CHECK-LABEL: @pad_determine_val_quant -func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> { +func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> { // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>} // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]] - %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> - %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32> + %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32> return %1 : tensor<?x?xi32> } diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 2902c4a62009..8198903b78ac 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -117,15 +117,6 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) return %1, %input : tensor<3x2xf32>, tensor<2x3xf32> } -// CHECK-LABEL: @transpose_nofold_quantized_types -func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> { - %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32> - %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2xi8> - // CHECK: tosa.transpose - %0 = tosa.transpose %input, %perms : (tensor<2x1x1x2xi8>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> - return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> -} - // CHECK-LABEL: @transpose_nofold_dense_resource func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> { %0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32> diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index cca50b25d14d..a6d57f8a2f61 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -25,7 +25,7 @@ func.func @test_const_non_tensor_attr() { func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { // expected-error@+1 {{expect both input and weight to be float or not together, got 'f32' and 'i8'}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } @@ -34,7 +34,7 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { // expected-error@+1 {{expect a ranked tensor for input, got <block argument> of type 'tensor<*xi8>' at index: 0}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } @@ -43,7 +43,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } @@ -52,13 +52,101 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { // expected-error@+1 {{'tosa.conv2d' op quantizationattr is required for quantized type, and not allowed for float type}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } // ----- +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} + : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> + return %0 : tensor<1x27x27x16xi8> +} + +// ----- + +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi16>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi16>) -> tensor<1x27x27x16xi16> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for i16 tensor is not i48}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} + : (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>) -> tensor<1x27x27x16xi16> + return %0 : tensor<1x27x27x16xi16> +} + +// ----- + +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E5M2>, %arg1: tensor<16x3x3x4xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} + : (tensor<1x29x29x4xf8E5M2>, tensor<16x3x3x4xf8E5M2>, tensor<16xf16>) -> tensor<1x27x27x16xf16> + return %0 : tensor<1x27x27x16xf16> +} + +// ----- + +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E4M3>, %arg1: tensor<16x3x3x4xf8E4M3>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} + : (tensor<1x29x29x4xf8E4M3>, tensor<16x3x3x4xf8E4M3>, tensor<16xf16>) -> tensor<1x27x27x16xf16> + return %0 : tensor<1x27x27x16xf16> +} + +// ----- + +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3x3x4xf16>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for f16 tensor is not f16/f32}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} + : (tensor<1x29x29x4xf16>, tensor<16x3x3x4xf16>, tensor<16xf16>) -> tensor<1x27x27x16xf16> + return %0 : tensor<1x27x27x16xf16> +} + +// ----- + +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xbf16>, %arg1: tensor<16x3x3x4xbf16>, %arg2: tensor<16xbf16>) -> tensor<1x27x27x16xbf16> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for bf16 tensor is not f32}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} + : (tensor<1x29x29x4xbf16>, tensor<16x3x3x4xbf16>, tensor<16xbf16>) -> tensor<1x27x27x16xbf16> + return %0 : tensor<1x27x27x16xbf16> +} + +// ----- + +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for f32 tensor is not f32}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} + : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32> + return %0 : tensor<1x27x27x16xf32> +} + +// ----- + +func.func @test_conv3d_acc_type(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi8>) -> tensor<1x4x8x21x34xi8> { + // expected-error@+1 {{'tosa.conv3d' op accumulator type for i8 tensor is not i32}} + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} + : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>) -> tensor<1x4x8x21x34xi8> + return %0 : tensor<1x4x8x21x34xi8> +} + +// ----- + +func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<1x1x4x2xi8>, %arg2: tensor<8xi8>) -> tensor<1x4x4x8xi8> { + // expected-error@+1 {{'tosa.depthwise_conv2d' op accumulator type for i8 tensor is not i32}} + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>) -> tensor<1x4x4x8xi8> + return %0 : tensor<1x4x4x8xi8> +} + +// ----- + +func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> { + // expected-error@+1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}} + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>) -> tensor<1x32x32x16xi8> + return %0 : tensor<1x32x32x16xi8> +} + +// ----- + func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}} @@ -77,48 +165,56 @@ func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : te // ----- -func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> { +func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> { // expected-error@+1 {{'tosa.pad' op padding of pad is not constant}} - %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32> + %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<6xi32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } // ----- func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x21x3xi8> { - %0 = "tosa.const"() {value = dense<[[0, 0], [0, 1], [0, 1]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %0 = "tosa.const"() {value = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xi32>} : () -> tensor<6xi32> // expected-error@+1 {{'tosa.pad' op pad_const of pad is not constant}} - %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, tensor<3x2xi32>, tensor<i8>) -> tensor<13x21x3xi8> + %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, tensor<6xi32>, tensor<i8>) -> tensor<13x21x3xi8> return %1 : tensor<13x21x3xi8> } // ----- -func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) { +func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>, %arg1: tensor<4xi32>) { // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}} - %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2x2xi32>) -> tensor<13x21x3xf32> + %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<4xi32>) -> tensor<13x21x3xf32> return } // ----- -func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2xi32>) { - // expected-error@+1 {{'tosa.pad' op expect 'padding' tensor rank equal to 2.}} - %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2xi32>) -> tensor<13x21xf32> +func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) { + // expected-error@+1 {{'tosa.pad' op operand #1 must be 1D tensor of 32-bit signless integer or 64-bit signless integer values, but got 'tensor<2x2xi32>'}} + %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2x2xi32>) -> tensor<13x21xf32> return } // ----- -func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) { +func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<4xi32>) { %0 = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32> // expected-error@+1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<1xf32>'}} - %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21xf32>, tensor<2x2xi32>, tensor<1xf32>) -> tensor<13x21xf32> + %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21xf32>, tensor<4xi32>, tensor<1xf32>) -> tensor<13x21xf32> return } // ----- +func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<4xi32>) -> tensor<13x21x3xf32> { + // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}} + %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<4xi32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- + func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> { // expected-error@+1 {{'tosa.transpose' op perms of transpose is not constant}} %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32> @@ -206,6 +302,15 @@ func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor // ----- +func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> { + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // expected-error@+1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}} + %1 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// ----- + func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> { %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> %1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32> @@ -416,7 +521,7 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> { func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> { // expected-error@+1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x29x0x4xf32>'}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32> return %0 : tensor<1x27x27x16xf32> } diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 529a16ca48c7..ba8ed8a1e5f5 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -226,7 +226,7 @@ func.func @test_avgpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32 func.func @test_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: dilation_y * KH <= MAX_KERNEL}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 4097, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 4097, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -235,7 +235,7 @@ func.func @test_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16 func.func @test_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: dilation_x * KW <= MAX_KERNEL}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 4097>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 4097>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -244,7 +244,7 @@ func.func @test_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16 func.func @test_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 8193, 1, 0, 1>, stride = array<i64: 1, 1>} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 8193, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -253,7 +253,7 @@ func.func @test_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x func.func @test_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 8193, 0, 1>, stride = array<i64: 1, 1>} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 8193, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -262,7 +262,7 @@ func.func @test_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16 func.func @test_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 8193, 1>, stride = array<i64: 1, 1>} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 8193, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -271,7 +271,7 @@ func.func @test_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2 func.func @test_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 8193>, stride = array<i64: 1, 1>} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 8193>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -280,7 +280,7 @@ func.func @test_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x func.func @test_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 8193, 1>} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 8193, 1>} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -289,7 +289,7 @@ func.func @test_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2 func.func @test_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 8193>} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 8193>} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -298,7 +298,7 @@ func.func @test_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2 func.func @test_conv3d_dilation_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: dilation_d * KD <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 4097, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 4097, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -307,7 +307,7 @@ func.func @test_conv3d_dilation_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor< func.func @test_conv3d_dilation_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: dilation_y * KH <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 4097, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 4097, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -316,7 +316,7 @@ func.func @test_conv3d_dilation_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor< func.func @test_conv3d_dilation_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: dilation_x * KW <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 4097>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 4097>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -325,7 +325,7 @@ func.func @test_conv3d_dilation_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor< func.func @test_conv3d_pad_d0(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 8193, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 8193, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -334,7 +334,7 @@ func.func @test_conv3d_pad_d0(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2 func.func @test_conv3d_pad_d1(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 1, 8193, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 1, 8193, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -343,7 +343,7 @@ func.func @test_conv3d_pad_d1(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2 func.func @test_conv3d_pad_top(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 8193, 1, 0, 1>, stride = array<i64: 1, 1, 1>} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 8193, 1, 0, 1>, stride = array<i64: 1, 1, 1>} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -352,7 +352,7 @@ func.func @test_conv3d_pad_top(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x func.func @test_conv3d_pad_bottom(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 8193, 0, 1>, stride = array<i64: 1, 1, 1>} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 8193, 0, 1>, stride = array<i64: 1, 1, 1>} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -361,7 +361,7 @@ func.func @test_conv3d_pad_bottom(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor< func.func @test_conv3d_pad_left(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 8193, 1>, stride = array<i64: 1, 1, 1>} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 8193, 1>, stride = array<i64: 1, 1, 1>} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -370,7 +370,7 @@ func.func @test_conv3d_pad_left(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16 func.func @test_conv3d_pad_right(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 8193>, stride = array<i64: 1, 1, 1>} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 8193>, stride = array<i64: 1, 1, 1>} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -379,7 +379,7 @@ func.func @test_conv3d_pad_right(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<1 func.func @test_conv3d_stride_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 8193, 1, 1>} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 8193, 1, 1>} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -388,7 +388,7 @@ func.func @test_conv3d_stride_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16 func.func @test_conv3d_stride_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 8193, 1>} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 8193, 1>} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -397,7 +397,7 @@ func.func @test_conv3d_stride_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16 func.func @test_conv3d_stride_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 8193>} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 8193>} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -406,7 +406,7 @@ func.func @test_conv3d_stride_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16 func.func @test_depthwise_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: dilation_y * KH <= MAX_KERNEL}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 4097, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 4097, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -415,7 +415,7 @@ func.func @test_depthwise_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: func.func @test_depthwise_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: dilation_x * KW <= MAX_KERNEL}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 4097>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 4097>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -424,7 +424,7 @@ func.func @test_depthwise_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: func.func @test_depthwise_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 8193, 1, 0, 1>, stride = array<i64: 1, 1>} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 8193, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -433,7 +433,7 @@ func.func @test_depthwise_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: te func.func @test_depthwise_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 8193, 0, 1>, stride = array<i64: 1, 1>} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 8193, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -442,7 +442,7 @@ func.func @test_depthwise_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: func.func @test_depthwise_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 8193, 1>, stride = array<i64: 1, 1>} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 8193, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -451,7 +451,7 @@ func.func @test_depthwise_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: t func.func @test_depthwise_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 8193>, stride = array<i64: 1, 1>} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 8193>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -460,7 +460,7 @@ func.func @test_depthwise_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: func.func @test_depthwise_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 8193, 1>} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 8193, 1>} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -469,7 +469,7 @@ func.func @test_depthwise_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: t func.func @test_depthwise_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 8193>} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 8193>} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -603,7 +603,7 @@ func.func @test_rfft2d_input_w(%arg0: tensor<13x8x8193xf32>) -> (tensor<13x8x9xf func.func @test_transpose_conv2d_weight_h(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x8193x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: KH <= MAX_KERNEL}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x8193x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -612,7 +612,7 @@ func.func @test_transpose_conv2d_weight_h(%arg0: tensor<1x32x32x8xf32>, %arg1: t func.func @test_transpose_conv2d_weight_w(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x8193x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: KW <= MAX_KERNEL}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x8193x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -621,7 +621,7 @@ func.func @test_transpose_conv2d_weight_w(%arg0: tensor<1x32x32x8xf32>, %arg1: t func.func @test_transpose_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 8193, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 8193, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -630,7 +630,7 @@ func.func @test_transpose_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: te func.func @test_transpose_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 8193, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 8193, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -639,7 +639,7 @@ func.func @test_transpose_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: func.func @test_transpose_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 8193, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 0, 8193, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -648,7 +648,7 @@ func.func @test_transpose_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: t func.func @test_transpose_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 8193>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 0, 0, 8193>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -657,7 +657,7 @@ func.func @test_transpose_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: func.func @test_transpose_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 8193, 1>} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 8193, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -666,7 +666,7 @@ func.func @test_transpose_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: t func.func @test_transpose_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 8193>} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 8193>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 88fa94ae90db..f2e1cff72ab2 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -54,7 +54,7 @@ func.func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01> // ----- // CHECK-LABEL: conv2d func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> return %0 : tensor<1x4x4x8xf32> } @@ -63,7 +63,7 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, % func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> { %0 = "tosa.const"() {value = dense<0> : tensor<3x11x11x3xi4>} : () -> tensor<3x11x11x3xi4> %1 = "tosa.const"() {value = dense<[12, 23, 55]> : tensor<3xi32>} : () -> tensor<3xi32> - %2 = "tosa.conv2d"(%arg0, %0, %1) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32> + %2 = "tosa.conv2d"(%arg0, %0, %1) {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32> %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array<i8: 37, 36, 37>} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8> return %3 : tensor<1x1x1x3xi8> } @@ -71,28 +71,28 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> // ----- // CHECK-LABEL: conv3d func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> { - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32> + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32> return %0 : tensor<1x4x8x21x34xf32> } // ----- // CHECK-LABEL: conv3d_with_local_bound func.func @test_conv3d_with_local_bound(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> { - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, local_bound = true} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32> + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, local_bound = true} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32> return %0 : tensor<1x4x8x21x34xf32> } // ----- // CHECK-LABEL: depthwise_conv2d func.func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> return %0 : tensor<1x4x4x8xf32> } // ----- // CHECK-LABEL: depthwise_conv2d_with_local_bound func.func @test_depthwise_conv2d_with_local_bound(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> return %0 : tensor<1x4x4x8xf32> } @@ -162,14 +162,14 @@ func.func @test_rfft2d_with_local_bound(%arg0: tensor<13x8x16xf32>) -> (tensor<1 // ----- // CHECK-LABEL: transpose_conv2d func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } // ----- // CHECK-LABEL: transpose_conv2d_with_local_bound func.func @test_transpose_conv2d_with_local_bound(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>, local_bound = false} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>, local_bound = false} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -525,16 +525,16 @@ func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) - // ----- // CHECK-LABEL: pad -func.func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> { - %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32> +func.func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> { + %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<6xi32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } // ----- // CHECK-LABEL: pad_explicit_value -func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> { +func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> { %0 = "tosa.const"() {value = dense<3.14> : tensor<f32>} : () -> tensor<f32> - %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<13x21x3xf32> + %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21x3xf32>, tensor<6xi32>, tensor<f32>) -> tensor<13x21x3xf32> return %1 : tensor<13x21x3xf32> } diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir index 82a87dbc2494..6437f12e3ff8 100644 --- a/mlir/test/Dialect/Tosa/quant-test.mlir +++ b/mlir/test/Dialect/Tosa/quant-test.mlir @@ -10,9 +10,9 @@ func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32 // ----- // CHECK-LABEL: test_build_mult_and_shift -func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>> { +func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>> { // CHECK: tosa.conv2d - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = array<i64: 1, 1, 2, 2>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -1, weight_zp = 0>} : (tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>> - return %0 : tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>> + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = i32, pad = array<i64: 1, 1, 2, 2>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -1, weight_zp = 0>} : (tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>> + return %0 : tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir index d876ccfb3b91..8df4630f9c17 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir @@ -14,7 +14,7 @@ func.func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>} // CHECK-SAME: -> tensor<4x10x10x3xf32> // CHECK: return %[[VAR3]] - %0 = tosa.conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> return %0 : tensor<4x10x10x3xf32> } @@ -33,7 +33,7 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>} // CHECK-SAME: -> tensor<4x10x10x3xi32> // CHECK: return %[[VAR3]] - %0 = tosa.conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> return %0 : tensor<4x10x10x3xi32> } @@ -50,7 +50,7 @@ func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: -1, 14, 14, 384>} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32> // CHECK: return %[[VAL_6]] : tensor<?x14x14x384xi32> // CHECK: } - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>, stride = array<i64: 1, 1>} : (tensor<?x14x14x64xi8>, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor<?x14x14x384xi32> + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>, stride = array<i64: 1, 1>} : (tensor<?x14x14x64xi8>, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor<?x14x14x384xi32> return %0 : tensor<?x14x14x384xi32> } @@ -58,13 +58,13 @@ func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384 // CHECK-LABEL: @conv2d_as_fully_connected_padded func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x12x12x3xi32> { - // CHECK-DAG: %[[PAD_SHAPE:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>} + // CHECK-DAG: %[[PAD_SHAPE:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi64>} // CHECK-DAG: %[[PAD_VAL:.+]] = "tosa.const"() <{value = dense<42> : tensor<i8>} - // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, tensor<4x2xi64>, tensor<i8>) -> tensor<4x12x12x2xi8> + // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, tensor<8xi64>, tensor<i8>) -> tensor<4x12x12x2xi8> // CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: 576, 2>} // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>} // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} // CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array<i64: 4, 12, 12, 3>} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x12x12x3xi32> + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x12x12x3xi32> return %0 : tensor<4x12x12x3xi32> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir index 2224bf3f57b2..cfff6396ad48 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir @@ -18,7 +18,7 @@ func.func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1 // CHECK: %[[VAR5:.*]] = tosa.add %[[VAR3]], %[[VAR4]] // CHECK-SAME: -> tensor<4x10x10x6xf32> // CHECK: return %[[VAR5]] - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> return %0 : tensor<4x10x10x6xf32> } @@ -38,7 +38,7 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor< // CHECK: %[[reO:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 10, 10, 6>} // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>} // CHECK: %[[add:.+]] = tosa.add %[[reO]], %[[reArg2]] - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 7, weight_zp = 11>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 7, weight_zp = 11>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> return %0 : tensor<4x10x10x6xi32> } @@ -46,15 +46,15 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor< // CHECK-LABEL: @depthwise_conv2d_as_mul_padded func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> { - // CHECK-DAG: %[[pad:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0], [0, 0]]> : tensor<5x2xi64>} + // CHECK-DAG: %[[pad:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xi64>} // CHECK-DAG: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>} // CHECK: %[[reIn:.+]] = tosa.reshape %arg0 {new_shape = array<i64: 4, 10, 10, 2, 1>} - // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, tensor<5x2xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32> + // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, tensor<10xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32> // CHECK: %[[reArg1:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 2, 3>} // CHECK: %[[mul:.+]] = tosa.mul %3, %[[reArg1]] {shift = 0 : i8} // CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 12, 12, 6>} // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>} // CHECK: %[[add:.+]] = tosa.add %[[reOut]], %[[reArg2]] - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x12x12x6xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x12x12x6xf32> return %0 : tensor<4x12x12x6xf32> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir index 1f2bb3fb9a36..c361c7c2899f 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir @@ -6,7 +6,7 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32} // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, stride = array<i64: 1, 1> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32> return %0 : tensor<2x18x19x5xf32> } @@ -17,8 +17,8 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) { // CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32} // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32} - // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32> + // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32> return %0 : tensor<2x18x19x5xi32> } @@ -32,6 +32,7 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>, // CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 { + acc_type = i32, out_pad = array<i64: 1, 2, 3, 4>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, @@ -44,7 +45,7 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: // CHECK-LABEL: @transpose_conv2d_strided func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> { // Manipulate the weight matrix to handle striding. - // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>} + // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xi32>} // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>} // CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] // CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>} @@ -54,20 +55,20 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor< // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32} // Pad out the input matrix to handle the transpose conv. - // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>} + // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi32>} // CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] // Manipulate the final shape. // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<30xf32>} - // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} + // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 2, 18, 16, 2, 3, 5>} // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]] // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] {new_shape = array<i64: 2, 36, 48, 5>} // CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]] {size = array<i64: 2, 35, 47, 5>, start = array<i64: 0, 0, 0, 0>} // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 5>} // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]] - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32> %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32> return %1 : tensor<2x?x?x5xf32> } @@ -77,7 +78,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor< // CHECK-LABEL: @transpose_conv2d_strided_quantized func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) { // Manipulate the weight matrix to handle striding. - // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>} + // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xi32>} // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>} // CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {quantization_info = #tosa.pad_quant<input_zp = 42>} // CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>} @@ -87,20 +88,20 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1 // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32} // Pad out the input matrix to handle the transpose conv. - // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>} + // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi32>} // CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {quantization_info = #tosa.pad_quant<input_zp = -22>} // Manipulate the final shape. // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0> : tensor<30xi32>} - // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} + // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 2, 18, 16, 2, 3, 5>} // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]] // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] {new_shape = array<i64: 2, 36, 48, 5>} // CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]] {size = array<i64: 2, 35, 47, 5>, start = array<i64: 0, 0, 0, 0>} // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 5>} // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]] - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32> return %0 : tensor<2x35x47x5xi32> } @@ -108,12 +109,12 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1 // CHECK-LABEL: @transpose_conv2d_strided_overpad func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) { - // CHECK-DAG: %[[WEIGHT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 0]]> : tensor<4x2xi32> + // CHECK-DAG: %[[WEIGHT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xi32> // CHECK-DAG: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>} - // CHECK-DAG: %[[INPUT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [1, 1], [0, 0], [0, 0]]> : tensor<4x2xi32>} + // CHECK-DAG: %[[INPUT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xi32>} // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>} // CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} - // CHECK-DAG: %[[RESULT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [2, 0], [0, 0], [0, 0]]> : tensor<4x2xi32>} + // CHECK-DAG: %[[RESULT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xi32>} // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = 93>} // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>} // CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]] @@ -129,6 +130,7 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : // CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>} // CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]] %2 = tosa.transpose_conv2d %arg0, %arg1, %arg2 { + acc_type = i32, out_pad = array<i64: 2, 0, 0, 1>, out_shape = array<i64: 1, -1, -1, 1>, stride = array<i64: 1, 2>, diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index d46de740800e..82f3e22a3872 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -495,9 +495,9 @@ func.func @test_concat_axis_1(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) // ----- // CHECK-LABEL: @test_padding_no_const -func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32>) -> () { - // CHECK: tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32> - %0 = tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32> +func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<4xi32>) -> () { + // CHECK: tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32> + %0 = tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32> return } @@ -505,9 +505,9 @@ func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32 // CHECK-LABEL:@test_padding_dynamic_input func.func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () { - %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> - // CHECK: tosa.pad %arg0, %cst : (tensor<1x?xf32>, tensor<2x2xi32>) -> tensor<4x?xf32> - %1 = tosa.pad %arg0, %0 : (tensor<1x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32> + %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + // CHECK: tosa.pad %arg0, %cst : (tensor<1x?xf32>, tensor<4xi32>) -> tensor<4x?xf32> + %1 = tosa.pad %arg0, %0 : (tensor<1x?xf32>, tensor<4xi32>) -> tensor<?x?xf32> return } @@ -515,9 +515,9 @@ func.func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () { // CHECK-LABEL: @test_padding_simple func.func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () { - %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> - // CHECK: tosa.pad %arg0, %cst : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<4x9xf32> - %1 = tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32> + %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + // CHECK: tosa.pad %arg0, %cst : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<4x9xf32> + %1 = tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32> return } @@ -674,7 +674,7 @@ func.func @test_pool_static(%arg0: tensor<3x5x6x7xf32>) { // CHECK-LABEL: @conv2d_static func.func @conv2d_static(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () { // CHECK: -> tensor<2x6x4x5xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32> return } @@ -683,7 +683,7 @@ func.func @conv2d_static(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf // CHECK-LABEL: @conv2d_dynamic_input func.func @conv2d_dynamic_input(%input: tensor<?x?x?x?xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () { // CHECK: -> tensor<?x?x?x5xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32> return } @@ -716,7 +716,7 @@ func.func @test_pool_padded(%arg0: tensor<3x5x6x7xf32>) { // CHECK-LABEL: @conv2d_dynamic_weight func.func @conv2d_dynamic_weight(%input: tensor<2x8x9x3xf32>, %weights: tensor<?x?x?x?xf32>, %bias: tensor<5xf32>) -> () { // CHECK: -> tensor<2x?x?x5xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32> return } @@ -725,7 +725,7 @@ func.func @conv2d_dynamic_weight(%input: tensor<2x8x9x3xf32>, %weights: tensor<? // CHECK-LABEL: @conv2d_dynamic_bias func.func @conv2d_dynamic_bias(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<?xf32>) -> () { // CHECK: -> tensor<2x6x4x5xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>) -> tensor<?x?x?x?xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>) -> tensor<?x?x?x?xf32> return } @@ -746,7 +746,7 @@ func.func @test_pool_stride(%arg0: tensor<3x11x12x7xf32>) { // CHECK-LABEL: @conv2d_padded func.func @conv2d_padded(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () { // CHECK: -> tensor<2x9x11x5xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32> return } @@ -755,7 +755,7 @@ func.func @conv2d_padded(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf // CHECK-LABEL: @conv2d_dilated func.func @conv2d_dilated(%input: tensor<2x12x14x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () { // CHECK: -> tensor<2x6x4x5xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 3, 2>} : (tensor<2x12x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 3, 2>} : (tensor<2x12x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32> return } @@ -764,7 +764,7 @@ func.func @conv2d_dilated(%input: tensor<2x12x14x3xf32>, %weights: tensor<5x3x6x // CHECK-LABEL: @conv2d_strided func.func @conv2d_strided(%input: tensor<1x13x14x1xf32>, %weights: tensor<1x1x1x1xf32>, %bias: tensor<1xf32>) -> () { // CHECK: -> tensor<1x5x7x1xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 3, 2>, dilation = array<i64: 1, 1>} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 3, 2>, dilation = array<i64: 1, 1>} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32> return } @@ -773,7 +773,7 @@ func.func @conv2d_strided(%input: tensor<1x13x14x1xf32>, %weights: tensor<1x1x1x // CHECK-LABEL: @conv3d_static func.func @conv3d_static(%input: tensor<2x8x9x10x3xf32>, %weights: tensor<5x3x6x4x3xf32>, %bias: tensor<5xf32>) -> () { // CHECK: -> tensor<2x6x4x7x5xf32> - %0 = tosa.conv3d %input, %weights, %bias {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32> + %0 = tosa.conv3d %input, %weights, %bias {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32> return } @@ -782,7 +782,7 @@ func.func @conv3d_static(%input: tensor<2x8x9x10x3xf32>, %weights: tensor<5x3x6x // CHECK-LABEL: @conv3d_dynamic_input func.func @conv3d_dynamic_input(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<?x?x?x?x5xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<?x?x?x?x?xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32> + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<?x?x?x?x?xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32> return } @@ -791,7 +791,7 @@ func.func @conv3d_dynamic_input(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<5x3x // CHECK-LABEL: @conv3d_dynamic_weight func.func @conv3d_dynamic_weight(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x?x?x?x5xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<?x?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32> + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<?x?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32> return } @@ -800,7 +800,7 @@ func.func @conv3d_dynamic_weight(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<?x // CHECK-LABEL: @conv3d_dynamic_bias func.func @conv3d_dynamic_bias(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<?xf32>) { // CHECK: -> tensor<2x6x4x7x5xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<?xf32>) -> tensor<?x?x?x?x?xf32> + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<?xf32>) -> tensor<?x?x?x?x?xf32> return } @@ -809,7 +809,7 @@ func.func @conv3d_dynamic_bias(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x // CHECK-LABEL: @conv3d_padded func.func @conv3d_padded(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x9x11x18x5xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 1, 2, 3, 4, 5, 6>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32> + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 1, 2, 3, 4, 5, 6>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32> return } @@ -818,7 +818,7 @@ func.func @conv3d_padded(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3x // CHECK-LABEL: @conv3d_dilated func.func @conv3d_dilated(%arg0: tensor<2x12x14x16x3xf32>, %arg1: tensor<5x3x6x2x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x6x4x12x5xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 3, 2, 4>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x12x14x16x3xf32>, tensor<5x3x6x2x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32> + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 3, 2, 4>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x12x14x16x3xf32>, tensor<5x3x6x2x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32> return } @@ -827,7 +827,7 @@ func.func @conv3d_dilated(%arg0: tensor<2x12x14x16x3xf32>, %arg1: tensor<5x3x6x2 // CHECK-LABEL: @conv3d_strided func.func @conv3d_strided(%arg0: tensor<1x13x14x15x1xf32>, %arg1: tensor<1x1x1x1x1xf32>, %arg2: tensor<1xf32>) { // CHECK: -> tensor<1x5x7x4x1xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 3, 2, 4>} : (tensor<1x13x14x15x1xf32>, tensor<1x1x1x1x1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32> + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 3, 2, 4>} : (tensor<1x13x14x15x1xf32>, tensor<1x1x1x1x1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32> return } @@ -836,7 +836,7 @@ func.func @conv3d_strided(%arg0: tensor<1x13x14x15x1xf32>, %arg1: tensor<1x1x1x1 // CHECK-LABEL: @depthwise_conv2d_static func.func @depthwise_conv2d_static(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) { // CHECK: -> tensor<2x6x4x15xf32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32> return } @@ -845,7 +845,7 @@ func.func @depthwise_conv2d_static(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6 // CHECK-LABEL: @depthwise_conv2d_dynamic_input func.func @depthwise_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) { // CHECK: -> tensor<?x?x?x15xf32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<?x?x?x15xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<?x?x?x15xf32> return } @@ -854,7 +854,7 @@ func.func @depthwise_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: ten // CHECK-LABEL: @depthwise_conv2d_dynamic_weight func.func @depthwise_conv2d_dynamic_weight(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<15xf32>) { // CHECK: -> tensor<2x?x?x15xf32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<?x?x?x?xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<?x?x?x?xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32> return } @@ -863,7 +863,7 @@ func.func @depthwise_conv2d_dynamic_weight(%arg0: tensor<2x8x9x3xf32>, %arg1: te // CHECK-LABEL: @depthwise_conv2d_dynamic_bias func.func @depthwise_conv2d_dynamic_bias(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<?xf32>) { // CHECK: -> tensor<2x6x4x15xf32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<?xf32>) -> tensor<2x6x4x15xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<?xf32>) -> tensor<2x6x4x15xf32> return } @@ -872,7 +872,7 @@ func.func @depthwise_conv2d_dynamic_bias(%arg0: tensor<2x8x9x3xf32>, %arg1: tens // CHECK-LABEL: @depthwise_conv2d_padded func.func @depthwise_conv2d_padded(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) { // CHECK: -> tensor<2x9x11x15xf32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x9x11x15xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x9x11x15xf32> return } @@ -881,7 +881,7 @@ func.func @depthwise_conv2d_padded(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6 // CHECK-LABEL: @depthwise_conv2d_dilated func.func @depthwise_conv2d_dilated(%arg0: tensor<2x12x14x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) { // CHECK: -> tensor<2x6x4x15xf32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 3, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x12x14x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 3, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x12x14x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32> return } @@ -890,7 +890,7 @@ func.func @depthwise_conv2d_dilated(%arg0: tensor<2x12x14x3xf32>, %arg1: tensor< // CHECK-LABEL: @depthwise_conv2d_strided func.func @depthwise_conv2d_strided(%arg0: tensor<1x13x14x1xf32>, %arg1: tensor<1x1x1x1xf32>, %arg2: tensor<1xf32>) { // CHECK: -> tensor<1x5x7x1xf32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 3, 2>} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x5x7x1xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 3, 2>} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x5x7x1xf32> return } @@ -899,7 +899,7 @@ func.func @depthwise_conv2d_strided(%arg0: tensor<1x13x14x1xf32>, %arg1: tensor< // CHECK-LABEL: @transpose_conv2d_out_shape func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x8x9x5xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, 8, 9, -1>, stride = array<i64: 1, 1>} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x8x9x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, 8, 9, -1>, stride = array<i64: 1, 1>} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x8x9x5xf32> return } @@ -908,7 +908,7 @@ func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor< // CHECK-LABEL: @transpose_conv2d_static func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x18x19x5xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32> return } @@ -917,7 +917,7 @@ func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5 // CHECK-LABEL: @transpose_conv2d_static_strided func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x33x45x5xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32> return } @@ -926,7 +926,7 @@ func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1: // CHECK-LABEL: @transpose_conv2d_dynamic_input func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<?x?x?x5xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32> return } @@ -935,7 +935,7 @@ func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: ten // CHECK-LABEL: @transpose_conv2d_dynamic_weights func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x?x?x5xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32> return } @@ -944,7 +944,7 @@ func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: t // CHECK-LABEL: @transpose_conv2d_dynamic_bias func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<?xf32>) { // CHECK: -> tensor<2x8x9x5xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>) -> tensor<2x8x9x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>) -> tensor<2x8x9x5xf32> return } @@ -953,14 +953,14 @@ func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tens // CHECK-LABEL: @transpose_conv2d_padded func.func @transpose_conv2d_padded(%arg0: tensor<2x9x11x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x10x13x5xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 1, 0, 3, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x10x13x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 1, 0, 3, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x10x13x5xf32> return } // CHECK-LABEL: @transpose_conv2d_strided func.func @transpose_conv2d_strided(%arg0: tensor<1x5x7x1xf32>, %arg1: tensor<1x1x1x1xf32>, %arg2: tensor<1xf32>) { // CHECK: -> tensor<1x13x13x1xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 3, 2>} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 3, 2>} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32> return } @@ -1368,7 +1368,7 @@ func.func @test_non_tosa_consumer_still_propagates(%arg0: tensor<1x1x8xf32>, %ar func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<?x16x16x16xf32> { // CHECK: [[CONV:%.+]] = tosa.conv2d %arg0, %arg1, %arg2 // CHECK: (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<?x32x32x16xf32> + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<?x32x32x16xf32> // CHECK: tosa.max_pool2d [[CONV]] // CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32> %1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32> diff --git a/mlir/test/Integration/GPU/CUDA/assert.mlir b/mlir/test/Integration/GPU/CUDA/assert.mlir new file mode 100644 index 000000000000..3d6527fe59b2 --- /dev/null +++ b/mlir/test/Integration/GPU/CUDA/assert.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-cpu-runner \ +// RUN: --shared-libs=%mlir_cuda_runtime \ +// RUN: --shared-libs=%mlir_runner_utils \ +// RUN: --entry-point-result=void 2>&1 \ +// RUN: | FileCheck %s + +// CHECK-DAG: thread 0: print after passing assertion +// CHECK-DAG: thread 1: print after passing assertion +// CHECK-DAG: callee_file.cc:7: callee_func_name: block: [0,0,0], thread: [0,0,0] Assertion `failing assertion` failed. +// CHECK-DAG: callee_file.cc:7: callee_func_name: block: [0,0,0], thread: [1,0,0] Assertion `failing assertion` failed. +// CHECK-NOT: print after failing assertion + +module attributes {gpu.container_module} { +gpu.module @kernels { +gpu.func @test_assert(%c0: i1, %c1: i1) kernel { + %0 = gpu.thread_id x + cf.assert %c1, "passing assertion" + gpu.printf "thread %lld: print after passing assertion\n", %0 : index + // Test callsite(callsite(name)) location. + cf.assert %c0, "failing assertion" loc(callsite(callsite("callee_func_name"("callee_file.cc":7:9) at "caller_file.cc":10:8) at "caller2_file.cc":11:12)) + gpu.printf "thread %lld: print after failing assertion\n", %0 : index + gpu.return +} +} + +func.func @main() { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0_i1 = arith.constant 0 : i1 + %c1_i1 = arith.constant 1 : i1 + gpu.launch_func @kernels::@test_assert + blocks in (%c1, %c1, %c1) + threads in (%c2, %c1, %c1) + args(%c0_i1 : i1, %c1_i1 : i1) + return +} +} diff --git a/mlir/test/Integration/GPU/CUDA/printf.mlir b/mlir/test/Integration/GPU/CUDA/printf.mlir index 99ea1208e9c5..15b0bf02d911 100644 --- a/mlir/test/Integration/GPU/CUDA/printf.mlir +++ b/mlir/test/Integration/GPU/CUDA/printf.mlir @@ -14,7 +14,7 @@ module attributes {gpu.container_module} { %0 = gpu.thread_id x %csti8 = arith.constant 2 : i8 %cstf32 = arith.constant 3.0 : f32 - gpu.printf "Hello from %lld, %d, %f\n" %0, %csti8, %cstf32 : index, i8, f32 + gpu.printf "Hello from %lld, %d, %f\n", %0, %csti8, %cstf32 : index, i8, f32 gpu.return } } diff --git a/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir index c70c940564a2..a22a34b9393a 100644 --- a/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir +++ b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir @@ -43,7 +43,7 @@ module attributes {gpu.container_module} { %cnd2 = arith.cmpi eq, %bidY, %c3 : index scf.if %cnd1 { scf.if %cnd2 { - gpu.printf "clusterIdx: (%d, %d, %d) in Cluster Dimension: (%d, %d, %d) blockIdx: (%d, %d, %d) \n" + gpu.printf "clusterIdx: (%d, %d, %d) in Cluster Dimension: (%d, %d, %d) blockIdx: (%d, %d, %d) \n", %cidX_i32, %cidY_i32, %cidZ_i32, diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir index b50772f8249f..95bde40deb48 100644 --- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir +++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir @@ -85,7 +85,7 @@ module @mymod { // Step 7. First thread does TMA load scf.if %10 { - gpu.printf "[GPU] TMA SIZE %d\0A" %c8192 : index + gpu.printf "[GPU] TMA SIZE %d\0A", %c8192 : index nvgpu.tma.async.load %3[%c0, %c0], %9[%c0] to %7 : !lhsTensorMap, !barrierType -> !shmemlhs nvgpu.mbarrier.arrive.expect_tx %9[%c0], %c8192 : !barrierType } else { @@ -98,16 +98,16 @@ module @mymod { // Step 9. Print loaded data in 128b swizzled scf.if %10 { - gpu.printf "===--- Matrix A ---=== %d \0A" %c-1_i32 : i32 + gpu.printf "===--- Matrix A ---=== %d \0A", %c-1_i32 : i32 scf.for %arg12 = %c0 to %c128 step %c1 { scf.for %arg13 = %c0 to %c64 step %c1 { %15 = memref.load %7[%arg12, %arg13] : !shmemlhs %16 = arith.extf %15 : f16 to f32 - gpu.printf "%.0f, " %16 : f32 + gpu.printf "%.0f, ", %16 : f32 } - gpu.printf "%d\0A" %c-1_i32 : i32 + gpu.printf "%d\0A", %c-1_i32 : i32 } - gpu.printf "===----------------=== %d \0A" %c-1_i32 : i32 + gpu.printf "===----------------=== %d \0A", %c-1_i32 : i32 } gpu.terminator } diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir index 65e5fc0aff6a..e76fa03903b8 100644 --- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir +++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir @@ -109,7 +109,7 @@ module @mymod { // Step 6. First thread does TMA load scf.if %10 { - gpu.printf "[GPU] TMA SIZE %d\0A" %c32768 : index + gpu.printf "[GPU] TMA SIZE %d\0A", %c32768 : index nvgpu.tma.async.load %d_lhsTensorMap[%c0, %c0], %9[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs nvgpu.tma.async.load %d_rhsTensorMap[%c0, %c0], %9[%c0] to %rhsShmem1 : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1]>, 3> nvgpu.tma.async.load %d_rhsTensorMap[%c64, %c0], %9[%c0] to %rhsShmem2 : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 4096>, 3> @@ -124,16 +124,16 @@ module @mymod { // Step 8. Print loaded data in 128b swizzled scf.if %10 { - gpu.printf "===--- Matrix B ---=== %d \n" %c-1_i32 : i32 + gpu.printf "===--- Matrix B ---=== %d \n", %c-1_i32 : i32 scf.for %ii = %c0 to %c64 step %c1 { scf.for %j = %c0 to %c128 step %c1 { %lhs0 = memref.load %rhsShmem[%ii, %j] : !shmemrhs %lhs032 = arith.extf %lhs0: f16 to f32 - gpu.printf "%.0f, " %lhs032 : f32 + gpu.printf "%.0f, ", %lhs032 : f32 } - gpu.printf "%d\n" %c-1_i32 : i32 + gpu.printf "%d\n", %c-1_i32 : i32 } - gpu.printf "===----------------=== %d \n" %c-1_i32 : i32 + gpu.printf "===----------------=== %d \n", %c-1_i32 : i32 } gpu.barrier gpu.terminator diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir index 391fda82e1e1..acca9811f570 100644 --- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir +++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir @@ -80,8 +80,8 @@ module @mymod { nvgpu.mbarrier.arrive.expect_tx %9[%c0], %c6144 : <memorySpace = #gpu.address_space<workgroup>> %11 = memref.load %7[%c0, %c0] : memref<64x8xf32, 3> %12 = memref.load %8[%c0, %c0] : memref<8x128xf32, 3> - gpu.printf "[GPU] TMA BEFORE lhs[45][7] %f\0A" %11 : f32 - gpu.printf "[GPU] TMA BEFORE rhs[7][0] %f\0A" %12 : f32 + gpu.printf "[GPU] TMA BEFORE lhs[45][7] %f\0A", %11 : f32 + gpu.printf "[GPU] TMA BEFORE rhs[7][0] %f\0A", %12 : f32 nvgpu.tma.async.load %3[%c0, %c0], %9[%c0] to %7 : <tensor = memref<64x8xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x8xf32, 3> nvgpu.tma.async.load %4[%c0, %c0], %9[%c0] to %8 : <tensor = memref<8x128xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<8x128xf32, 3> } else { @@ -92,8 +92,8 @@ module @mymod { scf.if %10 { %11 = memref.load %7[%c45, %c7] : memref<64x8xf32, 3> %12 = memref.load %8[%c7, %c0] : memref<8x128xf32, 3> - gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A" %11 : f32 - gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A" %12 : f32 + gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A", %11 : f32 + gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A", %12 : f32 } gpu.terminator } diff --git a/mlir/test/Integration/GPU/CUDA/sm90/transform-dialect/tma_load_64x8_8x128_noswizzle-transform.mlir b/mlir/test/Integration/GPU/CUDA/sm90/transform-dialect/tma_load_64x8_8x128_noswizzle-transform.mlir index f83f65bb2963..fe6c645357ec 100644 --- a/mlir/test/Integration/GPU/CUDA/sm90/transform-dialect/tma_load_64x8_8x128_noswizzle-transform.mlir +++ b/mlir/test/Integration/GPU/CUDA/sm90/transform-dialect/tma_load_64x8_8x128_noswizzle-transform.mlir @@ -96,8 +96,8 @@ func.func @main() { scf.if %10 { %11 = memref.load %out[%c45, %c7] : memref<64x8xf32, 3> %12 = memref.load %out_1[%c7, %c0] : memref<8x128xf32, 3> - gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A" %11 : f32 - gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A" %12 : f32 + gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A", %11 : f32 + gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A", %12 : f32 } gpu.terminator } diff --git a/mlir/test/Integration/GPU/ROCM/printf.mlir b/mlir/test/Integration/GPU/ROCM/printf.mlir index d5e6e3757540..4a0e4d34bfab 100644 --- a/mlir/test/Integration/GPU/ROCM/printf.mlir +++ b/mlir/test/Integration/GPU/ROCM/printf.mlir @@ -13,7 +13,7 @@ module attributes {gpu.container_module} { gpu.module @kernels { gpu.func @hello() kernel { %0 = gpu.thread_id x - gpu.printf "Hello from %d\n" %0 : index + gpu.printf "Hello from %d\n", %0 : index gpu.return } } diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll index 6bde174642d5..b616cb81e0a8 100644 --- a/mlir/test/Target/LLVMIR/Import/import-failure.ll +++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll @@ -13,15 +13,6 @@ bb2: ; // ----- ; CHECK: <unknown> -; CHECK-SAME: error: unhandled value: ptr asm "bswap $0", "=r,r" -define i32 @unhandled_value(i32 %arg1) { - %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1) - ret i32 %1 -} - -; // ----- - -; CHECK: <unknown> ; CHECK-SAME: unhandled constant: ptr blockaddress(@unhandled_constant, %bb1) since blockaddress(...) is unsupported ; CHECK: <unknown> ; CHECK-SAME: error: unhandled instruction: ret ptr blockaddress(@unhandled_constant, %bb1) diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll index fff48bbc486b..7377e2584110 100644 --- a/mlir/test/Target/LLVMIR/Import/instructions.ll +++ b/mlir/test/Target/LLVMIR/Import/instructions.ll @@ -535,6 +535,17 @@ define void @indirect_vararg_call(ptr addrspace(42) %fn) { ; // ----- +; CHECK-LABEL: @inlineasm +; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +define i32 @inlineasm(i32 %arg1) { + ; CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects "bswap $0", "=r,r" %[[ARG1]] : (i32) -> i32 + %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1) + ; CHECK: return %[[RES]] + ret i32 %1 +} + +; // ----- + ; CHECK-LABEL: @gep_static_idx ; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]] define void @gep_static_idx(ptr %ptr) { diff --git a/mlir/test/Target/LLVMIR/Import/metadata-alias-scopes.ll b/mlir/test/Target/LLVMIR/Import/metadata-alias-scopes.ll index f5128ff76bc5..bf4c85786216 100644 --- a/mlir/test/Target/LLVMIR/Import/metadata-alias-scopes.ll +++ b/mlir/test/Target/LLVMIR/Import/metadata-alias-scopes.ll @@ -92,3 +92,38 @@ declare void @foo(ptr %arg1) !0 = distinct !{!0, !"The domain"} !1 = !{!1, !0} !2 = !{!1} + +; // ----- + +; CHECK: #[[DOMAIN:.*]] = #llvm.alias_scope_domain<id = "domain1"> +; CHECK: #[[$SCOPE0:.*]] = #llvm.alias_scope<id = "scopeid1", domain = #[[DOMAIN]], description = "The first scope"> +; CHECK: #[[$SCOPE1:.*]] = #llvm.alias_scope<id = "scopeid2", domain = #[[DOMAIN]]> +; CHECK: #[[$SCOPE2:.*]] = #llvm.alias_scope<id = "scopeid3", domain = #[[DOMAIN]]> + +; CHECK-LABEL: llvm.func @alias_scope +define void @alias_scope(ptr %arg1) { + ; CHECK: llvm.load + ; CHECK-SAME: alias_scopes = [#[[$SCOPE0]]] + ; CHECK-SAME: noalias_scopes = [#[[$SCOPE1]], #[[$SCOPE2]]] + %1 = load i32, ptr %arg1, !alias.scope !4, !noalias !7 + ; CHECK: llvm.load + ; CHECK-SAME: alias_scopes = [#[[$SCOPE1]]] + ; CHECK-SAME: noalias_scopes = [#[[$SCOPE0]], #[[$SCOPE2]]] + %2 = load i32, ptr %arg1, !alias.scope !5, !noalias !8 + ; CHECK: llvm.load + ; CHECK-SAME: alias_scopes = [#[[$SCOPE2]]] + ; CHECK-SAME: noalias_scopes = [#[[$SCOPE0]], #[[$SCOPE1]]] + %3 = load i32, ptr %arg1, !alias.scope !6, !noalias !9 + ret void +} + +!0 = !{!"domain1"} +!1 = !{!"scopeid1", !0, !"The first scope"} +!2 = !{!"scopeid2", !0} +!3 = !{!"scopeid3", !0} +!4 = !{!1} +!5 = !{!2} +!6 = !{!3} +!7 = !{!2, !3} +!8 = !{!1, !3} +!9 = !{!1, !2} diff --git a/mlir/test/Target/LLVMIR/attribute-alias-scopes.mlir b/mlir/test/Target/LLVMIR/attribute-alias-scopes.mlir index fa3395533af2..fb71a51512ae 100644 --- a/mlir/test/Target/LLVMIR/attribute-alias-scopes.mlir +++ b/mlir/test/Target/LLVMIR/attribute-alias-scopes.mlir @@ -104,3 +104,54 @@ llvm.func @self_reference() { // CHECK-DAG: ![[SCOPES]] = !{![[SCOPE]]} // CHECK-DAG: = !DISubroutineType(types: ![[TYPES:[0-9]+]]) // CHECK-DAG: ![[TYPES]] = !{null} + +// ----- + +llvm.func @foo(%arg0: !llvm.ptr) + +#alias_scope_domain = #llvm.alias_scope_domain<id = "domain1", description = "The domain"> +#alias_scope1 = #llvm.alias_scope<id = "scope1", domain = #alias_scope_domain, description = "The first scope"> +#alias_scope2 = #llvm.alias_scope<id = "scope2", domain = #alias_scope_domain> +#alias_scope3 = #llvm.alias_scope<id = "scope3", domain = #alias_scope_domain> + +// CHECK-LABEL: @alias_scopes +llvm.func @alias_scopes(%arg1 : !llvm.ptr) { + %0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: call void @llvm.experimental.noalias.scope.decl(metadata ![[SCOPES1:[0-9]+]]) + llvm.intr.experimental.noalias.scope.decl #alias_scope1 + // CHECK: store {{.*}}, !alias.scope ![[SCOPES1]], !noalias ![[SCOPES23:[0-9]+]] + llvm.store %0, %arg1 {alias_scopes = [#alias_scope1], noalias_scopes = [#alias_scope2, #alias_scope3]} : i32, !llvm.ptr + // CHECK: load {{.*}}, !alias.scope ![[SCOPES2:[0-9]+]], !noalias ![[SCOPES13:[0-9]+]] + %1 = llvm.load %arg1 {alias_scopes = [#alias_scope2], noalias_scopes = [#alias_scope1, #alias_scope3]} : !llvm.ptr -> i32 + // CHECK: atomicrmw {{.*}}, !alias.scope ![[SCOPES3:[0-9]+]], !noalias ![[SCOPES12:[0-9]+]] + %2 = llvm.atomicrmw add %arg1, %0 monotonic {alias_scopes = [#alias_scope3], noalias_scopes = [#alias_scope1, #alias_scope2]} : !llvm.ptr, i32 + // CHECK: cmpxchg {{.*}}, !alias.scope ![[SCOPES3]] + %3 = llvm.cmpxchg %arg1, %1, %2 acq_rel monotonic {alias_scopes = [#alias_scope3]} : !llvm.ptr, i32 + %5 = llvm.mlir.constant(42 : i8) : i8 + // CHECK: llvm.memcpy{{.*}}, !alias.scope ![[SCOPES3]] + "llvm.intr.memcpy"(%arg1, %arg1, %0) <{isVolatile = false}> {alias_scopes = [#alias_scope3]} : (!llvm.ptr, !llvm.ptr, i32) -> () + // CHECK: llvm.memset{{.*}}, !noalias ![[SCOPES3]] + "llvm.intr.memset"(%arg1, %5, %0) <{isVolatile = false}> {noalias_scopes = [#alias_scope3]} : (!llvm.ptr, i8, i32) -> () + // CHECK: call void @foo({{.*}} !alias.scope ![[SCOPES3]] + llvm.call @foo(%arg1) {alias_scopes = [#alias_scope3]} : (!llvm.ptr) -> () + // CHECK: call void @foo({{.*}} !noalias ![[SCOPES3]] + llvm.call @foo(%arg1) {noalias_scopes = [#alias_scope3]} : (!llvm.ptr) -> () + llvm.return +} + +// Check the intrinsic declarations. +// CHECK-DAG: declare void @llvm.experimental.noalias.scope.decl(metadata) +// CHECK-DAG: declare void @llvm.memcpy.p0.p0.i32(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i32, i1 immarg) +// CHECK-DAG: declare void @llvm.memset.p0.i32(ptr nocapture writeonly, i8, i32, i1 immarg) + +// Check the translated metadata. +// CHECK-DAG: ![[DOMAIN:[0-9]+]] = !{!"domain1", !"The domain"} +// CHECK-DAG: ![[SCOPE1:[0-9]+]] = !{!"scope1", ![[DOMAIN]], !"The first scope"} +// CHECK-DAG: ![[SCOPE2:[0-9]+]] = !{!"scope2", ![[DOMAIN]]} +// CHECK-DAG: ![[SCOPE3:[0-9]+]] = !{!"scope3", ![[DOMAIN]]} +// CHECK-DAG: ![[SCOPES1]] = !{![[SCOPE1]]} +// CHECK-DAG: ![[SCOPES2]] = !{![[SCOPE2]]} +// CHECK-DAG: ![[SCOPES3]] = !{![[SCOPE3]]} +// CHECK-DAG: ![[SCOPES12]] = !{![[SCOPE1]], ![[SCOPE2]]} +// CHECK-DAG: ![[SCOPES13]] = !{![[SCOPE1]], ![[SCOPE3]]} +// CHECK-DAG: ![[SCOPES23]] = !{![[SCOPE2]], ![[SCOPE3]]} diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index b69d77496351..2d7710e7cbf2 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -556,9 +556,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel} { llvm.return } -// CHECK: !nvvm.annotations = -// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1} -// CHECK: {ptr @kernel_func, !"kernel", i32 1} +// CHECK: ptx_kernel void @kernel_func // ----- @@ -566,9 +564,8 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 2 llvm.return } +// CHECK: define ptx_kernel void @kernel_func // CHECK: !nvvm.annotations = -// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1} -// CHECK: {ptr @kernel_func, !"kernel", i32 1} // CHECK: {ptr @kernel_func, !"maxntidx", i32 1} // CHECK: {ptr @kernel_func, !"maxntidy", i32 23} // CHECK: {ptr @kernel_func, !"maxntidz", i32 32} @@ -578,9 +575,8 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 1, 2 llvm.return } +// CHECK: define ptx_kernel void @kernel_func // CHECK: !nvvm.annotations = -// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1} -// CHECK: {ptr @kernel_func, !"kernel", i32 1} // CHECK: {ptr @kernel_func, !"reqntidx", i32 1} // CHECK: {ptr @kernel_func, !"reqntidy", i32 23} // CHECK: {ptr @kernel_func, !"reqntidz", i32 32} @@ -590,31 +586,28 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.cluster_dim = array<i32: llvm.return } +// CHECK: define ptx_kernel void @kernel_func // CHECK: !nvvm.annotations = -// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1} // CHECK: {ptr @kernel_func, !"cluster_dim_x", i32 3} // CHECK: {ptr @kernel_func, !"cluster_dim_y", i32 5} // CHECK: {ptr @kernel_func, !"cluster_dim_z", i32 7} -// CHECK: {ptr @kernel_func, !"kernel", i32 1} // ----- llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.cluster_max_blocks = 8} { llvm.return } +// CHECK: define ptx_kernel void @kernel_func // CHECK: !nvvm.annotations = -// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1} // CHECK: {ptr @kernel_func, !"cluster_max_blocks", i32 8} -// CHECK: {ptr @kernel_func, !"kernel", i32 1} // ----- llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.minctasm = 16} { llvm.return } +// CHECK: define ptx_kernel void @kernel_func // CHECK: !nvvm.annotations = -// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1} -// CHECK: {ptr @kernel_func, !"kernel", i32 1} // CHECK: {ptr @kernel_func, !"minctasm", i32 16} // ----- @@ -622,9 +615,8 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxnreg = 16} { llvm.return } +// CHECK: define ptx_kernel void @kernel_func // CHECK: !nvvm.annotations = -// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1} -// CHECK: {ptr @kernel_func, !"kernel", i32 1} // CHECK: {ptr @kernel_func, !"maxnreg", i32 16} // ----- @@ -633,9 +625,8 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 2 llvm.return } +// CHECK: define ptx_kernel void @kernel_func // CHECK: !nvvm.annotations = -// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1} -// CHECK: {ptr @kernel_func, !"kernel", i32 1} // CHECK: {ptr @kernel_func, !"maxnreg", i32 32} // CHECK: {ptr @kernel_func, !"maxntidx", i32 1} // CHECK: {ptr @kernel_func, !"maxntidy", i32 23} @@ -643,19 +634,19 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 2 // CHECK: {ptr @kernel_func, !"minctasm", i32 16} // ----- +// CHECK: define ptx_kernel void @kernel_func // CHECK: !nvvm.annotations = // CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2} // CHECK: !2 = !{i32 1} -// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1} llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} { llvm.return } // ----- +// CHECK: define ptx_kernel void @kernel_func // CHECK: !nvvm.annotations = // CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2} // CHECK: !2 = !{i32 1, i32 3} -// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1} llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} { llvm.return } diff --git a/mlir/test/Target/LLVMIR/omptarget-threadprivate-device-lowering.mlir b/mlir/test/Target/LLVMIR/omptarget-threadprivate-device-lowering.mlir new file mode 100644 index 000000000000..279ecb3f8e99 --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-threadprivate-device-lowering.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// Not intended to be a functional example, the aim of this test is to verify +// omp.threadprivate does not crash on lowering during the OpenMP target device +// pass when used in conjunction with target code in the same module. + +module attributes {omp.is_target_device = true } { + llvm.func @func() attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>} { + %0 = llvm.mlir.addressof @_QFEpointer2 : !llvm.ptr + %1 = omp.threadprivate %0 : !llvm.ptr -> !llvm.ptr + %2 = omp.map.info var_ptr(%1 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>) map_clauses(implicit, to) capture(ByRef) -> !llvm.ptr + omp.target map_entries(%2 -> %arg0 : !llvm.ptr) { + %3 = llvm.mlir.constant(1 : i32) : i32 + %4 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> + llvm.store %3, %4 : i32, !llvm.ptr + omp.terminator + } + llvm.return + } + llvm.mlir.global internal @_QFEpointer2() {addr_space = 0 : i32} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> { + %0 = llvm.mlir.undef : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> + llvm.return %0 : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> + } +} + +// CHECK: define weak_odr protected void @{{.*}}(ptr %{{.*}}, ptr %[[ARG1:.*]]) { +// CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8 +// CHECK: store ptr %[[ARG1]], ptr %[[ALLOCA]], align 8 +// CHECK: %[[LOAD_ALLOCA:.*]] = load ptr, ptr %[[ALLOCA]], align 8 +// CHECK: store i32 1, ptr %[[LOAD_ALLOCA]], align 4 diff --git a/mlir/test/Target/LLVMIR/openmp-simd-aligned.mlir b/mlir/test/Target/LLVMIR/openmp-simd-aligned.mlir new file mode 100644 index 000000000000..234604e4b664 --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-simd-aligned.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s + +//CHECK-LABEL: define void @_QPsimd_aligned_pointer() { +//CHECK: %[[A_PTR:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, i64 1, align 8 +//CHECK: %[[A_VAL:.*]] = load ptr, ptr %[[A_PTR]], align 8 +//CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[A_VAL]], i64 256) ] +llvm.func @_QPsimd_aligned_pointer() { + %1 = llvm.mlir.constant(1 : i64) : i64 + %2 = llvm.alloca %1 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {bindc_name = "x"} : (i64) -> !llvm.ptr + %3 = llvm.alloca %1 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr + %4 = llvm.mlir.constant(1 : i32) : i32 + %5 = llvm.mlir.constant(10 : i32) : i32 + %6 = llvm.mlir.constant(1 : i32) : i32 + omp.simd aligned(%2 : !llvm.ptr -> 256 : i64) { + omp.loop_nest (%arg0) : i32 = (%4) to (%5) inclusive step (%6) { + llvm.store %arg0, %3 : i32, !llvm.ptr + omp.yield + } + } + llvm.return +} + +//CHECK-LABEL: define void @_QPsimd_aligned_cptr() { +//CHECK: %[[A_CPTR:.*]] = alloca %_QM__fortran_builtinsT__builtin_c_ptr, i64 1, align 8 +//CHECK: %[[A_VAL:.*]] = load ptr, ptr %[[A_CPTR]], align 8 +//CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[A_VAL]], i64 256) ] +llvm.func @_QPsimd_aligned_cptr() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x !llvm.struct<"_QM__fortran_builtinsT__builtin_c_ptr", (i64)> {bindc_name = "a"} : (i64) -> !llvm.ptr + %2 = llvm.mlir.constant(1 : i64) : i64 + %3 = llvm.alloca %2 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr + %4 = llvm.mlir.constant(1 : i32) : i32 + %5 = llvm.mlir.constant(10 : i32) : i32 + %6 = llvm.mlir.constant(1 : i32) : i32 + omp.simd aligned(%1 : !llvm.ptr -> 256 : i64) { + omp.loop_nest (%arg0) : i32 = (%4) to (%5) inclusive step (%6) { + llvm.store %arg0, %3 : i32, !llvm.ptr + omp.yield + } + } + llvm.return +} + +//CHECK-LABEL: define void @_QPsimd_aligned_allocatable() { +//CHECK: %[[A_ADDR:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i64 1, align 8 +//CHECK: %[[A_VAL:.*]] = load ptr, ptr %[[A_ADDR]], align 8 +//CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[A_VAL]], i64 256) ] +llvm.func @_QPsimd_aligned_allocatable() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {bindc_name = "a"} : (i64) -> !llvm.ptr + %2 = llvm.mlir.constant(1 : i32) : i32 + %3 = llvm.mlir.constant(10 : i32) : i32 + %4 = llvm.mlir.constant(1 : i32) : i32 + omp.simd aligned(%1 : !llvm.ptr -> 256 : i64) { + omp.loop_nest (%arg0) : i32 = (%2) to (%3) inclusive step (%4) { + omp.yield + } + } + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 8f3e466cfbbe..83a0990d6316 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -127,18 +127,6 @@ llvm.func @sections_private(%x : !llvm.ptr) { llvm.return } -// ----- - -llvm.func @simd_aligned(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { - // expected-error@below {{not yet implemented: Unhandled clause aligned in omp.simd operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.simd}} - omp.simd aligned(%x : !llvm.ptr -> 32) { - omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - omp.yield - } - } - llvm.return -} // ----- diff --git a/mlir/test/Transforms/location-snapshot.mlir b/mlir/test/Transforms/location-snapshot.mlir index 9f48cb6e3b3f..aeddfedd08ae 100644 --- a/mlir/test/Transforms/location-snapshot.mlir +++ b/mlir/test/Transforms/location-snapshot.mlir @@ -1,5 +1,6 @@ // RUN: mlir-opt -allow-unregistered-dialect -snapshot-op-locations='filename=%/t' -mlir-print-local-scope -mlir-print-debuginfo %s | FileCheck %s -DFILE=%/t // RUN: mlir-opt -allow-unregistered-dialect -snapshot-op-locations='filename=%/t tag='tagged'' -mlir-print-local-scope -mlir-print-debuginfo %s | FileCheck %s --check-prefix=TAG -DFILE=%/t +// RUN: mlir-opt -allow-unregistered-dialect -snapshot-op-locations='filename=%/t print-debuginfo' -mlir-print-local-scope -mlir-print-debuginfo %s | FileCheck %s --check-prefix=DBG -DFILE=%/t && cat %/t | FileCheck %s --check-prefix=DBGFILE // CHECK: func @function( // CHECK-NEXT: loc("[[FILE]]":{{[0-9]+}}:{{[0-9]+}}) @@ -15,3 +16,18 @@ func.func @function() -> i32 { %1 = "foo"() : () -> i32 loc("original") return %1 : i32 loc("original") } loc("original") + +// DBG: func @function2( +// DBG-NEXT: loc("[[FILE]]":{{[0-9]+}}:{{[0-9]+}}) +// DBG-NEXT: loc("[[FILE]]":{{[0-9]+}}:{{[0-9]+}}) +// DBG-NEXT: } loc("[[FILE]]":{{[0-9]+}}:{{[0-9]+}}) + +// DBGFILE: func @function2( +// DBGFILE-NEXT: loc("{{.*}}location-snapshot.mlir":{{[0-9]+}}:{{[0-9]+}}) +// DBGFILE-NEXT: loc("{{.*}}location-snapshot.mlir":{{[0-9]+}}:{{[0-9]+}}) +// DBGFILE-NEXT: } loc("{{.*}}location-snapshot.mlir":{{[0-9]+}}:{{[0-9]+}}) + +func.func @function2() -> i32 { + %1 = "foo"() : () -> i32 + return %1 : i32 +}
\ No newline at end of file diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir index e4c423ce7052..5133c14414c9 100644 --- a/mlir/test/Transforms/loop-invariant-code-motion.mlir +++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir @@ -124,6 +124,64 @@ func.func @invariant_affine_if() { // ----- +func.func @hoist_invariant_affine_if_success(%lb: index, %ub: index, %step: index) -> i32 { + %cst_0 = arith.constant 0 : i32 + %cst_42 = arith.constant 42 : i32 + %sum_result = affine.for %i = %lb to %ub iter_args(%acc = %cst_0) -> i32 { + %conditional_add = affine.if affine_set<() : ()> () -> (i32) { + %add = arith.addi %cst_42, %cst_42 : i32 + affine.yield %add : i32 + } else { + %poison = ub.poison : i32 + affine.yield %poison : i32 + } + %sum = arith.addi %acc, %conditional_add : i32 + affine.yield %sum : i32 + } + + // CHECK-LABEL: hoist_invariant_affine_if_success + // CHECK-NEXT: arith.constant 0 : i32 + // CHECK-NEXT: %[[CST:.*]] = arith.constant 42 : i32 + // CHECK-NEXT: %[[IF:.*]] = affine.if + // CHECK-NEXT: arith.addi %[[CST]], %[[CST]] : i32 + // CHECK: affine.for + // CHECK-NOT: affine.if + // CHECK-NEXT: arith.addi %{{.*}}, %[[IF]] + + return %sum_result : i32 +} + +// ----- + +func.func @hoist_variant_affine_if_failure(%lb: index, %ub: index, %step: index) -> i32 { + %cst_0 = arith.constant 0 : i32 + %cst_42 = arith.constant 42 : i32 + %ind_7 = arith.constant 7 : index + %sum_result = affine.for %i = %lb to %ub iter_args(%acc = %cst_0) -> i32 { + %conditional_add = affine.if affine_set<(d0, d1) : (d1 - d0 >= 0)> (%i, %ind_7) -> (i32) { + %add = arith.addi %cst_42, %cst_42 : i32 + affine.yield %add : i32 + } else { + %poison = ub.poison : i32 + affine.yield %poison : i32 + } + %sum = arith.addi %acc, %conditional_add : i32 + affine.yield %sum : i32 + } + + // CHECK-LABEL: hoist_variant_affine_if_failure + // CHECK-NEXT: arith.constant 0 : i32 + // CHECK-NEXT: %[[CST:.*]] = arith.constant 42 : i32 + // CHECK-NEXT: arith.constant 7 : index + // CHECK-NEXT: affine.for + // CHECK-NEXT: %[[IF:.*]] = affine.if + // CHECK: arith.addi %{{.*}}, %[[IF]] + + return %sum_result : i32 +} + +// ----- + func.func @hoist_affine_for_with_unknown_trip_count(%lb: index, %ub: index) { affine.for %arg0 = 0 to 10 { affine.for %arg1 = %lb to %ub { @@ -383,6 +441,69 @@ func.func @parallel_loop_with_invariant() { // ----- +func.func @hoist_invariant_scf_if_success(%lb: index, %ub: index, %step: index) -> i32 { + %cst_0 = arith.constant 0 : i32 + %cst_42 = arith.constant 42 : i32 + %true = arith.constant true + %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 { + %conditional_add = scf.if %true -> (i32) { + %add = arith.addi %cst_42, %cst_42 : i32 + scf.yield %add : i32 + } else { + %poison = ub.poison : i32 + scf.yield %poison : i32 + } + %sum = arith.addi %acc, %conditional_add : i32 + scf.yield %sum : i32 + } + + // CHECK-LABEL: hoist_invariant_scf_if_success + // CHECK-NEXT: arith.constant 0 : i32 + // CHECK-NEXT: %[[CST:.*]] = arith.constant 42 : i32 + // CHECK-NEXT: %[[TRUE:.*]] = arith.constant true + // CHECK-NEXT: %[[IF:.*]] = scf.if %[[TRUE]] + // CHECK-NEXT: arith.addi %[[CST]], %[[CST]] : i32 + // CHECK: scf.for + // CHECK-NOT: scf.if + // CHECK-NEXT: arith.addi %{{.*}}, %[[IF]] + + return %sum_result : i32 +} + +// ----- + +func.func @hoist_variant_scf_if_failure(%lb: index, %ub: index, %step: index) -> i32 { + %cst_0 = arith.constant 0 : i32 + %cst_42 = arith.constant 42 : i32 + %ind_7 = arith.constant 7 : index + %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 { + %cond = arith.cmpi ult, %i, %ind_7 : index + %conditional_add = scf.if %cond -> (i32) { + %add = arith.addi %cst_42, %cst_42 : i32 + scf.yield %add : i32 + } else { + %poison = ub.poison : i32 + scf.yield %poison : i32 + } + %sum = arith.addi %acc, %conditional_add : i32 + scf.yield %sum : i32 + } + + // CHECK-LABEL: hoist_variant_scf_if_failure + // CHECK-NEXT: arith.constant 0 : i32 + // CHECK-NEXT: %[[CST_42:.*]] = arith.constant 42 : i32 + // CHECK-NEXT: %[[CST_7:.*]] = arith.constant 7 : index + // CHECK-NEXT: scf.for %[[IV:.*]] = %{{.*}} to %{{.*}} + // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[IV]], %[[CST_7]] + // CHECK-NEXT: %[[IF:.*]] = scf.if %[[CMP]] + // CHECK-NEXT: arith.addi %[[CST_42]], %[[CST_42]] : i32 + // CHECK: arith.addi %{{.*}}, %[[IF]] + + return %sum_result : i32 +} + +// ----- + func.func private @make_val() -> (index) // CHECK-LABEL: func @nested_uses_inside diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 2ca5f4963752..ae7d344b7167 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -64,9 +64,6 @@ func.func @remap_call_1_to_1(%arg0: i64) { // Contents of the old block are moved to the new block. // CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown -// The new block arguments are used in "test.return". -// CHECK-NEXT: notifyOperationModified: test.return - // The old block is erased. // CHECK-NEXT: notifyBlockErased @@ -390,8 +387,8 @@ func.func @caller() { // CHECK: %[[call:.*]]:2 = call @callee() : () -> (f16, f16) %0:2 = func.call @callee() : () -> (f32, i24) - // CHECK: %[[cast1:.*]] = "test.cast"() : () -> i24 - // CHECK: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32 + // CHECK-DAG: %[[cast1:.*]] = "test.cast"() : () -> i24 + // CHECK-DAG: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32 // CHECK: "test.some_user"(%[[cast0]], %[[cast1]]) : (f32, i24) -> () // expected-remark @below{{'test.some_user' is not legalizable}} "test.some_user"(%0#0, %0#1) : (f32, i24) -> () @@ -450,7 +447,7 @@ func.func @fold_legalization() -> i32 { // ----- // CHECK-LABEL: func @convert_detached_signature() -// CHECK: "test.legal_op_with_region"() ({ +// CHECK: "test.legal_op"() ({ // CHECK: ^bb0(%arg0: f64): // CHECK: "test.return"() : () -> () // CHECK: }) : () -> () @@ -483,3 +480,20 @@ func.func @test_1_to_n_block_signature_conversion() { "test.return"() : () -> () } +// ----- + +// CHECK: notifyOperationInserted: test.step_1 +// CHECK: notifyOperationReplaced: test.multiple_1_to_n_replacement +// CHECK: notifyOperationErased: test.multiple_1_to_n_replacement +// CHECK: notifyOperationInserted: test.legal_op +// CHECK: notifyOperationReplaced: test.step_1 +// CHECK: notifyOperationErased: test.step_1 + +// CHECK-LABEL: func @test_multiple_1_to_n_replacement() +// CHECK: %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16) +// CHECK: %[[cast:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1, %[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16, f16, f16) -> f16 +// CHECK: "test.valid"(%[[cast]]) : (f16) -> () +func.func @test_multiple_1_to_n_replacement() { + %0 = "test.multiple_1_to_n_replacement"() : () -> (f16) + "test.invalid"(%0) : (f16) -> () +} diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp index 09c5b4b2a0ad..d0b62e71ab0c 100644 --- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp +++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp @@ -139,7 +139,7 @@ struct TestDecomposeCallGraphTypes tupleType.getFlattenedTypes(types); return success(); }); - typeConverter.addArgumentMaterialization(buildMakeTupleOp); + typeConverter.addSourceMaterialization(buildMakeTupleOp); typeConverter.addTargetMaterialization(buildDecomposeTuple); populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index a470497fdbb5..5b7c36c9b97b 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -785,7 +785,7 @@ struct TestDetachedSignatureConversion : public ConversionPattern { ConversionPatternRewriter &rewriter) const final { if (op->getNumRegions() != 1) return failure(); - OperationState state(op->getLoc(), "test.legal_op_with_region", operands, + OperationState state(op->getLoc(), "test.legal_op", operands, op->getResultTypes(), {}, BlockRange()); Region *newRegion = state.addRegion(); rewriter.inlineRegionBefore(op->getRegion(0), *newRegion, @@ -1234,6 +1234,41 @@ public: } }; +/// A pattern that tests two back-to-back 1 -> 2 op replacements. +class TestMultiple1ToNReplacement : public ConversionPattern { +public: + TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter) + : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1, + ctx) {} + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const final { + // Helper function that replaces the given op with a new op of the given + // name and doubles each result (1 -> 2 replacement of each result). + auto replaceWithDoubleResults = [&](Operation *op, StringRef name) { + SmallVector<Type> types; + for (Type t : op->getResultTypes()) { + types.push_back(t); + types.push_back(t); + } + OperationState state(op->getLoc(), name, + /*operands=*/{}, types, op->getAttrs()); + auto *newOp = rewriter.create(state); + SmallVector<ValueRange> repls; + for (size_t i = 0, e = op->getNumResults(); i < e; ++i) + repls.push_back(newOp->getResults().slice(2 * i, 2)); + rewriter.replaceOpWithMultiple(op, repls); + return newOp; + }; + + // Replace test.multiple_1_to_n_replacement with test.step_1. + Operation *repl1 = replaceWithDoubleResults(op, "test.step_1"); + // Now replace test.step_1 with test.legal_op. + replaceWithDoubleResults(repl1, "test.legal_op"); + return success(); + } +}; + } // namespace namespace { @@ -1241,7 +1276,6 @@ struct TestTypeConverter : public TypeConverter { using TypeConverter::TypeConverter; TestTypeConverter() { addConversion(convertType); - addArgumentMaterialization(materializeCast); addSourceMaterialization(materializeCast); } @@ -1319,7 +1353,8 @@ struct TestLegalizePatternDriver TestUndoPropertiesModification, TestEraseOp, TestRepetitive1ToNConsumer>(&getContext()); patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp, - TestPassthroughInvalidOp>(&getContext(), converter); + TestPassthroughInvalidOp, TestMultiple1ToNReplacement>( + &getContext(), converter); patterns.add<TestDuplicateBlockArgs>(converter, &getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1330,8 +1365,7 @@ struct TestLegalizePatternDriver target.addLegalOp<ModuleOp>(); target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, TerminatorOp, OneRegionOp>(); - target.addLegalOp( - OperationName("test.legal_op_with_region", &getContext())); + target.addLegalOp(OperationName("test.legal_op", &getContext())); target .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp index ac904c3e01c9..83db1188861a 100644 --- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp +++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp @@ -149,7 +149,7 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op, op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(), tosaConv2DOp.getWeight(), tosaConv2DOp.getBias(), tosaConv2DOp.getPadAttr(), tosaConv2DOp.getStrideAttr(), - tosaConv2DOp.getDilationAttr()); + tosaConv2DOp.getDilationAttr(), tosaConv2DOp.getAccTypeAttr()); // Create rescale to quantized type double inputScale = inputQType.getScale(); diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp index 2cc1fb5d39d7..a03bf0a1023d 100644 --- a/mlir/test/lib/Transforms/TestDialectConversion.cpp +++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp @@ -28,7 +28,6 @@ namespace { struct PDLLTypeConverter : public TypeConverter { PDLLTypeConverter() { addConversion(convertType); - addArgumentMaterialization(materializeCast); addSourceMaterialization(materializeCast); } diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index fb4c75b53379..8785d6d36007 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -103,6 +103,42 @@ def testFuseIntoContainingOpCompact(target): @run @create_sequence +def testFuseOpCompact(target): + structured.FuseOp( + target, tile_sizes=[4, 8], tile_interchange=[0, 1], apply_cleanup=True + ) + # CHECK-LABEL: TEST: testFuseOpCompact + # CHECK: transform.sequence + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8] + # CHECK-SAME: interchange [0, 1] apply_cleanup = true + # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + +@run +@create_sequence +def testFuseOpNoArg(target): + structured.FuseOp(target) + # CHECK-LABEL: TEST: testFuseOpNoArg + # CHECK: transform.sequence + # CHECK: %{{.+}} = transform.structured.fuse %{{.*}} : + # CHECK-SAME: (!transform.any_op) -> !transform.any_op + + +@run +@create_sequence +def testFuseOpAttributes(target): + attr = DenseI64ArrayAttr.get([4, 8]) + ichange = DenseI64ArrayAttr.get([0, 1]) + structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange) + # CHECK-LABEL: TEST: testFuseOpAttributes + # CHECK: transform.sequence + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8] + # CHECK-SAME: interchange [0, 1] + # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + +@run +@create_sequence def testGeneralize(target): structured.GeneralizeOp(target) # CHECK-LABEL: TEST: testGeneralize diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index 6d3a8db8c24b..0d12c35d96be 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -306,7 +306,7 @@ def testUnrankedMemRefWithOffsetCallback(): log(arr) with Context(): - # The module takes a subview of the argument memref, casts it to an unranked memref and + # The module takes a subview of the argument memref, casts it to an unranked memref and # calls the callback with it. module = Module.parse( r""" diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py index d59c6a6bc424..5a2ed684d298 100644 --- a/mlir/test/python/ir/dialects.py +++ b/mlir/test/python/ir/dialects.py @@ -121,3 +121,39 @@ def testAppendPrefixSearchPath(): sys.path.append(".") _cext.globals.append_dialect_search_prefix("custom_dialect") assert _cext.globals._check_dialect_module_loaded("custom") + + +# CHECK-LABEL: TEST: testDialectLoadOnCreate +@run +def testDialectLoadOnCreate(): + with Context(load_on_create_dialects=[]) as ctx: + ctx.emit_error_diagnostics = True + ctx.allow_unregistered_dialects = True + + def callback(d): + # CHECK: DIAGNOSTIC + # CHECK-SAME: op created with unregistered dialect + print(f"DIAGNOSTIC={d.message}") + return True + + handler = ctx.attach_diagnostic_handler(callback) + loc = Location.unknown(ctx) + try: + op = Operation.create("arith.addi", loc=loc) + ctx.allow_unregistered_dialects = False + op.verify() + except MLIRError as e: + pass + + with Context(load_on_create_dialects=["func"]) as ctx: + loc = Location.unknown(ctx) + fn = Operation.create("func.func", loc=loc) + + # TODO: This may require an update if a site wide policy is set. + # CHECK: Load on create: [] + print(f"Load on create: {get_load_on_create_dialects()}") + append_load_on_create_dialect("func") + # CHECK: Load on create: + # CHECK-SAME: func + print(f"Load on create: {get_load_on_create_dialects()}") + print(get_load_on_create_dialects()) diff --git a/mlir/test/tblgen-lsp-server/templ-arg-check.test b/mlir/test/tblgen-lsp-server/templ-arg-check.test new file mode 100644 index 000000000000..cda9b79a1f46 --- /dev/null +++ b/mlir/test/tblgen-lsp-server/templ-arg-check.test @@ -0,0 +1,15 @@ +// RUN: tblgen-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s +{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"tablegen","capabilities":{},"trace":"off"}} +// ----- +{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{ + "uri":"test:///foo.td", + "languageId":"tablegen", + "version":1, + "text":"class Foo<int i>;\ndef : Foo<\"\">;" +}}} +// CHECK: "method": "textDocument/publishDiagnostics", +// CHECK: "message": "Value specified for template argument 'Foo:i' is of type string; expected type int: \"\"", +// ----- +{"jsonrpc":"2.0","id":3,"method":"shutdown"} +// ----- +{"jsonrpc":"2.0","method":"exit"} diff --git a/mlir/utils/pygments/README.md b/mlir/utils/pygments/README.md new file mode 100644 index 000000000000..838faceb01b0 --- /dev/null +++ b/mlir/utils/pygments/README.md @@ -0,0 +1,45 @@ +## Pygments Lexer for MLIR + +This file contains a simple Pygments lexer configuration for MLIR, derived from +the version used in the original CGO paper. Pygments allows for advanced +configurable syntax highlighting of any code. This lexer is known to be +incomplete and support mostly core IR with a subset of built-in types. +Additions and customizations are welcome. + +### Standalone Usage + +Install Pygments, e.g., by running `pip install Pygments` or a Python package +manager of your choosing. Use the standalone `pygmentize` command by +instructing it to load the custom lexer: + +``` +pygmentize -l /path/to/mlir_lexer.py:MlirLexer -x myfile.mlir +``` + +This will produce highlighted output in the terminal. Other output formats are +available, see Pygments [documentation](https://pygments.org/docs/) for more +information. + +### LaTeX Usage + +First, make sure your distribution includes the `minted` package and list in +the preamble. + +```latex +\usepackage{minted} +``` + +Place the `mlir_lexer.py` in a place where the `latex` binary can find it, +typically in the working directory next to the main `.tex` file. Note that you +will have to invoke `latex` with the `-shell-escape` flag. See the `minted` +package [documentation](https://ctan.org/pkg/minted?lang=en) for more +information. + +Leverage the custom lexer facility of `minted` to use this lexer in your +document as: + +```latex +\begin{minted}{mlir_lexer.py:MlirLexer -x} + ... your code here ... +\end{minted} +``` diff --git a/mlir/utils/pygments/mlir_lexer.py b/mlir/utils/pygments/mlir_lexer.py new file mode 100644 index 000000000000..179a058e9110 --- /dev/null +++ b/mlir/utils/pygments/mlir_lexer.py @@ -0,0 +1,38 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from pygments.lexer import RegexLexer +from pygments.token import * + + +class MlirLexer(RegexLexer): + name = "MLIR" + aliases = ["mlir"] + filenames = ["*.mlir"] + + tokens = { + "root": [ + (r"%[a-zA-Z0-9_]+", Name.Variable), + (r"@[a-zA-Z_][a-zA-Z0-9_]+", Name.Function), + (r"\^[a-zA-Z0-9_]+", Name.Label), + (r"#[a-zA-Z0-9_]+", Name.Constant), + (r"![a-zA-Z0-9_]+", Keyword.Type), + (r"[a-zA-Z_][a-zA-Z0-9_]*\.", Name.Entity), + (r"memref[^.]", Keyword.Type), + (r"index", Keyword.Type), + (r"i[0-9]+", Keyword.Type), + (r"f[0-9]+", Keyword.Type), + (r"[0-9]+", Number.Integer), + (r"[0-9]*\.[0-9]*", Number.Float), + (r'"[^"]*"', String.Double), + (r"affine_map", Keyword.Reserved), + # TODO: this should be within affine maps only + (r"\+-\*\/", Operator), + (r"floordiv", Operator.Word), + (r"ceildiv", Operator.Word), + (r"mod", Operator.Word), + (r"()\[\]<>,{}", Punctuation), + (r"\/\/.*\n", Comment.Single), + ] + } |
