summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:49:54 +0900
committerNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:49:54 +0900
commite2810c9a248f4c7fbfae84bb32b6f7e01027458b (patch)
treeae0b02a8491b969a1cee94ea16ffe42c559143c5 /mlir
parentfa04eb4af95c1ca7377279728cb004bcd2324d01 (diff)
parentbdcf47e4bcb92889665825654bb80a8bbe30379e (diff)
Merge branch 'users/chapuni/cov/single/base' into users/chapuni/cov/single/switch (users/chapuni/cov/single/switch)
Diffstat (limited to 'mlir')
-rw-r--r--mlir/CMakeLists.txt8
-rw-r--r--mlir/cmake/modules/AddMLIRPython.cmake7
-rw-r--r--mlir/docs/Bindings/Python.md6
-rw-r--r--mlir/docs/DialectConversion.md35
-rw-r--r--mlir/docs/Tutorials/Toy/Ch-2.md2
-rw-r--r--mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp1
-rw-r--r--mlir/include/mlir-c/Dialect/LLVM.h7
-rw-r--r--mlir/include/mlir/Analysis/DataFlowFramework.h23
-rw-r--r--mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h4
-rw-r--r--mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h4
-rw-r--r--mlir/include/mlir/Dialect/Affine/IR/AffineOps.td27
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h6
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h6
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h28
-rw-r--r--mlir/include/mlir/Dialect/EmitC/IR/EmitC.td4
-rw-r--r--mlir/include/mlir/Dialect/Func/IR/FuncOps.td2
-rw-r--r--mlir/include/mlir/Dialect/GPU/IR/GPUOps.td2
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td8
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td3
-rw-r--r--mlir/include/mlir/Dialect/SCF/IR/SCF.h6
-rw-r--r--mlir/include/mlir/Dialect/SCF/IR/SCFOps.td4
-rw-r--r--mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h3
-rw-r--r--mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td5
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td10
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td20
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td14
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td6
-rw-r--r--mlir/include/mlir/IR/Dialect.h1
-rw-r--r--mlir/include/mlir/IR/Dominance.h49
-rw-r--r--mlir/include/mlir/IR/OperationSupport.h13
-rw-r--r--mlir/include/mlir/Target/LLVMIR/ModuleImport.h6
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h18
-rw-r--r--mlir/include/mlir/Transforms/LocationSnapshot.h12
-rw-r--r--mlir/include/mlir/Transforms/OneToNTypeConversion.h11
-rw-r--r--mlir/include/mlir/Transforms/Passes.td10
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp10
-rw-r--r--mlir/lib/CAPI/Dialect/LLVM.cpp10
-rw-r--r--mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp3
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp122
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h21
-rw-r--r--mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp101
-rw-r--r--mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp12
-rw-r--r--mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp16
-rw-r--r--mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp1
-rw-r--r--mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp206
-rw-r--r--mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp12
-rw-r--r--mlir/lib/Dialect/Affine/IR/AffineOps.cpp276
-rw-r--r--mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp62
-rw-r--r--mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp3
-rw-r--r--mlir/lib/Dialect/Affine/Utils/Utils.cpp9
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithOps.cpp24
-rw-r--r--mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp94
-rw-r--r--mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp20
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp27
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp53
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp46
-rw-r--r--mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp1
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp40
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp1
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp2
-rw-r--r--mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp1
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp31
-rw-r--r--mlir/lib/Dialect/SCF/Utils/Utils.cpp3
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp3
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorOps.cpp123
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp4
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaOps.cpp90
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp2
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp2
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp17
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp8
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp1
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp19
-rw-r--r--mlir/lib/IR/Dominance.cpp124
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp10
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp61
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleImport.cpp137
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp19
-rw-r--r--mlir/lib/Transforms/LocationSnapshot.cpp31
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp455
-rw-r--r--mlir/python/mlir/_mlir_libs/__init__.py42
-rw-r--r--mlir/python/mlir/dialects/transform/structured.py71
-rw-r--r--mlir/python/mlir/ir.py6
-rw-r--r--mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir2
-rw-r--r--mlir/test/Conversion/AffineToStandard/lower-affine.mlir26
-rw-r--r--mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir31
-rw-r--r--mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir2
-rw-r--r--mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir2
-rw-r--r--mlir/test/Conversion/GPUToSPIRV/printf.mlir2
-rw-r--r--mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir26
-rw-r--r--mlir/test/Conversion/SCFToEmitC/for.mlir89
-rw-r--r--mlir/test/Conversion/SCFToEmitC/switch.mlir9
-rw-r--r--mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir36
-rw-r--r--mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir52
-rw-r--r--mlir/test/Dialect/Affine/canonicalize.mlir245
-rw-r--r--mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir81
-rw-r--r--mlir/test/Dialect/Arith/canonicalize.mlir64
-rw-r--r--mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir11
-rw-r--r--mlir/test/Dialect/GPU/indirect-device-func-call.mlir2
-rw-r--r--mlir/test/Dialect/GPU/ops.mlir17
-rw-r--r--mlir/test/Dialect/GPU/test-nvvm-pipeline.mlir2
-rw-r--r--mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir2
-rw-r--r--mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir2
-rw-r--r--mlir/test/Dialect/LLVMIR/roundtrip.mlir10
-rw-r--r--mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir5
-rw-r--r--mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir4
-rw-r--r--mlir/test/Dialect/Linalg/td/decompose-unpack.mlir12
-rw-r--r--mlir/test/Dialect/SCF/canonicalize.mlir22
-rw-r--r--mlir/test/Dialect/Tensor/canonicalize.mlir25
-rw-r--r--mlir/test/Dialect/Tensor/invalid.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/canonicalize.mlir35
-rw-r--r--mlir/test/Dialect/Tosa/constant-op-fold.mlir9
-rw-r--r--mlir/test/Dialect/Tosa/invalid.mlir137
-rw-r--r--mlir/test/Dialect/Tosa/level_check.mlir72
-rw-r--r--mlir/test/Dialect/Tosa/ops.mlir24
-rw-r--r--mlir/test/Dialect/Tosa/quant-test.mlir6
-rw-r--r--mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir12
-rw-r--r--mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir10
-rw-r--r--mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir30
-rw-r--r--mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir78
-rw-r--r--mlir/test/Integration/GPU/CUDA/assert.mlir38
-rw-r--r--mlir/test/Integration/GPU/CUDA/printf.mlir2
-rw-r--r--mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir2
-rw-r--r--mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir10
-rw-r--r--mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir10
-rw-r--r--mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir8
-rw-r--r--mlir/test/Integration/GPU/CUDA/sm90/transform-dialect/tma_load_64x8_8x128_noswizzle-transform.mlir4
-rw-r--r--mlir/test/Integration/GPU/ROCM/printf.mlir2
-rw-r--r--mlir/test/Target/LLVMIR/Import/import-failure.ll9
-rw-r--r--mlir/test/Target/LLVMIR/Import/instructions.ll11
-rw-r--r--mlir/test/Target/LLVMIR/Import/metadata-alias-scopes.ll35
-rw-r--r--mlir/test/Target/LLVMIR/attribute-alias-scopes.mlir51
-rw-r--r--mlir/test/Target/LLVMIR/nvvmir.mlir29
-rw-r--r--mlir/test/Target/LLVMIR/omptarget-threadprivate-device-lowering.mlir30
-rw-r--r--mlir/test/Target/LLVMIR/openmp-simd-aligned.mlir60
-rw-r--r--mlir/test/Target/LLVMIR/openmp-todo.mlir12
-rw-r--r--mlir/test/Transforms/location-snapshot.mlir16
-rw-r--r--mlir/test/Transforms/loop-invariant-code-motion.mlir121
-rw-r--r--mlir/test/Transforms/test-legalizer.mlir26
-rw-r--r--mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp2
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp44
-rw-r--r--mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp2
-rw-r--r--mlir/test/lib/Transforms/TestDialectConversion.cpp1
-rw-r--r--mlir/test/python/dialects/transform_structured_ext.py36
-rw-r--r--mlir/test/python/execution_engine.py2
-rw-r--r--mlir/test/python/ir/dialects.py36
-rw-r--r--mlir/test/tblgen-lsp-server/templ-arg-check.test15
-rw-r--r--mlir/utils/pygments/README.md45
-rw-r--r--mlir/utils/pygments/mlir_lexer.py38
153 files changed, 3408 insertions, 1295 deletions
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 5ea49c0dbfa7..a888ac243b04 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -170,7 +170,7 @@ configure_file(
# The pybind11 library can be found (set with -DPYBIND_DIR=...)
# The python executable is correct (set with -DPython3_EXECUTABLE=...)
# By default, find_package and probing for installed pybind11 is performed.
-# Super projects can set MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES=ON to
+# Super projects can set MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES=ON to
# disable all package setup and control it themselves.
#-------------------------------------------------------------------------------
@@ -196,8 +196,10 @@ endif()
set(CMAKE_INCLUDE_CURRENT_DIR ON)
-include_directories( "include")
-include_directories( ${MLIR_INCLUDE_DIR})
+include_directories(BEFORE
+ "include"
+ ${MLIR_INCLUDE_DIR}
+ )
# Adding tools/mlir-tblgen here as calling add_tablegen sets some variables like
# MLIR_TABLEGEN_EXE in PARENT_SCOPE which gets lost if that folder is included
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 9d4e06c7909c..717a503468a8 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -683,6 +683,13 @@ function(add_mlir_python_extension libname extname)
${eh_rtti_enable}
)
endif()
+
+ if(APPLE)
+ # NanobindAdaptors.h uses PyClassMethod_New to build `pure_subclass`es but nanobind
+ # doesn't declare this API as undefined in its linker flags. So we need to declare it as such
+ # for downstream users that do not do something like `-undefined dynamic_lookup`.
+ set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -Wl,-U -Wl,_PyClassMethod_New")
+ endif()
endif()
target_compile_options(${libname} PRIVATE ${eh_rtti_enable})
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index a0bd1cac118b..32df3310d811 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1035,7 +1035,7 @@ class ConstantOp(_ods_ir.OpView):
...
```
-expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
+expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
Thus, a natural extension is a builder that accepts a MLIR type and a Python value and instantiates the appropriate `TypedAttr`:
```python
@@ -1181,9 +1181,9 @@ make the passes available along with the dialect.
Dialect functionality other than IR objects or passes, such as helper functions,
can be exposed to Python similarly to attributes and types. C API is expected to
exist for this functionality, which can then be wrapped using pybind11 and
-`[include/mlir/Bindings/Python/PybindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)`,
+[`include/mlir/Bindings/Python/PybindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h),
or nanobind and
-`[include/mlir/Bindings/Python/NanobindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h)`
+[`include/mlir/Bindings/Python/NanobindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h)
utilities to connect to the rest of Python API. The bindings can be located in a
separate module or in the same module as attributes and types, and
loaded along with the dialect.
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 3168f5e13c75..abacd5a82c61 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -242,19 +242,6 @@ cannot. These materializations are used by the conversion framework to ensure
type safety during the conversion process. There are several types of
materializations depending on the situation.
-* Argument Materialization
-
- - An argument materialization is used when converting the type of a block
- argument during a [signature conversion](#region-signature-conversion).
- The new block argument types are specified in a `SignatureConversion`
- object. An original block argument can be converted into multiple
- block arguments, which is not supported everywhere in the dialect
- conversion. (E.g., adaptors support only a single replacement value for
- each original value.) Therefore, an argument materialization is used to
- convert potentially multiple new block arguments back into a single SSA
- value. An argument materialization is also used when replacing an op
- result with multiple values.
-
* Source Materialization
- A source materialization is used when a value was replaced with a value
@@ -344,17 +331,6 @@ class TypeConverter {
/// persist after the conversion has finished.
/// This method registers a materialization that will be called when
- /// converting (potentially multiple) block arguments that were the result of
- /// a signature conversion of a single block argument, to a single SSA value
- /// with the old argument type.
- template <typename FnT,
- typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
- void addArgumentMaterialization(FnT &&callback) {
- argumentMaterializations.emplace_back(
- wrapMaterialization<T>(std::forward<FnT>(callback)));
- }
-
- /// This method registers a materialization that will be called when
/// converting a replacement value back to its original source type.
/// This is used when some uses of the original value persist beyond the main
/// conversion.
@@ -406,12 +382,11 @@ done explicitly via a conversion pattern.
To convert the types of block arguments within a Region, a custom hook on the
`ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook
uses a provided type converter to apply type conversions to all blocks of a
-given region. As noted above, the conversions performed by this method use the
-argument materialization hook on the `TypeConverter`. This hook also takes an
-optional `TypeConverter::SignatureConversion` parameter that applies a custom
-conversion to the entry block of the region. The types of the entry block
-arguments are often tied semantically to the operation, e.g.,
-`func::FuncOp`, `AffineForOp`, etc.
+given region. This hook also takes an optional
+`TypeConverter::SignatureConversion` parameter that applies a custom conversion
+to the entry block of the region. The types of the entry block arguments are
+often tied semantically to the operation, e.g., `func::FuncOp`, `AffineForOp`,
+etc.
To convert the signature of just one given block, the
`applySignatureConversion` hook can be used.
diff --git a/mlir/docs/Tutorials/Toy/Ch-2.md b/mlir/docs/Tutorials/Toy/Ch-2.md
index b807ee3a2049..039417c9c9a1 100644
--- a/mlir/docs/Tutorials/Toy/Ch-2.md
+++ b/mlir/docs/Tutorials/Toy/Ch-2.md
@@ -262,7 +262,7 @@ class ConstantOp : public mlir::Op<
mlir::OpTrait::OneResult,
/// We also provide a utility `getType` accessor that
/// returns the TensorType of the single result.
- mlir::OpTraits::OneTypedResult<TensorType>::Impl> {
+ mlir::OpTrait::OneTypedResult<TensorType>::Impl> {
public:
/// Inherit the constructors from the base Op class.
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 3ad70e727969..123d114ae163 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -220,6 +220,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
+ cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the
diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h
index 0992285f997e..26c4140757c3 100644
--- a/mlir/include/mlir-c/Dialect/LLVM.h
+++ b/mlir/include/mlir-c/Dialect/LLVM.h
@@ -45,6 +45,13 @@ MLIR_CAPI_EXPORTED MlirType
mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
MlirType const *argumentTypes, bool isVarArg);
+/// Returns the number of input types.
+MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type);
+
+/// Returns the pos-th input type.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetInput(MlirType type,
+ intptr_t pos);
+
/// Returns `true` if the type is an LLVM dialect struct type.
MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type);
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index dfd358e7017a..b6d10ba0bea2 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -332,9 +332,11 @@ public:
/// does not exist.
template <typename StateT, typename AnchorT>
const StateT *lookupState(AnchorT anchor) const {
- auto it =
- analysisStates.find({LatticeAnchor(anchor), TypeID::get<StateT>()});
- if (it == analysisStates.end())
+ const auto &mapIt = analysisStates.find(LatticeAnchor(anchor));
+ if (mapIt == analysisStates.end())
+ return nullptr;
+ auto it = mapIt->second.find(TypeID::get<StateT>());
+ if (it == mapIt->second.end())
return nullptr;
return static_cast<const StateT *>(it->second.get());
}
@@ -343,11 +345,7 @@ public:
template <typename AnchorT>
void eraseState(AnchorT anchor) {
LatticeAnchor la(anchor);
-
- for (auto it = analysisStates.begin(); it != analysisStates.end(); ++it) {
- if (it->first.first == la)
- analysisStates.erase(it);
- }
+ analysisStates.erase(LatticeAnchor(anchor));
}
// Erase all analysis states
@@ -426,7 +424,8 @@ private:
/// A type-erased map of lattice anchors to associated analysis states for
/// first-class lattice anchors.
- DenseMap<std::pair<LatticeAnchor, TypeID>, std::unique_ptr<AnalysisState>>
+ DenseMap<LatticeAnchor, DenseMap<TypeID, std::unique_ptr<AnalysisState>>,
+ DenseMapInfo<LatticeAnchor::ParentTy>>
analysisStates;
/// Allow the base child analysis class to access the internals of the solver.
@@ -643,7 +642,7 @@ AnalysisT *DataFlowSolver::load(Args &&...args) {
template <typename StateT, typename AnchorT>
StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) {
std::unique_ptr<AnalysisState> &state =
- analysisStates[{LatticeAnchor(anchor), TypeID::get<StateT>()}];
+ analysisStates[LatticeAnchor(anchor)][TypeID::get<StateT>()];
if (!state) {
state = std::unique_ptr<StateT>(new StateT(anchor));
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -689,10 +688,6 @@ struct DenseMapInfo<mlir::ProgramPoint> {
}
};
-template <>
-struct DenseMapInfo<mlir::LatticeAnchor>
- : public DenseMapInfo<mlir::LatticeAnchor::ParentTy> {};
-
// Allow llvm::cast style functions.
template <typename To>
struct CastInfo<To, mlir::LatticeAnchor>
diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
index b88c1e8b20f3..88f18022da9b 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
@@ -29,6 +29,10 @@ namespace cf {
/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
/// conversion patterns capture the LLVMTypeConverter by reference meaning the
/// references have to remain alive during the entire pattern lifetime.
+///
+/// Note: This function does not populate the default cf.assert lowering. That
+/// is because some platforms have a custom cf.assert lowering. The default
+/// lowering can be populated with `populateAssertToLLVMConversionPattern`.
void populateControlFlowToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h b/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h
index 22df7f1c5dcf..acc39e6acf72 100644
--- a/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h
+++ b/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h
@@ -9,6 +9,7 @@
#ifndef MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
#define MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
+#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
@@ -19,7 +20,8 @@ class RewritePatternSet;
#include "mlir/Conversion/Passes.h.inc"
/// Collect a set of patterns to convert SCF operations to the EmitC dialect.
-void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns);
+void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
+ TypeConverter &typeConverter);
} // namespace mlir
#endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index f5ca24389065..e2eab1fb2178 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1083,6 +1083,9 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
%indices_2 = affine.apply #map2()[%linear_index]
```
+ In other words, `%0:3 = affine.delinearize_index %x into (B, C)` produces
+ `%0 = {%x / (B * C), (%x mod (B * C)) / C, %x mod C}`.
+
The basis may either contain `N` or `N-1` elements, where `N` is the number of results.
If there are N basis elements, the first one will not be used during computations,
but may be used during analysis and canonicalization to eliminate terms from
@@ -1098,7 +1101,12 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
%0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index
```
- Note that, due to the constraints of affine maps, all the basis elements must
+ Note that, for symmetry with `getPaddedBasis()`, if `hasOuterBound` is `true`
+ when one of the `OpFoldResult` builders is called but the first element of the
+ basis is `nullptr`, that first element is ignored and the builder proceeds as if
+ there was no outer bound.
+
+ Due to the constraints of affine maps, all the basis elements must
be strictly positive. A dynamic basis element being 0 or negative causes
undefined behavior.
}];
@@ -1136,6 +1144,11 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
/// Return a vector that contains the basis of the operation, removing
/// the outer bound if one is present.
SmallVector<OpFoldResult> getEffectiveBasis();
+
+ /// Return the vector with one basis element per result of the operation. If
+ /// there is no outer bound specified, the leading entry of this result will be
+ /// nullptr.
+ SmallVector<OpFoldResult> getPaddedBasis();
}];
let hasVerifier = 1;
@@ -1160,6 +1173,9 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j
```
+ In other words, `%0 = affine.linearize_index [%z, %y, %x] by (Z, Y, X)`
+ gives `%0 = %x + %y * X + %z * X * Y`, or `%0 = %x + X * (%y + Y * (%z))`.
+
The basis may either have `N` or `N-1` elements, where `N` is the number of
inputs to linearize_index. If `N` inputs are provided, the first one is not used
in computation, but may be used during analysis or canonicalization as a bound
@@ -1168,6 +1184,10 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
If all `N` basis elements are provided, the linearize_index operation is said to
"have an outer bound".
+ As a convenience, and for symmetry with `getPaddedBasis()`, if the first
+ element of a set of `OpFoldResult`s passed to the builders of this operation is
+ `nullptr`, that element is ignored.
+
If the `disjoint` property is present, this is an optimization hint that,
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
except that `%idx_0` may be negative to make the index as a whole negative.
@@ -1224,6 +1244,11 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
/// Return a vector that contains the basis of the operation, removing
/// the outer bound if one is present.
SmallVector<OpFoldResult> getEffectiveBasis();
+
+ /// Return the vector with one basis element per index operand of the operation.
+ /// If there is no outer bound specified, the leading entry of this basis will be
+ /// nullptr.
+ SmallVector<OpFoldResult> getPaddedBasis();
}];
let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 983f7a29cb22..d1a102e2a6e4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -456,7 +456,7 @@ public:
/// read by themselves (e.g., ExtractSliceOp).
bool isValueRead(Value value) const;
- /// Starting from `value`, follow the use-def chain in reverse, always
+ /// Starting from `opOperand`, follow the use-def chain in reverse, always
/// selecting the aliasing OpOperands. Find and return Values for which
/// `condition` evaluates to true. OpOperands of such matching Values are not
/// traversed any further, the visited aliasing opOperands will be preserved
@@ -484,7 +484,7 @@ public:
/// Additional stopping conditions for the traversal can be specified in
/// `config`.
SetVector<Value> findValueInReverseUseDefChain(
- Value value, llvm::function_ref<bool(Value)> condition,
+ OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
TraversalConfig config = TraversalConfig(),
llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;
@@ -520,7 +520,7 @@ public:
///
/// Note: OpResults of unknown ops are handled conservatively and assumed to
/// be definitions.
- SetVector<Value> findDefinitions(Value value) const;
+ SetVector<Value> findDefinitions(OpOperand *opOperand) const;
/// Return `true` if the given OpResult has been decided to bufferize inplace.
virtual bool isInPlace(OpOperand &opOperand) const;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index d50a3042aeea..bd23a19f7472 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -127,9 +127,9 @@ public:
/// Return true if the buffer of the given tensor value is writable.
bool isWritable(Value value) const;
- /// Find the definitions of the given tensor value or retrieve them from the
- /// cache.
- const SetVector<Value> &findDefinitionsCached(Value value);
+ /// Find the definitions of the given operand's value or
+ /// retrieve them from the cache.
+ const SetVector<Value> &findDefinitionsCached(OpOperand *opOperand);
/// Reset cached data structures.
void resetCache() override;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index 892675954493..a4ee893ca534 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -10,7 +10,9 @@
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TRANSFORMS_H
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
namespace mlir {
namespace bufferization {
@@ -34,13 +36,35 @@ struct OneShotBufferizationOptions;
/// "tensor.empty" op.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op);
+/// A function type that defines a callback to control the construction
+/// of the subset extraction of the `SubsetInsertionOpInterface`.
+/// The subset extraction value can be used as a replacement for the
+/// `emptyTensorOp` value which is being consumed by `user`. Failure
+/// to build such a value should be indicated with an empty value.
+/// This function should guarantee the legality of the replacement,
+/// i.e. the replacement should dominate the user of the `emptyTensorOp`
+/// being eliminated.
+using ControlBuildSubsetExtractionFn =
+ std::function<Value(RewriterBase &, SubsetInsertionOpInterface,
+ tensor::EmptyOp emptyTensorOp, Operation *user)>;
+
+/// This method builds and returns a subset extraction value for the
+/// destination tensor that the given `op` inserts into.
+/// It returns a value which should replace the `emptyTensorOp` use
+/// that is being consumed by `user`.
+/// If no such value is found, it returns an empty Value.
+Value buildSubsetExtraction(RewriterBase &rewriter,
+ SubsetInsertionOpInterface op,
+ tensor::EmptyOp emptyTensorOp, Operation *user);
+
/// Try to eliminate "tensor.empty" ops inside `op`.
///
/// This function overload accepts an existing `OneShotAnalysisState`, which
/// contains in-place bufferization decisions. This overload is useful if an
/// existing analysis should be reused for empty tensor elimination.
-LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
- OneShotAnalysisState &state);
+LogicalResult eliminateEmptyTensors(
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
+ ControlBuildSubsetExtractionFn subsetsExtractionFn = buildSubsetExtraction);
/// Within the given operation, hoist buffers from loops where possible. See
/// "BufferLoopHoistingPass" for more information.
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index fc5a33541533..744a0dc4770e 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -727,7 +727,7 @@ def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">,
Example:
```mlir
- emitc.func @foo() : (i32) {
+ emitc.func @foo() -> (i32) {
...
emitc.return %0 : i32
}
@@ -1305,8 +1305,6 @@ def EmitC_IfOp : EmitC_Op<"if",
Block* body = getBody(1);
return OpBuilder::atBlockEnd(body, listener);
}
- Block* thenBlock();
- Block* elseBlock();
}];
let hasCustomAssemblyFormat = 1;
}
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
index 237a825c1910..211201802b08 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -352,7 +352,7 @@ def ReturnOp : Func_Op<"return", [Pure, HasParent<"FuncOp">,
Example:
```mlir
- func.func @foo() : (i32, f8) {
+ func.func @foo() -> (i32, f8) {
...
return %0, %1 : i32, f8
}
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 42a017db300a..3adfd5f4f2c4 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1055,7 +1055,7 @@ def GPU_PrintfOp : GPU_Op<"printf", [MemoryEffects<[MemWrite]>]>,
imposed by one's target platform.
}];
let assemblyFormat = [{
- $format attr-dict ($args^ `:` type($args))?
+ $format attr-dict (`,` $args^ `:` type($args))?
}];
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index e8eeafd09a9c..267389774bd5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -825,7 +825,7 @@ def LLVM_MemoryEffectsAttr : LLVM_Attr<"MemoryEffects", "memory_effects"> {
def LLVM_AliasScopeDomainAttr : LLVM_Attr<"AliasScopeDomain",
"alias_scope_domain"> {
let parameters = (ins
- "DistinctAttr":$id,
+ "Attribute":$id,
OptionalParameter<"StringAttr">:$description
);
@@ -853,7 +853,7 @@ def LLVM_AliasScopeDomainAttr : LLVM_Attr<"AliasScopeDomain",
def LLVM_AliasScopeAttr : LLVM_Attr<"AliasScope", "alias_scope"> {
let parameters = (ins
- "DistinctAttr":$id,
+ "Attribute":$id,
"AliasScopeDomainAttr":$domain,
OptionalParameter<"StringAttr">:$description
);
@@ -891,6 +891,8 @@ def LLVM_AliasScopeAttr : LLVM_Attr<"AliasScope", "alias_scope"> {
}
```
+ The first attribute can either be a DistinctAttr or a StringAttr.
+
See the following link for more details:
https://llvm.org/docs/LangRef.html#noalias-and-alias-scope-metadata
}];
@@ -898,6 +900,8 @@ def LLVM_AliasScopeAttr : LLVM_Attr<"AliasScope", "alias_scope"> {
let summary = "LLVM dialect alias scope";
let assemblyFormat = "`<` struct(params) `>`";
+
+ let genVerifyDecl = 1;
}
def LLVM_AliasScopeArrayAttr
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 37eec6e07963..fff4048ee125 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -472,9 +472,6 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
getRegionBuilder() {
return regionBuilder;
}
-
- static void createRegion(::mlir::OpBuilder &opBuilder,
- ::mlir::OperationState & odsState);
}];
let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index b62c94179794..ba648181daec 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -40,12 +40,6 @@ void buildTerminatedBody(OpBuilder &builder, Location loc);
namespace mlir {
namespace scf {
-// Insert `loop.yield` at the end of the only region's only block if it
-// does not have a terminator already. If a new `loop.yield` is inserted,
-// the location is specified by `loc`. If the region is empty, insert a new
-// block first.
-void ensureLoopTerminator(Region &region, Builder &builder, Location loc);
-
/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
ForOp getForInductionVarOwner(Value val);
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 23c597a1ca51..6f408b3c924d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -302,7 +302,7 @@ def ForallOp : SCF_Op<"forall", [
AttrSizedOperandSegments,
AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
- ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
+ ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
RecursiveMemoryEffects,
@@ -671,7 +671,7 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getNumRegionInvocations", "getRegionInvocationBounds",
"getEntrySuccessorRegions"]>,
InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
- RecursiveMemoryEffects, NoRegionArguments]> {
+ RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> {
let summary = "if-then-else operation";
let description = [{
The `scf.if` operation represents an if-then-else construct for
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index b87407d302a8..18c9dfd205de 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -66,6 +66,9 @@ void populateSCFStructuralTypeConversionTarget(
/// Populates the provided pattern set with patterns that do 1:N type
/// conversions on (some) SCF ops. This is intended to be used with
/// applyPartialOneToNConversion.
+/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
+/// 1:N support has been added to the regular dialect conversion driver.
+/// Use populateSCFStructuralTypeConversions() instead.
void populateSCFStructuralOneToNTypeConversions(
const TypeConverter &typeConverter, RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 08a0398e74b0..8bccba426ab1 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -321,11 +321,6 @@ def Shape_DimOp : Shape_Op<"dim",
let assemblyFormat = "$value `,` $index attr-dict `:` type($value) `,`"
"type($index) `->` type($extent)";
- let builders = [
- // Builder that allows passing a constant dimension as a simple integer.
- OpBuilder<(ins "Value":$value, "int64_t":$index)>
- ];
-
let extraClassDeclaration = [{
/// Get the `index` value as integer if it is constant.
std::optional<int64_t> getConstantIndex();
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index f5536927dc25..d3f12c34421b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -126,11 +126,12 @@ def Tosa_ConvOpQuantInfoBuilder : OpBuilder<
(ins "::mlir::Type":$outputType, "::mlir::Value":$input,
"::mlir::Value":$weight, "::mlir::Value":$bias,
"::mlir::DenseI64ArrayAttr":$pad, "::mlir::DenseI64ArrayAttr":$stride,
- "::mlir::DenseI64ArrayAttr":$dilation),
+ "::mlir::DenseI64ArrayAttr":$dilation,
+ "::mlir::TypeAttr":$acc_type),
[{
buildConvOpWithQuantInfo($_builder, $_state, outputType,
input, weight, bias,
- pad, stride, dilation);
+ pad, stride, dilation, acc_type);
}]>;
// Handles tosa.transpose_conv2d which has an outpad and output shape attribute.
@@ -139,12 +140,13 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
"::mlir::Value":$weight, "mlir::Value":$bias,
"::mlir::DenseI64ArrayAttr":$outpad,
"::mlir::DenseI64ArrayAttr":$stride,
- "::mlir::DenseI64ArrayAttr":$outputShape),
+ "::mlir::DenseI64ArrayAttr":$outputShape,
+ "::mlir::TypeAttr":$acc_type),
[{
buildTransConvOpWithQuantInfo($_builder, $_state, outputType,
input, weight, bias,
outpad, stride,
- outputShape);
+ outputShape, acc_type);
}]>;
// The tosa.fully_connected op has its own builder as it does not have
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e3c725801d16..6b43c9a259b1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -57,7 +57,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
// Accumulator types.
//===----------------------------------------------------------------------===//
-def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
+def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>;
//===----------------------------------------------------------------------===//
// Operator: avg_pool2d
@@ -106,6 +106,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
+ TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -135,6 +136,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
+ TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -165,6 +167,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
+ TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -348,6 +351,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$out_shape,
+ TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -357,6 +361,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
);
let builders = [Tosa_TransConvOpQuantInfoBuilder];
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1552,21 +1557,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
Example:
```mlir
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
```
Example 2:
```mlir
- %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
- tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+ %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
+ tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
```
}];
let arguments = (ins
Tosa_RankedTensor:$input1,
- Tosa_Int32Or64Tensor:$padding,
+ TosaTensorRankOf<[Tosa_Int32Or64], [1]>:$padding,
Optional<Tosa_ScalarTensor>:$pad_const,
OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
);
@@ -1698,7 +1703,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
// Operator: transpose
//===----------------------------------------------------------------------===//
def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
- [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ AllElementTypesMatch<["input1", "output"]>]> {
let summary = "Transpose operator";
let description = [{
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index a6d3163d4446..d3cc6e92bac2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -65,17 +65,17 @@ def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
// int8 : symmetric per tensor/per channel, signed
// int16 : symmetric per tensor, signed
//===----------------------------------------------------------------------===//
-def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
- Tosa_QuantizedType<"int4", [4, 0], 1>,
- Tosa_QuantizedType<"int8", [8, 0], 1>,
- Tosa_QuantizedType<"int16", [16, 0], 1>,
- Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
+ Tosa_QuantizedType<"int4", [4, 0], 1>,
+ Tosa_QuantizedType<"int8", [8, 0], 1>,
+ Tosa_QuantizedType<"int16", [16, 0], 1>,
+ Tosa_QuantizedType<"int32", [32, 0], 1>]>;
//===----------------------------------------------------------------------===//
// Multi-category types.
//===----------------------------------------------------------------------===//
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
- "number">;
+ "number">;
// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
@@ -112,7 +112,7 @@ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
def Tosa_I1Tensor : TosaTensorOf<[I1]>;
def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>;
-def Tosa_Int32Or64Tensor :TosaTensorOf<[Tosa_Int32Or64]>;
+def Tosa_Int32Or64Tensor : TosaTensorOf<[Tosa_Int32Or64]>;
def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 2aaa7fd4221a..4841f94de75f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -164,11 +164,9 @@ def XeGPU_SGMapAttr : XeGPUAttr<"SGMap", "sg_map"> {
}];
let parameters = (ins
ArrayRefParameter<"uint32_t">:$wi_layout,
- ArrayRefParameter<"uint32_t">:$wi_data);
+ ArrayRefParameter<"uint32_t">:$wi_data
+ );
- let builders = [
- AttrBuilder<(ins)>
- ];
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index f3e5f6d88c53..fb24a6895dab 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -368,7 +368,6 @@ private:
DenseSet<std::pair<TypeID, TypeID>> unresolvedPromisedInterfaces;
friend class DialectRegistry;
- friend void registerDialect();
friend class MLIRContext;
};
diff --git a/mlir/include/mlir/IR/Dominance.h b/mlir/include/mlir/IR/Dominance.h
index 16d17b9c0f3d..9e1254c1dfe1 100644
--- a/mlir/include/mlir/IR/Dominance.h
+++ b/mlir/include/mlir/IR/Dominance.h
@@ -113,12 +113,12 @@ protected:
llvm::PointerIntPair<DomTree *, 1, bool>
getDominanceInfo(Region *region, bool needsDomTree) const;
- /// Return "true" if the specified block A properly (post)dominates block B.
- bool properlyDominatesImpl(Block *a, Block *b) const;
-
- /// Return "true" if the specified op A properly (post)dominates op B.
- bool properlyDominatesImpl(Operation *a, Operation *b,
- bool enclosingOpOk = true) const;
+ /// Return "true" if block iterator A properly (post)dominates block iterator
+ /// B. If `enclosingOk` is set, A is considered to (post)dominate B if A
+ /// encloses B.
+ bool properlyDominatesImpl(Block *aBlock, Block::iterator aIt, Block *bBlock,
+ Block::iterator bIt,
+ bool enclosingOk = true) const;
/// A mapping of regions to their base dominator tree and a cached
/// "hasSSADominance" bit. This map does not contain dominator trees for
@@ -151,9 +151,7 @@ public:
/// The `enclosingOpOk` flag says whether we should return true if the B op
/// is enclosed by a region on A.
bool properlyDominates(Operation *a, Operation *b,
- bool enclosingOpOk = true) const {
- return super::properlyDominatesImpl(a, b, enclosingOpOk);
- }
+ bool enclosingOpOk = true) const;
/// Return true if operation A dominates operation B, i.e. if A and B are the
/// same operation or A properly dominates B.
@@ -188,8 +186,17 @@ public:
/// Graph regions have only a single block. To be consistent with "proper
/// dominance" of ops, the single block is considered to properly dominate
/// itself in a graph region.
- bool properlyDominates(Block *a, Block *b) const {
- return super::properlyDominatesImpl(a, b);
+ bool properlyDominates(Block *a, Block *b) const;
+
+ bool properlyDominates(Block *aBlock, Block::iterator aIt, Block *bBlock,
+ Block::iterator bIt, bool enclosingOk = true) const {
+ return super::properlyDominatesImpl(aBlock, aIt, bBlock, bIt, enclosingOk);
+ }
+
+ bool dominates(Block *aBlock, Block::iterator aIt, Block *bBlock,
+ Block::iterator bIt, bool enclosingOk = true) const {
+ return (aBlock == bBlock && aIt == bIt) ||
+ super::properlyDominatesImpl(aBlock, aIt, bBlock, bIt, enclosingOk);
}
};
@@ -200,9 +207,7 @@ public:
/// Return true if operation A properly postdominates operation B.
bool properlyPostDominates(Operation *a, Operation *b,
- bool enclosingOpOk = true) const {
- return super::properlyDominatesImpl(a, b, enclosingOpOk);
- }
+ bool enclosingOpOk = true) const;
/// Return true if operation A postdominates operation B.
bool postDominates(Operation *a, Operation *b) const {
@@ -210,14 +215,24 @@ public:
}
/// Return true if the specified block A properly postdominates block B.
- bool properlyPostDominates(Block *a, Block *b) const {
- return super::properlyDominatesImpl(a, b);
- }
+ bool properlyPostDominates(Block *a, Block *b) const;
/// Return true if the specified block A postdominates block B.
bool postDominates(Block *a, Block *b) const {
return a == b || properlyPostDominates(a, b);
}
+
+ bool properlyPostDominates(Block *aBlock, Block::iterator aIt, Block *bBlock,
+ Block::iterator bIt,
+ bool enclosingOk = true) const {
+ return super::properlyDominatesImpl(aBlock, aIt, bBlock, bIt, enclosingOk);
+ }
+
+ bool postDominates(Block *aBlock, Block::iterator aIt, Block *bBlock,
+ Block::iterator bIt, bool enclosingOk = true) const {
+ return (aBlock == bBlock && aIt == bIt) ||
+ super::properlyDominatesImpl(aBlock, aIt, bBlock, bIt, enclosingOk);
+ }
};
} // namespace mlir
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index ef5b8b178fbc..5eb2d69134ea 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -693,9 +693,6 @@ public:
/// Return the dialect this operation is registered to.
Dialect &getDialect() const { return *getImpl()->getDialect(); }
- /// Use the specified object to parse this ops custom assembly format.
- ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
-
/// Represent the operation name as an opaque pointer. (Used to support
/// PointerLikeTypeTraits).
static RegisteredOperationName getFromOpaquePointer(const void *pointer) {
@@ -1169,16 +1166,20 @@ public:
OpPrintingFlags &skipRegions(bool skip = true);
/// Do not verify the operation when using custom operation printers.
- OpPrintingFlags &assumeVerified();
+ OpPrintingFlags &assumeVerified(bool enable = true);
/// Use local scope when printing the operation. This allows for using the
/// printer in a more localized and thread-safe setting, but may not
/// necessarily be identical to what the IR will look like when dumping
/// the full module.
- OpPrintingFlags &useLocalScope();
+ OpPrintingFlags &useLocalScope(bool enable = true);
/// Print users of values as comments.
- OpPrintingFlags &printValueUsers();
+ OpPrintingFlags &printValueUsers(bool enable = true);
+
+ /// Print unique SSA ID numbers for values, block arguments and naming
+ /// conflicts across all regions.
+ OpPrintingFlags &printUniqueSSAIDs(bool enable = true);
/// Return if the given ElementsAttr should be elided.
bool shouldElideElementsAttr(ElementsAttr attr) const;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index eea0647895b0..33c9af7c6335 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -319,9 +319,13 @@ private:
/// Appends the converted result type and operands of `callInst` to the
/// `types` and `operands` arrays. For indirect calls, the method additionally
/// inserts the called function at the beginning of the `operands` array.
+ /// If `allowInlineAsm` is set to false (the default), it will return failure
+ /// if the called operand is an inline asm which isn't convertible to MLIR as
+ /// a value.
LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst,
SmallVectorImpl<Type> &types,
- SmallVectorImpl<Value> &operands);
+ SmallVectorImpl<Value> &operands,
+ bool allowInlineAsm = false);
/// Converts the parameter attributes attached to `func` and adds them to the
/// `funcOp`.
void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 28150e886913..9a6975dcf8df 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -181,6 +181,10 @@ public:
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value
/// with the old block argument type.
+ ///
+ /// Note: Argument materializations are used only with the 1:N dialect
+ /// conversion driver. The 1:N dialect conversion driver will be removed soon
+ /// and so will argument materializations.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
@@ -880,15 +884,7 @@ public:
void replaceOp(Operation *op, Operation *newOp) override;
/// Replace the given operation with the new value ranges. The number of op
- /// results and value ranges must match. If an original SSA value is replaced
- /// by multiple SSA values (i.e., a value range has more than 1 element), the
- /// conversion driver will insert an argument materialization to convert the
- /// N SSA values back into 1 SSA value of the original type. The given
- /// operation is erased.
- ///
- /// Note: The argument materialization is a workaround until we have full 1:N
- /// support in the dialect conversion. (It is going to disappear from both
- /// `replaceOpWithMultiple` and `applySignatureConversion`.)
+ /// results and value ranges must match. The given operation is erased.
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
/// PatternRewriter hook for erasing a dead operation. The uses of this
@@ -1285,8 +1281,8 @@ struct ConversionConfig {
// represented at the moment.
RewriterBase::Listener *listener = nullptr;
- /// If set to "true", the dialect conversion attempts to build source/target/
- /// argument materializations through the type converter API in lieu of
+ /// If set to "true", the dialect conversion attempts to build source/target
+ /// materializations through the type converter API in lieu of
/// "builtin.unrealized_conversion_cast ops". The conversion process fails if
/// at least one materialization could not be built.
///
diff --git a/mlir/include/mlir/Transforms/LocationSnapshot.h b/mlir/include/mlir/Transforms/LocationSnapshot.h
index ccfdbac007ac..cefe005d2c4c 100644
--- a/mlir/include/mlir/Transforms/LocationSnapshot.h
+++ b/mlir/include/mlir/Transforms/LocationSnapshot.h
@@ -51,18 +51,6 @@ void generateLocationsFromIR(raw_ostream &os, StringRef fileName, StringRef tag,
LogicalResult generateLocationsFromIR(StringRef fileName, StringRef tag,
Operation *op, OpPrintingFlags flags);
-/// Create a pass to generate new locations by snapshotting the IR to the given
-/// file, and using the printed locations within that file. If `filename` is
-/// empty, a temporary file is generated instead. If a 'tag' is non-empty, the
-/// generated locations are represented as a NameLoc with the given tag as the
-/// name, and then fused with the existing locations. Otherwise, the existing
-/// locations are replaced.
-std::unique_ptr<Pass> createLocationSnapshotPass(OpPrintingFlags flags,
- StringRef fileName = "",
- StringRef tag = "");
-/// Overload utilizing pass options for initialization.
-std::unique_ptr<Pass> createLocationSnapshotPass();
-
} // namespace mlir
#endif // MLIR_TRANSFORMS_LOCATIONSNAPSHOT_H
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index 7b4dd65cbff7..37a326818d64 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -6,6 +6,9 @@
//
//===----------------------------------------------------------------------===//
//
+// Note: The 1:N dialect conversion is deprecated and will be removed soon.
+// 1:N support has been added to the regular dialect conversion driver.
+//
// This file provides utils for implementing (poor-man's) dialect conversion
// passes with 1:N type conversions.
//
@@ -119,6 +122,8 @@ public:
/// types must be the same as the result types of the op) and the new values
/// (i.e., the converted types must be the same as the types of the new
/// values).
+ /// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
+ /// Use replaceOpWithMultiple() instead.
void replaceOp(Operation *op, ValueRange newValues,
const OneToNTypeMapping &resultMapping);
using PatternRewriter::replaceOp;
@@ -251,6 +256,9 @@ public:
/// or illegal types; the function simply applies the given patterns and does
/// not fail if some ops or types remain unconverted (i.e., the conversion is
/// only "partial").
+/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
+/// 1:N support has been added to the regular dialect conversion driver.
+/// Use applyPartialConversion() instead.
LogicalResult
applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns);
@@ -259,6 +267,9 @@ applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
/// FunctionOpInterface op with the given type converter. This only supports
/// ops which use FunctionType to represent their type. This is intended to be
/// used with the 1:N dialect conversion.
+/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
+/// 1:N support has been added to the regular dialect conversion driver.
+/// Use populateFunctionOpInterfaceTypeConversionPattern() instead.
void populateOneToNFunctionOpInterfaceTypeConversionPattern(
StringRef functionLikeOpName, const TypeConverter &converter,
RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 000d9f697618..c4a8e7a81fa4 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -331,13 +331,21 @@ def LocationSnapshot : Pass<"snapshot-op-locations"> {
... loc(fused["original_source.cpp":1:1, "snapshot"("snapshot_source.mlir":10:10)])
```
}];
- let constructor = "mlir::createLocationSnapshotPass()";
let options = [
Option<"fileName", "filename", "std::string", /*default=*/"",
"The filename to print the generated IR">,
Option<"tag", "tag", "std::string", /*default=*/"",
"A tag to use when fusing the new locations with the "
"original. If unset, the locations are replaced.">,
+ Option<"enableDebugInfo", "print-debuginfo", "bool", /*default=*/"false",
+ "Print debug info in MLIR output">,
+ Option<"printGenericOpForm", "print-op-generic", "bool", /*default=*/"false",
+ "Print the generic op form">,
+ Option<"useLocalScope", "print-local-scope", "bool", /*default=*/"false",
+ "Print with local scope and inline information (eliding "
+ "aliases for attributes, types, and locations">,
+ Option<"printPrettyDebugInfo", "pretty-debuginfo", "bool", /*default=*/"false",
+ "Print pretty debug info in MLIR output">,
];
}
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 86afa956398a..453d4f7c7e8b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -272,13 +272,13 @@ struct PyAttrBuilderMap {
static bool dunderContains(const std::string &attributeKind) {
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
}
- static nb::callable dundeGetItemNamed(const std::string &attributeKind) {
+ static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
if (!builder)
throw nb::key_error(attributeKind.c_str());
return *builder;
}
- static void dundeSetItemNamed(const std::string &attributeKind,
+ static void dunderSetItemNamed(const std::string &attributeKind,
nb::callable func, bool replace) {
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
replace);
@@ -287,8 +287,8 @@ struct PyAttrBuilderMap {
static void bind(nb::module_ &m) {
nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
.def_static("contains", &PyAttrBuilderMap::dunderContains)
- .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
- .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed,
+ .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed)
+ .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
"attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
"Register an attribute builder for building MLIR "
"attributes from python values.");
@@ -2587,6 +2587,8 @@ private:
//------------------------------------------------------------------------------
void mlir::python::populateIRCore(nb::module_ &m) {
+ // Disable leak warnings, which tend to be false positives.
+ nb::set_leak_warnings(false);
//----------------------------------------------------------------------------
// Enums.
//----------------------------------------------------------------------------
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index 6ed82ba1a025..da450dd3fd8a 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -55,6 +55,16 @@ MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg));
}
+intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type) {
+ return llvm::cast<LLVM::LLVMFunctionType>(unwrap(type)).getNumParams();
+}
+
+MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos) {
+ assert(pos >= 0 && "pos in array must be positive");
+ return wrap(llvm::cast<LLVM::LLVMFunctionType>(unwrap(type))
+ .getParamType(static_cast<unsigned>(pos)));
+}
+
bool mlirTypeIsALLVMStructType(MlirType type) {
return isa<LLVM::LLVMStructType>(unwrap(type));
}
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index 8672e7b849d9..d0ffb94f3f96 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -215,7 +215,6 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
- AssertOpLowering,
BranchOpLowering,
CondBranchOpLowering,
SwitchOpLowering>(converter);
@@ -258,6 +257,7 @@ struct ConvertControlFlowToLLVM
LLVMTypeConverter converter(ctx, options);
RewritePatternSet patterns(ctx);
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
+ mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -286,6 +286,7 @@ struct ControlFlowToLLVMDialectInterface
RewritePatternSet &patterns) const final {
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
+ mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
}
};
} // namespace
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index b3c3fd4956d0..544fc57949e2 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -19,6 +19,59 @@
using namespace mlir;
+LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
+ Location loc, OpBuilder &b,
+ StringRef name,
+ LLVM::LLVMFunctionType type) {
+ // Reuse the function if the module already has an LLVM function symbol with
+ // this name.
+ if (auto existingFunc = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))
+ return existingFunc;
+
+ // Otherwise declare it, with external linkage, at the start of the module
+ // body. The InsertionGuard scopes the insertion-point change to this call.
+ OpBuilder::InsertionGuard insertionGuard(b);
+ b.setInsertionPointToStart(moduleOp.getBody());
+ return b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
+}
+
+static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
+ StringRef prefix) {
+ // Probe `<prefix>0`, `<prefix>1`, ... and return the first name that does
+ // not resolve to a symbol in the module.
+ SmallString<16> candidate;
+ for (unsigned suffix = 0;; ++suffix) {
+ candidate.clear();
+ (prefix + Twine(suffix)).toStringRef(candidate);
+ if (!moduleOp.lookupSymbol(candidate))
+ return candidate;
+ }
+}
+
+LLVM::GlobalOp
+mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
+ gpu::GPUModuleOp moduleOp, Type llvmI8,
+ StringRef namePrefix, StringRef str,
+ uint64_t alignment, unsigned addrSpace) {
+ // The global stores the string together with a C-style null terminator.
+ llvm::SmallString<20> terminated(str);
+ terminated.push_back('\0');
+ auto arrayType =
+ LLVM::LLVMArrayType::get(llvmI8, terminated.size_in_bytes());
+ StringAttr valueAttr = b.getStringAttr(terminated);
+
+ // Reuse an existing constant global when type, value, alignment and address
+ // space all match.
+ for (LLVM::GlobalOp candidate : moduleOp.getOps<LLVM::GlobalOp>()) {
+ if (candidate.getGlobalType() != arrayType || !candidate.getConstant())
+ continue;
+ if (candidate.getValueAttr() != valueAttr)
+ continue;
+ if (candidate.getAlignment().value_or(0) != alignment ||
+ candidate.getAddrSpace() != addrSpace)
+ continue;
+ return candidate;
+ }
+
+ // No match: create a fresh internal-linkage constant at the start of the
+ // module body, named `namePrefix` plus a unique number.
+ OpBuilder::InsertionGuard insertionGuard(b);
+ b.setInsertionPointToStart(moduleOp.getBody());
+ SmallString<16> symbolName = getUniqueSymbolName(moduleOp, namePrefix);
+ return b.create<LLVM::GlobalOp>(loc, arrayType,
+ /*isConstant=*/true, LLVM::Linkage::Internal,
+ symbolName, valueAttr, alignment, addrSpace);
+}
+
LogicalResult
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
return success();
}
-static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
- const char formatStringPrefix[] = "printfFormat_";
- // Get a unique global name.
- unsigned stringNumber = 0;
- SmallString<16> stringConstName;
- do {
- stringConstName.clear();
- (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
- } while (moduleOp.lookupSymbol(stringConstName));
- return stringConstName;
-}
-
-/// Create an global that contains the given format string. If a global with
-/// the same format string exists already in the module, return that global.
-static LLVM::GlobalOp getOrCreateFormatStringConstant(
- OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
- StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
- llvm::SmallString<20> formatString(str);
- formatString.push_back('\0'); // Null terminate for C
- auto globalType =
- LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
- StringAttr attr = b.getStringAttr(formatString);
-
- // Try to find existing global.
- for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
- if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
- globalOp.getValueAttr() == attr &&
- globalOp.getAlignment().value_or(0) == alignment &&
- globalOp.getAddrSpace() == addrSpace)
- return globalOp;
-
- // Not found: create new global.
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointToStart(moduleOp.getBody());
- SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
- return b.create<LLVM::GlobalOp>(loc, globalType,
- /*isConstant=*/true, LLVM::Linkage::Internal,
- name, attr, alignment, addrSpace);
-}
-
-template <typename T>
-static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
- ConversionPatternRewriter &rewriter,
- StringRef name,
- LLVM::LLVMFunctionType type) {
- LLVM::LLVMFuncOp ret;
- if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
- ConversionPatternRewriter::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(moduleOp.getBody());
- ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
- LLVM::Linkage::External);
- }
- return ret;
-}
-
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
Value printfDesc = printfBeginCall.getResult();
// Create the global op or find an existing one.
- LLVM::GlobalOp global = getOrCreateFormatStringConstant(
- rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+ LLVM::GlobalOp global = getOrCreateStringConstant(
+ rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element and pass it to printf()
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
// Create the global op or find an existing one.
- LLVM::GlobalOp global = getOrCreateFormatStringConstant(
- rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
- addressSpace);
+ LLVM::GlobalOp global = getOrCreateStringConstant(
+ rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
+ /*alignment=*/0, addressSpace);
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
// Create the global op or find an existing one.
- LLVM::GlobalOp global = getOrCreateFormatStringConstant(
- rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+ LLVM::GlobalOp global = getOrCreateStringConstant(
+ rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 444a07a93ca3..e73a74845d2b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -14,6 +14,27 @@
namespace mlir {
+//===----------------------------------------------------------------------===//
+// Helper Functions
+//===----------------------------------------------------------------------===//
+
+/// Return the LLVM function named `name` in `moduleOp`; if none exists, declare one of type `type` with external linkage at the start of the module body.
+LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
+ OpBuilder &b, StringRef name,
+ LLVM::LLVMFunctionType type);
+
+/// Create a constant, internal-linkage global holding `str` plus a null terminator, named `namePrefix` followed by a unique number. If the module already
+/// has a constant global with the same type, value, alignment and address space, return that global instead of creating a new one.
+LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
+ gpu::GPUModuleOp moduleOp, Type llvmI8,
+ StringRef namePrefix, StringRef str,
+ uint64_t alignment = 0,
+ unsigned addrSpace = 0);
+
+//===----------------------------------------------------------------------===//
+// Lowering Patterns
+//===----------------------------------------------------------------------===//
+
/// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
/// create a 0-sized global array symbol similar as LLVM expects. It constructs
/// a memref descriptor with these values and return it.
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index e022d3ce6f63..2768929f460e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -25,6 +25,7 @@
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -236,6 +237,103 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
}
};
+/// Lowering of cf.assert into a conditional call to __assertfail: the call is only reached when the assert condition is false.
+struct AssertOpToAssertfailLowering
+ : public ConvertOpToLLVMPattern<cf::AssertOp> {
+ using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ MLIRContext *ctx = rewriter.getContext();
+ Location loc = assertOp.getLoc();
+ Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
+ Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
+ Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
+ Type ptrType = LLVM::LLVMPointerType::get(ctx);
+ Type voidType = LLVM::LLVMVoidType::get(ctx);
+
+ // Find or create the __assertfail declaration and tag it with a "noreturn" passthrough attribute.
+ auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>(); // NOTE(review): no null check; assumes the op is nested in a gpu.module.
+ auto assertfailType = LLVM::LLVMFunctionType::get(
+ voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
+ LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
+ moduleOp, loc, rewriter, "__assertfail", assertfailType);
+ assertfailDecl.setPassthroughAttr(
+ ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
+
+ // Split blocks and insert conditional branch.
+ // ^before:
+ // ...
+ // cf.cond_br %condition, ^after, ^assert
+ // ^assert:
+ // cf.assert (replaced below by the call to __assertfail)
+ // cf.br ^after
+ // ^after:
+ // ...
+ Block *beforeBlock = assertOp->getBlock();
+ Block *assertBlock =
+ rewriter.splitBlock(beforeBlock, assertOp->getIterator());
+ Block *afterBlock =
+ rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
+ rewriter.setInsertionPointToEnd(beforeBlock);
+ rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
+ assertBlock);
+ rewriter.setInsertionPointToEnd(assertBlock);
+ rewriter.create<cf::BranchOp>(loc, afterBlock);
+
+ // Continue cf.assert lowering: the remaining ops are created right before the assert op, inside ^assert.
+ rewriter.setInsertionPoint(assertOp);
+
+ // Populate file name, line number and function name from the location of
+ // the AssertOp; they keep the defaults "(unknown)" / 0 when the location does not provide them.
+ StringRef fileName = "(unknown)";
+ StringRef funcName = "(unknown)";
+ int32_t fileLine = 0;
+ while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc)) // Strip call-site wrappers down to the callee location.
+ loc = callSiteLoc.getCallee();
+ if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
+ fileName = fileLineColLoc.getFilename().strref();
+ fileLine = fileLineColLoc.getStartLine();
+ } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
+ funcName = nameLoc.getName().strref();
+ if (auto fileLineColLoc =
+ dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
+ fileName = fileLineColLoc.getFilename().strref();
+ fileLine = fileLineColLoc.getStartLine();
+ }
+ }
+
+ // Create constants.
+ auto getGlobal = [&](LLVM::GlobalOp global) {
+ // Get a pointer to the global string's first element.
+ Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
+ loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
+ global.getSymNameAttr());
+ Value start =
+ rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ return start;
+ };
+ Value assertMessage = getGlobal(getOrCreateStringConstant(
+ rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
+ Value assertFile = getGlobal(getOrCreateStringConstant(
+ rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
+ Value assertFunc = getGlobal(getOrCreateStringConstant(
+ rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
+ Value assertLine =
+ rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
+ Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1); // Last __assertfail argument; presumably the character size in bytes - TODO confirm.
+
+ // Insert function call to __assertfail.
+ SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
+ assertFunc, c1};
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
+ arguments);
+ return success();
+ }
+};
+
/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"
@@ -358,7 +456,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;
populateWithGenerated(patterns);
- patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
+ patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
+ converter);
patterns.add<
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index d52a86987b1c..afebded1c3ea 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -47,7 +47,6 @@
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
-#include "../GPUCommon/OpToFuncCallLowering.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
@@ -297,6 +296,7 @@ struct LowerGpuOpsToROCDLOpsPass
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
+ cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
@@ -346,16 +346,6 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}
-template <typename OpTy>
-static void populateOpPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns, StringRef f32Func,
- StringRef f64Func, StringRef f32ApproxFunc,
- StringRef f16Func) {
- patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f32ApproxFunc,
- f16Func);
-}
-
void mlir::populateGpuToROCDLConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
mlir::gpu::amd::Runtime runtime) {
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 49e2d9432866..72799e42cf3f 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -85,7 +85,7 @@ static Value unrankedMemRefMaterialization(OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs, Location loc,
const LLVMTypeConverter &converter) {
- // An argument materialization must return a value of type
+ // A source materialization must return a value of type
// `resultType`, so insert a cast from the memref descriptor type
// (!llvm.struct) to the original memref type.
Value packed =
@@ -101,7 +101,7 @@ static Value rankedMemRefMaterialization(OpBuilder &builder,
MemRefType resultType,
ValueRange inputs, Location loc,
const LLVMTypeConverter &converter) {
- // An argument materialization must return a value of type `resultType`,
+ // A source materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
// original memref type.
Value packed =
@@ -234,19 +234,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
.getResult(0);
});
- // Argument materializations convert from the new block argument types
+ // Source materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type.
- addArgumentMaterialization([&](OpBuilder &builder,
- UnrankedMemRefType resultType,
- ValueRange inputs, Location loc) {
- return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
- *this);
- });
- addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
- ValueRange inputs, Location loc) {
- return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
- });
addSourceMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType, ValueRange inputs,
Location loc) {
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 58fd3d565fce..5d0003911bca 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -304,6 +304,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
arith::populateArithToLLVMConversionPatterns(converter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
+ cf::populateAssertToLLVMConversionPattern(converter, patterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
populateFuncToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns);
diff --git a/mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt b/mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt
index 79119d374f7a..af5493be8a4b 100644
--- a/mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt
+++ b/mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRSCFToEmitC
LINK_LIBS PUBLIC
MLIRArithDialect
MLIREmitCDialect
+ MLIREmitCTransforms
MLIRSCFDialect
MLIRTransforms
)
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 67a43c43d608..92523ca4f12b 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -39,21 +40,22 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
// Lower scf::for to emitc::for, implementing result values using
// emitc::variable's updated within the loop body.
-struct ForLowering : public OpRewritePattern<ForOp> {
- using OpRewritePattern<ForOp>::OpRewritePattern;
+struct ForLowering : public OpConversionPattern<ForOp> {
+ using OpConversionPattern<ForOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(ForOp forOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult
+ matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
};
// Create an uninitialized emitc::variable op for each result of the given op.
template <typename T>
-static SmallVector<Value> createVariablesForResults(T op,
- PatternRewriter &rewriter) {
- SmallVector<Value> resultVariables;
-
+static LogicalResult
+createVariablesForResults(T op, const TypeConverter *typeConverter,
+ ConversionPatternRewriter &rewriter,
+ SmallVector<Value> &resultVariables) {
if (!op.getNumResults())
- return resultVariables;
+ return success();
Location loc = op->getLoc();
MLIRContext *context = op.getContext();
@@ -62,7 +64,9 @@ static SmallVector<Value> createVariablesForResults(T op,
rewriter.setInsertionPoint(op);
for (OpResult result : op.getResults()) {
- Type resultType = result.getType();
+ Type resultType = typeConverter->convertType(result.getType());
+ if (!resultType)
+ return rewriter.notifyMatchFailure(op, "result type conversion failed");
Type varType = emitc::LValueType::get(resultType);
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
emitc::VariableOp var =
@@ -70,13 +74,13 @@ static SmallVector<Value> createVariablesForResults(T op,
resultVariables.push_back(var);
}
- return resultVariables;
+ return success();
}
// Create a series of assign ops assigning given values to given variables at
// the current insertion point of given rewriter.
-static void assignValues(ValueRange values, SmallVector<Value> &variables,
- PatternRewriter &rewriter, Location loc) {
+static void assignValues(ValueRange values, ValueRange variables,
+ ConversionPatternRewriter &rewriter, Location loc) {
for (auto [value, var] : llvm::zip(values, variables))
rewriter.create<emitc::AssignOp>(loc, var, value);
}
@@ -89,18 +93,25 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
});
}
-static void lowerYield(SmallVector<Value> &resultVariables,
- PatternRewriter &rewriter, scf::YieldOp yield) {
+static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
+ ConversionPatternRewriter &rewriter,
+ scf::YieldOp yield) {
Location loc = yield.getLoc();
- ValueRange operands = yield.getOperands();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(yield);
- assignValues(operands, resultVariables, rewriter, loc);
+ SmallVector<Value> yieldOperands;
+ if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
+ return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
+ }
+
+ assignValues(yieldOperands, resultVariables, rewriter, loc);
rewriter.create<emitc::YieldOp>(loc);
rewriter.eraseOp(yield);
+
+ return success();
}
// Lower the contents of an scf::if/scf::index_switch regions to an
@@ -108,27 +119,32 @@ static void lowerYield(SmallVector<Value> &resultVariables,
// moved into the respective lowered region, but the scf::yield is replaced not
// only with an emitc::yield, but also with a sequence of emitc::assign ops that
// set the yielded values into the result variables.
-static void lowerRegion(SmallVector<Value> &resultVariables,
- PatternRewriter &rewriter, Region &region,
- Region &loweredRegion) {
+static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
+ ConversionPatternRewriter &rewriter,
+ Region &region, Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
- lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
+ return lowerYield(op, resultVariables, rewriter,
+ cast<scf::YieldOp>(terminator));
}
-LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
- PatternRewriter &rewriter) const {
+LogicalResult
+ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
Location loc = forOp.getLoc();
// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the loop body.
- SmallVector<Value> resultVariables =
- createVariablesForResults(forOp, rewriter);
+ SmallVector<Value> resultVariables;
+ if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
+ resultVariables)))
+ return rewriter.notifyMatchFailure(forOp,
+ "create variables for results failed");
- assignValues(forOp.getInits(), resultVariables, rewriter, loc);
+ assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
- loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
+ loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());
Block *loweredBody = loweredFor.getBody();
@@ -143,13 +159,27 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
rewriter.restoreInsertionPoint(ip);
+ // Convert the original region types into the new types by adding unrealized
+ // casts in the beginning of the loop. This performs the conversion in place.
+ if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
+ *getTypeConverter(), nullptr))) {
+ return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
+ }
+
+ // Register the replacements for the block arguments and inline the body of
+ // the scf.for loop into the body of the emitc::for loop.
+ Block *scfBody = &(forOp.getRegion().front());
SmallVector<Value> replacingValues;
replacingValues.push_back(loweredFor.getInductionVar());
replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
+ rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
- rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
- lowerYield(resultVariables, rewriter,
- cast<scf::YieldOp>(loweredBody->getTerminator()));
+ auto result = lowerYield(forOp, resultVariables, rewriter,
+ cast<scf::YieldOp>(loweredBody->getTerminator()));
+
+ if (failed(result)) {
+ return result;
+ }
// Load variables into SSA values after the for loop.
SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
@@ -160,38 +190,66 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
// Lower scf::if to emitc::if, implementing result values as emitc::variable's
// updated within the then and else regions.
-struct IfLowering : public OpRewritePattern<IfOp> {
- using OpRewritePattern<IfOp>::OpRewritePattern;
+struct IfLowering : public OpConversionPattern<IfOp> {
+ using OpConversionPattern<IfOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(IfOp ifOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult
+ matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
};
} // namespace
-LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
- PatternRewriter &rewriter) const {
+LogicalResult
+IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
Location loc = ifOp.getLoc();
// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the then & else regions.
- SmallVector<Value> resultVariables =
- createVariablesForResults(ifOp, rewriter);
-
- Region &thenRegion = ifOp.getThenRegion();
- Region &elseRegion = ifOp.getElseRegion();
+ SmallVector<Value> resultVariables;
+ if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
+ resultVariables)))
+ return rewriter.notifyMatchFailure(ifOp,
+ "create variables for results failed");
+
+ // Utility function to lower the contents of an scf::if region to an emitc::if
+ // region. The contents of the scf::if regions is moved into the respective
+ // emitc::if regions, but the scf::yield is replaced not only with an
+ // emitc::yield, but also with a sequence of emitc::assign ops that set the
+ // yielded values into the result variables.
+ auto lowerRegion = [&resultVariables, &rewriter,
+ &ifOp](Region &region, Region &loweredRegion) {
+ rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
+ Operation *terminator = loweredRegion.back().getTerminator();
+ auto result = lowerYield(ifOp, resultVariables, rewriter,
+ cast<scf::YieldOp>(terminator));
+ if (failed(result)) {
+ return result;
+ }
+ return success();
+ };
+
+ Region &thenRegion = adaptor.getThenRegion();
+ Region &elseRegion = adaptor.getElseRegion();
bool hasElseBlock = !elseRegion.empty();
auto loweredIf =
- rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
+ rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);
Region &loweredThenRegion = loweredIf.getThenRegion();
- lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion);
+ auto result = lowerRegion(thenRegion, loweredThenRegion);
+ if (failed(result)) {
+ return result;
+ }
if (hasElseBlock) {
Region &loweredElseRegion = loweredIf.getElseRegion();
- lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion);
+ auto result = lowerRegion(elseRegion, loweredElseRegion);
+ if (failed(result)) {
+ return result;
+ }
}
rewriter.setInsertionPointAfter(ifOp);
@@ -203,37 +261,46 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
// Lower scf::index_switch to emitc::switch, implementing result values as
// emitc::variable's updated within the case and default regions.
-struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> {
- using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
+struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
+ using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult
+ matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
};
-LogicalResult
-IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
- PatternRewriter &rewriter) const {
+LogicalResult IndexSwitchOpLowering::matchAndRewrite(
+ IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
Location loc = indexSwitchOp.getLoc();
// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the case and default regions.
- SmallVector<Value> resultVariables =
- createVariablesForResults(indexSwitchOp, rewriter);
+ SmallVector<Value> resultVariables;
+ if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
+ rewriter, resultVariables))) {
+ return rewriter.notifyMatchFailure(indexSwitchOp,
+ "create variables for results failed");
+ }
auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
- loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(),
- indexSwitchOp.getNumCases());
+ loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases());
// Lowering all case regions.
- for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(),
- loweredSwitch.getCaseRegions())) {
- lowerRegion(resultVariables, rewriter, std::get<0>(pair),
- std::get<1>(pair));
+ for (auto pair :
+ llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
+ if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
+ *std::get<0>(pair), std::get<1>(pair)))) {
+ return failure();
+ }
}
// Lowering default region.
- lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(),
- loweredSwitch.getDefaultRegion());
+ if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
+ adaptor.getDefaultRegion(),
+ loweredSwitch.getDefaultRegion()))) {
+ return failure();
+ }
rewriter.setInsertionPointAfter(indexSwitchOp);
SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
@@ -242,15 +309,22 @@ IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
return success();
}
-void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
- patterns.add<ForLowering>(patterns.getContext());
- patterns.add<IfLowering>(patterns.getContext());
- patterns.add<IndexSwitchOpLowering>(patterns.getContext());
+void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
+ TypeConverter &typeConverter) {
+ patterns.add<ForLowering>(typeConverter, patterns.getContext());
+ patterns.add<IfLowering>(typeConverter, patterns.getContext());
+ patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
}
void SCFToEmitCPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
- populateSCFToEmitCConversionPatterns(patterns);
+ TypeConverter typeConverter;
+ // Fallback converter
+ // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
+ // Type converters are called most to least recently inserted
+ typeConverter.addConversion([](Type t) { return t; });
+ populateEmitCSizeTTypeConversions(typeConverter);
+ populateSCFToEmitCConversionPatterns(patterns, typeConverter);
// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 6f085cb6ed06..b5a0da15e780 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -338,11 +338,6 @@ public:
padOp, "tosa.pad was unable to determine the pad constant value.");
}
- Value lowIndex =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
- Value highIndex =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
-
SmallVector<OpFoldResult, 3> lowValues;
SmallVector<OpFoldResult, 3> highValues;
@@ -350,11 +345,12 @@ public:
highValues.reserve(rank);
for (int i = 0; i < rank; i++) {
- Value inputIndex = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ Value lowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i);
+ Value highIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1);
Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
- loc, padding, ValueRange({inputIndex, lowIndex}));
+ loc, padding, ValueRange({lowIndex}));
Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
- loc, padding, ValueRange({inputIndex, highIndex}));
+ loc, padding, ValueRange({highIndex}));
lowVal = rewriter.createOrFold<arith::IndexCastOp>(
loc, rewriter.getIndexType(), lowVal);
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index dceebbfec586..b45829bcf6d2 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4520,6 +4520,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
Value linearIndex, ValueRange basis,
bool hasOuterBound) {
+ if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
+ hasOuterBound = false;
+ basis = basis.drop_front();
+ }
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
@@ -4533,6 +4537,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
Value linearIndex,
ArrayRef<OpFoldResult> basis,
bool hasOuterBound) {
+ if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
+ hasOuterBound = false;
+ basis = basis.drop_front();
+ }
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
@@ -4654,6 +4662,13 @@ SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}
+SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
+ SmallVector<OpFoldResult> ret = getMixedBasis();
+ if (!hasOuterBound())
+ ret.insert(ret.begin(), OpFoldResult());
+ return ret;
+}
+
namespace {
// Drops delinearization indices that correspond to unit-extent basis
@@ -4672,25 +4687,27 @@ struct DropUnitExtentBasis
return zero.value();
};
- bool hasOuterBound = delinearizeOp.hasOuterBound();
// Replace all indices corresponding to unit-extent basis with 0.
// Remaining basis can be used to get a new `affine.delinearize_index` op.
SmallVector<OpFoldResult> newBasis;
- for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
- std::optional<int64_t> basisVal = getConstantIntValue(basis);
+ for (auto [index, basis] :
+ llvm::enumerate(delinearizeOp.getPaddedBasis())) {
+ std::optional<int64_t> basisVal =
+ basis ? getConstantIntValue(basis) : std::nullopt;
if (basisVal && *basisVal == 1)
- replacements[index + (hasOuterBound ? 0 : 1)] = getZero();
+ replacements[index] = getZero();
else
newBasis.push_back(basis);
}
- if (newBasis.size() == delinearizeOp.getStaticBasis().size())
+ if (newBasis.size() == delinearizeOp.getNumResults())
return rewriter.notifyMatchFailure(delinearizeOp,
"no unit basis elements");
- if (!newBasis.empty() || !hasOuterBound) {
+ if (!newBasis.empty()) {
+ // Will drop the leading nullptr from `basis` if there was no outer bound.
auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
- loc, delinearizeOp.getLinearIndex(), newBasis, hasOuterBound);
+ loc, delinearizeOp.getLinearIndex(), newBasis);
int newIndex = 0;
// Map back the new delinearized indices to the values they replace.
for (auto &replacement : replacements) {
@@ -4871,6 +4888,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
ValueRange multiIndex, ValueRange basis,
bool disjoint) {
+ if (!basis.empty() && basis.front() == Value())
+ basis = basis.drop_front();
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
@@ -4883,6 +4902,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
ValueRange multiIndex,
ArrayRef<OpFoldResult> basis,
bool disjoint) {
+ if (!basis.empty() && basis.front() == OpFoldResult())
+ basis = basis.drop_front();
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
@@ -4965,7 +4986,14 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
builder);
}
- return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+ return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+}
+
+SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
+ SmallVector<OpFoldResult> ret = getMixedBasis();
+ if (!hasOuterBound())
+ ret.insert(ret.begin(), OpFoldResult());
+ return ret;
}
namespace {
@@ -5027,38 +5055,228 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
}
};
-/// Cancel out linearize_index(delinearize_index(x, B), B).
+/// Return the product of `terms`, creating an `affine.apply` if any of them are
+/// non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
+static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+ ArrayRef<OpFoldResult> terms) {
+ int64_t nDynamic = 0;
+ SmallVector<Value> dynamicPart;
+ AffineExpr result = builder.getAffineConstantExpr(1);
+ for (OpFoldResult term : terms) {
+ if (!term)
+ return term;
+ std::optional<int64_t> maybeConst = getConstantIntValue(term);
+ if (maybeConst) {
+ result = result * builder.getAffineConstantExpr(*maybeConst);
+ } else {
+ dynamicPart.push_back(term.get<Value>());
+ result = result * builder.getAffineSymbolExpr(nDynamic++);
+ }
+ }
+ if (auto constant = dyn_cast<AffineConstantExpr>(result))
+ return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+ return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+
+/// If consecutive outputs of a delinearize_index are linearized with the same
+/// bounds, canonicalize away the redundant arithmetic.
+///
+/// That is, if we have
+/// ```
+/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
+/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
+/// by (...e, B1, B2, ..., BK, ...f)
+/// ```
///
-/// That is, rewrite
+/// We can rewrite this to
/// ```
-/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN)
-/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
-/// %bN)
+/// B = B1 * B2 ... BK
+/// %sMerged:(N-K+1) = affine.delinearize_index %x into (...a, B, ...b)
+/// %t = affine.linearize_index [...c, %sMerged#I, ...d] by (...e, B, ...f)
/// ```
-/// to replacing `%y` with `%x`.
-struct CancelLinearizeOfDelinearizeExact final
+/// where we replace all results of %s unaffected by the change with results
+/// from %sMerged.
+///
+/// As a special case, if all results of the delinearize are merged in this way
+/// we can replace those usages with %x, thus cancelling the delinearization
+/// entirely, as in
+/// ```
+/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
+/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
+/// ```
+/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
+struct CancelLinearizeOfDelinearizePortion final
: OpRewritePattern<affine::AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
+private:
+ // Struct representing a case where the cancellation pattern
+ // applies. A `Match` means that `length` inputs to the linearize operation
+ // starting at `linStart` can be cancelled with `length` outputs of
+ // `delinearize`, starting from `delinStart`.
+ struct Match {
+ AffineDelinearizeIndexOp delinearize;
+ unsigned linStart = 0;
+ unsigned delinStart = 0;
+ unsigned length = 0;
+ };
+
+public:
LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
PatternRewriter &rewriter) const override {
- auto delinearizeOp = linearizeOp.getMultiIndex()
- .front()
- .getDefiningOp<affine::AffineDelinearizeIndexOp>();
- if (!delinearizeOp)
- return rewriter.notifyMatchFailure(
- linearizeOp, "last entry doesn't come from a delinearize");
+ SmallVector<Match> matches;
+
+ const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
+ ArrayRef<OpFoldResult> linBasisRef = linBasis;
+
+ ValueRange multiIndex = linearizeOp.getMultiIndex();
+ unsigned numLinArgs = multiIndex.size();
+ unsigned linArgIdx = 0;
+ // We only want to replace one run from the same delinearize op per
+ // pattern invocation lest we run into invalidation issues.
+ llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
+ while (linArgIdx < numLinArgs) {
+ auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
+ if (!asResult) {
+ linArgIdx++;
+ continue;
+ }
- if (linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
- return rewriter.notifyMatchFailure(
- linearizeOp, "basis of linearize and delinearize don't match exactly "
- "(excluding outer bounds)");
+ auto delinearizeOp =
+ dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
+ if (!delinearizeOp) {
+ linArgIdx++;
+ continue;
+ }
+
+ /// Result 0 of the delinearize and argument 0 of the linearize can
+ /// leave their maximum value unspecified. However, even if this happens
+ /// we can still sometimes start the match process. Specifically, if
+ /// - The argument we're matching is result 0 and argument 0 (so the
+ /// bounds don't matter). For example,
+ ///
+ /// %0:2 = affine.delinearize_index %x into (8) : index, index
+ /// %1 = affine.linearize_index [%0#0, %0#1, ...] by (8, ...)
+ /// allows cancellation
+ /// - The delinearization doesn't specify a bound, but the linearization
+ /// is `disjoint`, which asserts that the bound on the linearization is
+ /// correct.
+ unsigned delinArgIdx = asResult.getResultNumber();
+ SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
+ OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
+ OpFoldResult firstLinBound = linBasis[linArgIdx];
+ bool boundsMatch = firstDelinBound == firstLinBound;
+ bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
+ bool knownByDisjoint =
+ linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
+ if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
+ linArgIdx++;
+ continue;
+ }
+
+ unsigned j = 1;
+ unsigned numDelinOuts = delinearizeOp.getNumResults();
+ for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
+ ++j) {
+ if (multiIndex[linArgIdx + j] !=
+ delinearizeOp.getResult(delinArgIdx + j))
+ break;
+ if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
+ break;
+ }
+ // If there're multiple matches against the same delinearize_index,
+ // only rewrite the first one we find to prevent invalidations. The next
+ // ones will be taken care of by subsequent pattern invocations.
+ if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
+ linArgIdx++;
+ continue;
+ }
+ matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
+ linArgIdx += j;
+ }
- if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
+ if (matches.empty())
return rewriter.notifyMatchFailure(
- linearizeOp, "not all indices come from delinearize");
+ linearizeOp, "no run of delinearize outputs to deal with");
+
+ // Record all the delinearize replacements so we can do them after creating
+ // the new linearization operation, since the new operation might use
+ // outputs of something we're replacing.
+ SmallVector<SmallVector<Value>> delinearizeReplacements;
+
+ SmallVector<Value> newIndex;
+ newIndex.reserve(numLinArgs);
+ SmallVector<OpFoldResult> newBasis;
+ newBasis.reserve(numLinArgs);
+ unsigned prevMatchEnd = 0;
+ for (Match m : matches) {
+ unsigned gap = m.linStart - prevMatchEnd;
+ llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
+ llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
+ // Update here so we don't forget this during early continues
+ prevMatchEnd = m.linStart + m.length;
+
+ PatternRewriter::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(m.delinearize);
+
+ ArrayRef<OpFoldResult> basisToMerge =
+ linBasisRef.slice(m.linStart, m.length);
+ // We use the slice from the linearize's basis above because of the
+ // "bounds inferred from `disjoint`" case above.
+ OpFoldResult newSize =
+ computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
+
+ // Trivial case where we can just skip past the delinearize altogether
+ if (m.length == m.delinearize.getNumResults()) {
+ newIndex.push_back(m.delinearize.getLinearIndex());
+ newBasis.push_back(newSize);
+ // Pad out set of replacements so we don't do anything with this one.
+ delinearizeReplacements.push_back(SmallVector<Value>());
+ continue;
+ }
+
+ SmallVector<Value> newDelinResults;
+ SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
+ newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
+ newDelinBasis.begin() + m.delinStart + m.length);
+ newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
+ auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
+ m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
+ newDelinBasis);
+
+ // Since there may be other uses of the indices we just merged together,
+ // create a residual affine.delinearize_index that delinearizes the
+ // merged output into its component parts.
+ Value combinedElem = newDelinearize.getResult(m.delinStart);
+ auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
+ m.delinearize.getLoc(), combinedElem, basisToMerge);
+
+ // Swap all the uses of the unaffected delinearize outputs to the new
+ // delinearization so that the old code can be removed if this
+ // linearize_index is the only user of the merged results.
+ llvm::append_range(newDelinResults,
+ newDelinearize.getResults().take_front(m.delinStart));
+ llvm::append_range(newDelinResults, residualDelinearize.getResults());
+ llvm::append_range(
+ newDelinResults,
+ newDelinearize.getResults().drop_front(m.delinStart + 1));
+
+ delinearizeReplacements.push_back(newDelinResults);
+ newIndex.push_back(combinedElem);
+ newBasis.push_back(newSize);
+ }
+ llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
+ llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
+ rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
+ linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
+
+ for (auto [m, newResults] :
+ llvm::zip_equal(matches, delinearizeReplacements)) {
+ if (newResults.empty())
+ continue;
+ rewriter.replaceOp(m.delinearize, newResults);
+ }
- rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex());
return success();
}
};
@@ -5096,7 +5314,7 @@ struct DropLinearizeLeadingZero final
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
+ patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
DropLinearizeUnitComponentsIfDisjointOrZero>(context);
}
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index 82a9fb0d4908..e93b99b4f498 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -91,6 +91,64 @@ struct AffineMaxOpInterface
};
};
+struct AffineDelinearizeIndexOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<
+ AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> {
+ void populateBoundsForIndexValue(Operation *rawOp, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ auto op = cast<AffineDelinearizeIndexOp>(rawOp);
+ auto result = cast<OpResult>(value);
+ assert(result.getOwner() == rawOp &&
+ "bounded value isn't a result of this delinearize_index");
+ unsigned resIdx = result.getResultNumber();
+
+ AffineExpr linearIdx = cstr.getExpr(op.getLinearIndex());
+
+ SmallVector<OpFoldResult> basis = op.getPaddedBasis();
+ AffineExpr divisor = cstr.getExpr(1);
+ for (OpFoldResult basisElem : llvm::drop_begin(basis, resIdx + 1))
+ divisor = divisor * cstr.getExpr(basisElem);
+
+ if (resIdx == 0) {
+ cstr.bound(value) == linearIdx.floorDiv(divisor);
+ if (!basis.front().isNull())
+ cstr.bound(value) < cstr.getExpr(basis.front());
+ return;
+ }
+ AffineExpr thisBasis = cstr.getExpr(basis[resIdx]);
+ cstr.bound(value) == (linearIdx % (thisBasis * divisor)).floorDiv(divisor);
+ }
+};
+
+struct AffineLinearizeIndexOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<
+ AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> {
+ void populateBoundsForIndexValue(Operation *rawOp, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ auto op = cast<AffineLinearizeIndexOp>(rawOp);
+ assert(value == op.getResult() &&
+ "value isn't the result of this linearize");
+
+ AffineExpr bound = cstr.getExpr(0);
+ AffineExpr stride = cstr.getExpr(1);
+ SmallVector<OpFoldResult> basis = op.getPaddedBasis();
+ OperandRange multiIndex = op.getMultiIndex();
+ unsigned numArgs = multiIndex.size();
+ for (auto [revArgNum, length] : llvm::enumerate(llvm::reverse(basis))) {
+ unsigned argNum = numArgs - (revArgNum + 1);
+ if (argNum == 0)
+ break;
+ OpFoldResult indexAsFoldRes = getAsOpFoldResult(multiIndex[argNum]);
+ bound = bound + cstr.getExpr(indexAsFoldRes) * stride;
+ stride = stride * cstr.getExpr(length);
+ }
+ bound = bound + cstr.getExpr(op.getMultiIndex().front()) * stride;
+ cstr.bound(value) == bound;
+ if (op.getDisjoint() && !basis.front().isNull()) {
+ cstr.bound(value) < stride *cstr.getExpr(basis.front());
+ }
+ }
+};
} // namespace
} // namespace mlir
@@ -100,6 +158,10 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx);
AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx);
AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
+ AffineDelinearizeIndexOp::attachInterface<
+ AffineDelinearizeIndexOpInterface>(*ctx);
+ AffineLinearizeIndexOp::attachInterface<AffineLinearizeIndexOpInterface>(
+ *ctx);
});
}
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 0f2c889d4f39..4e02559a0894 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -919,8 +919,9 @@ static void generateUnrolledLoop(
// 'forOp'.
auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
+ constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
if (!annotateFn)
- annotateFn = [](unsigned, Operation *, OpBuilder) {};
+ annotateFn = defaultAnnotateFn;
// Keep a pointer to the last non-terminator operation in the original block
// so that we know what to clone (since we are doing this in-place).
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 4d3ead20fb5c..9e3257a62b12 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -51,12 +51,14 @@ public:
loc(loc) {}
template <typename OpTy>
- Value buildBinaryExpr(AffineBinaryOpExpr expr) {
+ Value buildBinaryExpr(AffineBinaryOpExpr expr,
+ arith::IntegerOverflowFlags overflowFlags =
+ arith::IntegerOverflowFlags::none) {
auto lhs = visit(expr.getLHS());
auto rhs = visit(expr.getRHS());
if (!lhs || !rhs)
return nullptr;
- auto op = builder.create<OpTy>(loc, lhs, rhs);
+ auto op = builder.create<OpTy>(loc, lhs, rhs, overflowFlags);
return op.getResult();
}
@@ -65,7 +67,8 @@ public:
}
Value visitMulExpr(AffineBinaryOpExpr expr) {
- return buildBinaryExpr<arith::MulIOp>(expr);
+ return buildBinaryExpr<arith::MulIOp>(expr,
+ arith::IntegerOverflowFlags::nsw);
}
/// Euclidean modulo operation: negative RHS is not allowed.
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d8b314a3fa43..e016a6e16e59 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -580,11 +580,31 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns(
// DivUIOp
//===----------------------------------------------------------------------===//
+/// Fold `(a * b) / b -> a`
+static Value foldDivMul(Value lhs, Value rhs,
+ arith::IntegerOverflowFlags ovfFlags) {
+ auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
+ if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
+ return {};
+
+ if (mul.getLhs() == rhs)
+ return mul.getRhs();
+
+ if (mul.getRhs() == rhs)
+ return mul.getLhs();
+
+ return {};
+}
+
OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
// divui (x, 1) -> x.
if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
+ // (a * b) / b -> a
+ if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
+ return val;
+
// Don't fold if it would require a division by zero.
bool div0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
@@ -621,6 +641,10 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
+ // (a * b) / b -> a
+ if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
+ return val;
+
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 61767f3b21c9..12c65a72babc 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -17,7 +17,7 @@
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -25,7 +25,8 @@
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "arm-sme-vector-legalization"
@@ -172,12 +173,12 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
struct LegalizeArithConstantOpsByDecomposition
- : public OneToNOpConversionPattern<arith::ConstantOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<arith::ConstantOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ ConversionPatternRewriter &rewriter) const override {
auto vectorType = dyn_cast<VectorType>(constantOp.getType());
auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
if (!vectorType || !denseAttr || !denseAttr.isSplat())
@@ -191,8 +192,8 @@ struct LegalizeArithConstantOpsByDecomposition
auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
auto tileSplat = rewriter.create<arith::ConstantOp>(
constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
- rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
- adaptor.getResultMapping());
+ SmallVector<Value> repl(tileCount, tileSplat);
+ rewriter.replaceOpWithMultiple(constantOp, {repl});
return success();
}
@@ -201,12 +202,13 @@ struct LegalizeArithConstantOpsByDecomposition
/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeVectorOuterProductOpsByDecomposition
- : public OneToNOpConversionPattern<vector::OuterProductOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::OuterProductOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::OuterProductOp outerProductOp,
+ OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto vectorType = outerProductOp.getResultVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(outerProductOp,
@@ -219,6 +221,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
auto maskOp = outerProductOp.getMaskingOp();
mask = maskOp.getMask();
rootOp = maskOp;
+ rewriter.setInsertionPoint(rootOp);
}
if (!isSupportedMaskOp(mask))
@@ -248,7 +251,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
resultSMETiles.push_back(maskedOuterProduct->getResult(0));
}
- rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
+ rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
return success();
}
};
@@ -259,12 +262,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
- : public OneToNOpConversionPattern<vector::MaskOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::MaskOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
maskOp.getMaskableOp())) {
LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
@@ -279,12 +282,12 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition
/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferReadOpsByDecomposition
- : public OneToNOpConversionPattern<vector::TransferReadOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::TransferReadOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto vectorType = readOp.getVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(readOp,
@@ -319,7 +322,7 @@ struct LegalizeTransferReadOpsByDecomposition
resultSMETiles.push_back(smeRead);
}
- rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
+ rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
return success();
}
};
@@ -327,12 +330,12 @@ struct LegalizeTransferReadOpsByDecomposition
/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferWriteOpsByDecomposition
- : public OneToNOpConversionPattern<vector::TransferWriteOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::TransferWriteOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto vectorType = writeOp.getVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(writeOp,
@@ -409,12 +412,12 @@ struct LegalizeTransferWriteOpsByDecomposition
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
- : public OneToNOpConversionPattern<vector::TransferWriteOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::TransferWriteOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
if (writeOp.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(
writeOp, "TODO: tensor semantics are unsupported");
@@ -936,10 +939,16 @@ struct VectorLegalizationPass
return success();
});
- patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
- LiftIllegalVectorTransposeToMemory,
- ConvertIllegalShapeCastOpsToTransposes,
- LowerIllegalTransposeStoreViaZA>(context);
+ // Apply preprocessing patterns.
+ RewritePatternSet rewritePatterns(context);
+ rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
+ LiftIllegalVectorTransposeToMemory,
+ ConvertIllegalShapeCastOpsToTransposes,
+ LowerIllegalTransposeStoreViaZA>(context);
+ if (failed(
+ applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
+ return signalPassFailure();
+
// Note: These two patterns are added with a high benefit to ensure:
// - Masked outer products are handled before unmasked ones
// - Multi-tile writes are lowered as a store loop (if possible)
@@ -950,11 +959,20 @@ struct VectorLegalizationPass
LegalizeVectorOuterProductOpsByDecomposition,
LegalizeTransferReadOpsByDecomposition,
LegalizeTransferWriteOpsByDecomposition>(converter, context);
- populateFuncTypeConversionPatterns(converter, patterns);
- scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
-
- if (failed(applyPartialOneToNConversion(getOperation(), converter,
- std::move(patterns))))
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+ converter);
+ populateCallOpTypeConversionPattern(patterns, converter);
+ populateReturnOpTypeConversionPattern(patterns, converter);
+ scf::populateSCFStructuralTypeConversions(converter, patterns);
+
+ ConversionTarget target(getContext());
+ target.markUnknownOpDynamicallyLegal(
+ [&](Operation *op) { return converter.isLegal(op); });
+ target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+ return converter.isSignatureLegal(op.getFunctionType());
+ });
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
return signalPassFailure();
}
};
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 845a32c4d97b..2bdb640699d0 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -20,22 +20,6 @@
using namespace mlir;
using namespace mlir::arm_sve;
-template <typename OpTy>
-class ForwardOperands : public OpConversionPattern<OpTy> {
- using OpConversionPattern<OpTy>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
- return rewriter.notifyMatchFailure(op, "operand types already match");
-
- rewriter.modifyOpInPlace(op,
- [&]() { op->setOperands(adaptor.getOperands()); });
- return success();
- }
-};
-
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
@@ -204,10 +188,6 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
// Populate conversion patterns
// clang-format off
- patterns.add<ForwardOperands<func::CallOp>,
- ForwardOperands<func::CallIndirectOp>,
- ForwardOperands<func::ReturnOp>>(converter,
- &converter.getContext());
patterns.add<SdotOpLowering,
SmmlaOpLowering,
UdotOpLowering,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 349841f06959..1eb27e44810b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -480,18 +480,21 @@ bool AnalysisState::isValueRead(Value value) const {
return false;
}
-// Starting from `value`, follow the use-def chain in reverse, always selecting
-// the aliasing OpOperands. Find and return Values for which `condition`
-// evaluates to true. OpOperands of such matching Values are not traversed any
-// further, the visited aliasing opOperands will be preserved through
-// `visitedOpOperands`.
+// Starting from `opOperand`, follow the use-def chain in reverse, always
+// selecting the aliasing OpOperands. Find and return Values for which
+// `condition` evaluates to true. Uses of such matching Values are not
+// traversed any further. The visited aliasing OpOperands are recorded in
+// `visitedOpOperands` (if provided).
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
- Value value, llvm::function_ref<bool(Value)> condition,
+ OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
TraversalConfig config,
llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
llvm::DenseSet<Value> visited;
llvm::SetVector<Value> result, workingSet;
- workingSet.insert(value);
+ workingSet.insert(opOperand->get());
+
+ if (visitedOpOperands)
+ visitedOpOperands->insert(opOperand);
while (!workingSet.empty()) {
Value value = workingSet.pop_back_val();
@@ -563,12 +566,14 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
return result;
}
-// Find the values that define the contents of the given value.
-llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
+// Find the values that define the contents of the given operand's value.
+llvm::SetVector<Value>
+AnalysisState::findDefinitions(OpOperand *opOperand) const {
TraversalConfig config;
config.alwaysIncludeLeaves = false;
return findValueInReverseUseDefChain(
- value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
+ opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
+ config);
}
AnalysisState::AnalysisState(const BufferizationOptions &options)
@@ -892,7 +897,7 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
config.alwaysIncludeLeaves = false;
for (AliasingOpOperand alias : opOperands) {
if (!state
- .findValueInReverseUseDefChain(alias.opOperand->get(),
+ .findValueInReverseUseDefChain(alias.opOperand,
isMemoryWriteInsideOp, config)
.empty())
return true;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index abc0635a2cdf..2c4e362101f8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -93,8 +93,31 @@ findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
return nullptr;
}
+Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
+ SubsetInsertionOpInterface op,
+ tensor::EmptyOp emptyTensorOp,
+ Operation *user) {
+
+ mlir::OpBuilder::InsertionGuard guard(rewriter);
+ // All values that are needed to create the replacement op.
+ SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
+ // Find a suitable insertion point. If no suitable insertion point
+ // for the replacement can be found, return an empty value to skip
+ // this replacement.
+ Operation *insertionPoint =
+ findValidInsertionPoint(emptyTensorOp, user, neededValues);
+ if (!insertionPoint)
+ return {};
+
+ rewriter.setInsertionPoint(insertionPoint);
+ Value replacement =
+ op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
+ return replacement;
+}
+
LogicalResult mlir::bufferization::eliminateEmptyTensors(
- RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
+ ControlBuildSubsetExtractionFn subsetsExtractionFn) {
OpBuilder::InsertionGuard g(rewriter);
llvm::DenseSet<OpOperand *> visitedOpOperands;
op->walk([&](SubsetInsertionOpInterface op) {
@@ -105,10 +128,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
if (!state.isInPlace(source))
return WalkResult::skip();
- // All values that are needed to create the replacement op.
- SmallVector<Value> neededValues =
- op.getValuesNeededToBuildSubsetExtraction();
-
// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
// equivalent tensors. I.e., stop when there are ops such as extract_slice
// on the path.
@@ -124,35 +143,23 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
// %3 = tensor.insert_slice %2 into ...
config.followSameTypeOrCastsOnly = true;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
- source.get(), /*condition=*/
+ &source, /*condition=*/
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
&visitedOpOperands);
for (Value v : emptyTensors) {
- Operation *emptyTensorOp = v.getDefiningOp();
-
+ auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
+ assert(emptyTensorOp && "expected tensor.empty op");
// Find the use to be replaced from the use-def chain.
auto iter = llvm::find_if(
visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
return llvm::count(emptyTensorOp->getUses(), *opOperand);
});
- // This could be achieved when a use of `emptyTensorOp` is being
- // consumed by `SubsetInsertionOpInterface`'s source directly.
- if (iter == visitedOpOperands.end())
- continue;
+
+ assert(iter != visitedOpOperands.end() && "could not find use");
OpOperand *useToBeReplaced = *iter;
Operation *user = useToBeReplaced->getOwner();
-
- // Find a suitable insertion point. If no suitable insertion point for
- // the replacement can be found, skip this replacement.
- Operation *insertionPoint =
- findValidInsertionPoint(emptyTensorOp, user, neededValues);
- if (!insertionPoint)
- continue;
-
- rewriter.setInsertionPoint(insertionPoint);
- Value replacement =
- op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
+ auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
if (!replacement)
continue;
if (emptyTensorOp == replacement.getDefiningOp())
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index d1e6acef324f..fc1b221b4f03 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -196,7 +196,12 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
// If there is no preceding definition, the tensor contents are
// undefined.
- if (findDefinitionsCached(opResult).empty())
+ if (opResult.getUses().empty())
+ continue;
+ // It does not matter which use is taken to search for the value's
+ // definitions.
+ OpOperand *opOperand = &(*opResult.getUses().begin());
+ if (findDefinitionsCached(opOperand).empty())
for (OpOperand &use : opResult.getUses())
undefinedTensorUses.insert(&use);
}
@@ -464,7 +469,8 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
/// indexing. I.e., the tensor types do not change along the use-def chain,
/// apart from static <-> dynamic dim casts.
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
- Value start, Value other) {
+ OpOperand *start,
+ Value other) {
TraversalConfig config;
config.followEquivalentOnly = true;
config.alwaysIncludeLeaves = false;
@@ -475,9 +481,10 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
.empty();
}
-/// Return "true" if `value` is originating from a subset that is equivalent to
-/// the subset that `subsetOp` inserts into.
-static bool matchesInsertDestination(const AnalysisState &state, Value value,
+/// Return "true" if the given operand's value is originating from a subset
+/// that is equivalent to the subset that `subsetOp` inserts into.
+static bool matchesInsertDestination(const AnalysisState &state,
+ OpOperand *opOperand,
SubsetInsertionOpInterface subsetOp) {
auto matchingSubset = [&](Value val) {
if (auto opResult = dyn_cast<OpResult>(val))
@@ -490,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value,
// There may be multiple leaves at which the reverse SSA use-def chain lookup
// terminates. All of them must be equivalent subsets.
SetVector<Value> backwardSlice =
- state.findValueInReverseUseDefChain(value, matchingSubset);
+ state.findValueInReverseUseDefChain(opOperand, matchingSubset);
return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
}
@@ -516,7 +523,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
// {inplace= [true] }
if (uRead == &subsetOp.getDestinationOperand() &&
- matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
+ matchesInsertDestination(state, uConflictingWrite, subsetOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
@@ -533,7 +540,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
if (uRead == &subsetOp.getSourceOperand() &&
uConflictingWrite == &subsetOp.getDestinationOperand() &&
- matchesInsertDestination(state, uRead->get(), subsetOp))
+ matchesInsertDestination(state, uRead, subsetOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
@@ -567,8 +574,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
state.areEquivalentBufferizedValues(
uRead->get(), subsetOp.getSourceOperand().get()) &&
- matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
- subsetOp))
+ matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp))
return true;
return false;
@@ -600,9 +606,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// even though that op just bufferizes to an allocation but does define
// the contents of the buffer.
SetVector<Value> definitionsOrLeaves =
- state.findValueInReverseUseDefChain(
- uConflictingWrite->get(),
- [&](Value v) { return state.bufferizesToMemoryWrite(v); });
+ state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) {
+ return state.bufferizesToMemoryWrite(v);
+ });
assert(!definitionsOrLeaves.empty() &&
"expected at least one definition or leaf");
@@ -641,8 +647,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// In the above example, if uRead is the OpOperand of reading_op, the
// definition is %0. Note that operations that create an alias but do not
// bufferize to a memory write (such as ExtractSliceOp) are skipped.
- const SetVector<Value> &definitions =
- state.findDefinitionsCached(uRead->get());
+ const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
if (definitions.empty()) {
// Fast path: No conflict if there are no definitions.
LLVM_DEBUG(llvm::dbgs()
@@ -714,9 +719,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (bufferizableOp.bufferizesToElementwiseAccess(
state, {uRead, uConflictingWrite})) {
if (hasEquivalentValueInReverseUseDefChain(
- state, uRead->get(), uConflictingWrite->get()) ||
+ state, uRead, uConflictingWrite->get()) ||
hasEquivalentValueInReverseUseDefChain(
- state, uConflictingWrite->get(), uRead->get())) {
+ state, uConflictingWrite, uRead->get())) {
LLVM_DEBUG(
llvm::dbgs()
<< " no conflict: op bufferizes to element-wise access\n");
@@ -965,11 +970,12 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
// Bufferization analyses.
//===----------------------------------------------------------------------===//
-// Find the values that define the contents of the given value.
+// Find the values that define the contents of the given operand's value.
const llvm::SetVector<Value> &
-OneShotAnalysisState::findDefinitionsCached(Value value) {
+OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
+ Value value = opOperand->get();
if (!cachedDefinitions.count(value))
- cachedDefinitions[value] = findDefinitions(value);
+ cachedDefinitions[value] = findDefinitions(opOperand);
return cachedDefinitions[value];
}
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 0b3a494794f3..72c8fd0f3248 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
converter.addSourceMaterialization(materializeAsUnrealizedCast);
converter.addTargetMaterialization(materializeAsUnrealizedCast);
- converter.addArgumentMaterialization(materializeAsUnrealizedCast);
}
/// Get an unsigned integer or size data type corresponding to \p ty.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index c7ddc1b36f4d..ff1636bc121b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -48,10 +48,28 @@ void LLVMDialect::registerAttributes() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"
+
>();
}
//===----------------------------------------------------------------------===//
+// AliasScopeAttr
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AliasScopeAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+ Attribute id, AliasScopeDomainAttr domain,
+ StringAttr description) {
+ (void)domain;
+ (void)description;
+ if (!llvm::isa<StringAttr, DistinctAttr>(id))
+ return emitError()
+ << "id of an alias scope must be a StringAttr or a DistrinctAttr";
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// DINodeAttr
//===----------------------------------------------------------------------===//
@@ -232,7 +250,7 @@ DIRecursiveTypeAttrInterface DISubprogramAttr::withRecId(DistinctAttr recId) {
DIRecursiveTypeAttrInterface DISubprogramAttr::getRecSelf(DistinctAttr recId) {
return DISubprogramAttr::get(recId.getContext(), recId, /*isRecSelf=*/true,
- {}, {}, {}, {}, {}, 0, 0, {}, {}, {}, {}, {});
+ {}, {}, {}, {}, {}, {}, 0, 0, {}, {}, {}, {});
}
//===----------------------------------------------------------------------===//
@@ -288,6 +306,16 @@ TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context,
}));
}
+TargetFeaturesAttr
+TargetFeaturesAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+ MLIRContext *context,
+ llvm::ArrayRef<StringRef> features) {
+ return Base::getChecked(emitError, context,
+ llvm::map_to_vector(features, [&](StringRef feature) {
+ return StringAttr::get(context, feature);
+ }));
+}
+
TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context,
StringRef targetFeatures) {
SmallVector<StringRef> features;
@@ -296,6 +324,16 @@ TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context,
return get(context, features);
}
+TargetFeaturesAttr
+TargetFeaturesAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+ MLIRContext *context, StringRef targetFeatures) {
+ SmallVector<StringRef> features;
+ targetFeatures.split(features, ',', /*MaxSplit=*/-1,
+ /*KeepEmpty=*/false);
+ ArrayRef featuresRef(features);
+ return getChecked(emitError, context, featuresRef);
+}
+
LogicalResult
TargetFeaturesAttr::verify(function_ref<InFlightDiagnostic()> emitError,
llvm::ArrayRef<StringAttr> features) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 6801b68a8538..6c1087730ebb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -553,7 +553,7 @@ Value linalg::bufferizeToAllocation(
Value alloc = createAllocationForTensor(
rewriter, op->getLoc(), operand->get(), options, memorySpace);
allocs.push_back(alloc);
- if (!state.findDefinitions(operand->get()).empty()) {
+ if (!state.findDefinitions(operand).empty()) {
// Initialize buffer with a copy of the operand data. Not needed if the
// tensor is uninitialized.
createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 0e651f4cee4c..fc6671ef8117 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -154,7 +154,6 @@ public:
});
addSourceMaterialization(sourceMaterializationCallback);
- addArgumentMaterialization(sourceMaterializationCallback);
}
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 4776883ed95c..b710bde87f9f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -59,7 +59,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
config.followEquivalentOnly = true;
config.alwaysIncludeLeaves = false;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
- in->get(), /*condition=*/
+ in, /*condition=*/
[&](Value val) {
return val.getDefiningOp<tensor::EmptyOp>() &&
val.getType() == in->get().getType();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 60cf897b00de..50593b08ad74 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1656,8 +1656,8 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
}
void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
- // TODO: Add and test patterns for tensor.unpack
patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
+ patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(patterns.getContext());
}
void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 619127226628..71b88d1be1b0 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -56,7 +56,6 @@ public:
addConversion(convertQuantizedType);
addConversion(convertTensorType);
- addArgumentMaterialization(materializeConversion);
addSourceMaterialization(materializeConversion);
addTargetMaterialization(materializeConversion);
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f12..83ae79ce4826 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -839,8 +839,7 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
namespace {
// Fold away ForOp iter arguments when:
// 1) The op yields the iter arguments.
-// 2) The iter arguments have no use and the corresponding outer region
-// iterators (inputs) are yielded.
+// 2) The argument's corresponding outer region iterators (inputs) are yielded.
// 3) The iter arguments have no use and the corresponding (operation) results
// have no use.
//
@@ -872,30 +871,28 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
newIterArgs.reserve(forOp.getInitArgs().size());
newYieldValues.reserve(numResults);
newResultValues.reserve(numResults);
- for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
- forOp.getRegionIterArgs(), // iter inside region
- forOp.getResults(), // op results
- forOp.getYieldedValues() // iter yield
- )) {
+ for (auto [init, arg, result, yielded] :
+ llvm::zip(forOp.getInitArgs(), // iter from outside
+ forOp.getRegionIterArgs(), // iter inside region
+ forOp.getResults(), // op results
+ forOp.getYieldedValues() // iter yield
+ )) {
// Forwarded is `true` when:
// 1) The region `iter` argument is yielded.
- // 2) The region `iter` argument has no use, and the corresponding iter
- // operand (input) is yielded.
+ // 2) The region `iter` argument's corresponding input is yielded.
// 3) The region `iter` argument has no use, and the corresponding op
// result has no use.
- bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
- (std::get<1>(it).use_empty() &&
- (std::get<0>(it) == std::get<3>(it) ||
- std::get<2>(it).use_empty())));
+ bool forwarded = (arg == yielded) || (init == yielded) ||
+ (arg.use_empty() && result.use_empty());
keepMask.push_back(!forwarded);
canonicalize |= forwarded;
if (forwarded) {
- newBlockTransferArgs.push_back(std::get<0>(it));
- newResultValues.push_back(std::get<0>(it));
+ newBlockTransferArgs.push_back(init);
+ newResultValues.push_back(init);
continue;
}
- newIterArgs.push_back(std::get<0>(it));
- newYieldValues.push_back(std::get<3>(it));
+ newIterArgs.push_back(init);
+ newYieldValues.push_back(yielded);
newBlockTransferArgs.push_back(Value()); // placeholder with null value
newResultValues.push_back(Value()); // placeholder with null value
}
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 41410a0a56aa..6cda7100fe07 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -329,8 +329,9 @@ static void generateUnrolledLoop(
// 'forOp'.
auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
+ constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
if (!annotateFn)
- annotateFn = [](unsigned, Operation *, OpBuilder) {};
+ annotateFn = defaultAnnotateFn;
// Keep a pointer to the last non-terminator operation in the original block
// so that we know what to clone (since we are doing this in-place).
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index 834e3634cc13..8bbb2cac5efd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
// Required by scf.for 1:N type conversion.
addSourceMaterialization(materializeTuple);
-
- // Required as a workaround until we have full 1:N support.
- addArgumentMaterialization(materializeTuple);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f79c774ceb3e..24a1d5531531 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4795,6 +4795,44 @@ static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
return newOperands;
}
+// Given the (potentially) updated packed type, `newPackedTy`, generates an
+// updated mixed-tile-sizes attribute. A tile size is updated only
+// when:
+// * a dim from newPackedTy is static, and
+// * the corresponding size from mixedTiles is still dynamic.
+// Otherwise, the original tile size is preserved.
+// Note - packed-type-dim and mixed-tile-size should always match!
+static SmallVector<OpFoldResult>
+getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
+ SmallVector<OpFoldResult> mixedTiles) {
+ SmallVector<OpFoldResult> newMixedTileSizes;
+ for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
+ .getShape()
+ .take_back(mixedTiles.size()),
+ mixedTiles)) {
+ int64_t shape = std::get<0>(it);
+ if (shape == ShapedType::kDynamic) {
+ newMixedTileSizes.push_back(std::get<1>(it));
+ continue;
+ }
+
+ // If the current result dim is static, update the dynamic mixed-size
+ // (provided the original value is dynamic).
+ OpFoldResult tile = std::get<1>(it);
+ if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
+ // Already a constant
+ newMixedTileSizes.push_back(tile);
+ } else {
+ assert(getConstantIntValue(tile).value() == shape &&
+ "tile size and dim size don't match!");
+ newMixedTileSizes.push_back(
+ (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+ }
+ }
+
+ return newMixedTileSizes;
+}
+
/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
/// `tensor.cast` has source that is more static than the consuming op.
///
@@ -4821,31 +4859,13 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
// Get the updated mixed-tile-sizes attribute.
- SmallVector<OpFoldResult> newMixedTileSizes;
- for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
- .getShape()
- .take_back(op.getMixedTiles().size()),
- op.getMixedTiles())) {
- int64_t shape = std::get<0>(it);
- if (shape == ShapedType::kDynamic) {
- newMixedTileSizes.push_back(std::get<1>(it));
- continue;
- }
-
- if (Attribute attr =
- llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
- // Already a constant
- newMixedTileSizes.push_back(std::get<1>(it));
- } else {
- int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
- assert(tileSize == shape && "tile size and dim size don't match!");
- (void)tileSize;
- newMixedTileSizes.push_back(
- (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
- }
- }
+ SmallVector<OpFoldResult> newMixedTileSizes =
+ getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
// Clone op.
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+ // this point. However, in practice, we use them for things that we'd like
+ // to preserve. Implement a better abstraction.
PackOp newOp = rewriter.create<PackOp>(
op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
@@ -4865,6 +4885,59 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
}
};
+/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = tensor.unpack %0 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+ using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(UnPackOp op,
+ PatternRewriter &rewriter) const override {
+ if (!foldTensorCastPrecondition(op))
+ return failure();
+
+ SmallVector<Type> newResultTypes(op->getResultTypes());
+ SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+ Value sourceTensor = newOperands[0];
+
+ // Get the updated mixed-tile-sizes attribute.
+ SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
+ rewriter, sourceTensor.getType(), op.getMixedTiles());
+
+ // Clone op.
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+ // this point. However, in practice, we use them for things that we'd like
+ // to preserve. Implement a better abstraction.
+ UnPackOp newOp = rewriter.create<UnPackOp>(
+ op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
+ newMixedTileSizes, op.getOuterDimsPerm());
+ newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+ // Replace op.
+ Value oldResult = op.getResult();
+ Value newResult = newOp.getResult();
+ Value replacement = (newResult.getType() != oldResult.getType())
+ ? rewriter.create<tensor::CastOp>(
+ op->getLoc(), oldResult.getType(), newResult)
+ : newResult;
+
+ rewriter.replaceOp(op, {replacement});
+
+ return success();
+ }
+};
+
/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the `tensor.cast` has source that is more static than the consuming op.
///
@@ -4890,7 +4963,8 @@ struct FoldTensorCastProducerOp
PatternRewriter &rewriter) const override {
// Reject tensor::PackOp - there's dedicated pattern for that instead.
- if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+ if (!foldTensorCastPrecondition(op) ||
+ isa<tensor::PackOp, tensor::UnPackOp>(*op))
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
@@ -4923,6 +4997,7 @@ struct FoldTensorCastProducerOp
void TensorDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<FoldTensorCastPackOp>(getContext());
+ results.add<FoldTensorCastUnPackOp>(getContext());
results.add<FoldTensorCastProducerOp>(getContext());
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 39d0ee122b16..f51c3dbce6ee 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1002,10 +1002,6 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
return input.reshape(resultTy);
}
- // Transpose does not change the input type.
- if (getInput1().getType() != getType())
- return {};
-
// Transpose is not the identity transpose.
SmallVector<int32_t> perms;
if (getConstantPerms(perms).failed())
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 631d3c48f2df..764a5db48e07 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -210,7 +210,12 @@ template <typename T>
static LogicalResult verifyConvOp(T op) {
// All TOSA conv ops have an input() and weight().
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
- auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
+
+ RankedTensorType weightType;
+ if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
+ weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType());
+ else
+ weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
// Must be ranked tensor types
if (!inputType) {
@@ -218,7 +223,13 @@ static LogicalResult verifyConvOp(T op) {
return failure();
}
if (!weightType) {
- op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
+ if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
+ op.emitOpError("expect a ranked tensor for filter, got ")
+ << op.getFilter();
+ } else {
+ op.emitOpError("expect a ranked tensor for weight, got ")
+ << op.getWeight();
+ }
return failure();
}
@@ -271,6 +282,38 @@ LogicalResult tosa::ConstOp::verify() {
return success();
}
+template <typename T>
+static LogicalResult verifyConvOpModes(T op) {
+ auto inputEType =
+ llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
+
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+ inputEType = quantType.getStorageType();
+
+ auto accType = op.getAccType();
+ if (inputEType.isInteger(8) && !accType.isInteger(32))
+ return op.emitOpError("accumulator type for i8 tensor is not i32");
+
+ if (inputEType.isInteger(16) && !accType.isInteger(48))
+ return op.emitOpError("accumulator type for i16 tensor is not i48");
+
+ if ((inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3()) &&
+ !accType.isF16())
+ return op.emitOpError("accumulator type for f8 tensor is not f16");
+
+ if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
+ return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
+
+ if (inputEType.isBF16() && !accType.isF32())
+ return op.emitOpError("accumulator type for bf16 tensor is not f32");
+
+ if (inputEType.isF32() && !accType.isF32())
+ return op.emitOpError("accumulator type for f32 tensor is not f32");
+
+ return success();
+}
+
LogicalResult tosa::ArgMaxOp::verify() {
// Ensure output is of 32-bit integer
const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -368,12 +411,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Type outputType, Value input, Value weight,
Value bias, DenseI64ArrayAttr pad,
DenseI64ArrayAttr stride,
- DenseI64ArrayAttr dilation) {
+ DenseI64ArrayAttr dilation,
+ TypeAttr accType) {
result.addOperands({input, weight, bias});
result.addAttribute("pad", pad);
result.addAttribute("stride", stride);
result.addAttribute("dilation", dilation);
+ result.addAttribute("acc_type", accType);
auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
if (quantAttr) {
@@ -390,11 +435,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
static void buildTransConvOpWithQuantInfo(
OpBuilder &builder, OperationState &result, Type outputType, Value input,
Value weight, Value bias, DenseI64ArrayAttr outpad,
- DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
+ DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
result.addOperands({input, weight, bias});
result.addAttribute("out_pad", outpad);
result.addAttribute("stride", stride);
result.addAttribute("out_shape", outputShape);
+ result.addAttribute("acc_type", accType);
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
if (quantAttr) {
@@ -787,7 +833,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
return success();
}
- outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
+ outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic);
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
@@ -823,13 +869,17 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
LogicalResult tosa::PadOp::verify() {
RankedTensorType inputType = getInput1().getType();
RankedTensorType outputType = getOutput().getType();
- TensorType paddingType = getPadding().getType();
+ RankedTensorType paddingType = getPadding().getType();
if (inputType.getRank() != outputType.getRank())
return emitOpError() << "expect same input and output tensor rank.";
- if (paddingType.hasRank() && paddingType.getRank() != 2)
- return emitOpError() << "expect 'padding' tensor rank equal to 2.";
+ if (!paddingType.isDynamicDim(0) &&
+ paddingType.getDimSize(0) != inputType.getRank() * 2)
+ return emitOpError() << "expected padding tensor dim 0 to have size "
+ << inputType.getRank() * 2
+ << " (2*rank(shape1)) but got size "
+ << paddingType.getDimSize(0);
return success();
}
@@ -1595,7 +1645,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
return success();
}
-LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
+LogicalResult Conv2DOp::verify() {
+ if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+ return failure();
+ return success();
+}
LogicalResult Conv3DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
@@ -1667,7 +1721,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
return success();
}
-LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
+LogicalResult Conv3DOp::verify() {
+ if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+ return failure();
+ return success();
+}
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
@@ -1762,7 +1820,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
return success();
}
-LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
+LogicalResult DepthwiseConv2DOp::verify() {
+ if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+ return failure();
+ return success();
+}
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
@@ -1828,6 +1890,12 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
return success();
}
+LogicalResult TransposeConv2DOp::verify() {
+ if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+ return failure();
+ return success();
+}
+
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
IfOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 44f64f76e9b0..04a709c59677 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -81,7 +81,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
}
}
- auto padSizeTy = RankedTensorType::get({4, 2}, rewriter.getI64Type());
+ auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type());
auto padSize =
DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
Value padSizeVal =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index e6fba211dc37..14f392ab8c45 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -108,7 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
}
}
- auto padSizeTy = RankedTensorType::get({5, 2}, rewriter.getI64Type());
+ auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type());
auto padSize =
DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
Value padSizeVal =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 0779cdb9667a..db1e219b601b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -75,13 +75,15 @@ public:
loc, resultTy, input, reverse2, bias,
rewriter.getDenseI64ArrayAttr(convPad),
rewriter.getDenseI64ArrayAttr(stride),
- rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
+ rewriter.getDenseI64ArrayAttr({1, 1}),
+ /* acc_type = */ op.getAccType(), *op.getQuantizationInfo());
} else {
conv2d = rewriter.create<tosa::Conv2DOp>(
loc, resultTy, input, reverse2, bias,
rewriter.getDenseI64ArrayAttr(convPad),
rewriter.getDenseI64ArrayAttr(stride),
- rewriter.getDenseI64ArrayAttr({1, 1}));
+ rewriter.getDenseI64ArrayAttr({1, 1}),
+ /* acc_type = */ op.getAccTypeAttr());
}
rewriter.replaceOp(op, conv2d);
@@ -139,7 +141,7 @@ public:
weightPadding[5] =
(weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
+ RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
@@ -202,7 +204,7 @@ public:
inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
+ RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding);
Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
@@ -238,7 +240,7 @@ public:
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
/*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
- *op.getQuantizationInfo())
+ /* acc_type = */ op.getAccType(), *op.getQuantizationInfo())
.getResult();
} else {
conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
@@ -246,7 +248,8 @@ public:
weight, zeroBias,
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
/*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
- /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}))
+ /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
+ /* acc_type = */ op.getAccTypeAttr())
.getResult();
}
@@ -314,7 +317,7 @@ public:
resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);
+ RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding);
Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 6fd671051362..8588c878bfe4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -542,9 +542,13 @@ bool TosaValidation::isValidElementType(Type type) {
void TosaValidation::runOnOperation() {
configLevelAndProfile();
+
+ TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
+ if (!tosaDialect)
+ return;
+
getOperation().walk([&](Operation *op) {
- if (!op->getDialect() ||
- op->getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+ if (op->getDialect() != tosaDialect)
return;
for (Value operand : op->getOperands()) {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 106a79473509..798853a75441 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2840,6 +2840,7 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
llvm::outs() << "top-level ]]]\n";
state.getTopLevel()->print(llvm::outs(), printFlags);
llvm::outs() << "\n";
+ llvm::outs().flush();
return DiagnosedSilenceableFailure::success();
}
@@ -2849,6 +2850,7 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter,
llvm::outs() << "\n";
}
+ llvm::outs().flush();
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f22..68535ae5a7a5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
};
- typeConverter.addArgumentMaterialization(materializeCast);
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
target.markUnknownOpDynamicallyLegal(
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6fe96504ae10..c603db450cbd 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -284,22 +284,29 @@ OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) {
}
/// Do not verify the operation when using custom operation printers.
-OpPrintingFlags &OpPrintingFlags::assumeVerified() {
- assumeVerifiedFlag = true;
+OpPrintingFlags &OpPrintingFlags::assumeVerified(bool enable) {
+ assumeVerifiedFlag = enable;
return *this;
}
/// Use local scope when printing the operation. This allows for using the
/// printer in a more localized and thread-safe setting, but may not necessarily
/// be identical of what the IR will look like when dumping the full module.
-OpPrintingFlags &OpPrintingFlags::useLocalScope() {
- printLocalScope = true;
+OpPrintingFlags &OpPrintingFlags::useLocalScope(bool enable) {
+ printLocalScope = enable;
return *this;
}
/// Print users of values as comments.
-OpPrintingFlags &OpPrintingFlags::printValueUsers() {
- printValueUsersFlag = true;
+OpPrintingFlags &OpPrintingFlags::printValueUsers(bool enable) {
+ printValueUsersFlag = enable;
+ return *this;
+}
+
+/// Print unique SSA ID numbers for values, block arguments and naming conflicts
+/// across all regions
+OpPrintingFlags &OpPrintingFlags::printUniqueSSAIDs(bool enable) {
+ printUniqueSSAIDsFlag = enable;
return *this;
}
diff --git a/mlir/lib/IR/Dominance.cpp b/mlir/lib/IR/Dominance.cpp
index 406e0f2d62d6..1c54e09d29b9 100644
--- a/mlir/lib/IR/Dominance.cpp
+++ b/mlir/lib/IR/Dominance.cpp
@@ -213,61 +213,73 @@ DominanceInfoBase<IsPostDom>::findNearestCommonDominator(Block *a,
return getDomTree(a->getParent()).findNearestCommonDominator(a, b);
}
-/// Return true if the specified block A properly dominates block B.
-template <bool IsPostDom>
-bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(Block *a,
- Block *b) const {
- assert(a && b && "null blocks not allowed");
+/// Returns the given block iterator if it lies within the given region.
+/// Otherwise, finds the ancestor of the given block iterator that lies
+/// within the given region. Returns an "empty" iterator if the latter
+/// fails.
+///
+/// Note: This is a variant of Region::findAncestorOpInRegion that operates on
+/// block iterators instead of ops.
+static std::pair<Block *, Block::iterator>
+findAncestorIteratorInRegion(Region *r, Block *b, Block::iterator it) {
+ // Case 1: The iterator lies within the given region.
+ if (b->getParent() == r)
+ return std::make_pair(b, it);
+
+ // Otherwise: Find ancestor iterator. Bail if we run out of parent ops.
+ Operation *parentOp = b->getParentOp();
+ if (!parentOp)
+ return std::make_pair(static_cast<Block *>(nullptr), Block::iterator());
+ Operation *op = r->findAncestorOpInRegion(*parentOp);
+ if (!op)
+ return std::make_pair(static_cast<Block *>(nullptr), Block::iterator());
+ return std::make_pair(op->getBlock(), op->getIterator());
+}
- // A block dominates, but does not properly dominate, itself unless this
- // is a graph region.
+/// Given two iterators into the same block, return "true" if `a` is before `b`.
+/// Note: This is a variant of Operation::isBeforeInBlock that operates on
+/// block iterators instead of ops.
+static bool isBeforeInBlock(Block *block, Block::iterator a,
+ Block::iterator b) {
if (a == b)
- return !hasSSADominance(a);
-
- // If both blocks are not in the same region, `a` properly dominates `b` if
- // `b` is defined in an operation region that (recursively) ends up being
- // dominated by `a`. Walk up the list of containers enclosing B.
- Region *regionA = a->getParent();
- if (regionA != b->getParent()) {
- b = regionA ? regionA->findAncestorBlockInRegion(*b) : nullptr;
- // If we could not find a valid block b then it is a not a dominator.
- if (!b)
- return false;
-
- // Check to see if the ancestor of `b` is the same block as `a`. A properly
- // dominates B if it contains an op that contains the B block.
- if (a == b)
- return true;
- }
-
- // Otherwise, they are two different blocks in the same region, use DomTree.
- return getDomTree(regionA).properlyDominates(a, b);
+ return false;
+ if (a == block->end())
+ return false;
+ if (b == block->end())
+ return true;
+ return a->isBeforeInBlock(&*b);
}
template <bool IsPostDom>
bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(
- Operation *a, Operation *b, bool enclosingOpOk) const {
- Block *aBlock = a->getBlock(), *bBlock = b->getBlock();
- assert(aBlock && bBlock && "operations must be in a block");
+ Block *aBlock, Block::iterator aIt, Block *bBlock, Block::iterator bIt,
+ bool enclosingOk) const {
+ assert(aBlock && bBlock && "expected non-null blocks");
- // An operation (pos)dominates, but does not properly (pos)dominate, itself
- // unless this is a graph region.
- if (a == b)
+ // A block iterator (post)dominates, but does not properly (post)dominate,
+ // itself unless this is a graph region.
+ if (aBlock == bBlock && aIt == bIt)
return !hasSSADominance(aBlock);
- // If these ops are in different regions, then normalize one into the other.
+ // If the iterators are in different regions, then normalize one into the
+ // other.
Region *aRegion = aBlock->getParent();
if (aRegion != bBlock->getParent()) {
- // Scoot up b's region tree until we find an operation in A's region that
+ // Scoot up b's region tree until we find a location in A's region that
// encloses it. If this fails, then we know there is no (post)dom relation.
- b = aRegion ? aRegion->findAncestorOpInRegion(*b) : nullptr;
- if (!b)
+ if (!aRegion) {
+ bBlock = nullptr;
+ bIt = Block::iterator();
+ } else {
+ std::tie(bBlock, bIt) =
+ findAncestorIteratorInRegion(aRegion, bBlock, bIt);
+ }
+ if (!bBlock)
return false;
- bBlock = b->getBlock();
- assert(bBlock->getParent() == aRegion);
+ assert(bBlock->getParent() == aRegion && "expected block in regionA");
// If 'a' encloses 'b', then we consider it to (post)dominate.
- if (a == b && enclosingOpOk)
+ if (aBlock == bBlock && aIt == bIt && enclosingOk)
return true;
}
@@ -279,9 +291,9 @@ bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(
if (!hasSSADominance(aBlock))
return true;
if constexpr (IsPostDom) {
- return b->isBeforeInBlock(a);
+ return isBeforeInBlock(aBlock, bIt, aIt);
} else {
- return a->isBeforeInBlock(b);
+ return isBeforeInBlock(aBlock, aIt, bIt);
}
}
@@ -309,6 +321,18 @@ template class detail::DominanceInfoBase</*IsPostDom=*/false>;
// DominanceInfo
//===----------------------------------------------------------------------===//
+bool DominanceInfo::properlyDominates(Operation *a, Operation *b,
+ bool enclosingOpOk) const {
+ return super::properlyDominatesImpl(a->getBlock(), a->getIterator(),
+ b->getBlock(), b->getIterator(),
+ enclosingOpOk);
+}
+
+bool DominanceInfo::properlyDominates(Block *a, Block *b) const {
+ return super::properlyDominatesImpl(a, a->begin(), b, b->begin(),
+ /*enclosingOk=*/true);
+}
+
/// Return true if the `a` value properly dominates operation `b`, i.e if the
/// operation that defines `a` properlyDominates `b` and the operation that
/// defines `a` does not contain `b`.
@@ -322,3 +346,19 @@ bool DominanceInfo::properlyDominates(Value a, Operation *b) const {
// `b`, but `a` does not itself enclose `b` in one of its regions.
return properlyDominates(a.getDefiningOp(), b, /*enclosingOpOk=*/false);
}
+
+//===----------------------------------------------------------------------===//
+// PostDominanceInfo
+//===----------------------------------------------------------------------===//
+
+bool PostDominanceInfo::properlyPostDominates(Operation *a, Operation *b,
+ bool enclosingOpOk) const {
+ return super::properlyDominatesImpl(a->getBlock(), a->getIterator(),
+ b->getBlock(), b->getIterator(),
+ enclosingOpOk);
+}
+
+bool PostDominanceInfo::properlyPostDominates(Block *a, Block *b) const {
+ return super::properlyDominatesImpl(a, a->end(), b, b->end(),
+ /*enclosingOk=*/true);
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index cf58bc5d8f47..659ab1227f11 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -237,15 +237,7 @@ public:
generateMetadata(value.getInt(), "maxnreg");
} else if (attribute.getName() ==
NVVM::NVVMDialect::getKernelFuncAttrName()) {
- llvm::Metadata *llvmMetadataKernel[] = {
- llvm::ValueAsMetadata::get(llvmFunc),
- llvm::MDString::get(llvmContext, "kernel"),
- llvm::ValueAsMetadata::get(
- llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 1))};
- llvm::MDNode *llvmMetadataNode =
- llvm::MDNode::get(llvmContext, llvmMetadataKernel);
- moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
- ->addOperand(llvmMetadataNode);
+ llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
}
return success();
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9a30266103b1..87cb7f03fec6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -150,10 +150,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
<< " operation";
};
- auto checkAligned = [&todo](auto op, LogicalResult &result) {
- if (!op.getAlignedVars().empty() || op.getAlignments())
- result = todo("aligned");
- };
auto checkAllocate = [&todo](auto op, LogicalResult &result) {
if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
result = todo("allocate");
@@ -275,7 +271,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
})
.Case([&](omp::ParallelOp op) { checkAllocate(op, result); })
.Case([&](omp::SimdOp op) {
- checkAligned(op, result);
checkLinear(op, result);
checkNontemporal(op, result);
checkPrivate(op, result);
@@ -2302,6 +2297,24 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
+ llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
+ std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
+ mlir::OperandRange operands = simdOp.getAlignedVars();
+ for (size_t i = 0; i < operands.size(); ++i) {
+ llvm::Value *alignment = nullptr;
+ llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
+ llvm::Type *ty = llvmVal->getType();
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
+ alignment = builder.getInt64(intAttr.getInt());
+ assert(ty->isPointerTy() && "Invalid type for aligned variable");
+ assert(alignment && "Invalid alignment value");
+ auto curInsert = builder.saveIP();
+ builder.SetInsertPoint(sourceBlock->getTerminator());
+ llvmVal = builder.CreateLoad(ty, llvmVal);
+ builder.restoreIP(curInsert);
+ alignedVars[llvmVal] = alignment;
+ }
+ }
ompBuilder->applySimd(loopInfo, alignedVars,
simdOp.getIfExpr()
? moduleTranslation.lookupValue(simdOp.getIfExpr())
@@ -2575,6 +2588,7 @@ static LogicalResult
convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
if (failed(checkImplementationStatus(opInst)))
@@ -2582,6 +2596,10 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
Value symAddr = threadprivateOp.getSymAddr();
auto *symOp = symAddr.getDefiningOp();
+
+ if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
+ symOp = asCast.getOperand().getDefiningOp();
+
if (!isa<LLVM::AddressOfOp>(symOp))
return opInst.emitError("Addressing symbol not found");
LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
@@ -2589,17 +2607,20 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::GlobalOp global =
addressOfOp.getGlobal(moduleTranslation.symbolTable());
llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
- llvm::Type *type = globalValue->getValueType();
- llvm::TypeSize typeSize =
- builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
- type);
- llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
- llvm::StringRef suffix = llvm::StringRef(".cache", 6);
- std::string cacheName = (Twine(global.getSymName()).concat(suffix)).str();
- llvm::Value *callInst =
- moduleTranslation.getOpenMPBuilder()->createCachedThreadPrivate(
- ompLoc, globalValue, size, cacheName);
- moduleTranslation.mapValue(opInst.getResult(0), callInst);
+
+ if (!ompBuilder->Config.isTargetDevice()) {
+ llvm::Type *type = globalValue->getValueType();
+ llvm::TypeSize typeSize =
+ builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
+ type);
+ llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
+ llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
+ ompLoc, globalValue, size, global.getSymName() + ".cache");
+ moduleTranslation.mapValue(opInst.getResult(0), callInst);
+ } else {
+ moduleTranslation.mapValue(opInst.getResult(0), globalValue);
+ }
+
return success();
}
@@ -4199,6 +4220,14 @@ static bool isTargetDeviceOp(Operation *op) {
if (op->getParentOfType<omp::TargetOp>())
return true;
+ // Certain operations return results, and whether utilised in host or
+ // target there is a chance an LLVM Dialect operation depends on it
+ // by taking it in as an operand, so we must always lower these in
+ // some manner or result in an ICE (whether they end up in a no-op
+ // or otherwise).
+ if (mlir::isa<omp::ThreadprivateOp>(op))
+ return true;
+
if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
if (auto declareTargetIface =
llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index b0d5e635248d..2d8d7745eca9 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -427,19 +427,33 @@ ModuleImport::processAliasScopeMetadata(const llvm::MDNode *node) {
return node->getNumOperands() != 0 &&
node == dyn_cast<llvm::MDNode>(node->getOperand(0));
};
+ auto verifySelfRefOrString = [](const llvm::MDNode *node) {
+ return node->getNumOperands() != 0 &&
+ (node == dyn_cast<llvm::MDNode>(node->getOperand(0)) ||
+ isa<llvm::MDString>(node->getOperand(0)));
+ };
// Helper that verifies the given operand is a string or does not exist.
auto verifyDescription = [](const llvm::MDNode *node, unsigned idx) {
return idx >= node->getNumOperands() ||
isa<llvm::MDString>(node->getOperand(idx));
};
+
+ auto getIdAttr = [&](const llvm::MDNode *node) -> Attribute {
+ if (verifySelfRef(node))
+ return DistinctAttr::create(builder.getUnitAttr());
+
+ auto name = cast<llvm::MDString>(node->getOperand(0));
+ return builder.getStringAttr(name->getString());
+ };
+
// Helper that creates an alias scope domain attribute.
auto createAliasScopeDomainOp = [&](const llvm::MDNode *aliasDomain) {
StringAttr description = nullptr;
if (aliasDomain->getNumOperands() >= 2)
if (auto *operand = dyn_cast<llvm::MDString>(aliasDomain->getOperand(1)))
description = builder.getStringAttr(operand->getString());
- return builder.getAttr<AliasScopeDomainAttr>(
- DistinctAttr::create(builder.getUnitAttr()), description);
+ Attribute idAttr = getIdAttr(aliasDomain);
+ return builder.getAttr<AliasScopeDomainAttr>(idAttr, description);
};
// Collect the alias scopes and domains to translate them.
@@ -452,10 +466,11 @@ ModuleImport::processAliasScopeMetadata(const llvm::MDNode *node) {
// verifying its domain. Perform the verification before looking it up in
// the alias scope mapping since it could have been inserted as a domain
// node before.
- if (!verifySelfRef(scope) || !domain || !verifyDescription(scope, 2))
+ if (!verifySelfRefOrString(scope) || !domain ||
+ !verifyDescription(scope, 2))
return emitError(loc) << "unsupported alias scope node: "
<< diagMD(scope, llvmModule.get());
- if (!verifySelfRef(domain) || !verifyDescription(domain, 1))
+ if (!verifySelfRefOrString(domain) || !verifyDescription(domain, 1))
return emitError(loc) << "unsupported alias domain node: "
<< diagMD(domain, llvmModule.get());
@@ -473,9 +488,10 @@ ModuleImport::processAliasScopeMetadata(const llvm::MDNode *node) {
StringAttr description = nullptr;
if (!aliasScope.getName().empty())
description = builder.getStringAttr(aliasScope.getName());
+ Attribute idAttr = getIdAttr(scope);
auto aliasScopeOp = builder.getAttr<AliasScopeAttr>(
- DistinctAttr::create(builder.getUnitAttr()),
- cast<AliasScopeDomainAttr>(it->second), description);
+ idAttr, cast<AliasScopeDomainAttr>(it->second), description);
+
aliasScopeMapping.try_emplace(aliasScope.getNode(), aliasScopeOp);
}
}
@@ -1473,18 +1489,20 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
return success();
}
-LogicalResult
-ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst,
- SmallVectorImpl<Type> &types,
- SmallVectorImpl<Value> &operands) {
+LogicalResult ModuleImport::convertCallTypeAndOperands(
+ llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
+ SmallVectorImpl<Value> &operands, bool allowInlineAsm) {
if (!callInst->getType()->isVoidTy())
types.push_back(convertType(callInst->getType()));
if (!callInst->getCalledFunction()) {
- FailureOr<Value> called = convertValue(callInst->getCalledOperand());
- if (failed(called))
- return failure();
- operands.push_back(*called);
+ if (!allowInlineAsm ||
+ !isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
+ FailureOr<Value> called = convertValue(callInst->getCalledOperand());
+ if (failed(called))
+ return failure();
+ operands.push_back(*called);
+ }
}
SmallVector<llvm::Value *> args(callInst->args());
FailureOr<SmallVector<Value>> arguments = convertValues(args);
@@ -1579,7 +1597,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
SmallVector<Type> types;
SmallVector<Value> operands;
- if (failed(convertCallTypeAndOperands(callInst, types, operands)))
+ if (failed(convertCallTypeAndOperands(callInst, types, operands,
+ /*allowInlineAsm=*/true)))
return failure();
auto funcTy =
@@ -1587,45 +1606,59 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (!funcTy)
return failure();
- CallOp callOp;
-
- if (llvm::Function *callee = callInst->getCalledFunction()) {
- callOp = builder.create<CallOp>(
- loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
- operands);
+ if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
+ auto callOp = builder.create<InlineAsmOp>(
+ loc, funcTy.getReturnType(), operands,
+ builder.getStringAttr(asmI->getAsmString()),
+ builder.getStringAttr(asmI->getConstraintString()),
+ /*has_side_effects=*/true,
+ /*is_align_stack=*/false, /*asm_dialect=*/nullptr,
+ /*operand_attrs=*/nullptr);
+ if (!callInst->getType()->isVoidTy())
+ mapValue(inst, callOp.getResult(0));
+ else
+ mapNoResultOp(inst, callOp);
} else {
- callOp = builder.create<CallOp>(loc, funcTy, operands);
+ CallOp callOp;
+
+ if (llvm::Function *callee = callInst->getCalledFunction()) {
+ callOp = builder.create<CallOp>(
+ loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
+ operands);
+ } else {
+ callOp = builder.create<CallOp>(loc, funcTy, operands);
+ }
+ callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
+ callOp.setTailCallKind(
+ convertTailCallKindFromLLVM(callInst->getTailCallKind()));
+ setFastmathFlagsAttr(inst, callOp);
+
+ // Handle function attributes.
+ if (callInst->hasFnAttr(llvm::Attribute::Convergent))
+ callOp.setConvergent(true);
+ if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
+ callOp.setNoUnwind(true);
+ if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
+ callOp.setWillReturn(true);
+
+ llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
+ ModRefInfo othermem = convertModRefInfoFromLLVM(
+ memEffects.getModRef(llvm::MemoryEffects::Location::Other));
+ ModRefInfo argMem = convertModRefInfoFromLLVM(
+ memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
+ ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
+ memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
+ auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem,
+ argMem, inaccessibleMem);
+ // Only set the attribute when it does not match the default value.
+ if (!memAttr.isReadWrite())
+ callOp.setMemoryEffectsAttr(memAttr);
+
+ if (!callInst->getType()->isVoidTy())
+ mapValue(inst, callOp.getResult());
+ else
+ mapNoResultOp(inst, callOp);
}
- callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
- callOp.setTailCallKind(
- convertTailCallKindFromLLVM(callInst->getTailCallKind()));
- setFastmathFlagsAttr(inst, callOp);
-
- // Handle function attributes.
- if (callInst->hasFnAttr(llvm::Attribute::Convergent))
- callOp.setConvergent(true);
- if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
- callOp.setNoUnwind(true);
- if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
- callOp.setWillReturn(true);
-
- llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
- ModRefInfo othermem = convertModRefInfoFromLLVM(
- memEffects.getModRef(llvm::MemoryEffects::Location::Other));
- ModRefInfo argMem = convertModRefInfoFromLLVM(
- memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
- ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
- memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
- auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem, argMem,
- inaccessibleMem);
- // Only set the attribute when it does not match the default value.
- if (!memAttr.isReadWrite())
- callOp.setMemoryEffectsAttr(memAttr);
-
- if (!callInst->getType()->isVoidTy())
- mapValue(inst, callOp.getResult());
- else
- mapNoResultOp(inst, callOp);
return success();
}
if (inst->getOpcode() == llvm::Instruction::LandingPad) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ad62ae0cef57..4367100e3aca 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1724,25 +1724,36 @@ ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr) {
aliasScopeAttr.getDomain(), nullptr);
if (insertedDomain) {
llvm::SmallVector<llvm::Metadata *, 2> operands;
- // Placeholder for self-reference.
+ // Placeholder for potential self-reference.
operands.push_back(dummy.get());
if (StringAttr description = aliasScopeAttr.getDomain().getDescription())
operands.push_back(llvm::MDString::get(ctx, description));
domainIt->second = llvm::MDNode::get(ctx, operands);
// Self-reference for uniqueness.
- domainIt->second->replaceOperandWith(0, domainIt->second);
+ llvm::Metadata *replacement;
+ if (auto stringAttr =
+ dyn_cast<StringAttr>(aliasScopeAttr.getDomain().getId()))
+ replacement = llvm::MDString::get(ctx, stringAttr.getValue());
+ else
+ replacement = domainIt->second;
+ domainIt->second->replaceOperandWith(0, replacement);
}
// Convert the scope metadata node.
assert(domainIt->second && "Scope's domain should already be valid");
llvm::SmallVector<llvm::Metadata *, 3> operands;
- // Placeholder for self-reference.
+ // Placeholder for potential self-reference.
operands.push_back(dummy.get());
operands.push_back(domainIt->second);
if (StringAttr description = aliasScopeAttr.getDescription())
operands.push_back(llvm::MDString::get(ctx, description));
scopeIt->second = llvm::MDNode::get(ctx, operands);
// Self-reference for uniqueness.
- scopeIt->second->replaceOperandWith(0, scopeIt->second);
+ llvm::Metadata *replacement;
+ if (auto stringAttr = dyn_cast<StringAttr>(aliasScopeAttr.getId()))
+ replacement = llvm::MDString::get(ctx, stringAttr.getValue());
+ else
+ replacement = scopeIt->second;
+ scopeIt->second->replaceOperandWith(0, replacement);
return scopeIt->second;
}
diff --git a/mlir/lib/Transforms/LocationSnapshot.cpp b/mlir/lib/Transforms/LocationSnapshot.cpp
index b85850acda91..f701c8b4f0a9 100644
--- a/mlir/lib/Transforms/LocationSnapshot.cpp
+++ b/mlir/lib/Transforms/LocationSnapshot.cpp
@@ -10,6 +10,7 @@
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/Support/FileSystem.h"
@@ -131,29 +132,23 @@ LogicalResult mlir::generateLocationsFromIR(StringRef fileName, StringRef tag,
namespace {
struct LocationSnapshotPass
: public impl::LocationSnapshotBase<LocationSnapshotPass> {
- LocationSnapshotPass() = default;
- LocationSnapshotPass(OpPrintingFlags flags, StringRef fileName, StringRef tag)
- : flags(flags) {
- this->fileName = fileName.str();
- this->tag = tag.str();
- }
+ using impl::LocationSnapshotBase<LocationSnapshotPass>::LocationSnapshotBase;
void runOnOperation() override {
Operation *op = getOperation();
- if (failed(generateLocationsFromIR(fileName, op, OpPrintingFlags(), tag)))
+ if (failed(generateLocationsFromIR(fileName, op, getFlags(), tag)))
return signalPassFailure();
}
- /// The printing flags to use when creating the snapshot.
- OpPrintingFlags flags;
+private:
+ /// Build the flags from the command line arguments to the pass.
+ OpPrintingFlags getFlags() {
+ OpPrintingFlags flags;
+ flags.enableDebugInfo(enableDebugInfo, printPrettyDebugInfo);
+ flags.printGenericOpForm(printGenericOpForm);
+ if (useLocalScope)
+ flags.useLocalScope();
+ return flags;
+ }
};
} // namespace
-
-std::unique_ptr<Pass> mlir::createLocationSnapshotPass(OpPrintingFlags flags,
- StringRef fileName,
- StringRef tag) {
- return std::make_unique<LocationSnapshotPass>(flags, fileName, tag);
-}
-std::unique_ptr<Pass> mlir::createLocationSnapshotPass() {
- return std::make_unique<LocationSnapshotPass>();
-}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 255b0ba2559e..403321d40d53 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -63,11 +64,55 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
return OpBuilder::InsertPoint(insertBlock, insertPt);
}
+/// Helper function that computes an insertion point where the given values are
+/// defined and can be used without a dominance violation.
+static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
+ assert(!vals.empty() && "expected at least one value");
+ DominanceInfo domInfo;
+ OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
+ for (Value v : vals.drop_front()) {
+ // Choose the "later" insertion point.
+ OpBuilder::InsertPoint nextPt = computeInsertPoint(v);
+ if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
+ nextPt.getPoint())) {
+ // pt is before nextPt => choose nextPt.
+ pt = nextPt;
+ } else {
+#ifndef NDEBUG
+ // nextPt should be before pt => choose pt.
+ // If pt and nextPt have no dominance relationship, there is no valid
+ // insertion point at which all given values are defined.
+ bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
+ pt.getBlock(), pt.getPoint());
+ assert(dom && "unable to find valid insertion point");
+#endif // NDEBUG
+ }
+ }
+ return pt;
+}
+
//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
+/// A vector of SSA values, optimized for the most common case of a single
+/// value.
+using ValueVector = SmallVector<Value, 1>;
+
namespace {
+
+/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
+struct ValueVectorMapInfo {
+ static ValueVector getEmptyKey() { return ValueVector{Value()}; }
+ static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
+ static ::llvm::hash_code getHashValue(const ValueVector &val) {
+ return ::llvm::hash_combine_range(val.begin(), val.end());
+ }
+ static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
+ return LHS == RHS;
+ }
+};
+
/// This class wraps a IRMapping to provide recursive lookup
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
struct ConversionValueMapping {
@@ -75,68 +120,128 @@ struct ConversionValueMapping {
/// false positives.
bool isMappedTo(Value value) const { return mappedTo.contains(value); }
- /// Lookup the most recently mapped value with the desired type in the
+ /// Lookup the most recently mapped values with the desired types in the
/// mapping.
///
/// Special cases:
- /// - If the desired type is "null", simply return the most recently mapped
- /// value.
- /// - If there is no mapping to the desired type, also return the most
- /// recently mapped value.
- /// - If there is no mapping for the given value at all, return the given
+ /// - If the desired type range is empty, simply return the most recently
+ /// mapped values.
+ /// - If there is no mapping to the desired types, also return the most
+ /// recently mapped values.
+ /// - If there is no mapping for the given value at all, return the given
/// value.
- Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
+ ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
- /// Lookup a mapped value within the map, or return null if a mapping does not
- /// exist. If a mapping exists, this follows the same behavior of
- /// `lookupOrDefault`.
- Value lookupOrNull(Value from, Type desiredType = nullptr) const;
+ /// Lookup the given value within the map, or return an empty vector if the
+ /// value is not mapped. If it is mapped, this follows the same behavior
+ /// as `lookupOrDefault`.
+ ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
- /// Map a value to the one provided.
- void map(Value oldVal, Value newVal) {
+ template <typename T>
+ struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
+
+ /// Map a value vector to the one provided.
+ template <typename OldVal, typename NewVal>
+ std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
+ map(OldVal &&oldVal, NewVal &&newVal) {
LLVM_DEBUG({
- for (Value it = newVal; it; it = mapping.lookupOrNull(it))
- assert(it != oldVal && "inserting cyclic mapping");
+ ValueVector next(newVal);
+ while (true) {
+ assert(next != oldVal && "inserting cyclic mapping");
+ auto it = mapping.find(next);
+ if (it == mapping.end())
+ break;
+ next = it->second;
+ }
});
- mapping.map(oldVal, newVal);
- mappedTo.insert(newVal);
+ for (Value v : newVal)
+ mappedTo.insert(v);
+
+ mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
}
- /// Drop the last mapping for the given value.
- void erase(Value value) { mapping.erase(value); }
+ /// Map a value vector or single value to the one provided.
+ template <typename OldVal, typename NewVal>
+ std::enable_if_t<!IsValueVector<OldVal>::value ||
+ !IsValueVector<NewVal>::value>
+ map(OldVal &&oldVal, NewVal &&newVal) {
+ if constexpr (IsValueVector<OldVal>{}) {
+ map(std::forward<OldVal>(oldVal), ValueVector{newVal});
+ } else if constexpr (IsValueVector<NewVal>{}) {
+ map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
+ } else {
+ map(ValueVector{oldVal}, ValueVector{newVal});
+ }
+ }
+
+ /// Drop the last mapping for the given values.
+ void erase(const ValueVector &value) { mapping.erase(value); }
private:
/// Current value mappings.
- IRMapping mapping;
+ DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping;
/// All SSA values that are mapped to. May contain false positives.
DenseSet<Value> mappedTo;
};
} // namespace
-Value ConversionValueMapping::lookupOrDefault(Value from,
- Type desiredType) const {
- // Try to find the deepest value that has the desired type. If there is no
- // such value, simply return the deepest value.
- Value desiredValue;
+ValueVector
+ConversionValueMapping::lookupOrDefault(Value from,
+ TypeRange desiredTypes) const {
+ // Try to find the deepest values that have the desired types. If there is no
+ // such mapping, simply return the deepest values.
+ ValueVector desiredValue;
+ ValueVector current{from};
do {
- if (!desiredType || from.getType() == desiredType)
- desiredValue = from;
+ // Store the current value if the types match.
+ if (TypeRange(ValueRange(current)) == desiredTypes)
+ desiredValue = current;
+
+ // If possible, replace each value with (one or multiple) mapped values.
+ ValueVector next;
+ for (Value v : current) {
+ auto it = mapping.find({v});
+ if (it != mapping.end()) {
+ llvm::append_range(next, it->second);
+ } else {
+ next.push_back(v);
+ }
+ }
+ if (next != current) {
+ // If at least one value was replaced, continue the lookup from there.
+ current = std::move(next);
+ continue;
+ }
- Value mappedValue = mapping.lookupOrNull(from);
- if (!mappedValue)
+ // Otherwise: Check if there is a mapping for the entire vector. Such
+ // mappings are materializations. (N:M mappings are not supported for value
+ // replacements.)
+ //
+ // Note: From a correctness point of view, materializations do not have to
+ // be stored (and looked up) in the mapping. But for performance reasons,
+ // we choose to reuse existing IR (when possible) instead of creating it
+ // multiple times.
+ auto it = mapping.find(current);
+ if (it == mapping.end()) {
+ // No mapping found: The lookup stops here.
break;
- from = mappedValue;
+ }
+ current = it->second;
} while (true);
- // If the desired value was found use it, otherwise default to the leaf value.
- return desiredValue ? desiredValue : from;
+ // If the desired values were found use them, otherwise default to the leaf
+ // values.
+ // Note: If `desiredTypes` is empty, this function always returns `current`.
+ return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
}
-Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
- Value result = lookupOrDefault(from, desiredType);
- if (result == from || (desiredType && result.getType() != desiredType))
- return nullptr;
+ValueVector ConversionValueMapping::lookupOrNull(Value from,
+ TypeRange desiredTypes) const {
+ ValueVector result = lookupOrDefault(from, desiredTypes);
+ if (result == ValueVector{from} ||
+ (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
+ return {};
return result;
}
@@ -651,10 +756,6 @@ public:
/// The type of materialization.
enum MaterializationKind {
- /// This materialization materializes a conversion for an illegal block
- /// argument type, to the original one.
- Argument,
-
/// This materialization materializes a conversion from an illegal type to a
/// legal one.
Target,
@@ -673,7 +774,7 @@ public:
UnrealizedConversionCastOp op,
const TypeConverter *converter,
MaterializationKind kind, Type originalType,
- Value mappedValue);
+ ValueVector mappedValues);
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,9 +809,9 @@ private:
/// materializations.
Type originalType;
- /// The value in the conversion value mapping that is being replaced by the
+ /// The values in the conversion value mapping that are being replaced by the
/// results of this unresolved materialization.
- Value mappedValue;
+ ValueVector mappedValues;
};
} // namespace
@@ -779,7 +880,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
LogicalResult remapValues(StringRef valueDiagTag,
std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVector<SmallVector<Value>> &remapped);
+ SmallVector<ValueVector> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
/// converted.
@@ -820,39 +921,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// If a cast op was built, it can optionally be returned with the `castOp`
/// output argument.
///
- /// If `valueToMap` is set to a non-null Value, then that value is mapped to
+ /// If `valuesToMap` is non-empty, then those values are mapped to
/// the results of the unresolved materialization in the conversion value
/// mapping.
ValueRange buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+ ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
UnrealizedConversionCastOp *castOp = nullptr);
- Value buildUnresolvedMaterialization(
- MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
- const TypeConverter *converter,
- UnrealizedConversionCastOp *castOp = nullptr) {
- return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs,
- TypeRange(outputType), originalType,
- converter, castOp)
- .front();
- }
-
- /// Build an N:1 materialization for the given original value that was
- /// replaced with the given replacement values.
- ///
- /// This is a workaround around incomplete 1:N support in the dialect
- /// conversion driver. The conversion mapping can store only 1:1 replacements
- /// and the conversion patterns only support single Value replacements in the
- /// adaptor, so N values must be converted back to a single value. This
- /// function will be deleted when full 1:N support has been added.
- ///
- /// This function inserts an argument materialization back to the original
- /// type.
- void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
- ValueRange replacements, Value originalValue,
- const TypeConverter *converter);
/// Find a replacement value for the given SSA value in the conversion value
/// mapping. The replacement value must have the same type as the given SSA
@@ -862,16 +938,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Value findOrBuildReplacementValue(Value value,
const TypeConverter *converter);
- /// Unpack an N:1 materialization and return the inputs of the
- /// materialization. This function unpacks only those materializations that
- /// were built with `insertNTo1Materialization`.
- ///
- /// This is a workaround around incomplete 1:N support in the dialect
- /// conversion driver. It allows us to write 1:N conversion patterns while
- /// 1:N support is still missing in the conversion value mapping. This
- /// function will be deleted when full 1:N support has been added.
- SmallVector<Value> unpackNTo1Materialization(Value value);
-
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
@@ -974,10 +1040,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
unresolvedMaterializations;
- /// A set of all N:1 materializations that were added to work around
- /// incomplete 1:N support in the dialect conversion driver.
- DenseSet<UnrealizedConversionCastOp> nTo1TempMaterializations;
-
/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;
@@ -1041,7 +1103,7 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
});
}
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
+void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
auto *listener =
@@ -1082,7 +1144,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
void ReplaceOperationRewrite::rollback() {
for (auto result : op->getResults())
- rewriterImpl.mapping.erase(result);
+ rewriterImpl.mapping.erase({result});
}
void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
@@ -1101,20 +1163,19 @@ void CreateOperationRewrite::rollback() {
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
const TypeConverter *converter, MaterializationKind kind, Type originalType,
- Value mappedValue)
+ ValueVector mappedValues)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
converterAndKind(converter, kind), originalType(originalType),
- mappedValue(mappedValue) {
+ mappedValues(std::move(mappedValues)) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
rewriterImpl.unresolvedMaterializations[op] = this;
}
void UnresolvedMaterializationRewrite::rollback() {
- if (mappedValue)
- rewriterImpl.mapping.erase(mappedValue);
+ if (!mappedValues.empty())
+ rewriterImpl.mapping.erase(mappedValues);
rewriterImpl.unresolvedMaterializations.erase(getOperation());
- rewriterImpl.nTo1TempMaterializations.erase(getOperation());
op->erase();
}
@@ -1160,7 +1221,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
LogicalResult ConversionPatternRewriterImpl::remapValues(
StringRef valueDiagTag, std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVector<SmallVector<Value>> &remapped) {
+ SmallVector<ValueVector> &remapped) {
remapped.reserve(llvm::size(values));
for (const auto &it : llvm::enumerate(values)) {
@@ -1168,18 +1229,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Type origType = operand.getType();
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
- // Find the most recently mapped value. Unpack all temporary N:1
- // materializations. Such conversions are a workaround around missing
- // 1:N support in the ConversionValueMapping. (The conversion patterns
- // already support 1:N replacements.)
- Value repl = mapping.lookupOrDefault(operand);
- SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
-
if (!currentTypeConverter) {
// The current pattern does not have a type converter. I.e., it does not
// distinguish between legal and illegal types. For each operand, simply
- // pass through the most recently mapped value.
- remapped.push_back(std::move(unpacked));
+ // pass through the most recently mapped values.
+ remapped.push_back(mapping.lookupOrDefault(operand));
continue;
}
@@ -1192,51 +1246,28 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
});
return failure();
}
-
// If a type is converted to 0 types, there is nothing to do.
if (legalTypes.empty()) {
remapped.push_back({});
continue;
}
- if (legalTypes.size() != 1) {
- // TODO: This is a 1:N conversion. The conversion value mapping does not
- // store such materializations yet. If the types of the most recently
- // mapped values do not match, build a target materialization.
- ValueRange unpackedRange(unpacked);
- if (TypeRange(unpackedRange) == legalTypes) {
- remapped.push_back(std::move(unpacked));
- continue;
- }
-
- // Insert a target materialization if the current pattern expects
- // different legalized types.
- ValueRange targetMat = buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
- /*valueToMap=*/Value(), /*inputs=*/unpacked,
- /*outputType=*/legalTypes, /*originalType=*/origType,
- currentTypeConverter);
- remapped.push_back(targetMat);
+ ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
+ if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
+ // Mapped values have the correct type or there is an existing
+ // materialization. Or the operand is not mapped at all and has the
+ // correct type.
+ remapped.push_back(std::move(repl));
continue;
}
- // Handle 1->1 type conversions.
- Type desiredType = legalTypes.front();
- // Try to find a mapped value with the desired type. (Or the operand itself
- // if the value is not mapped at all.)
- Value newOperand = mapping.lookupOrDefault(operand, desiredType);
- if (newOperand.getType() != desiredType) {
- // If the looked up value's type does not have the desired type, it means
- // that the value was replaced with a value of different type and no
- // target materialization was created yet.
- Value castValue = buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(newOperand),
- operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked,
- /*outputType=*/desiredType, /*originalType=*/origType,
- currentTypeConverter);
- newOperand = castValue;
- }
- remapped.push_back({newOperand});
+ // Create a materialization for the most recently mapped values.
+ repl = mapping.lookupOrDefault(operand);
+ ValueRange castValues = buildUnresolvedMaterialization(
+ MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
+ /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
+ /*originalType=*/origType, currentTypeConverter);
+ remapped.push_back(castValues);
}
return success();
}
@@ -1353,7 +1384,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
buildUnresolvedMaterialization(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*valueToMap=*/origArg, /*inputs=*/ValueRange(),
+ /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
continue;
@@ -1369,19 +1400,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
continue;
}
- // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
- // dialect conversion. Therefore, we need an argument materialization to
- // turn the replacement block arguments into a single SSA value that can be
- // used as a replacement.
+ // This is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- if (replArgs.size() == 1) {
- mapping.map(origArg, replArgs.front());
- } else {
- insertNTo1Materialization(
- OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
- }
+ ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
+ mapping.map(origArg, std::move(replArgVals));
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}
@@ -1402,20 +1425,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+ ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
UnrealizedConversionCastOp *castOp) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
-
- // Avoid materializing an unnecessary cast.
- if (TypeRange(inputs) == outputTypes) {
- if (valueToMap) {
- assert(inputs.size() == 1 && "1:N mapping is not supported");
- mapping.map(valueToMap, inputs.front());
- }
- return inputs;
- }
+ assert(TypeRange(inputs) != outputTypes &&
+ "materialization is not necessary");
// Create an unresolved materialization. We use a new OpBuilder to avoid
// tracking the materialization like we do for other operations.
@@ -1423,37 +1439,23 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
- if (valueToMap) {
- assert(outputTypes.size() == 1 && "1:N mapping is not supported");
- mapping.map(valueToMap, convertOp.getResult(0));
- }
+ if (!valuesToMap.empty())
+ mapping.map(valuesToMap, convertOp.getResults());
if (castOp)
*castOp = convertOp;
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
- originalType, valueToMap);
+ appendRewrite<UnresolvedMaterializationRewrite>(
+ convertOp, converter, kind, originalType, std::move(valuesToMap));
return convertOp.getResults();
}
-void ConversionPatternRewriterImpl::insertNTo1Materialization(
- OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
- Value originalValue, const TypeConverter *converter) {
- // Insert argument materialization back to the original type.
- Type originalType = originalValue.getType();
- UnrealizedConversionCastOp argCastOp;
- buildUnresolvedMaterialization(
- MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
- /*inputs=*/replacements, originalType,
- /*originalType=*/Type(), converter, &argCastOp);
- if (argCastOp)
- nTo1TempMaterializations.insert(argCastOp);
-}
-
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
Value value, const TypeConverter *converter) {
- // Find a replacement value with the same type.
- Value repl = mapping.lookupOrNull(value, value.getType());
- if (repl)
- return repl;
+ // Try to find a replacement value with the same type in the conversion value
+ // mapping. This includes cached materializations. We try to reuse those
+ // instead of generating duplicate IR.
+ ValueVector repl = mapping.lookupOrNull(value, value.getType());
+ if (!repl.empty())
+ return repl.front();
// Check if the value is dead. No replacement value is needed in that case.
// This is an approximate check that may have false negatives but does not
@@ -1468,7 +1470,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
// (regardless of the type) and build a source materialization to the
// original type.
repl = mapping.lookupOrNull(value);
- if (!repl) {
+ if (repl.empty()) {
// No replacement value is registered in the mapping. This means that the
// value is dropped and no longer needed. (If the value were still needed,
// a source materialization producing a replacement value "out of thin air"
@@ -1476,34 +1478,22 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
// `applySignatureConversion`.)
return Value();
}
- Value castValue = buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
- /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
- /*originalType=*/Type(), converter);
- mapping.map(value, castValue);
- return castValue;
-}
-SmallVector<Value>
-ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
- // Unpack unrealized_conversion_cast ops that were inserted as a N:1
- // workaround.
- auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>();
- if (!castOp)
- return {value};
- if (!nTo1TempMaterializations.contains(castOp))
- return {value};
- assert(castOp->getNumResults() == 1 && "expected single result");
-
- SmallVector<Value> result;
- for (Value v : castOp.getOperands()) {
- // Keep unpacking if possible. This is needed because during block
- // signature conversions and 1:N op replacements, the driver may have
- // inserted two materializations back-to-back: first an argument
- // materialization, then a target materialization.
- llvm::append_range(result, unpackNTo1Materialization(v));
- }
- return result;
+ // Note: `computeInsertPoint` computes the "earliest" insertion point at
+ // which all values in `repl` are defined. It is important to emit the
+ // materialization at that location because the same materialization may be
+ // reused in a different context. (That's because materializations are cached
+ // in the conversion value mapping.) The insertion point of the
+ // materialization must be valid for all future users that may be created
+ // later in the conversion process.
+ Value castValue =
+ buildUnresolvedMaterialization(MaterializationKind::Source,
+ computeInsertPoint(repl), value.getLoc(),
+ /*valuesToMap=*/repl, /*inputs=*/repl,
+ /*outputType=*/value.getType(),
+ /*originalType=*/Type(), converter)
+ .front();
+ return castValue;
}
//===----------------------------------------------------------------------===//
@@ -1554,7 +1544,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
// Materialize a replacement value "out of thin air".
buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result),
- result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(),
+ result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
/*outputType=*/result.getType(), /*originalType=*/Type(),
currentTypeConverter);
continue;
@@ -1572,16 +1562,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
// Remap result to replacement value.
if (repl.empty())
continue;
-
- if (repl.size() == 1) {
- // Single replacement value: replace directly.
- mapping.map(result, repl.front());
- } else {
- // Multiple replacement values: insert N:1 materialization.
- insertNTo1Materialization(computeInsertPoint(result), result.getLoc(),
- /*replacements=*/repl, /*outputValue=*/result,
- currentTypeConverter);
- }
+ mapping.map(result, repl);
}
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1660,8 +1641,13 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
SmallVector<ValueRange> newVals;
- for (size_t i = 0; i < newValues.size(); ++i)
- newVals.push_back(newValues.slice(i, 1));
+ for (size_t i = 0; i < newValues.size(); ++i) {
+ if (newValues[i]) {
+ newVals.push_back(newValues.slice(i, 1));
+ } else {
+ newVals.push_back(ValueRange());
+ }
+ }
impl->notifyOpReplaced(op, newVals);
}
@@ -1733,7 +1719,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
- SmallVector<SmallVector<Value>> remappedValues;
+ SmallVector<ValueVector> remappedValues;
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
remappedValues)))
return nullptr;
@@ -1746,7 +1732,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
SmallVectorImpl<Value> &results) {
if (keys.empty())
return success();
- SmallVector<SmallVector<Value>> remapped;
+ SmallVector<ValueVector> remapped;
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
remapped)))
return failure();
@@ -1872,7 +1858,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
getTypeConverter());
// Remap the operands of the operation.
- SmallVector<SmallVector<Value>> remapped;
+ SmallVector<ValueVector> remapped;
if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
op->getOperands(), remapped))) {
return failure();
@@ -2625,19 +2611,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
rewriter.setInsertionPoint(op);
SmallVector<Value> newMaterialization;
switch (rewrite->getMaterializationKind()) {
- case MaterializationKind::Argument: {
- // Try to materialize an argument conversion.
- assert(op->getNumResults() == 1 && "expected single result");
- Value argMat = converter->materializeArgumentConversion(
- rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
- if (argMat) {
- newMaterialization.push_back(argMat);
- break;
- }
- }
- // If an argument materialization failed, fallback to trying a target
- // materialization.
- [[fallthrough]];
case MaterializationKind::Target:
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index c5cb22c6dccb..d021dde05dd8 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -58,6 +58,7 @@ def get_include_dirs() -> Sequence[str]:
# needs.
_dialect_registry = None
+_load_on_create_dialects = None
def get_dialect_registry():
@@ -71,6 +72,21 @@ def get_dialect_registry():
return _dialect_registry
+def append_load_on_create_dialect(dialect: str):
+ global _load_on_create_dialects
+ if _load_on_create_dialects is None:
+ _load_on_create_dialects = [dialect]
+ else:
+ _load_on_create_dialects.append(dialect)
+
+
+def get_load_on_create_dialects():
+ global _load_on_create_dialects
+ if _load_on_create_dialects is None:
+ _load_on_create_dialects = []
+ return _load_on_create_dialects
+
+
def _site_initialize():
import importlib
import itertools
@@ -132,15 +148,35 @@ def _site_initialize():
break
class Context(ir._BaseContext):
- def __init__(self, *args, **kwargs):
+ def __init__(self, load_on_create_dialects=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.append_dialect_registry(get_dialect_registry())
for hook in post_init_hooks:
hook(self)
if not disable_multithreading:
self.enable_multithreading(True)
- if not disable_load_all_available_dialects:
- self.load_all_available_dialects()
+ if load_on_create_dialects is not None:
+ logger.debug(
+ "Loading all dialects from load_on_create_dialects arg %r",
+ load_on_create_dialects,
+ )
+ for dialect in load_on_create_dialects:
+ # This triggers loading the dialect into the context.
+ _ = self.dialects[dialect]
+ else:
+ if disable_load_all_available_dialects:
+ dialects = get_load_on_create_dialects()
+ if dialects:
+ logger.debug(
+ "Loading all dialects from global load_on_create_dialects %r",
+ dialects,
+ )
+ for dialect in dialects:
+ # This triggers loading the dialect into the context.
+ _ = self.dialects[dialect]
+ else:
+ logger.debug("Loading all available dialects")
+ self.load_all_available_dialects()
if init_module:
logger.debug(
"Registering translations from initializer %r", init_module
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 9121aa8e4023..bf40cc532065 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -141,6 +141,77 @@ class FuseIntoContainingOp(FuseIntoContainingOp):
@_ods_cext.register_operation(_Dialect, replace=True)
+class FuseOp(FuseOp):
+ """Specialization for FuseOp class."""
+
+ @overload
+ def __init__(
+ self,
+ loop_types: Union[Type, Sequence[Type]],
+ target: Union[Operation, Value, OpView],
+ *,
+ tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+ tile_interchange: OptionalIntList = None,
+ apply_cleanup: Optional[bool] = False,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, Value, OpView],
+ *,
+ tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+ tile_interchange: OptionalIntList = None,
+ apply_cleanup: Optional[bool] = False,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],
+ target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+ *,
+ tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+ tile_interchange: OptionalIntList = None,
+ apply_cleanup: Optional[bool] = False,
+ loc=None,
+ ip=None,
+ ):
+ tile_sizes = tile_sizes if tile_sizes else []
+ tile_interchange = tile_interchange if tile_interchange else []
+ _, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes)
+ _, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange)
+ num_loops = sum(0 if v == 0 else 1 for v in tile_sizes)
+
+ if isinstance(loop_types_or_target, (Operation, Value, OpView)):
+ loop_types = [transform.AnyOpType.get()] * num_loops
+ target = loop_types_or_target
+ assert target_or_none is None, "Cannot construct FuseOp with two targets."
+ else:
+ loop_types = (
+ ([loop_types_or_target] * num_loops)
+ if isinstance(loop_types_or_target, Type)
+ else loop_types_or_target
+ )
+ target = target_or_none
+ super().__init__(
+ target.type,
+ loop_types,
+ target,
+ tile_sizes=tile_sizes,
+ tile_interchange=tile_interchange,
+ apply_cleanup=apply_cleanup,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
class GeneralizeOp(GeneralizeOp):
"""Specialization for GeneralizeOp class."""
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 9a6ce462047a..6f37266d5bf3 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,7 +5,11 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster, register_value_caster
-from ._mlir_libs import get_dialect_registry
+from ._mlir_libs import (
+ get_dialect_registry,
+ append_load_on_create_dialect,
+ get_load_on_create_dialects,
+)
# Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
index 58580a194df0..f2e0306073f2 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
@@ -26,7 +26,7 @@ func.func @affine_vector_store(%arg0 : index) {
// CHECK: %[[buf:.*]] = memref.alloc
// CHECK: %[[val:.*]] = arith.constant dense
// CHECK: %[[c_1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[a:.*]] = arith.muli %arg0, %[[c_1]] : index
+// CHECK-NEXT: %[[a:.*]] = arith.muli %arg0, %[[c_1]] overflow<nsw> : index
// CHECK-NEXT: %[[b:.*]] = arith.addi %{{.*}}, %[[a]] : index
// CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
// CHECK-NEXT: %[[c:.*]] = arith.addi %[[b]], %[[c7]] : index
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 00d7b6b8d65f..550ea71882e1 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -156,7 +156,7 @@ func.func private @get_idx() -> (index)
// CHECK-NEXT: %[[v0:.*]] = call @get_idx() : () -> index
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[c20:.*]] = arith.constant 20 : index
// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
// CHECK-NEXT: %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
@@ -177,7 +177,7 @@ func.func @if_only() {
// CHECK-NEXT: %[[v0:.*]] = call @get_idx() : () -> index
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[c20:.*]] = arith.constant 20 : index
// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
// CHECK-NEXT: %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
@@ -202,7 +202,7 @@ func.func @if_else() {
// CHECK-NEXT: %[[v0:.*]] = call @get_idx() : () -> index
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[c20:.*]] = arith.constant 20 : index
// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
// CHECK-NEXT: %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
@@ -272,7 +272,7 @@ func.func @if_with_yield() -> (i64) {
// CHECK-NEXT: %[[v0:.*]] = call @get_idx() : () -> index
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %{{.*}} : index
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[v2]], %[[c1]] : index
@@ -316,7 +316,7 @@ func.func @if_for() {
%i = call @get_idx() : () -> (index)
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT: %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[c20:.*]] = arith.constant 20 : index
// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
// CHECK-NEXT: %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
@@ -371,7 +371,7 @@ func.func @if_for() {
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[c42]] step %[[c1]] {
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
+// CHECK-NEXT: %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} : index
// CHECK-NEXT: %[[max:.*]] = arith.maxsi %{{.*}}, %[[add0]] : index
// CHECK-NEXT: %[[c10:.*]] = arith.constant 10 : index
@@ -448,22 +448,22 @@ func.func @affine_applies(%arg0 : index) {
%one = affine.apply #map3(%symbZero)[%zero]
// CHECK-NEXT: %[[c2:.*]] = arith.constant 2 : index
-// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] : index
+// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] overflow<nsw> : index
// CHECK-NEXT: %[[v3:.*]] = arith.addi %arg0, %[[v2]] : index
// CHECK-NEXT: %[[c3:.*]] = arith.constant 3 : index
-// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] : index
+// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] overflow<nsw> : index
// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] : index
// CHECK-NEXT: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] : index
+// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] overflow<nsw> : index
// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v5]], %[[v6]] : index
// CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : index
-// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] : index
+// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] overflow<nsw> : index
// CHECK-NEXT: %[[v9:.*]] = arith.addi %[[v7]], %[[v8]] : index
// CHECK-NEXT: %[[c6:.*]] = arith.constant 6 : index
-// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] : index
+// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] overflow<nsw> : index
// CHECK-NEXT: %[[v11:.*]] = arith.addi %[[v9]], %[[v10]] : index
// CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] : index
+// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] overflow<nsw> : index
// CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] : index
%four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0]
return
@@ -610,7 +610,7 @@ func.func @affine_store(%arg0 : index) {
affine.store %1, %0[%i0 - symbol(%arg0) + 7] : memref<10xf32>
}
// CHECK: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
+// CHECK-NEXT: %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] overflow<nsw> : index
// CHECK-NEXT: %[[b:.*]] = arith.addi %{{.*}}, %[[a]] : index
// CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
// CHECK-NEXT: %[[c:.*]] = arith.addi %[[b]], %[[c7]] : index
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 748dfe8c68fc..f52dd6c0d0ce 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -633,7 +633,7 @@ gpu.module @test_module_29 {
// CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)>
// CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : f64, !llvm.ptr
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
- gpu.printf "Hello: %d\n" %arg0, %arg1 : i32, f32
+ gpu.printf "Hello: %d\n", %arg0, %arg1 : i32, f32
gpu.return
}
}
@@ -969,6 +969,35 @@ gpu.module @test_module_50 {
}
}
+// CHECK-LABEL: gpu.module @test_module_51
+// CHECK: llvm.mlir.global internal constant @[[func_name:.*]]("(unknown)\00") {addr_space = 0 : i32}
+// CHECK: llvm.mlir.global internal constant @[[file_name:.*]]("{{.*}}gpu-to-nvvm.mlir{{.*}}") {addr_space = 0 : i32}
+// CHECK: llvm.mlir.global internal constant @[[message:.*]]("assert message\00") {addr_space = 0 : i32}
+// CHECK: llvm.func @__assertfail(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) attributes {passthrough = ["noreturn"]}
+// CHECK: llvm.func @test_assert(%[[cond:.*]]: i1) attributes {gpu.kernel, nvvm.kernel} {
+// CHECK: llvm.cond_br %[[cond]], ^[[after_block:.*]], ^[[assert_block:.*]]
+// CHECK: ^[[assert_block]]:
+// CHECK: %[[message_ptr:.*]] = llvm.mlir.addressof @[[message]] : !llvm.ptr
+// CHECK: %[[message_start:.*]] = llvm.getelementptr %[[message_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<15 x i8>
+// CHECK: %[[file_ptr:.*]] = llvm.mlir.addressof @[[file_name]] : !llvm.ptr
+// CHECK: %[[file_start:.*]] = llvm.getelementptr %[[file_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8>
+// CHECK: %[[func_ptr:.*]] = llvm.mlir.addressof @[[func_name]] : !llvm.ptr
+// CHECK: %[[func_start:.*]] = llvm.getelementptr %[[func_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8>
+// CHECK: %[[line_num:.*]] = llvm.mlir.constant({{.*}} : i32) : i32
+// CHECK: %[[ptr:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: llvm.call @__assertfail(%[[message_start]], %[[file_start]], %[[line_num]], %[[func_start]], %[[ptr]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) -> ()
+// CHECK: llvm.br ^[[after_block]]
+// CHECK: ^[[after_block]]:
+// CHECK: llvm.return
+// CHECK: }
+
+gpu.module @test_module_51 {
+ gpu.func @test_assert(%arg0: i1) kernel {
+ cf.assert %arg0, "assert message"
+ gpu.return
+ }
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir
index 1b904fa142ba..2dc6a5ab2a86 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir
@@ -36,7 +36,7 @@ gpu.module @test_module {
// CHECK-NEXT: %[[NARGS1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[ARG0_64:.*]] = llvm.zext %[[ARG0]] : i32 to i64
// CHECK-NEXT: %{{.*}} = llvm.call @__ockl_printf_append_args(%[[DESC1]], %[[NARGS1]], %[[ARG0_64]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[ISLAST]]) : (i64, i32, i64, i64, i64, i64, i64, i64, i64, i32) -> i64
- gpu.printf "Hello: %d\n" %arg0 : i32
+ gpu.printf "Hello: %d\n", %arg0 : i32
gpu.return
}
}
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir
index 870f5c5016ec..00d1d7d85268 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir
@@ -9,7 +9,7 @@ gpu.module @test_module {
// CHECK: %[[IMM0:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL]] : !llvm.ptr<4>
// CHECK-NEXT: %[[IMM2:.*]] = llvm.getelementptr %[[IMM0]][0, 0] : (!llvm.ptr<4>) -> !llvm.ptr<4>, !llvm.array<11 x i8>
// CHECK-NEXT: %{{.*}} = llvm.call @printf(%[[IMM2]], %[[ARG0]]) vararg(!llvm.func<i32 (ptr<4>, ...)>) : (!llvm.ptr<4>, i32) -> i32
- gpu.printf "Hello: %d\n" %arg0 : i32
+ gpu.printf "Hello: %d\n", %arg0 : i32
gpu.return
}
}
diff --git a/mlir/test/Conversion/GPUToSPIRV/printf.mlir b/mlir/test/Conversion/GPUToSPIRV/printf.mlir
index bc091124ea4c..7fe9752b088d 100644
--- a/mlir/test/Conversion/GPUToSPIRV/printf.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/printf.mlir
@@ -62,7 +62,7 @@ module attributes {
// CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
// CHECK-NEXT: [[FMTSTR_PTR1:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
// CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR1]] {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i8, UniformConstant>, i32, f32, i32 -> i32
- gpu.printf "\nHello, world : %d %f \n Thread id: %d\n" %arg0, %arg1, %2: i32, f32, index
+ gpu.printf "\nHello, world : %d %f \n Thread id: %d\n", %arg0, %arg1, %2: i32, f32, index
// CHECK: spirv.Return
gpu.return
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index a78db9733b7e..1fe4217cde98 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -59,7 +59,7 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
// CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+ // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index
// CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
// CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
@@ -95,10 +95,10 @@ func.func @subview_non_zero_addrspace(%0 : memref<64x4xf32, strided<[4, 1], offs
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+ // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index
// CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
- // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
+ // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
// CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
@@ -131,10 +131,10 @@ func.func @subview_const_size(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>,
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[C4]] : i64
+ // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[C4]] overflow<nsw> : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index
// CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
- // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
+ // CHECK: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
// CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -168,8 +168,8 @@ func.func @subview_const_stride(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG0]], %[[C4]] : i64
- // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[ARG1]] : i64
+ // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG0]], %[[C4]] overflow<nsw> : i64
+ // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[ARG1]] : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
// CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -234,12 +234,12 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+ // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index
// CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
- // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG1]], %[[STRIDE0]] : i64
+ // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG1]], %[[STRIDE0]] overflow<nsw> : i64
// CHECK: %[[BASE_OFF:.*]] = llvm.mlir.constant(8 : index) : i64
- // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[BASE_OFF]] : i64
+ // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[BASE_OFF]] : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
// CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -301,7 +301,7 @@ func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?x
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// Compute and insert offset from 2 + dynamic value.
// CHECK: %[[CST_OFF0:.*]] = llvm.mlir.constant(2 : index) : i64
- // CHECK: %[[OFF0:.*]] = llvm.mul %[[STRIDE0]], %[[CST_OFF0]] : i64
+ // CHECK: %[[OFF0:.*]] = llvm.mul %[[STRIDE0]], %[[CST_OFF0]] overflow<nsw> : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF0]] : i64 to index
// CHECK: %[[OFF0:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -425,7 +425,7 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] : i64
+// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] overflow<nsw> : i64
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -547,7 +547,7 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] : i64
+// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] overflow<nsw> : i64
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Conversion/SCFToEmitC/for.mlir b/mlir/test/Conversion/SCFToEmitC/for.mlir
index 83592187a9b6..7f41e636936b 100644
--- a/mlir/test/Conversion/SCFToEmitC/for.mlir
+++ b/mlir/test/Conversion/SCFToEmitC/for.mlir
@@ -7,8 +7,11 @@ func.func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
return
}
// CHECK-LABEL: func.func @simple_std_for_loop(
-// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) {
-// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) {
+// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
+// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
+// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
+// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t {
// CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK-NEXT: }
// CHECK-NEXT: return
@@ -24,10 +27,13 @@ func.func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) {
return
}
// CHECK-LABEL: func.func @simple_std_2_for_loops(
-// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) {
-// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) {
+// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
+// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
+// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
+// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t {
// CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t {
// CHECK-NEXT: %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK-NEXT: }
// CHECK-NEXT: }
@@ -44,14 +50,17 @@ func.func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32)
return %result#0, %result#1 : f32, f32
}
// CHECK-LABEL: func.func @for_yield(
-// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> (f32, f32) {
+// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> (f32, f32) {
+// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
+// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
+// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
// CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-NEXT: %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<f32>
// CHECK-NEXT: %[[VAL_6:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<f32>
// CHECK-NEXT: emitc.assign %[[VAL_3]] : f32 to %[[VAL_5]] : <f32>
// CHECK-NEXT: emitc.assign %[[VAL_4]] : f32 to %[[VAL_6]] : <f32>
-// CHECK-NEXT: emitc.for %[[VAL_7:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT: emitc.for %[[VAL_7:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t {
// CHECK-NEXT: %[[VAL_8:.*]] = emitc.load %[[VAL_5]] : <f32>
// CHECK-NEXT: %[[VAL_9:.*]] = emitc.load %[[VAL_6]] : <f32>
// CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_8]], %[[VAL_9]] : f32
@@ -75,15 +84,18 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32
return %r : f32
}
// CHECK-LABEL: func.func @nested_for_yield(
-// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> f32 {
+// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> f32 {
+// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
+// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
+// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
// CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-NEXT: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<f32>
// CHECK-NEXT: emitc.assign %[[VAL_3]] : f32 to %[[VAL_4]] : <f32>
-// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t {
// CHECK-NEXT: %[[VAL_6:.*]] = emitc.load %[[VAL_4]] : <f32>
// CHECK-NEXT: %[[VAL_7:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<f32>
// CHECK-NEXT: emitc.assign %[[VAL_6]] : f32 to %[[VAL_7]] : <f32>
-// CHECK-NEXT: emitc.for %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT: emitc.for %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t {
// CHECK-NEXT: %[[VAL_9:.*]] = emitc.load %[[VAL_7]] : <f32>
// CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_9]], %[[VAL_9]] : f32
// CHECK-NEXT: emitc.assign %[[VAL_10]] : f32 to %[[VAL_7]] : <f32>
@@ -94,3 +106,60 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32
// CHECK-NEXT: %[[VAL_12:.*]] = emitc.load %[[VAL_4]] : <f32>
// CHECK-NEXT: return %[[VAL_12]] : f32
// CHECK-NEXT: }
+
+func.func @for_yield_index(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
+ %zero = arith.constant 0 : index
+ %r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index {
+ scf.yield %acc : index
+ }
+ return %r : index
+}
+
+// CHECK-LABEL: func.func @for_yield_index(
+// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index {
+// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
+// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
+// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<!emitc.size_t>
+// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : <!emitc.size_t>
+// CHECK: emitc.for %[[VAL_5:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] : !emitc.size_t {
+// CHECK: %[[V:.*]] = emitc.load %[[VAL_4]] : <!emitc.size_t>
+// CHECK: emitc.assign %[[V]] : !emitc.size_t to %[[VAL_4]] : <!emitc.size_t>
+// CHECK: }
+// CHECK: %[[V2:.*]] = emitc.load %[[VAL_4]] : <!emitc.size_t>
+// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[V2]] : !emitc.size_t to index
+// CHECK: return %[[VAL_8]] : index
+// CHECK: }
+
+
+func.func @for_yield_update_loop_carried_var(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
+ %zero = arith.constant 0 : index
+ %r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index {
+ %sn = arith.addi %acc, %acc : index
+ scf.yield %sn: index
+ }
+ return %r : index
+ }
+
+// CHECK-LABEL: func.func @for_yield_update_loop_carried_var(
+// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index {
+// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
+// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
+// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<!emitc.size_t>
+// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : <!emitc.size_t>
+// CHECK: emitc.for %[[ARG_3:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] : !emitc.size_t {
+// CHECK: %[[V:.*]] = emitc.load %[[VAL_4]] : <!emitc.size_t>
+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[V]] : !emitc.size_t to index
+// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : index
+// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : index to !emitc.size_t
+// CHECK: emitc.assign %[[VAL_8]] : !emitc.size_t to %[[VAL_4]] : <!emitc.size_t>
+// CHECK: }
+// CHECK: %[[V2:.*]] = emitc.load %[[VAL_4]] : <!emitc.size_t>
+// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[V2]] : !emitc.size_t to index
+// CHECK: return %[[VAL_9]] : index
+// CHECK: }
diff --git a/mlir/test/Conversion/SCFToEmitC/switch.mlir b/mlir/test/Conversion/SCFToEmitC/switch.mlir
index 86d96ed21f1b..61015b0ae483 100644
--- a/mlir/test/Conversion/SCFToEmitC/switch.mlir
+++ b/mlir/test/Conversion/SCFToEmitC/switch.mlir
@@ -1,7 +1,8 @@
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-emitc %s | FileCheck %s
// CHECK-LABEL: func.func @switch_no_result(
-// CHECK-SAME: %[[VAL_0:.*]]: index) {
+// CHECK-SAME: %[[ARG_0:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
// CHECK: emitc.switch %[[VAL_0]]
// CHECK: case 2 {
// CHECK: %[[VAL_1:.*]] = arith.constant 10 : i32
@@ -33,7 +34,8 @@ func.func @switch_no_result(%arg0 : index) {
}
// CHECK-LABEL: func.func @switch_one_result(
-// CHECK-SAME: %[[VAL_0:.*]]: index) {
+// CHECK-SAME: %[[ARG_0:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
// CHECK: %[[VAL_1:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<i32>
// CHECK: emitc.switch %[[VAL_0]]
// CHECK: case 2 {
@@ -70,7 +72,8 @@ func.func @switch_one_result(%arg0 : index) {
}
// CHECK-LABEL: func.func @switch_two_results(
-// CHECK-SAME: %[[VAL_0:.*]]: index) -> (i32, f32) {
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> (i32, f32) {
+// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
// CHECK: %[[VAL_1:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<i32>
// CHECK: %[[VAL_2:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<f32>
// CHECK: emitc.switch %[[VAL_0]]
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index bfdc72ee07e9..453a8610e716 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -510,7 +510,7 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<1xf32>) -> () {
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
return
}
@@ -531,7 +531,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
// CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
// HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
- %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32>
return
}
@@ -552,7 +552,7 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
// CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%1 : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
// HWCF: linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
return
}
@@ -571,7 +571,7 @@ func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27
// CHECK: linalg.yield %[[IN]] : f32
// CHECK: } -> tensor<?x45x40x28xf32>
// CHECK: %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor<?x45x40x28xf32>) -> tensor<?x45x40x28xf32>
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<?x45x40x28xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<?x45x40x28xf32>
return
}
@@ -627,7 +627,7 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
// CHECK: } -> tensor<1x?x?x28xf32>
// CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%17 : tensor<1x?x?x28xf32>) -> tensor<1x?x?x28xf32>
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32>
return
}
@@ -650,7 +650,7 @@ func.func @conv2d_dyn_output(%input: tensor<2x6x5x4xf32>, %weights: tensor<4x3x3
// linalg.yield %[[ADD]] : f32
// } -> tensor<?x4x3x4xf32>
- %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32 >, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32 >, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>
return
}
@@ -662,7 +662,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
// CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
// CHECK: tensor.yield %[[C0]]
// CHECK: linalg.conv_2d_nhwc_fhwc
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
return
}
@@ -674,7 +674,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
// CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
// CHECK: tensor.yield %[[C22]]
// CHECK: linalg.conv_2d_nhwc_fhwc_q
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
return
}
@@ -696,7 +696,7 @@ func.func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf
// CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
// CHECK: linalg.yield [[ADD]] : f32
// CHECK: } -> tensor<1x5x5x33xf32>
- %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32>
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32>
return
}
@@ -712,7 +712,7 @@ func.func @depthwise_conv_scalar_bias(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tenso
// CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
// CHECK: linalg.yield [[ADD]] : f32
// CHECK: } -> tensor<1x5x5x33xf32>
- %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32>
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32>
return
}
@@ -736,7 +736,7 @@ func.func @depthwise_conv_dyn(%arg0 : tensor<?x7x5x3xf32>, %arg1 : tensor<3x1x3x
// CHECK: %[[ADD:.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: } -> tensor<?x5x5x33xf32>
- %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<?x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<?x5x5x33xf32>
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<?x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<?x5x5x33xf32>
return
}
@@ -758,7 +758,7 @@ func.func @depthwise_conv_strides(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3
// CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
// CHECK: linalg.yield [[ADD]] : f32
// CHECK: } -> tensor<1x5x5x33xf32>
- %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1> } : (tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32>
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1> } : (tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32>
return
}
@@ -786,7 +786,7 @@ func.func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3
// CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32
// CHECK: linalg.yield [[ADD]] : i32
// CHECK: } -> tensor<1x12x12x512xi32>
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32>
return
}
@@ -810,7 +810,7 @@ func.func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 :
// CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32
// CHECK: linalg.yield [[ADD]] : i32
// CHECK: } -> tensor<1x10x10x512xi32>
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 2> } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 2> } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32>
return
}
@@ -826,7 +826,7 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
// CHECK: } : tensor<2x?x?x3xf32> to tensor<2x?x?x3xf32>
// CHECK: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} ins(%[[PADDED]], %arg1 : tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>) outs(%{{.*}} : tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x3x5xf32>
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[CONV]] {{\[}}[0], [1], [2], [3, 4]]
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 1, 2, 3, 4>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 2>} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 1, 2, 3, 4>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 2>} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32>
return
}
@@ -850,7 +850,7 @@ func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4
// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
// CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>)
// CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
- %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>) -> tensor<1x47x45x43x28xf32>
+ %0 = tosa.conv3d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>) -> tensor<1x47x45x43x28xf32>
return
}
@@ -864,7 +864,7 @@ func.func @conv3d_scalar_bias_f32(%input: tensor<1x49x48x47x27xf32>, %weights: t
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
- %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
+ %0 = tosa.conv3d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
return
}
@@ -892,7 +892,7 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5
// CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
// CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
- %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32>
+ %0 = tosa.conv3d %input, %weights, %bias {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32>
return
}
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 1e62e25176a0..0b9a64494bc0 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -459,85 +459,65 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
// CHECK-LABEL: @pad_float
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // TODO: Output contains multiple "arith.constant 1 : index".
- // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
- // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
- // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
- // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
return %1 : tensor<4x9xf32>
}
func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: [[CST:%.+]] = arith.constant 0 : i32
// CHECK: tensor.pad
// CHECK: tensor.yield [[CST]]
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<4xi32>) -> (tensor<4x9xi32>)
return %1 : tensor<4x9xi32>
}
func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: [[CST:%.+]] = arith.constant 42 : i32
// CHECK: tensor.pad
// CHECK: tensor.yield [[CST]]
- %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+ %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<4xi32>) -> (tensor<4x9xi32>)
return %1 : tensor<4x9xi32>
}
// -----
func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // TODO: Output contains multiple "arith.constant 1 : index".
- // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
- // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
- // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
- // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
%1 = arith.constant dense<42.0> : tensor<f32>
- %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<2x2xi32>, tensor<f32>) -> (tensor<4x9xf32>)
+ %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<4xi32>, tensor<f32>) -> (tensor<4x9xf32>)
return %2 : tensor<4x9xf32>
}
// -----
func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // TODO: Output contains multiple "arith.constant 1 : index".
- // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
- // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
- // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
- // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
return %1 : tensor<?x9xf32>
}
func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
- %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
- // TODO: Output contains multiple "arith.constant 1 : index".
- // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
- // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
- // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
- // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
return %1 : tensor<?x9xf32>
}
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 717004eb50c0..a9ac13ad7162 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1917,12 +1917,12 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
// -----
-// CHECK-LABEL: func @cancel_linearize_denearize_exact(
+// CHECK-LABEL: func @cancel_linearize_delinearize_exact(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK: return %[[ARG0]]
-func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @cancel_linearize_delinearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
return %1 : index
@@ -1930,12 +1930,12 @@ func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: i
// -----
-// CHECK-LABEL: func @cancel_linearize_denearize_linearize_extra_bound(
+// CHECK-LABEL: func @cancel_linearize_delinearize_linearize_extra_bound(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK: return %[[ARG0]]
-func.func @cancel_linearize_denearize_linearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @cancel_linearize_delinearize_linearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
%0:3 = affine.delinearize_index %arg0 into (4, %arg2) : index, index, index
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
return %1 : index
@@ -1943,12 +1943,12 @@ func.func @cancel_linearize_denearize_linearize_extra_bound(%arg0: index, %arg1:
// -----
-// CHECK-LABEL: func @cancel_linearize_denearize_delinearize_extra_bound(
+// CHECK-LABEL: func @cancel_linearize_delinearize_delinearize_extra_bound(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK: return %[[ARG0]]
-func.func @cancel_linearize_denearize_delinearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @cancel_linearize_delinearize_delinearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (4, %arg2) : index
return %1 : index
@@ -1956,31 +1956,252 @@ func.func @cancel_linearize_denearize_delinearize_extra_bound(%arg0: index, %arg
// -----
+// The two outermost delinearize results feed the two outermost linearize
+// operands under the same basis prefix (3, 4); the expected output merges that
+// prefix into a single basis element of 12 on both ops.
+// CHECK-LABEL: func @cancel_linearize_delinearize_head(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (12, 8)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[ARG1]]] by (12, 16)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_head(%arg0: index, %arg1: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (3, 4, 8) : index, index, index
+ %1 = affine.linearize_index [%0#0, %0#1, %arg1] by (3, 4, 16) : index
+ return %1 : index
+}
+
+// -----
+
+// Head cancellation where the delinearize omits its outer bound (three results
+// from the two-element basis (4, 8)) while the linearize basis is (3, 4, 16);
+// the expected output still uses a merged leading element of 12 on both ops.
+// CHECK-LABEL: func @cancel_linearize_delinearize_head_delinearize_unbounded(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (12, 8)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[ARG1]]] by (12, 16)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_head_delinearize_unbounded(%arg0: index, %arg1: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (4, 8) : index, index, index
+ %1 = affine.linearize_index [%0#0, %0#1, %arg1] by (3, 4, 16) : index
+ return %1 : index
+}
+
+// -----
+
+// Head cancellation where the linearize omits its outer bound (three operands,
+// basis (4, 16)) while the delinearize basis is (3, 4, 8); the expected output
+// has a single-element basis on both ops, i.e. the merged head is unbounded.
+// CHECK-LABEL: func @cancel_linearize_delinearize_head_linearize_unbounded(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (8)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[ARG1]]] by (16)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_head_linearize_unbounded(%arg0: index, %arg1: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (3, 4, 8) : index, index, index
+ %1 = affine.linearize_index [%0#0, %0#1, %arg1] by (4, 16) : index
+ return %1 : index
+}
+
+// -----
+
+// Head cancellation where both ops omit their outer bound; the expected output
+// keeps a single-element basis on each op with the merged head left unbounded.
+// CHECK-LABEL: func @cancel_linearize_delinearize_head_both_unbounded(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (8)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[ARG1]]] by (16)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_head_both_unbounded(%arg0: index, %arg1: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (4, 8) : index, index, index
+ %1 = affine.linearize_index [%0#0, %0#1, %arg1] by (4, 16) : index
+ return %1 : index
+}
+
+// -----
+
+// The two innermost delinearize results feed the two innermost linearize
+// operands under the same basis suffix (4, 8); the expected output merges that
+// suffix into a single basis element of 32 on both ops.
+// CHECK-LABEL: func @cancel_linearize_delinearize_tail(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (3, 32)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG1]], %[[DELIN]]#1] by (5, 32)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_tail(%arg0: index, %arg1: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (3, 4, 8) : index, index, index
+ %1 = affine.linearize_index [%arg1, %0#1, %0#2] by (5, 4, 8) : index
+ return %1 : index
+}
+
+// -----
+
+// All three delinearize results appear, in order, in the middle of the
+// linearize operands with the matching basis (2, 3, 5); the expected output
+// uses %arg0 directly with a combined basis element of 30.
+// CHECK-LABEL: func @cancel_linearize_delinearize_middle_exact(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG1]], %[[ARG0]], %[[ARG2]]] by (9, 30, 7)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_middle_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (2, 3, 5) : index, index, index
+ %1 = affine.linearize_index [%arg1, %0#0, %0#1, %0#2, %arg2] by (9, 2, 3, 5, 7) : index
+ return %1 : index
+}
+
+// -----
+
+// Middle cancellation with a partly dynamic basis (2, %arg1, %arg2, 8): the
+// expected output computes the merged basis element with an affine.apply of
+// (s0 * s1) * 16 over %arg1 and %arg2.
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) * 16)>
+
+// CHECK-LABEL: func @cancel_linearize_delinearize_middle_exact_dynamic_basis(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[SIZEPROD:.+]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[C1]], %[[ARG0]], %[[C1]]] by (3, %[[SIZEPROD]], 4)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_middle_exact_dynamic_basis(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %c1 = arith.constant 1 : index
+ %0:4 = affine.delinearize_index %arg0 into (2, %arg1, %arg2, 8) : index, index, index, index
+ %1 = affine.linearize_index [%c1, %0#0, %0#1, %0#2, %0#3, %c1] by (3, 2, %arg1, %arg2, 8, 4) : index
+ return %1 : index
+}
+
+// -----
+
+// The delinearize omits its outer bound (three results from basis (3, 5)), but
+// the linearize is marked `disjoint` and carries the basis (2, 3, 5) for those
+// operands; the expected output uses %arg0 directly with basis element 30.
+// CHECK-LABEL: func @cancel_linearize_delinearize_middle_exact_delinearize_unbounded_disjoint(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG0]], %[[ARG2]]] by (9, 30, 7)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_middle_exact_delinearize_unbounded_disjoint(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (3, 5) : index, index, index
+ %1 = affine.linearize_index disjoint [%arg1, %0#0, %0#1, %0#2, %arg2] by (9, 2, 3, 5, 7) : index
+ return %1 : index
+}
+
+// -----
+
+// Unlike in the test above, the linearize indices aren't asserted to be disjoint, so
+// we can't know if the `2` from the basis is a correct bound. The expected
+// output therefore keeps both ops unchanged.
+// CHECK-LABEL: func @dont_cancel_linearize_delinearize_middle_exact_delinearize_unbounded(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (3)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG1]], %[[DELIN]]#0, %[[DELIN]]#1, %[[ARG2]]] by (9, 2, 3, 7)
+// CHECK: return %[[LIN]]
+
+func.func @dont_cancel_linearize_delinearize_middle_exact_delinearize_unbounded(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0:2 = affine.delinearize_index %arg0 into (3) : index, index
+ %1 = affine.linearize_index [%arg1, %0#0, %0#1, %arg2] by (9, 2, 3, 7) : index
+ return %1 : index
+}
+
+// -----
+
+// The presence of a `disjoint` here tells us that the "unbounded" term on the
+// delinearization can't have been above 2. The expected output merges the
+// first two delinearize results into one of size 6.
+// CHECK-LABEL: func @cancel_linearize_delinearize_middle_delinearize_unbounded_disjoint_implied_bound(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (6, 5)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG1]], %[[DELIN]]#0, %[[ARG2]]] by (9, 6, 7)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_middle_delinearize_unbounded_disjoint_implied_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (3, 5) : index, index, index
+ %1 = affine.linearize_index disjoint [%arg1, %0#0, %0#1, %arg2] by (9, 2, 3, 7) : index
+ return %1 : index
+}
+
+// -----
+
+// Two separate in-order runs of delinearize results, (%0#1, %0#2) and
+// (%0#4, %0#5, %0#6), appear in one linearize; the expected output merges them
+// into basis elements 16 and 64 while %arg1 and the constant 0 stay in place.
+// CHECK-LABEL: func @cancel_linearize_delinearize_multiple_matches(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[C0:.+]] = arith.constant 0
+// CHECK: %[[DELIN:.+]]:4 = affine.delinearize_index %[[ARG0]] into (4, 16, 4, 64)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG1]], %[[DELIN]]#1, %[[C0]], %[[DELIN]]#3] by (4, 16, 4, 64)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_multiple_matches(%arg0: index, %arg1: index) -> index {
+ %c0 = arith.constant 0 : index
+ %0:7 = affine.delinearize_index %arg0 into (4, 4, 4, 4, 4, 4, 4) : index, index, index, index, index, index, index
+ %1 = affine.linearize_index [%arg1, %0#1, %0#2, %c0, %0#4, %0#5, %0#6] by (4, 4, 4, 4, 4, 4, 4) : index
+ return %1 : index
+}
+
+// -----
+
+// Results of two different delinearize ops are each consumed completely and in
+// order by one linearize; the expected output linearizes %arg0 and %arg1
+// directly with basis (32, 32).
+// CHECK-LABEL: func @cancel_linearize_delinearize_multiple_delinearizes(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (32, 32)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_multiple_delinearizes(%arg0: index, %arg1: index) -> index {
+ %0:2 = affine.delinearize_index %arg0 into (4, 8) : index, index
+ %1:2 = affine.delinearize_index %arg1 into (2, 16) : index, index
+ %2 = affine.linearize_index [%0#0, %0#1, %1#0, %1#1] by (4, 8, 2, 16) : index
+ return %2 : index
+}
+
+// -----
+
// Don't cancel because the values from the delinearize aren't used in order
-// CHECK-LABEL: func @no_cancel_linearize_denearize_permuted(
+// CHECK-LABEL: func @no_cancel_linearize_delinearize_permuted(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
-// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], 4, %[[ARG2]])
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], %[[ARG2]], 4)
// CHECK: return %[[LIN]]
-func.func @no_cancel_linearize_denearize_permuted(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @no_cancel_linearize_delinearize_permuted(%arg0: index, %arg1: index, %arg2: index) -> index {
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
- %1 = affine.linearize_index [%0#0, %0#2, %0#1] by (%arg1, 4, %arg2) : index
+ %1 = affine.linearize_index [%0#0, %0#2, %0#1] by (%arg1, %arg2, 4) : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 3)>
+// But these cancel because they're a contiguous segment: %0#2 and %0#3 are
+// used in order with basis (%arg2, 3), so the expected output merges them into
+// one element of size %arg2 * 3 while %0#1 stays permuted.
+// CHECK-LABEL: func @partial_cancel_linearize_delinearize_not_fully_permuted(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[SIZEPROD:.+]] = affine.apply #[[$MAP]]()[%[[ARG2]]]
+// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[SIZEPROD]])
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], %[[SIZEPROD]], 4)
+// CHECK: return %[[LIN]]
+func.func @partial_cancel_linearize_delinearize_not_fully_permuted(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0:4 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2, 3) : index, index, index, index
+ %1 = affine.linearize_index [%0#0, %0#2, %0#3, %0#1] by (%arg1, %arg2, 3, 4) : index
return %1 : index
}
// -----
+// Ensure we don't get SSA errors when creating new `affine.delinearize` operations.
+// One delinearize feeds three linearize ops; only the second uses an in-order
+// run (%0#1, %0#2). The directives below pin both resulting delinearize ops
+// ahead of all three linearize ops.
+// CHECK-LABEL: func @cancel_linearize_delinearize_placement
+// CHECK-SAME: (%[[ARG0:.+]]: index)
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[NEW_DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (8, 32) : index, index
+// CHECK-NEXT: %[[DELIN_PART:.+]]:2 = affine.delinearize_index %[[NEW_DELIN]]#1 into (8, 4) : index, index
+// CHECK-NEXT: %[[L1:.+]] = affine.linearize_index disjoint [%[[DELIN_PART]]#1, %[[NEW_DELIN]]#0, %[[C0]], %[[C0]]] by (4, 8, 4, 8)
+// CHECK-NEXT: %[[L2:.+]] = affine.linearize_index disjoint [%[[NEW_DELIN]]#1, %[[C0]], %[[C0]]] by (32, 8, 4)
+// CHECK-NEXT: %[[L3:.+]] = affine.linearize_index disjoint [%[[DELIN_PART]]#0, %[[NEW_DELIN]]#0, %[[C0]], %[[C0]]] by (8, 8, 4, 4)
+// CHECK-NEXT: return %[[L1]], %[[L2]], %[[L3]]
+func.func @cancel_linearize_delinearize_placement(%arg0: index) -> (index, index, index) {
+ %c0 = arith.constant 0 : index
+ %0:3 = affine.delinearize_index %arg0 into (8, 8, 4) : index, index, index
+ %1 = affine.linearize_index disjoint [%0#2, %0#0, %c0, %c0] by (4, 8, 4, 8) : index
+ %2 = affine.linearize_index disjoint [%0#1, %0#2, %c0, %c0] by (8, 4, 8, 4) : index
+ %3 = affine.linearize_index disjoint [%0#1, %0#0, %c0, %c0] by (8, 8, 4, 4) : index
+ return %1, %2, %3 : index, index, index
+}
+
+// -----
+
// Won't cancel because the linearize and delinearize are using a different basis
-// CHECK-LABEL: func @no_cancel_linearize_denearize_different_basis(
+// CHECK-LABEL: func @no_cancel_linearize_delinearize_different_basis(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] by (%[[ARG1]], 8, %[[ARG2]])
// CHECK: return %[[LIN]]
-func.func @no_cancel_linearize_denearize_different_basis(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @no_cancel_linearize_delinearize_different_basis(%arg0: index, %arg1: index, %arg2: index) -> index {
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 8, %arg2) : index
return %1 : index
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 935c08aceff5..5354eb38d7b0 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -155,3 +155,84 @@ func.func @compare_maps(%a: index, %b: index) {
: (index, index, index, index) -> ()
return
}
+
+// -----
+
+// Delinearize with the fully static basis (2, 3, 5): each result's reified
+// equality bound is expected as an affine.apply of %arg0, and the first two
+// results are expected to compare below their basis elements 2 and 3.
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 floordiv 15)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> ((s0 mod 15) floordiv 5)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<()[s0] -> (s0 mod 5)>
+// CHECK-LABEL: func.func @delinearize_static
+// CHECK-SAME: (%[[arg0:.+]]: index)
+// CHECK-DAG: %[[v1:.+]] = affine.apply #[[$map1]]()[%[[arg0]]]
+// CHECK-DAG: %[[v2:.+]] = affine.apply #[[$map2]]()[%[[arg0]]]
+// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map3]]()[%[[arg0]]]
+// CHECK: return %[[v1]], %[[v2]], %[[v3]]
+func.func @delinearize_static(%arg0: index) -> (index, index, index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %0:3 = affine.delinearize_index %arg0 into (2, 3, 5) : index, index, index
+ %1 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
+ %2 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
+ %3 = "test.reify_bound"(%0#2) {type = "EQ"} : (index) -> (index)
+ // expected-remark @below{{true}}
+ "test.compare"(%0#0, %c2) {cmp = "LT"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%0#1, %c3) {cmp = "LT"} : (index, index) -> ()
+ return %1, %2, %3 : index, index, index
+}
+
+// -----
+
+// Delinearize with basis (3, 5) and three results, i.e. no bound on the
+// outermost result. The reified equality bounds are expected to match the
+// fully bounded case.
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 floordiv 15)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> ((s0 mod 15) floordiv 5)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<()[s0] -> (s0 mod 5)>
+// CHECK-LABEL: func.func @delinearize_static_no_outer_bound
+// CHECK-SAME: (%[[arg0:.+]]: index)
+// CHECK-DAG: %[[v1:.+]] = affine.apply #[[$map1]]()[%[[arg0]]]
+// CHECK-DAG: %[[v2:.+]] = affine.apply #[[$map2]]()[%[[arg0]]]
+// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map3]]()[%[[arg0]]]
+// CHECK: return %[[v1]], %[[v2]], %[[v3]]
+func.func @delinearize_static_no_outer_bound(%arg0: index) -> (index, index, index) {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %0:3 = affine.delinearize_index %arg0 into (3, 5) : index, index, index
+ %1 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
+ %2 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
+ %3 = "test.reify_bound"(%0#2) {type = "EQ"} : (index) -> (index)
+ // The op name was previously misspelled, so this comparison was never run.
+ // NOTE(review): the `unknown` diagnostic is assumed by analogy with
+ // @linearize_static_no_outer_bound in this file; confirm by running the test.
+ // expected-error @below{{unknown}}
+ "test.compare"(%0#0, %c2) {cmp = "LT"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%0#1, %c3) {cmp = "LT"} : (index, index) -> ()
+ return %1, %2, %3 : index, index, index
+}
+
+// -----
+
+// Disjoint linearize with the full static basis (2, 3): the reified equality
+// bound is expected as %arg1 + %arg0 * 3, and the result is expected to
+// compare below 6.
+// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
+// CHECK-LABEL: func.func @linearize_static
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index)
+// CHECK: %[[v1:.+]] = affine.apply #[[$map]]()[%[[arg1]], %[[arg0]]]
+// CHECK: return %[[v1]]
+func.func @linearize_static(%arg0: index, %arg1: index) -> index {
+ %c6 = arith.constant 6 : index
+ %0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 3) : index
+ %1 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
+ // expected-remark @below{{true}}
+ "test.compare"(%0, %c6) {cmp = "LT"} : (index, index) -> ()
+ return %1 : index
+}
+
+// -----
+
+// Disjoint linearize whose basis (3) omits the outer bound: the reified
+// equality bound is the same map as in the bounded case, but the comparison
+// against 6 is expected to be reported as unknown.
+// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
+// CHECK-LABEL: func.func @linearize_static_no_outer_bound
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index)
+// CHECK: %[[v1:.+]] = affine.apply #[[$map]]()[%[[arg1]], %[[arg0]]]
+// CHECK: return %[[v1]]
+func.func @linearize_static_no_outer_bound(%arg0: index, %arg1: index) -> index {
+ %c6 = arith.constant 6 : index
+ %0 = affine.linearize_index disjoint [%arg0, %arg1] by (3) : index
+ %1 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
+ // expected-error @below{{unknown}}
+ "test.compare"(%0, %c6) {cmp = "LT"} : (index, index) -> ()
+ return %1 : index
+}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 6a186a0c6cec..522711b08f28 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2060,6 +2060,70 @@ func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) {
// -----
+// divui(muli(a, b) with nuw, a) is expected to fold to b.
+func.func @fold_divui_of_muli_0(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
+ %1 = arith.divui %0, %arg0 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @fold_divui_of_muli_0(
+// CHECK-SAME: %[[ARG0:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: index)
+// CHECK: return %[[ARG1]]
+
+// divui(muli(a, b) with nuw, b) is expected to fold to a.
+func.func @fold_divui_of_muli_1(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
+ %1 = arith.divui %0, %arg1 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @fold_divui_of_muli_1(
+// CHECK-SAME: %[[ARG0:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: index)
+// CHECK: return %[[ARG0]]
+
+// divsi(muli(a, b) with nsw, a) is expected to fold to b.
+func.func @fold_divsi_of_muli_0(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 overflow<nsw> : index
+ %1 = arith.divsi %0, %arg0 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @fold_divsi_of_muli_0(
+// CHECK-SAME: %[[ARG0:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: index)
+// CHECK: return %[[ARG1]]
+
+// divsi(muli(a, b) with nsw, b) is expected to fold to a.
+func.func @fold_divsi_of_muli_1(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 overflow<nsw> : index
+ %1 = arith.divsi %0, %arg1 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @fold_divsi_of_muli_1(
+// CHECK-SAME: %[[ARG0:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: index)
+// CHECK: return %[[ARG0]]
+
+// Do not fold divui(mul(a, v), v) -> a without the nuw attribute: the muli
+// below carries no overflow flag, so both ops are expected to remain.
+func.func @no_fold_divui_of_muli(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 : index
+ %1 = arith.divui %0, %arg0 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @no_fold_divui_of_muli
+// CHECK: %[[T0:.+]] = arith.muli
+// CHECK: %[[T1:.+]] = arith.divui %[[T0]],
+// CHECK: return %[[T1]]
+
+// Do not fold divsi(mul(a, v), v) -> a without the nsw attribute: the muli
+// below carries no overflow flag, so both ops are expected to remain.
+func.func @no_fold_divsi_of_muli(%arg0 : index, %arg1 : index) -> index {
+ %0 = arith.muli %arg0, %arg1 : index
+ %1 = arith.divsi %0, %arg0 : index
+ return %1 : index
+}
+// CHECK-LABEL: func @no_fold_divsi_of_muli
+// CHECK: %[[T0:.+]] = arith.muli
+// CHECK: %[[T1:.+]] = arith.divsi %[[T0]],
+// CHECK: return %[[T1]]
+
+// -----
+
// CHECK-LABEL: @test_cmpf(
func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
// CHECK-DAG: %[[T:.*]] = arith.constant true
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 26434774730e..820fb3dfa5e5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -465,3 +465,14 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
}
+
+// -----
+
+// A tensor.empty is the direct source of a tensor.insert_slice into a function
+// argument; the output is expected to contain no memref.alloc.
+// CHECK-LABEL: func.func @direct_use_of_tensor_empty
+func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
+ // CHECK-NOT: memref.alloc
+ %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+ %inserted_slice_1 = tensor.insert_slice %empty_1 into %arg0[0, 0, 0][5, 6, 64][1, 1, 1]
+ : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+ return %inserted_slice_1 : tensor<5x6x128xf32>
+}
diff --git a/mlir/test/Dialect/GPU/indirect-device-func-call.mlir b/mlir/test/Dialect/GPU/indirect-device-func-call.mlir
index 91d7f1cd6c67..85805da3ac10 100644
--- a/mlir/test/Dialect/GPU/indirect-device-func-call.mlir
+++ b/mlir/test/Dialect/GPU/indirect-device-func-call.mlir
@@ -6,7 +6,7 @@ gpu.module @kernels {
func.func @hello(%arg0 : f32) {
%tid_x = gpu.thread_id x
%csti8 = arith.constant 2 : i8
- gpu.printf "Hello from %lld, %d, %f\n" %tid_x, %csti8, %arg0 : index, i8, f32
+ gpu.printf "Hello from %lld, %d, %f\n", %tid_x, %csti8, %arg0 : index, i8, f32
return
}
// CHECK-LABEL: @hello_indirect
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index c0ff2044b76c..99915c493ea4 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -229,9 +229,22 @@ module attributes {gpu.container_module} {
// CHECK-LABEL: gpu.func @printf_test
// CHECK: (%[[ARG0:.*]]: i32)
- // CHECK: gpu.printf "Value: %d" %[[ARG0]] : i32
+ // CHECK: gpu.printf "Value: %d", %[[ARG0]] : i32
gpu.func @printf_test(%arg0 : i32) {
- gpu.printf "Value: %d" %arg0 : i32
+ gpu.printf "Value: %d", %arg0 : i32
+ gpu.return
+ }
+
+ // gpu.printf with a format string and no arguments, both directly in the
+ // function body and nested inside an scf.if region.
+ // CHECK-LABEL: gpu.func @printf_empty
+ // CHECK: gpu.printf "]"
+ // CHECK: scf.if
+ // CHECK: gpu.printf ", "
+ gpu.func @printf_empty(%arg0 : i32) {
+ gpu.printf "]"
+ %1 = arith.cmpi slt, %arg0, %arg0 : i32
+ scf.if %1 {
+ gpu.printf ", "
+ }
gpu.return
}
diff --git a/mlir/test/Dialect/GPU/test-nvvm-pipeline.mlir b/mlir/test/Dialect/GPU/test-nvvm-pipeline.mlir
index 732f40c4333d..f02b26dba97d 100644
--- a/mlir/test/Dialect/GPU/test-nvvm-pipeline.mlir
+++ b/mlir/test/Dialect/GPU/test-nvvm-pipeline.mlir
@@ -23,7 +23,7 @@ func.func @test_math(%arg0 : f32) {
threads(%6, %7, %8) in (%9 = %c2, %10 = %c1, %11 = %c1) {
// CHECK-NVVM: __nv_expf
%s1 = math.exp %arg0 : f32
- gpu.printf "%f" %s1 : f32
+ gpu.printf "%f", %s1 : f32
gpu.terminator
}
return
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
index f6d3387d99b3..2785b5088612 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
@@ -28,7 +28,7 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
// CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>
// CHECK-DAG: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+ // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64
// CHECK-DAG: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
// CHECK-DAG: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
index a74553cc2268..c1f30c7eaf64 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
@@ -27,7 +27,7 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
// CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>
// CHECK-DAG: %[[STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
- // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+ // CHECK-DAG: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] overflow<nsw> : i64
// CHECK-DAG: %[[OFF2:.*]] = llvm.add %[[DESCSTRIDE0]], %[[ARG1]] : i64
// CHECK-DAG: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index aebfd7492093..88660ce598f3 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -750,6 +750,16 @@ llvm.func @experimental_noalias_scope_decl() {
llvm.return
}
+// Alias scope and alias scope domain whose `id` parameters are string
+// literals, used by the noalias scope declaration intrinsic below.
+#alias_scope_domain2 = #llvm.alias_scope_domain<id = "domainid", description = "The domain">
+#alias_scope2 = #llvm.alias_scope<id = "stringid", domain = #alias_scope_domain2, description = "The domain">
+
+// CHECK-LABEL: @experimental_noalias_scope_with_string_id
+llvm.func @experimental_noalias_scope_with_string_id() {
+ // CHECK: llvm.intr.experimental.noalias.scope.decl #{{.*}}
+ llvm.intr.experimental.noalias.scope.decl #alias_scope2
+ llvm.return
+}
+
// CHECK-LABEL: @experimental_constrained_fptrunc
llvm.func @experimental_constrained_fptrunc(%in: f64) {
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
diff --git a/mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir b/mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir
index 6d9709caf709..0dbdf470bbfc 100644
--- a/mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir
@@ -1,4 +1,7 @@
-// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-decompose-tensor-unpack" %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -transform-interpreter --canonicalize \
+// RUN: -transform-preload-library='transform-library-paths=%p/td/decompose-unpack.mlir' \
+// RUN: -transform-interpreter=entry-point=decompose_unpack \
+// RUN: -transform-interpreter %s | FileCheck %s
func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> {
%0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32>
diff --git a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
index bd60504f5334..ba1f21495256 100644
--- a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-decompose-tensor-unpack" %s | FileCheck %s
+// RUN: mlir-opt -split-input-file \
+// RUN: -transform-preload-library='transform-library-paths=%p/td/decompose-unpack.mlir' \
+// RUN: -transform-interpreter=entry-point=decompose_unpack %s | FileCheck %s
func.func @simple_KCRSsr_to_KCRS(%arg0: tensor<1x1x1x1x8x32xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32> {
%0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x1x1x8x32xf32> -> tensor<1x1x32x8xf32>
diff --git a/mlir/test/Dialect/Linalg/td/decompose-unpack.mlir b/mlir/test/Dialect/Linalg/td/decompose-unpack.mlir
new file mode 100644
index 000000000000..11243634262e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/td/decompose-unpack.mlir
@@ -0,0 +1,12 @@
+// Transform library module providing the named sequence @decompose_unpack.
+// The sequence matches every tensor.unpack op under %module, takes the
+// isolated-from-above parent of the matches, and applies the linalg
+// decompose_pack_unpack patterns to that parent.
+module @transforms attributes { transform.with_named_sequence } {
+ transform.named_sequence @decompose_unpack(%module: !transform.any_op {transform.readonly}) {
+ // Note: despite its name, %pack holds the matched tensor.unpack ops.
+ %pack = transform.structured.match ops{["tensor.unpack"]} in %module : (!transform.any_op) -> !transform.any_op
+
+ %1 = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %1 {
+ transform.apply_patterns.linalg.decompose_pack_unpack
+ } : !transform.any_op
+
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8c4e7a41ee6b..828758df6d31 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -408,6 +408,20 @@ func.func @for_yields_4() -> i32 {
// -----
+// The iter_arg is initialized with the constant %c0_i32 and every iteration
+// yields that same constant; the expected output is a loop without iter_args
+// whose body uses %c0_i32 directly.
+// CHECK-LABEL: @constant_iter_arg
+func.func @constant_iter_arg(%arg0: index, %arg1: index, %arg2: index) {
+ %c0_i32 = arith.constant 0 : i32
+ // CHECK: scf.for %arg3 = %arg0 to %arg1 step %arg2 {
+ %0 = scf.for %i = %arg0 to %arg1 step %arg2 iter_args(%arg3 = %c0_i32) -> i32 {
+ // CHECK-NEXT: "test.use"(%c0_i32)
+ "test.use"(%arg3) : (i32) -> ()
+ scf.yield %c0_i32 : i32
+ }
+ return
+}
+
+// -----
+
// CHECK-LABEL: @replace_true_if
func.func @replace_true_if() {
%true = arith.constant true
@@ -1789,7 +1803,7 @@ module {
}
// CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
-// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK: %[[RESULT:.*]] = scf.forall
// CHECK-SAME: shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
@@ -1832,7 +1846,7 @@ module {
}
// CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK: %[[RESULT:.*]] = scf.forall
// CHECK-SAME: shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
@@ -1856,7 +1870,7 @@ func.func @index_switch_fold() -> (f32, f32) {
%y = arith.constant 42.0 : f32
scf.yield %y : f32
}
-
+
%switch_cst_2 = arith.constant 2: index
%1 = scf.index_switch %switch_cst_2 -> f32
case 0 {
@@ -1867,7 +1881,7 @@ func.func @index_switch_fold() -> (f32, f32) {
%y = arith.constant 42.0 : f32
scf.yield %y : f32
}
-
+
return %0, %1 : f32, f32
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e8fc4ce834e1..01d14871072c 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
%0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
return %0#1 : index
}
+
// -----
// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
@@ -2794,7 +2795,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
// CHECK-SAME: %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
// CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
-// CHECK-SAME: some_attr
+// CHECK-SAME: test_attr
// CHECK-SAME: : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
// CHECK: return %[[PACK]] : tensor<1x1x8x1xi32>
func.func @fold_cast_pack_dynamic_tile_size(
@@ -2807,13 +2808,33 @@ func.func @fold_cast_pack_dynamic_tile_size(
%pack = tensor.pack %src padding_value(%pad : i32)
inner_dims_pos = [0, 1]
inner_tiles = [%c8, 1]
- into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
+ into %cast {test_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
%res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
return %res : tensor<1x1x8x1xi32>
}
// -----
+// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
+// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
+// CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+// CHECK: return %[[RES]] : tensor<7x?xi32>
+func.func @fold_cast_unpack_dynamic_tile_size(
+ %src: tensor<1x1x8x1xi32>,
+ %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
+
+ %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+ %c8 = arith.constant 8 : index
+ %unpack = tensor.unpack %cast
+ inner_dims_pos = [0, 1]
+ inner_tiles = [%c8, 1]
+ into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+ return %unpack : tensor<7x?xi32>
+}
+
+// -----
+
// CHECK-LABEL: func.func @pack_dont_drop_attributes(
// CHECK: tensor.pack {{.*}} {test_attr}
func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 83cb4b9d4ab2..1de3e281bc46 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -699,7 +699,7 @@ func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor
// -----
-func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
+func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
// expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
%0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 67cd01f62f0b..60121bb0ea2f 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -162,7 +162,7 @@ func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32
// CHECK: tosa.conv2d
%weight = "tosa.const"() {value = dense<[[[[1.0, 1.0]]], [[[1.0, 1.0]]], [[[1.0, 1.0]]]]> : tensor<3x1x1x2xf32>} : ()-> tensor<3x1x1x2xf32>
%bias = "tosa.const"() {value = dense<0.0> : tensor<3xf32>} : ()-> tensor<3xf32>
- %0 = tosa.conv2d %arg0, %weight, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
+ %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
return %0 : tensor<4x10x10x3xf32>
}
@@ -173,7 +173,7 @@ func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf
// CHECK: tosa.conv2d
%weight = "tosa.const"() {value = dense<[[[[1.0], [1.0]], [[1.0], [1.0]]]]> : tensor<1x2x2x1xf32>} : ()-> tensor<1x2x2x1xf32>
%bias = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : ()-> tensor<1xf32>
- %0 = tosa.conv2d %arg0, %weight, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x10x10x1xf32>
+ %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x10x10x1xf32>
return %0 : tensor<4x10x10x1xf32>
}
@@ -182,7 +182,7 @@ func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf
// CHECK-LABEL: @depthwise_conv2d_stride_2
func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
// CHECK: tosa.depthwise_conv2d
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
return %0 : tensor<4x10x10x6xf32>
}
@@ -191,7 +191,7 @@ func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor
// CHECK-LABEL: @depthwise_conv2d_weight_2x2
func.func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x2x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
// CHECK: tosa.depthwise_conv2d
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<2x2x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<2x2x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
return %0 : tensor<4x10x10x6xf32>
}
@@ -210,8 +210,8 @@ func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf3
// CHECK-LABEL: @pad_noop
func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: return %arg0
- %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
- %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+ %0 = "tosa.const"() { value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
@@ -221,8 +221,8 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[PAD:.+]] = tosa.pad
// CHECK: return %[[PAD]]
- %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
- %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+ %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
@@ -234,42 +234,39 @@ func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32>
// CHECK: return %[[PAD]]
%c0_i32 = arith.constant 0 : i32
- %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<1x2xi32>
+ %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<2xi32>
- %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<1x2xi32>) -> tensor<?xf32>
+ %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: @pad_determine_val_i32
-func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
// CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
- %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
- %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
+ %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
return %1 : tensor<?x?xi32>
}
// -----
// CHECK-LABEL: @pad_determine_val_f32
-func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
+func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<4xi32>) -> tensor<?x?xf32> {
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
// CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
- %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
- %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+ %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// -----
// CHECK-LABEL: @pad_determine_val_quant
-func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
// CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
- %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
- %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
+ %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
return %1 : tensor<?x?xi32>
}
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 2902c4a62009..8198903b78ac 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -117,15 +117,6 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>)
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
}
-// CHECK-LABEL: @transpose_nofold_quantized_types
-func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
- %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
- %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2xi8>
- // CHECK: tosa.transpose
- %0 = tosa.transpose %input, %perms : (tensor<2x1x1x2xi8>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
- return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
-}
-
// CHECK-LABEL: @transpose_nofold_dense_resource
func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
%0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index cca50b25d14d..a6d57f8a2f61 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -25,7 +25,7 @@ func.func @test_const_non_tensor_attr() {
func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
// expected-error@+1 {{expect both input and weight to be float or not together, got 'f32' and 'i8'}}
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
}
@@ -34,7 +34,7 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>,
func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
// expected-error@+1 {{expect a ranked tensor for input, got <block argument> of type 'tensor<*xi8>' at index: 0}}
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
}
@@ -43,7 +43,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t
func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
// expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
}
@@ -52,13 +52,101 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2:
func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
// expected-error@+1 {{'tosa.conv2d' op quantizationattr is required for quantized type, and not allowed for float type}}
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
}
// -----
+func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
+ // expected-error@+1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>}
+ : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
+ return %0 : tensor<1x27x27x16xi8>
+}
+
+// -----
+
+func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi16>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi16>) -> tensor<1x27x27x16xi16> {
+ // expected-error@+1 {{'tosa.conv2d' op accumulator type for i16 tensor is not i48}}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>}
+ : (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>) -> tensor<1x27x27x16xi16>
+ return %0 : tensor<1x27x27x16xi16>
+}
+
+// -----
+
+func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E5M2>, %arg1: tensor<16x3x3x4xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
+ // expected-error@+1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16}}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ : (tensor<1x29x29x4xf8E5M2>, tensor<16x3x3x4xf8E5M2>, tensor<16xf16>) -> tensor<1x27x27x16xf16>
+ return %0 : tensor<1x27x27x16xf16>
+}
+
+// -----
+
+func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E4M3>, %arg1: tensor<16x3x3x4xf8E4M3>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
+ // expected-error@+1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16}}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ : (tensor<1x29x29x4xf8E4M3>, tensor<16x3x3x4xf8E4M3>, tensor<16xf16>) -> tensor<1x27x27x16xf16>
+ return %0 : tensor<1x27x27x16xf16>
+}
+
+// -----
+
+func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3x3x4xf16>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
+ // expected-error@+1 {{'tosa.conv2d' op accumulator type for f16 tensor is not f16/f32}}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ : (tensor<1x29x29x4xf16>, tensor<16x3x3x4xf16>, tensor<16xf16>) -> tensor<1x27x27x16xf16>
+ return %0 : tensor<1x27x27x16xf16>
+}
+
+// -----
+
+func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xbf16>, %arg1: tensor<16x3x3x4xbf16>, %arg2: tensor<16xbf16>) -> tensor<1x27x27x16xbf16> {
+ // expected-error@+1 {{'tosa.conv2d' op accumulator type for bf16 tensor is not f32}}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ : (tensor<1x29x29x4xbf16>, tensor<16x3x3x4xbf16>, tensor<16xbf16>) -> tensor<1x27x27x16xbf16>
+ return %0 : tensor<1x27x27x16xbf16>
+}
+
+// -----
+
+func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
+ // expected-error@+1 {{'tosa.conv2d' op accumulator type for f32 tensor is not f32}}
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
+ return %0 : tensor<1x27x27x16xf32>
+}
+
+// -----
+
+func.func @test_conv3d_acc_type(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi8>) -> tensor<1x4x8x21x34xi8> {
+ // expected-error@+1 {{'tosa.conv3d' op accumulator type for i8 tensor is not i32}}
+ %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>}
+ : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>) -> tensor<1x4x8x21x34xi8>
+ return %0 : tensor<1x4x8x21x34xi8>
+}
+
+// -----
+
+func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<1x1x4x2xi8>, %arg2: tensor<8xi8>) -> tensor<1x4x4x8xi8> {
+ // expected-error@+1 {{'tosa.depthwise_conv2d' op accumulator type for i8 tensor is not i32}}
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>) -> tensor<1x4x4x8xi8>
+ return %0 : tensor<1x4x4x8xi8>
+}
+
+// -----
+
+func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> {
+ // expected-error@+1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}}
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>) -> tensor<1x32x32x16xi8>
+ return %0 : tensor<1x32x32x16xi8>
+}
+
+// -----
+
func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
@@ -77,48 +165,56 @@ func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : te
// -----
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
// expected-error@+1 {{'tosa.pad' op padding of pad is not constant}}
- %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
+ %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<6xi32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
// -----
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x21x3xi8> {
- %0 = "tosa.const"() {value = dense<[[0, 0], [0, 1], [0, 1]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+ %0 = "tosa.const"() {value = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xi32>} : () -> tensor<6xi32>
// expected-error@+1 {{'tosa.pad' op pad_const of pad is not constant}}
- %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, tensor<3x2xi32>, tensor<i8>) -> tensor<13x21x3xi8>
+ %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, tensor<6xi32>, tensor<i8>) -> tensor<13x21x3xi8>
return %1 : tensor<13x21x3xi8>
}
// -----
-func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>, %arg1: tensor<4xi32>) {
// expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
- %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2x2xi32>) -> tensor<13x21x3xf32>
+ %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<4xi32>) -> tensor<13x21x3xf32>
return
}
// -----
-func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2xi32>) {
- // expected-error@+1 {{'tosa.pad' op expect 'padding' tensor rank equal to 2.}}
- %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2xi32>) -> tensor<13x21xf32>
+func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+ // expected-error@+1 {{'tosa.pad' op operand #1 must be 1D tensor of 32-bit signless integer or 64-bit signless integer values, but got 'tensor<2x2xi32>'}}
+ %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2x2xi32>) -> tensor<13x21xf32>
return
}
// -----
-func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<4xi32>) {
%0 = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
// expected-error@+1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<1xf32>'}}
- %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21xf32>, tensor<2x2xi32>, tensor<1xf32>) -> tensor<13x21xf32>
+ %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21xf32>, tensor<4xi32>, tensor<1xf32>) -> tensor<13x21xf32>
return
}
// -----
+func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<4xi32>) -> tensor<13x21x3xf32> {
+ // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
+ %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<4xi32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
// expected-error@+1 {{'tosa.transpose' op perms of transpose is not constant}}
%0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
@@ -206,6 +302,15 @@ func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor
// -----
+func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // expected-error@+1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}}
+ %1 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xf32>
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
%1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
@@ -416,7 +521,7 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
// expected-error@+1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x29x0x4xf32>'}}
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
return %0 : tensor<1x27x27x16xf32>
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 529a16ca48c7..ba8ed8a1e5f5 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -226,7 +226,7 @@ func.func @test_avgpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32
func.func @test_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv2d' op failed level check: dilation_y * KH <= MAX_KERNEL}}
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 4097, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 4097, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -235,7 +235,7 @@ func.func @test_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16
func.func @test_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv2d' op failed level check: dilation_x * KW <= MAX_KERNEL}}
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 4097>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 4097>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -244,7 +244,7 @@ func.func @test_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16
func.func @test_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 8193, 1, 0, 1>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 8193, 1, 0, 1>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -253,7 +253,7 @@ func.func @test_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x
func.func @test_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 8193, 0, 1>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 8193, 0, 1>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -262,7 +262,7 @@ func.func @test_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16
func.func @test_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 8193, 1>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 8193, 1>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -271,7 +271,7 @@ func.func @test_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2
func.func @test_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 8193>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 8193>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -280,7 +280,7 @@ func.func @test_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x
func.func @test_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv2d' op failed level check: stride <= MAX_STRIDE}}
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 8193, 1>} :
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 8193, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -289,7 +289,7 @@ func.func @test_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2
func.func @test_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv2d' op failed level check: stride <= MAX_STRIDE}}
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 8193>} :
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 8193>} :
(tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -298,7 +298,7 @@ func.func @test_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2
func.func @test_conv3d_dilation_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv3d' op failed level check: dilation_d * KD <= MAX_KERNEL}}
- %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 4097, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
+ %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 4097, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
(tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
return %0 : tensor<1x1x32x32x16xf32>
}
@@ -307,7 +307,7 @@ func.func @test_conv3d_dilation_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<
func.func @test_conv3d_dilation_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv3d' op failed level check: dilation_y * KH <= MAX_KERNEL}}
- %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 4097, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
+ %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 4097, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
(tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
return %0 : tensor<1x1x32x32x16xf32>
}
@@ -316,7 +316,7 @@ func.func @test_conv3d_dilation_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<
func.func @test_conv3d_dilation_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv3d' op failed level check: dilation_x * KW <= MAX_KERNEL}}
- %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 4097>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
+ %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 4097>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
(tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
return %0 : tensor<1x1x32x32x16xf32>
}
@@ -325,7 +325,7 @@ func.func @test_conv3d_dilation_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<
func.func @test_conv3d_pad_d0(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 8193, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
+ %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 8193, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
(tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
return %0 : tensor<1x1x32x32x16xf32>
}
@@ -334,7 +334,7 @@ func.func @test_conv3d_pad_d0(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2
func.func @test_conv3d_pad_d1(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 1, 8193, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
+ %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 1, 8193, 0, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
(tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
return %0 : tensor<1x1x32x32x16xf32>
}
@@ -343,7 +343,7 @@ func.func @test_conv3d_pad_d1(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2
func.func @test_conv3d_pad_top(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 8193, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
+ %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 8193, 1, 0, 1>, stride = array<i64: 1, 1, 1>} :
(tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
return %0 : tensor<1x1x32x32x16xf32>
}
@@ -352,7 +352,7 @@ func.func @test_conv3d_pad_top(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x
func.func @test_conv3d_pad_bottom(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 8193, 0, 1>, stride = array<i64: 1, 1, 1>} :
+ %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 8193, 0, 1>, stride = array<i64: 1, 1, 1>} :
(tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
return %0 : tensor<1x1x32x32x16xf32>
}
@@ -361,7 +361,7 @@ func.func @test_conv3d_pad_bottom(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<
func.func @test_conv3d_pad_left(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 8193, 1>, stride = array<i64: 1, 1, 1>} :
+ %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 8193, 1>, stride = array<i64: 1, 1, 1>} :
(tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
return %0 : tensor<1x1x32x32x16xf32>
}
@@ -370,7 +370,7 @@ func.func @test_conv3d_pad_left(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16
func.func @test_conv3d_pad_right(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 8193>, stride = array<i64: 1, 1, 1>} :
+ %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 8193>, stride = array<i64: 1, 1, 1>} :
(tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
return %0 : tensor<1x1x32x32x16xf32>
}
@@ -379,7 +379,7 @@ func.func @test_conv3d_pad_right(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<1
func.func @test_conv3d_stride_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}}
- %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 8193, 1, 1>} :
+ %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 8193, 1, 1>} :
(tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
return %0 : tensor<1x1x32x32x16xf32>
}
@@ -388,7 +388,7 @@ func.func @test_conv3d_stride_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16
func.func @test_conv3d_stride_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}}
- %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 8193, 1>} :
+ %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 8193, 1>} :
(tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
return %0 : tensor<1x1x32x32x16xf32>
}
@@ -397,7 +397,7 @@ func.func @test_conv3d_stride_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16
func.func @test_conv3d_stride_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> {
// expected-error@+1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}}
- %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 8193>} :
+ %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 1, 0, 1, 0, 1>, stride = array<i64: 1, 1, 8193>} :
(tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32>
return %0 : tensor<1x1x32x32x16xf32>
}
@@ -406,7 +406,7 @@ func.func @test_conv3d_stride_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16
func.func @test_depthwise_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
// expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: dilation_y * KH <= MAX_KERNEL}}
- %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 4097, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 4097, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
return %0 : tensor<1x32x32x64xf32>
}
@@ -415,7 +415,7 @@ func.func @test_depthwise_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1:
func.func @test_depthwise_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
// expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: dilation_x * KW <= MAX_KERNEL}}
- %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 4097>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 4097>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
return %0 : tensor<1x32x32x64xf32>
}
@@ -424,7 +424,7 @@ func.func @test_depthwise_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1:
func.func @test_depthwise_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
// expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 8193, 1, 0, 1>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 8193, 1, 0, 1>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
return %0 : tensor<1x32x32x64xf32>
}
@@ -433,7 +433,7 @@ func.func @test_depthwise_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: te
func.func @test_depthwise_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
// expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 8193, 0, 1>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 8193, 0, 1>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
return %0 : tensor<1x32x32x64xf32>
}
@@ -442,7 +442,7 @@ func.func @test_depthwise_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1:
func.func @test_depthwise_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
// expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 8193, 1>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 8193, 1>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
return %0 : tensor<1x32x32x64xf32>
}
@@ -451,7 +451,7 @@ func.func @test_depthwise_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: t
func.func @test_depthwise_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
// expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 8193>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 8193>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
return %0 : tensor<1x32x32x64xf32>
}
@@ -460,7 +460,7 @@ func.func @test_depthwise_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1:
func.func @test_depthwise_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
// expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: stride <= MAX_STRIDE}}
- %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 8193, 1>} :
+ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 8193, 1>} :
(tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
return %0 : tensor<1x32x32x64xf32>
}
@@ -469,7 +469,7 @@ func.func @test_depthwise_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: t
func.func @test_depthwise_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> {
// expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: stride <= MAX_STRIDE}}
- %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 8193>} :
+ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 8193>} :
(tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32>
return %0 : tensor<1x32x32x64xf32>
}
@@ -603,7 +603,7 @@ func.func @test_rfft2d_input_w(%arg0: tensor<13x8x8193xf32>) -> (tensor<13x8x9xf
func.func @test_transpose_conv2d_weight_h(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x8193x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: KH <= MAX_KERNEL}}
- %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x8193x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -612,7 +612,7 @@ func.func @test_transpose_conv2d_weight_h(%arg0: tensor<1x32x32x8xf32>, %arg1: t
func.func @test_transpose_conv2d_weight_w(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x8193x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: KW <= MAX_KERNEL}}
- %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x1x8193x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -621,7 +621,7 @@ func.func @test_transpose_conv2d_weight_w(%arg0: tensor<1x32x32x8xf32>, %arg1: t
func.func @test_transpose_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 8193, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 8193, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -630,7 +630,7 @@ func.func @test_transpose_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: te
func.func @test_transpose_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 8193, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 8193, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -639,7 +639,7 @@ func.func @test_transpose_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1:
func.func @test_transpose_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 8193, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 0, 8193, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -648,7 +648,7 @@ func.func @test_transpose_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: t
func.func @test_transpose_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}}
- %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 8193>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 0, 0, 8193>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -657,7 +657,7 @@ func.func @test_transpose_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1:
func.func @test_transpose_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: stride <= MAX_STRIDE}}
- %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 8193, 1>} :
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 8193, 1>} :
(tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -666,7 +666,7 @@ func.func @test_transpose_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: t
func.func @test_transpose_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
// expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: stride <= MAX_STRIDE}}
- %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 8193>} :
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 8193>} :
(tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 88fa94ae90db..f2e1cff72ab2 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -54,7 +54,7 @@ func.func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>
// -----
// CHECK-LABEL: conv2d
func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
return %0 : tensor<1x4x4x8xf32>
}
@@ -63,7 +63,7 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> {
%0 = "tosa.const"() {value = dense<0> : tensor<3x11x11x3xi4>} : () -> tensor<3x11x11x3xi4>
%1 = "tosa.const"() {value = dense<[12, 23, 55]> : tensor<3xi32>} : () -> tensor<3xi32>
- %2 = "tosa.conv2d"(%arg0, %0, %1) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32>
+ %2 = "tosa.conv2d"(%arg0, %0, %1) {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32>
%3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array<i8: 37, 36, 37>} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8>
return %3 : tensor<1x1x1x3xi8>
}
@@ -71,28 +71,28 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8>
// -----
// CHECK-LABEL: conv3d
func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> {
- %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32>
return %0 : tensor<1x4x8x21x34xf32>
}
// -----
// CHECK-LABEL: conv3d_with_local_bound
func.func @test_conv3d_with_local_bound(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> {
- %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, local_bound = true} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, local_bound = true} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32>
return %0 : tensor<1x4x8x21x34xf32>
}
// -----
// CHECK-LABEL: depthwise_conv2d
func.func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
return %0 : tensor<1x4x4x8xf32>
}
// -----
// CHECK-LABEL: depthwise_conv2d_with_local_bound
func.func @test_depthwise_conv2d_with_local_bound(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
return %0 : tensor<1x4x4x8xf32>
}
@@ -162,14 +162,14 @@ func.func @test_rfft2d_with_local_bound(%arg0: tensor<13x8x16xf32>) -> (tensor<1
// -----
// CHECK-LABEL: transpose_conv2d
func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
// -----
// CHECK-LABEL: transpose_conv2d_with_local_bound
func.func @test_transpose_conv2d_with_local_bound(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>, local_bound = false} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>, local_bound = false} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
@@ -525,16 +525,16 @@ func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -
// -----
// CHECK-LABEL: pad
-func.func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
- %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
+func.func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<6xi32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
// -----
// CHECK-LABEL: pad_explicit_value
-func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
+func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
%0 = "tosa.const"() {value = dense<3.14> : tensor<f32>} : () -> tensor<f32>
- %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<13x21x3xf32>
+ %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21x3xf32>, tensor<6xi32>, tensor<f32>) -> tensor<13x21x3xf32>
return %1 : tensor<13x21x3xf32>
}
diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir
index 82a87dbc2494..6437f12e3ff8 100644
--- a/mlir/test/Dialect/Tosa/quant-test.mlir
+++ b/mlir/test/Dialect/Tosa/quant-test.mlir
@@ -10,9 +10,9 @@ func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32
// -----
// CHECK-LABEL: test_build_mult_and_shift
-func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>> {
+func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>> {
// CHECK: tosa.conv2d
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = array<i64: 1, 1, 2, 2>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -1, weight_zp = 0>} : (tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>>
- return %0 : tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>>
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = i32, pad = array<i64: 1, 1, 2, 2>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -1, weight_zp = 0>} : (tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>>
+ return %0 : tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index d876ccfb3b91..8df4630f9c17 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -14,7 +14,7 @@ func.func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor
// CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>}
// CHECK-SAME: -> tensor<4x10x10x3xf32>
// CHECK: return %[[VAR3]]
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
return %0 : tensor<4x10x10x3xf32>
}
@@ -33,7 +33,7 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
// CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>}
// CHECK-SAME: -> tensor<4x10x10x3xi32>
// CHECK: return %[[VAR3]]
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
return %0 : tensor<4x10x10x3xi32>
}
@@ -50,7 +50,7 @@ func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: -1, 14, 14, 384>} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>
// CHECK: return %[[VAL_6]] : tensor<?x14x14x384xi32>
// CHECK: }
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>, stride = array<i64: 1, 1>} : (tensor<?x14x14x64xi8>, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor<?x14x14x384xi32>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>, stride = array<i64: 1, 1>} : (tensor<?x14x14x64xi8>, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor<?x14x14x384xi32>
return %0 : tensor<?x14x14x384xi32>
}
@@ -58,13 +58,13 @@ func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384
// CHECK-LABEL: @conv2d_as_fully_connected_padded
func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x12x12x3xi32> {
- // CHECK-DAG: %[[PAD_SHAPE:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>}
+ // CHECK-DAG: %[[PAD_SHAPE:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi64>}
// CHECK-DAG: %[[PAD_VAL:.+]] = "tosa.const"() <{value = dense<42> : tensor<i8>}
- // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, tensor<4x2xi64>, tensor<i8>) -> tensor<4x12x12x2xi8>
+ // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, tensor<8xi64>, tensor<i8>) -> tensor<4x12x12x2xi8>
// CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: 576, 2>}
// CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
// CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
// CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array<i64: 4, 12, 12, 3>}
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x12x12x3xi32>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x12x12x3xi32>
return %0 : tensor<4x12x12x3xi32>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index 2224bf3f57b2..cfff6396ad48 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -18,7 +18,7 @@ func.func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1
// CHECK: %[[VAR5:.*]] = tosa.add %[[VAR3]], %[[VAR4]]
// CHECK-SAME: -> tensor<4x10x10x6xf32>
// CHECK: return %[[VAR5]]
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
return %0 : tensor<4x10x10x6xf32>
}
@@ -38,7 +38,7 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
// CHECK: %[[reO:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 10, 10, 6>}
// CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
// CHECK: %[[add:.+]] = tosa.add %[[reO]], %[[reArg2]]
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 7, weight_zp = 11>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 7, weight_zp = 11>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
return %0 : tensor<4x10x10x6xi32>
}
@@ -46,15 +46,15 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
// CHECK-LABEL: @depthwise_conv2d_as_mul_padded
func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> {
- // CHECK-DAG: %[[pad:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0], [0, 0]]> : tensor<5x2xi64>}
+ // CHECK-DAG: %[[pad:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xi64>}
// CHECK-DAG: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
// CHECK: %[[reIn:.+]] = tosa.reshape %arg0 {new_shape = array<i64: 4, 10, 10, 2, 1>}
- // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, tensor<5x2xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
+ // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, tensor<10xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
// CHECK: %[[reArg1:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 2, 3>}
// CHECK: %[[mul:.+]] = tosa.mul %3, %[[reArg1]] {shift = 0 : i8}
// CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 12, 12, 6>}
// CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
// CHECK: %[[add:.+]] = tosa.add %[[reOut]], %[[reArg2]]
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x12x12x6xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x12x12x6xf32>
return %0 : tensor<4x12x12x6xf32>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 1f2bb3fb9a36..c361c7c2899f 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -6,7 +6,7 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x
// CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
// CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2
// CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, stride = array<i64: 1, 1>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
return %0 : tensor<2x18x19x5xf32>
}
@@ -17,8 +17,8 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x
func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) {
// CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
// CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
- // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32>
+ // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32>
return %0 : tensor<2x18x19x5xi32>
}
@@ -32,6 +32,7 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1:
// CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>,
// CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {
+ acc_type = i32,
out_pad = array<i64: 1, 2, 3, 4>,
quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>,
out_shape = array<i64: -1, -1, -1, -1>,
@@ -44,7 +45,7 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1:
// CHECK-LABEL: @transpose_conv2d_strided
func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
// Manipulate the weight matrix to handle striding.
- // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xi32>}
// CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]]
// CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
@@ -54,20 +55,20 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
// CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
// Pad out the input matrix to handle the transpose conv.
- // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi32>}
// CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]]
// Manipulate the final shape.
// CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<30xf32>}
- // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
// CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 2, 18, 16, 2, 3, 5>}
// CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
// CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] {new_shape = array<i64: 2, 36, 48, 5>}
// CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]] {size = array<i64: 2, 35, 47, 5>, start = array<i64: 0, 0, 0, 0>}
// CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 5>}
// CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]]
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
%1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
return %1 : tensor<2x?x?x5xf32>
}
@@ -77,7 +78,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
// CHECK-LABEL: @transpose_conv2d_strided_quantized
func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
// Manipulate the weight matrix to handle striding.
- // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xi32>}
// CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {quantization_info = #tosa.pad_quant<input_zp = 42>}
// CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
@@ -87,20 +88,20 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
// Pad out the input matrix to handle the transpose conv.
- // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi32>}
// CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {quantization_info = #tosa.pad_quant<input_zp = -22>}
// Manipulate the final shape.
// CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0> : tensor<30xi32>}
- // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
+ // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
// CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 2, 18, 16, 2, 3, 5>}
// CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
// CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] {new_shape = array<i64: 2, 36, 48, 5>}
// CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]] {size = array<i64: 2, 35, 47, 5>, start = array<i64: 0, 0, 0, 0>}
// CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 5>}
// CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]]
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
return %0 : tensor<2x35x47x5xi32>
}
@@ -108,12 +109,12 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// CHECK-LABEL: @transpose_conv2d_strided_overpad
func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) {
- // CHECK-DAG: %[[WEIGHT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 0]]> : tensor<4x2xi32>
+ // CHECK-DAG: %[[WEIGHT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xi32>
// CHECK-DAG: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[INPUT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [1, 1], [0, 0], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[INPUT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xi32>}
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
// CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[RESULT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [2, 0], [0, 0], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[RESULT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xi32>}
// CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = 93>}
// CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
// CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
@@ -129,6 +130,7 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
// CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
// CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]]
%2 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {
+ acc_type = i32,
out_pad = array<i64: 2, 0, 0, 1>,
out_shape = array<i64: 1, -1, -1, 1>,
stride = array<i64: 1, 2>,
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d46de740800e..82f3e22a3872 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -495,9 +495,9 @@ func.func @test_concat_axis_1(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>)
// -----
// CHECK-LABEL: @test_padding_no_const
-func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32>) -> () {
- // CHECK: tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
- %0 = tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<4xi32>) -> () {
+ // CHECK: tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
+ %0 = tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
return
}
@@ -505,9 +505,9 @@ func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32
// CHECK-LABEL:@test_padding_dynamic_input
func.func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // CHECK: tosa.pad %arg0, %cst : (tensor<1x?xf32>, tensor<2x2xi32>) -> tensor<4x?xf32>
- %1 = tosa.pad %arg0, %0 : (tensor<1x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ // CHECK: tosa.pad %arg0, %cst : (tensor<1x?xf32>, tensor<4xi32>) -> tensor<4x?xf32>
+ %1 = tosa.pad %arg0, %0 : (tensor<1x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
return
}
@@ -515,9 +515,9 @@ func.func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () {
// CHECK-LABEL: @test_padding_simple
func.func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // CHECK: tosa.pad %arg0, %cst : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<4x9xf32>
- %1 = tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ // CHECK: tosa.pad %arg0, %cst : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<4x9xf32>
+ %1 = tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
return
}
@@ -674,7 +674,7 @@ func.func @test_pool_static(%arg0: tensor<3x5x6x7xf32>) {
// CHECK-LABEL: @conv2d_static
func.func @conv2d_static(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () {
// CHECK: -> tensor<2x6x4x5xf32>
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32>
return
}
@@ -683,7 +683,7 @@ func.func @conv2d_static(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf
// CHECK-LABEL: @conv2d_dynamic_input
func.func @conv2d_dynamic_input(%input: tensor<?x?x?x?xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () {
// CHECK: -> tensor<?x?x?x5xf32>
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32>
return
}
@@ -716,7 +716,7 @@ func.func @test_pool_padded(%arg0: tensor<3x5x6x7xf32>) {
// CHECK-LABEL: @conv2d_dynamic_weight
func.func @conv2d_dynamic_weight(%input: tensor<2x8x9x3xf32>, %weights: tensor<?x?x?x?xf32>, %bias: tensor<5xf32>) -> () {
// CHECK: -> tensor<2x?x?x5xf32>
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32>
return
}
@@ -725,7 +725,7 @@ func.func @conv2d_dynamic_weight(%input: tensor<2x8x9x3xf32>, %weights: tensor<?
// CHECK-LABEL: @conv2d_dynamic_bias
func.func @conv2d_dynamic_bias(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<?xf32>) -> () {
// CHECK: -> tensor<2x6x4x5xf32>
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>) -> tensor<?x?x?x?xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>) -> tensor<?x?x?x?xf32>
return
}
@@ -746,7 +746,7 @@ func.func @test_pool_stride(%arg0: tensor<3x11x12x7xf32>) {
// CHECK-LABEL: @conv2d_padded
func.func @conv2d_padded(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () {
// CHECK: -> tensor<2x9x11x5xf32>
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32>
return
}
@@ -755,7 +755,7 @@ func.func @conv2d_padded(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf
// CHECK-LABEL: @conv2d_dilated
func.func @conv2d_dilated(%input: tensor<2x12x14x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () {
// CHECK: -> tensor<2x6x4x5xf32>
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 3, 2>} : (tensor<2x12x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 3, 2>} : (tensor<2x12x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?xf32>
return
}
@@ -764,7 +764,7 @@ func.func @conv2d_dilated(%input: tensor<2x12x14x3xf32>, %weights: tensor<5x3x6x
// CHECK-LABEL: @conv2d_strided
func.func @conv2d_strided(%input: tensor<1x13x14x1xf32>, %weights: tensor<1x1x1x1xf32>, %bias: tensor<1xf32>) -> () {
// CHECK: -> tensor<1x5x7x1xf32>
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 3, 2>, dilation = array<i64: 1, 1>} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 3, 2>, dilation = array<i64: 1, 1>} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
return
}
@@ -773,7 +773,7 @@ func.func @conv2d_strided(%input: tensor<1x13x14x1xf32>, %weights: tensor<1x1x1x
// CHECK-LABEL: @conv3d_static
func.func @conv3d_static(%input: tensor<2x8x9x10x3xf32>, %weights: tensor<5x3x6x4x3xf32>, %bias: tensor<5xf32>) -> () {
// CHECK: -> tensor<2x6x4x7x5xf32>
- %0 = tosa.conv3d %input, %weights, %bias {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
+ %0 = tosa.conv3d %input, %weights, %bias {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
return
}
@@ -782,7 +782,7 @@ func.func @conv3d_static(%input: tensor<2x8x9x10x3xf32>, %weights: tensor<5x3x6x
// CHECK-LABEL: @conv3d_dynamic_input
func.func @conv3d_dynamic_input(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<5xf32>) {
// CHECK: -> tensor<?x?x?x?x5xf32>
- %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<?x?x?x?x?xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<?x?x?x?x?xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
return
}
@@ -791,7 +791,7 @@ func.func @conv3d_dynamic_input(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<5x3x
// CHECK-LABEL: @conv3d_dynamic_weight
func.func @conv3d_dynamic_weight(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<5xf32>) {
// CHECK: -> tensor<2x?x?x?x5xf32>
- %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<?x?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<?x?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
return
}
@@ -800,7 +800,7 @@ func.func @conv3d_dynamic_weight(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<?x
// CHECK-LABEL: @conv3d_dynamic_bias
func.func @conv3d_dynamic_bias(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<?xf32>) {
// CHECK: -> tensor<2x6x4x7x5xf32>
- %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<?xf32>) -> tensor<?x?x?x?x?xf32>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<?xf32>) -> tensor<?x?x?x?x?xf32>
return
}
@@ -809,7 +809,7 @@ func.func @conv3d_dynamic_bias(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x
// CHECK-LABEL: @conv3d_padded
func.func @conv3d_padded(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<5xf32>) {
// CHECK: -> tensor<2x9x11x18x5xf32>
- %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 1, 2, 3, 4, 5, 6>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 1, 2, 3, 4, 5, 6>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
return
}
@@ -818,7 +818,7 @@ func.func @conv3d_padded(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3x
// CHECK-LABEL: @conv3d_dilated
func.func @conv3d_dilated(%arg0: tensor<2x12x14x16x3xf32>, %arg1: tensor<5x3x6x2x3xf32>, %arg2: tensor<5xf32>) {
// CHECK: -> tensor<2x6x4x12x5xf32>
- %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 3, 2, 4>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x12x14x16x3xf32>, tensor<5x3x6x2x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 3, 2, 4>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x12x14x16x3xf32>, tensor<5x3x6x2x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
return
}
@@ -827,7 +827,7 @@ func.func @conv3d_dilated(%arg0: tensor<2x12x14x16x3xf32>, %arg1: tensor<5x3x6x2
// CHECK-LABEL: @conv3d_strided
func.func @conv3d_strided(%arg0: tensor<1x13x14x15x1xf32>, %arg1: tensor<1x1x1x1x1xf32>, %arg2: tensor<1xf32>) {
// CHECK: -> tensor<1x5x7x4x1xf32>
- %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 3, 2, 4>} : (tensor<1x13x14x15x1xf32>, tensor<1x1x1x1x1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 3, 2, 4>} : (tensor<1x13x14x15x1xf32>, tensor<1x1x1x1x1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32>
return
}
@@ -836,7 +836,7 @@ func.func @conv3d_strided(%arg0: tensor<1x13x14x15x1xf32>, %arg1: tensor<1x1x1x1
// CHECK-LABEL: @depthwise_conv2d_static
func.func @depthwise_conv2d_static(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) {
// CHECK: -> tensor<2x6x4x15xf32>
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32>
return
}
@@ -845,7 +845,7 @@ func.func @depthwise_conv2d_static(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6
// CHECK-LABEL: @depthwise_conv2d_dynamic_input
func.func @depthwise_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) {
// CHECK: -> tensor<?x?x?x15xf32>
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<?x?x?x15xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<?x?x?x15xf32>
return
}
@@ -854,7 +854,7 @@ func.func @depthwise_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: ten
// CHECK-LABEL: @depthwise_conv2d_dynamic_weight
func.func @depthwise_conv2d_dynamic_weight(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<15xf32>) {
// CHECK: -> tensor<2x?x?x15xf32>
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<?x?x?x?xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<?x?x?x?xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32>
return
}
@@ -863,7 +863,7 @@ func.func @depthwise_conv2d_dynamic_weight(%arg0: tensor<2x8x9x3xf32>, %arg1: te
// CHECK-LABEL: @depthwise_conv2d_dynamic_bias
func.func @depthwise_conv2d_dynamic_bias(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<?xf32>) {
// CHECK: -> tensor<2x6x4x15xf32>
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<?xf32>) -> tensor<2x6x4x15xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<?xf32>) -> tensor<2x6x4x15xf32>
return
}
@@ -872,7 +872,7 @@ func.func @depthwise_conv2d_dynamic_bias(%arg0: tensor<2x8x9x3xf32>, %arg1: tens
// CHECK-LABEL: @depthwise_conv2d_padded
func.func @depthwise_conv2d_padded(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) {
// CHECK: -> tensor<2x9x11x15xf32>
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x9x11x15xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x9x11x15xf32>
return
}
@@ -881,7 +881,7 @@ func.func @depthwise_conv2d_padded(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6
// CHECK-LABEL: @depthwise_conv2d_dilated
func.func @depthwise_conv2d_dilated(%arg0: tensor<2x12x14x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) {
// CHECK: -> tensor<2x6x4x15xf32>
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 3, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x12x14x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 3, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x12x14x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32>
return
}
@@ -890,7 +890,7 @@ func.func @depthwise_conv2d_dilated(%arg0: tensor<2x12x14x3xf32>, %arg1: tensor<
// CHECK-LABEL: @depthwise_conv2d_strided
func.func @depthwise_conv2d_strided(%arg0: tensor<1x13x14x1xf32>, %arg1: tensor<1x1x1x1xf32>, %arg2: tensor<1xf32>) {
// CHECK: -> tensor<1x5x7x1xf32>
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 3, 2>} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x5x7x1xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 3, 2>} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x5x7x1xf32>
return
}
@@ -899,7 +899,7 @@ func.func @depthwise_conv2d_strided(%arg0: tensor<1x13x14x1xf32>, %arg1: tensor<
// CHECK-LABEL: @transpose_conv2d_out_shape
func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
// CHECK: -> tensor<2x8x9x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, 8, 9, -1>, stride = array<i64: 1, 1>} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x8x9x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, 8, 9, -1>, stride = array<i64: 1, 1>} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x8x9x5xf32>
return
}
@@ -908,7 +908,7 @@ func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<
// CHECK-LABEL: @transpose_conv2d_static
func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
// CHECK: -> tensor<2x18x19x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
return
}
@@ -917,7 +917,7 @@ func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5
// CHECK-LABEL: @transpose_conv2d_static_strided
func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
// CHECK: -> tensor<2x33x45x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
return
}
@@ -926,7 +926,7 @@ func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1:
// CHECK-LABEL: @transpose_conv2d_dynamic_input
func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
// CHECK: -> tensor<?x?x?x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32>
return
}
@@ -935,7 +935,7 @@ func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: ten
// CHECK-LABEL: @transpose_conv2d_dynamic_weights
func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<5xf32>) {
// CHECK: -> tensor<2x?x?x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
return
}
@@ -944,7 +944,7 @@ func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: t
// CHECK-LABEL: @transpose_conv2d_dynamic_bias
func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<?xf32>) {
// CHECK: -> tensor<2x8x9x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>) -> tensor<2x8x9x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>) -> tensor<2x8x9x5xf32>
return
}
@@ -953,14 +953,14 @@ func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tens
// CHECK-LABEL: @transpose_conv2d_padded
func.func @transpose_conv2d_padded(%arg0: tensor<2x9x11x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
// CHECK: -> tensor<2x10x13x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 1, 0, 3, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x10x13x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 1, 0, 3, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x10x13x5xf32>
return
}
// CHECK-LABEL: @transpose_conv2d_strided
func.func @transpose_conv2d_strided(%arg0: tensor<1x5x7x1xf32>, %arg1: tensor<1x1x1x1xf32>, %arg2: tensor<1xf32>) {
// CHECK: -> tensor<1x13x13x1xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 3, 2>} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 3, 2>} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32>
return
}
@@ -1368,7 +1368,7 @@ func.func @test_non_tosa_consumer_still_propagates(%arg0: tensor<1x1x8xf32>, %ar
func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<?x16x16x16xf32> {
// CHECK: [[CONV:%.+]] = tosa.conv2d %arg0, %arg1, %arg2
// CHECK: (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<?x32x32x16xf32>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<?x32x32x16xf32>
// CHECK: tosa.max_pool2d [[CONV]]
// CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
%1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
diff --git a/mlir/test/Integration/GPU/CUDA/assert.mlir b/mlir/test/Integration/GPU/CUDA/assert.mlir
new file mode 100644
index 000000000000..3d6527fe59b2
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/assert.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-cpu-runner \
+// RUN: --shared-libs=%mlir_cuda_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --entry-point-result=void 2>&1 \
+// RUN: | FileCheck %s
+
+// CHECK-DAG: thread 0: print after passing assertion
+// CHECK-DAG: thread 1: print after passing assertion
+// CHECK-DAG: callee_file.cc:7: callee_func_name: block: [0,0,0], thread: [0,0,0] Assertion `failing assertion` failed.
+// CHECK-DAG: callee_file.cc:7: callee_func_name: block: [0,0,0], thread: [1,0,0] Assertion `failing assertion` failed.
+// CHECK-NOT: print after failing assertion
+
+module attributes {gpu.container_module} {
+gpu.module @kernels {
+gpu.func @test_assert(%c0: i1, %c1: i1) kernel {
+ %0 = gpu.thread_id x
+ cf.assert %c1, "passing assertion"
+ gpu.printf "thread %lld: print after passing assertion\n", %0 : index
+ // Test callsite(callsite(name)) location.
+ cf.assert %c0, "failing assertion" loc(callsite(callsite("callee_func_name"("callee_file.cc":7:9) at "caller_file.cc":10:8) at "caller2_file.cc":11:12))
+ gpu.printf "thread %lld: print after failing assertion\n", %0 : index
+ gpu.return
+}
+}
+
+func.func @main() {
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+ %c0_i1 = arith.constant 0 : i1
+ %c1_i1 = arith.constant 1 : i1
+ gpu.launch_func @kernels::@test_assert
+ blocks in (%c1, %c1, %c1)
+ threads in (%c2, %c1, %c1)
+ args(%c0_i1 : i1, %c1_i1 : i1)
+ return
+}
+}
diff --git a/mlir/test/Integration/GPU/CUDA/printf.mlir b/mlir/test/Integration/GPU/CUDA/printf.mlir
index 99ea1208e9c5..15b0bf02d911 100644
--- a/mlir/test/Integration/GPU/CUDA/printf.mlir
+++ b/mlir/test/Integration/GPU/CUDA/printf.mlir
@@ -14,7 +14,7 @@ module attributes {gpu.container_module} {
%0 = gpu.thread_id x
%csti8 = arith.constant 2 : i8
%cstf32 = arith.constant 3.0 : f32
- gpu.printf "Hello from %lld, %d, %f\n" %0, %csti8, %cstf32 : index, i8, f32
+ gpu.printf "Hello from %lld, %d, %f\n", %0, %csti8, %cstf32 : index, i8, f32
gpu.return
}
}
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
index c70c940564a2..a22a34b9393a 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
@@ -43,7 +43,7 @@ module attributes {gpu.container_module} {
%cnd2 = arith.cmpi eq, %bidY, %c3 : index
scf.if %cnd1 {
scf.if %cnd2 {
- gpu.printf "clusterIdx: (%d, %d, %d) in Cluster Dimension: (%d, %d, %d) blockIdx: (%d, %d, %d) \n"
+ gpu.printf "clusterIdx: (%d, %d, %d) in Cluster Dimension: (%d, %d, %d) blockIdx: (%d, %d, %d) \n",
%cidX_i32,
%cidY_i32,
%cidZ_i32,
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
index b50772f8249f..95bde40deb48 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
@@ -85,7 +85,7 @@ module @mymod {
// Step 7. First thread does TMA load
scf.if %10 {
- gpu.printf "[GPU] TMA SIZE %d\0A" %c8192 : index
+ gpu.printf "[GPU] TMA SIZE %d\0A", %c8192 : index
nvgpu.tma.async.load %3[%c0, %c0], %9[%c0] to %7 : !lhsTensorMap, !barrierType -> !shmemlhs
nvgpu.mbarrier.arrive.expect_tx %9[%c0], %c8192 : !barrierType
} else {
@@ -98,16 +98,16 @@ module @mymod {
// Step 9. Print loaded data in 128b swizzled
scf.if %10 {
- gpu.printf "===--- Matrix A ---=== %d \0A" %c-1_i32 : i32
+ gpu.printf "===--- Matrix A ---=== %d \0A", %c-1_i32 : i32
scf.for %arg12 = %c0 to %c128 step %c1 {
scf.for %arg13 = %c0 to %c64 step %c1 {
%15 = memref.load %7[%arg12, %arg13] : !shmemlhs
%16 = arith.extf %15 : f16 to f32
- gpu.printf "%.0f, " %16 : f32
+ gpu.printf "%.0f, ", %16 : f32
}
- gpu.printf "%d\0A" %c-1_i32 : i32
+ gpu.printf "%d\0A", %c-1_i32 : i32
}
- gpu.printf "===----------------=== %d \0A" %c-1_i32 : i32
+ gpu.printf "===----------------=== %d \0A", %c-1_i32 : i32
}
gpu.terminator
}
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
index 65e5fc0aff6a..e76fa03903b8 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
@@ -109,7 +109,7 @@ module @mymod {
// Step 6. First thread does TMA load
scf.if %10 {
- gpu.printf "[GPU] TMA SIZE %d\0A" %c32768 : index
+ gpu.printf "[GPU] TMA SIZE %d\0A", %c32768 : index
nvgpu.tma.async.load %d_lhsTensorMap[%c0, %c0], %9[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
nvgpu.tma.async.load %d_rhsTensorMap[%c0, %c0], %9[%c0] to %rhsShmem1 : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1]>, 3>
nvgpu.tma.async.load %d_rhsTensorMap[%c64, %c0], %9[%c0] to %rhsShmem2 : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 4096>, 3>
@@ -124,16 +124,16 @@ module @mymod {
// Step 8. Print loaded data in 128b swizzled
scf.if %10 {
- gpu.printf "===--- Matrix B ---=== %d \n" %c-1_i32 : i32
+ gpu.printf "===--- Matrix B ---=== %d \n", %c-1_i32 : i32
scf.for %ii = %c0 to %c64 step %c1 {
scf.for %j = %c0 to %c128 step %c1 {
%lhs0 = memref.load %rhsShmem[%ii, %j] : !shmemrhs
%lhs032 = arith.extf %lhs0: f16 to f32
- gpu.printf "%.0f, " %lhs032 : f32
+ gpu.printf "%.0f, ", %lhs032 : f32
}
- gpu.printf "%d\n" %c-1_i32 : i32
+ gpu.printf "%d\n", %c-1_i32 : i32
}
- gpu.printf "===----------------=== %d \n" %c-1_i32 : i32
+ gpu.printf "===----------------=== %d \n", %c-1_i32 : i32
}
gpu.barrier
gpu.terminator
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
index 391fda82e1e1..acca9811f570 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
@@ -80,8 +80,8 @@ module @mymod {
nvgpu.mbarrier.arrive.expect_tx %9[%c0], %c6144 : <memorySpace = #gpu.address_space<workgroup>>
%11 = memref.load %7[%c0, %c0] : memref<64x8xf32, 3>
%12 = memref.load %8[%c0, %c0] : memref<8x128xf32, 3>
- gpu.printf "[GPU] TMA BEFORE lhs[45][7] %f\0A" %11 : f32
- gpu.printf "[GPU] TMA BEFORE rhs[7][0] %f\0A" %12 : f32
+ gpu.printf "[GPU] TMA BEFORE lhs[45][7] %f\0A", %11 : f32
+ gpu.printf "[GPU] TMA BEFORE rhs[7][0] %f\0A", %12 : f32
nvgpu.tma.async.load %3[%c0, %c0], %9[%c0] to %7 : <tensor = memref<64x8xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x8xf32, 3>
nvgpu.tma.async.load %4[%c0, %c0], %9[%c0] to %8 : <tensor = memref<8x128xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<8x128xf32, 3>
} else {
@@ -92,8 +92,8 @@ module @mymod {
scf.if %10 {
%11 = memref.load %7[%c45, %c7] : memref<64x8xf32, 3>
%12 = memref.load %8[%c7, %c0] : memref<8x128xf32, 3>
- gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A" %11 : f32
- gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A" %12 : f32
+ gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A", %11 : f32
+ gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A", %12 : f32
}
gpu.terminator
}
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/transform-dialect/tma_load_64x8_8x128_noswizzle-transform.mlir b/mlir/test/Integration/GPU/CUDA/sm90/transform-dialect/tma_load_64x8_8x128_noswizzle-transform.mlir
index f83f65bb2963..fe6c645357ec 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/transform-dialect/tma_load_64x8_8x128_noswizzle-transform.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/transform-dialect/tma_load_64x8_8x128_noswizzle-transform.mlir
@@ -96,8 +96,8 @@ func.func @main() {
scf.if %10 {
%11 = memref.load %out[%c45, %c7] : memref<64x8xf32, 3>
%12 = memref.load %out_1[%c7, %c0] : memref<8x128xf32, 3>
- gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A" %11 : f32
- gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A" %12 : f32
+ gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A", %11 : f32
+ gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A", %12 : f32
}
gpu.terminator
}
diff --git a/mlir/test/Integration/GPU/ROCM/printf.mlir b/mlir/test/Integration/GPU/ROCM/printf.mlir
index d5e6e3757540..4a0e4d34bfab 100644
--- a/mlir/test/Integration/GPU/ROCM/printf.mlir
+++ b/mlir/test/Integration/GPU/ROCM/printf.mlir
@@ -13,7 +13,7 @@ module attributes {gpu.container_module} {
gpu.module @kernels {
gpu.func @hello() kernel {
%0 = gpu.thread_id x
- gpu.printf "Hello from %d\n" %0 : index
+ gpu.printf "Hello from %d\n", %0 : index
gpu.return
}
}
diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
index 6bde174642d5..b616cb81e0a8 100644
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -13,15 +13,6 @@ bb2:
; // -----
; CHECK: <unknown>
-; CHECK-SAME: error: unhandled value: ptr asm "bswap $0", "=r,r"
-define i32 @unhandled_value(i32 %arg1) {
- %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
- ret i32 %1
-}
-
-; // -----
-
-; CHECK: <unknown>
; CHECK-SAME: unhandled constant: ptr blockaddress(@unhandled_constant, %bb1) since blockaddress(...) is unsupported
; CHECK: <unknown>
; CHECK-SAME: error: unhandled instruction: ret ptr blockaddress(@unhandled_constant, %bb1)
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index fff48bbc486b..7377e2584110 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -535,6 +535,17 @@ define void @indirect_vararg_call(ptr addrspace(42) %fn) {
; // -----
+; CHECK-LABEL: @inlineasm
+; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+define i32 @inlineasm(i32 %arg1) {
+ ; CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects "bswap $0", "=r,r" %[[ARG1]] : (i32) -> i32
+ %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
+ ; CHECK: return %[[RES]]
+ ret i32 %1
+}
+
+; // -----
+
; CHECK-LABEL: @gep_static_idx
; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]]
define void @gep_static_idx(ptr %ptr) {
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-alias-scopes.ll b/mlir/test/Target/LLVMIR/Import/metadata-alias-scopes.ll
index f5128ff76bc5..bf4c85786216 100644
--- a/mlir/test/Target/LLVMIR/Import/metadata-alias-scopes.ll
+++ b/mlir/test/Target/LLVMIR/Import/metadata-alias-scopes.ll
@@ -92,3 +92,38 @@ declare void @foo(ptr %arg1)
!0 = distinct !{!0, !"The domain"}
!1 = !{!1, !0}
!2 = !{!1}
+
+; // -----
+
+; CHECK: #[[DOMAIN:.*]] = #llvm.alias_scope_domain<id = "domain1">
+; CHECK: #[[$SCOPE0:.*]] = #llvm.alias_scope<id = "scopeid1", domain = #[[DOMAIN]], description = "The first scope">
+; CHECK: #[[$SCOPE1:.*]] = #llvm.alias_scope<id = "scopeid2", domain = #[[DOMAIN]]>
+; CHECK: #[[$SCOPE2:.*]] = #llvm.alias_scope<id = "scopeid3", domain = #[[DOMAIN]]>
+
+; CHECK-LABEL: llvm.func @alias_scope
+define void @alias_scope(ptr %arg1) {
+ ; CHECK: llvm.load
+ ; CHECK-SAME: alias_scopes = [#[[$SCOPE0]]]
+ ; CHECK-SAME: noalias_scopes = [#[[$SCOPE1]], #[[$SCOPE2]]]
+ %1 = load i32, ptr %arg1, !alias.scope !4, !noalias !7
+ ; CHECK: llvm.load
+ ; CHECK-SAME: alias_scopes = [#[[$SCOPE1]]]
+ ; CHECK-SAME: noalias_scopes = [#[[$SCOPE0]], #[[$SCOPE2]]]
+ %2 = load i32, ptr %arg1, !alias.scope !5, !noalias !8
+ ; CHECK: llvm.load
+ ; CHECK-SAME: alias_scopes = [#[[$SCOPE2]]]
+ ; CHECK-SAME: noalias_scopes = [#[[$SCOPE0]], #[[$SCOPE1]]]
+ %3 = load i32, ptr %arg1, !alias.scope !6, !noalias !9
+ ret void
+}
+
+!0 = !{!"domain1"}
+!1 = !{!"scopeid1", !0, !"The first scope"}
+!2 = !{!"scopeid2", !0}
+!3 = !{!"scopeid3", !0}
+!4 = !{!1}
+!5 = !{!2}
+!6 = !{!3}
+!7 = !{!2, !3}
+!8 = !{!1, !3}
+!9 = !{!1, !2}
diff --git a/mlir/test/Target/LLVMIR/attribute-alias-scopes.mlir b/mlir/test/Target/LLVMIR/attribute-alias-scopes.mlir
index fa3395533af2..fb71a51512ae 100644
--- a/mlir/test/Target/LLVMIR/attribute-alias-scopes.mlir
+++ b/mlir/test/Target/LLVMIR/attribute-alias-scopes.mlir
@@ -104,3 +104,54 @@ llvm.func @self_reference() {
// CHECK-DAG: ![[SCOPES]] = !{![[SCOPE]]}
// CHECK-DAG: = !DISubroutineType(types: ![[TYPES:[0-9]+]])
// CHECK-DAG: ![[TYPES]] = !{null}
+
+// -----
+
+llvm.func @foo(%arg0: !llvm.ptr)
+
+#alias_scope_domain = #llvm.alias_scope_domain<id = "domain1", description = "The domain">
+#alias_scope1 = #llvm.alias_scope<id = "scope1", domain = #alias_scope_domain, description = "The first scope">
+#alias_scope2 = #llvm.alias_scope<id = "scope2", domain = #alias_scope_domain>
+#alias_scope3 = #llvm.alias_scope<id = "scope3", domain = #alias_scope_domain>
+
+// CHECK-LABEL: @alias_scopes
+llvm.func @alias_scopes(%arg1 : !llvm.ptr) {
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: call void @llvm.experimental.noalias.scope.decl(metadata ![[SCOPES1:[0-9]+]])
+ llvm.intr.experimental.noalias.scope.decl #alias_scope1
+ // CHECK: store {{.*}}, !alias.scope ![[SCOPES1]], !noalias ![[SCOPES23:[0-9]+]]
+ llvm.store %0, %arg1 {alias_scopes = [#alias_scope1], noalias_scopes = [#alias_scope2, #alias_scope3]} : i32, !llvm.ptr
+ // CHECK: load {{.*}}, !alias.scope ![[SCOPES2:[0-9]+]], !noalias ![[SCOPES13:[0-9]+]]
+ %1 = llvm.load %arg1 {alias_scopes = [#alias_scope2], noalias_scopes = [#alias_scope1, #alias_scope3]} : !llvm.ptr -> i32
+ // CHECK: atomicrmw {{.*}}, !alias.scope ![[SCOPES3:[0-9]+]], !noalias ![[SCOPES12:[0-9]+]]
+ %2 = llvm.atomicrmw add %arg1, %0 monotonic {alias_scopes = [#alias_scope3], noalias_scopes = [#alias_scope1, #alias_scope2]} : !llvm.ptr, i32
+ // CHECK: cmpxchg {{.*}}, !alias.scope ![[SCOPES3]]
+ %3 = llvm.cmpxchg %arg1, %1, %2 acq_rel monotonic {alias_scopes = [#alias_scope3]} : !llvm.ptr, i32
+ %5 = llvm.mlir.constant(42 : i8) : i8
+ // CHECK: llvm.memcpy{{.*}}, !alias.scope ![[SCOPES3]]
+ "llvm.intr.memcpy"(%arg1, %arg1, %0) <{isVolatile = false}> {alias_scopes = [#alias_scope3]} : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ // CHECK: llvm.memset{{.*}}, !noalias ![[SCOPES3]]
+ "llvm.intr.memset"(%arg1, %5, %0) <{isVolatile = false}> {noalias_scopes = [#alias_scope3]} : (!llvm.ptr, i8, i32) -> ()
+ // CHECK: call void @foo({{.*}} !alias.scope ![[SCOPES3]]
+ llvm.call @foo(%arg1) {alias_scopes = [#alias_scope3]} : (!llvm.ptr) -> ()
+ // CHECK: call void @foo({{.*}} !noalias ![[SCOPES3]]
+ llvm.call @foo(%arg1) {noalias_scopes = [#alias_scope3]} : (!llvm.ptr) -> ()
+ llvm.return
+}
+
+// Check the intrinsic declarations.
+// CHECK-DAG: declare void @llvm.experimental.noalias.scope.decl(metadata)
+// CHECK-DAG: declare void @llvm.memcpy.p0.p0.i32(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i32, i1 immarg)
+// CHECK-DAG: declare void @llvm.memset.p0.i32(ptr nocapture writeonly, i8, i32, i1 immarg)
+
+// Check the translated metadata.
+// CHECK-DAG: ![[DOMAIN:[0-9]+]] = !{!"domain1", !"The domain"}
+// CHECK-DAG: ![[SCOPE1:[0-9]+]] = !{!"scope1", ![[DOMAIN]], !"The first scope"}
+// CHECK-DAG: ![[SCOPE2:[0-9]+]] = !{!"scope2", ![[DOMAIN]]}
+// CHECK-DAG: ![[SCOPE3:[0-9]+]] = !{!"scope3", ![[DOMAIN]]}
+// CHECK-DAG: ![[SCOPES1]] = !{![[SCOPE1]]}
+// CHECK-DAG: ![[SCOPES2]] = !{![[SCOPE2]]}
+// CHECK-DAG: ![[SCOPES3]] = !{![[SCOPE3]]}
+// CHECK-DAG: ![[SCOPES12]] = !{![[SCOPE1]], ![[SCOPE2]]}
+// CHECK-DAG: ![[SCOPES13]] = !{![[SCOPE1]], ![[SCOPE3]]}
+// CHECK-DAG: ![[SCOPES23]] = !{![[SCOPE2]], ![[SCOPE3]]}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index b69d77496351..2d7710e7cbf2 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -556,9 +556,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel} {
llvm.return
}
-// CHECK: !nvvm.annotations =
-// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
-// CHECK: {ptr @kernel_func, !"kernel", i32 1}
+// CHECK: ptx_kernel void @kernel_func
// -----
@@ -566,9 +564,8 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 2
llvm.return
}
+// CHECK: define ptx_kernel void @kernel_func
// CHECK: !nvvm.annotations =
-// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
-// CHECK: {ptr @kernel_func, !"kernel", i32 1}
// CHECK: {ptr @kernel_func, !"maxntidx", i32 1}
// CHECK: {ptr @kernel_func, !"maxntidy", i32 23}
// CHECK: {ptr @kernel_func, !"maxntidz", i32 32}
@@ -578,9 +575,8 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 1, 2
llvm.return
}
+// CHECK: define ptx_kernel void @kernel_func
// CHECK: !nvvm.annotations =
-// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
-// CHECK: {ptr @kernel_func, !"kernel", i32 1}
// CHECK: {ptr @kernel_func, !"reqntidx", i32 1}
// CHECK: {ptr @kernel_func, !"reqntidy", i32 23}
// CHECK: {ptr @kernel_func, !"reqntidz", i32 32}
@@ -590,31 +586,28 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.cluster_dim = array<i32:
llvm.return
}
+// CHECK: define ptx_kernel void @kernel_func
// CHECK: !nvvm.annotations =
-// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
// CHECK: {ptr @kernel_func, !"cluster_dim_x", i32 3}
// CHECK: {ptr @kernel_func, !"cluster_dim_y", i32 5}
// CHECK: {ptr @kernel_func, !"cluster_dim_z", i32 7}
-// CHECK: {ptr @kernel_func, !"kernel", i32 1}
// -----
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.cluster_max_blocks = 8} {
llvm.return
}
+// CHECK: define ptx_kernel void @kernel_func
// CHECK: !nvvm.annotations =
-// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
// CHECK: {ptr @kernel_func, !"cluster_max_blocks", i32 8}
-// CHECK: {ptr @kernel_func, !"kernel", i32 1}
// -----
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.minctasm = 16} {
llvm.return
}
+// CHECK: define ptx_kernel void @kernel_func
// CHECK: !nvvm.annotations =
-// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
-// CHECK: {ptr @kernel_func, !"kernel", i32 1}
// CHECK: {ptr @kernel_func, !"minctasm", i32 16}
// -----
@@ -622,9 +615,8 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxnreg = 16} {
llvm.return
}
+// CHECK: define ptx_kernel void @kernel_func
// CHECK: !nvvm.annotations =
-// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
-// CHECK: {ptr @kernel_func, !"kernel", i32 1}
// CHECK: {ptr @kernel_func, !"maxnreg", i32 16}
// -----
@@ -633,9 +625,8 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 2
llvm.return
}
+// CHECK: define ptx_kernel void @kernel_func
// CHECK: !nvvm.annotations =
-// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
-// CHECK: {ptr @kernel_func, !"kernel", i32 1}
// CHECK: {ptr @kernel_func, !"maxnreg", i32 32}
// CHECK: {ptr @kernel_func, !"maxntidx", i32 1}
// CHECK: {ptr @kernel_func, !"maxntidy", i32 23}
@@ -643,19 +634,19 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 2
// CHECK: {ptr @kernel_func, !"minctasm", i32 16}
// -----
+// CHECK: define ptx_kernel void @kernel_func
// CHECK: !nvvm.annotations =
// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
// CHECK: !2 = !{i32 1}
-// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
llvm.return
}
// -----
+// CHECK: define ptx_kernel void @kernel_func
// CHECK: !nvvm.annotations =
// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
// CHECK: !2 = !{i32 1, i32 3}
-// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/omptarget-threadprivate-device-lowering.mlir b/mlir/test/Target/LLVMIR/omptarget-threadprivate-device-lowering.mlir
new file mode 100644
index 000000000000..279ecb3f8e99
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-threadprivate-device-lowering.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// Not intended to be a functional example; the aim of this test is to verify
+// omp.threadprivate does not crash on lowering during the OpenMP target device
+// pass when used in conjunction with target code in the same module.
+
+// Module compiled for the target device. A host declare-target function takes
+// a threadprivate view of @_QFEpointer2, maps it into an omp.target region,
+// and the region stores the constant 1 through the mapped pointer.
+module attributes {omp.is_target_device = true } {
+ llvm.func @func() attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>} {
+ %0 = llvm.mlir.addressof @_QFEpointer2 : !llvm.ptr
+ // The threadprivate op is the construct whose lowering is under test.
+ %1 = omp.threadprivate %0 : !llvm.ptr -> !llvm.ptr
+ %2 = omp.map.info var_ptr(%1 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>) map_clauses(implicit, to) capture(ByRef) -> !llvm.ptr
+ omp.target map_entries(%2 -> %arg0 : !llvm.ptr) {
+ %3 = llvm.mlir.constant(1 : i32) : i32
+ // Address of the first struct field of the mapped value.
+ %4 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ llvm.store %3, %4 : i32, !llvm.ptr
+ omp.terminator
+ }
+ llvm.return
+ }
+ // Backing global for the threadprivate variable; initialized to undef.
+ llvm.mlir.global internal @_QFEpointer2() {addr_space = 0 : i32} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {
+ %0 = llvm.mlir.undef : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ llvm.return %0 : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ }
+}
+
+// CHECK: define weak_odr protected void @{{.*}}(ptr %{{.*}}, ptr %[[ARG1:.*]]) {
+// CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8
+// CHECK: store ptr %[[ARG1]], ptr %[[ALLOCA]], align 8
+// CHECK: %[[LOAD_ALLOCA:.*]] = load ptr, ptr %[[ALLOCA]], align 8
+// CHECK: store i32 1, ptr %[[LOAD_ALLOCA]], align 4
diff --git a/mlir/test/Target/LLVMIR/openmp-simd-aligned.mlir b/mlir/test/Target/LLVMIR/openmp-simd-aligned.mlir
new file mode 100644
index 000000000000..234604e4b664
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-simd-aligned.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+//CHECK-LABEL: define void @_QPsimd_aligned_pointer() {
+//CHECK: %[[A_PTR:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, i64 1, align 8
+//CHECK: %[[A_VAL:.*]] = load ptr, ptr %[[A_PTR]], align 8
+//CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[A_VAL]], i64 256) ]
+// omp.simd with an `aligned` clause (256) on a stack slot holding a
+// seven-field struct whose first field is a pointer. The loop runs from 1 to
+// 10 inclusive with step 1 and stores the induction variable into %3.
+llvm.func @_QPsimd_aligned_pointer() {
+ %1 = llvm.mlir.constant(1 : i64) : i64
+ %2 = llvm.alloca %1 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {bindc_name = "x"} : (i64) -> !llvm.ptr
+ %3 = llvm.alloca %1 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr
+ %4 = llvm.mlir.constant(1 : i32) : i32
+ %5 = llvm.mlir.constant(10 : i32) : i32
+ %6 = llvm.mlir.constant(1 : i32) : i32
+ // %2 is the aligned operand; the expected output loads a pointer from it
+ // and emits an llvm.assume "align" bundle on the loaded value.
+ omp.simd aligned(%2 : !llvm.ptr -> 256 : i64) {
+ omp.loop_nest (%arg0) : i32 = (%4) to (%5) inclusive step (%6) {
+ llvm.store %arg0, %3 : i32, !llvm.ptr
+ omp.yield
+ }
+ }
+ llvm.return
+}
+
+//CHECK-LABEL: define void @_QPsimd_aligned_cptr() {
+//CHECK: %[[A_CPTR:.*]] = alloca %_QM__fortran_builtinsT__builtin_c_ptr, i64 1, align 8
+//CHECK: %[[A_VAL:.*]] = load ptr, ptr %[[A_CPTR]], align 8
+//CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[A_VAL]], i64 256) ]
+// Same shape as the pointer case, but the aligned operand is a stack slot of
+// the named struct type "_QM__fortran_builtinsT__builtin_c_ptr" with a
+// single i64 field.
+llvm.func @_QPsimd_aligned_cptr() {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x !llvm.struct<"_QM__fortran_builtinsT__builtin_c_ptr", (i64)> {bindc_name = "a"} : (i64) -> !llvm.ptr
+ %2 = llvm.mlir.constant(1 : i64) : i64
+ %3 = llvm.alloca %2 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr
+ %4 = llvm.mlir.constant(1 : i32) : i32
+ %5 = llvm.mlir.constant(10 : i32) : i32
+ %6 = llvm.mlir.constant(1 : i32) : i32
+ // Loop from 1 to 10 inclusive, step 1; the body stores the induction
+ // variable into %3.
+ omp.simd aligned(%1 : !llvm.ptr -> 256 : i64) {
+ omp.loop_nest (%arg0) : i32 = (%4) to (%5) inclusive step (%6) {
+ llvm.store %arg0, %3 : i32, !llvm.ptr
+ omp.yield
+ }
+ }
+ llvm.return
+}
+
+//CHECK-LABEL: define void @_QPsimd_aligned_allocatable() {
+//CHECK: %[[A_ADDR:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i64 1, align 8
+//CHECK: %[[A_VAL:.*]] = load ptr, ptr %[[A_ADDR]], align 8
+//CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[A_VAL]], i64 256) ]
+// Aligned operand is a stack slot of an eight-field struct that ends in a
+// 1 x 3 x i64 array. The loop body is empty apart from the yield, so only the
+// handling of the `aligned` clause is exercised.
+llvm.func @_QPsimd_aligned_allocatable() {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {bindc_name = "a"} : (i64) -> !llvm.ptr
+ %2 = llvm.mlir.constant(1 : i32) : i32
+ %3 = llvm.mlir.constant(10 : i32) : i32
+ %4 = llvm.mlir.constant(1 : i32) : i32
+ omp.simd aligned(%1 : !llvm.ptr -> 256 : i64) {
+ omp.loop_nest (%arg0) : i32 = (%2) to (%3) inclusive step (%4) {
+ omp.yield
+ }
+ }
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 8f3e466cfbbe..83a0990d6316 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -127,18 +127,6 @@ llvm.func @sections_private(%x : !llvm.ptr) {
llvm.return
}
-// -----
-
-llvm.func @simd_aligned(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
- // expected-error@below {{not yet implemented: Unhandled clause aligned in omp.simd operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.simd}}
- omp.simd aligned(%x : !llvm.ptr -> 32) {
- omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
- omp.yield
- }
- }
- llvm.return
-}
// -----
diff --git a/mlir/test/Transforms/location-snapshot.mlir b/mlir/test/Transforms/location-snapshot.mlir
index 9f48cb6e3b3f..aeddfedd08ae 100644
--- a/mlir/test/Transforms/location-snapshot.mlir
+++ b/mlir/test/Transforms/location-snapshot.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt -allow-unregistered-dialect -snapshot-op-locations='filename=%/t' -mlir-print-local-scope -mlir-print-debuginfo %s | FileCheck %s -DFILE=%/t
// RUN: mlir-opt -allow-unregistered-dialect -snapshot-op-locations='filename=%/t tag='tagged'' -mlir-print-local-scope -mlir-print-debuginfo %s | FileCheck %s --check-prefix=TAG -DFILE=%/t
+// RUN: mlir-opt -allow-unregistered-dialect -snapshot-op-locations='filename=%/t print-debuginfo' -mlir-print-local-scope -mlir-print-debuginfo %s | FileCheck %s --check-prefix=DBG -DFILE=%/t && cat %/t | FileCheck %s --check-prefix=DBGFILE
// CHECK: func @function(
// CHECK-NEXT: loc("[[FILE]]":{{[0-9]+}}:{{[0-9]+}})
@@ -15,3 +16,18 @@ func.func @function() -> i32 {
%1 = "foo"() : () -> i32 loc("original")
return %1 : i32 loc("original")
} loc("original")
+
+// DBG: func @function2(
+// DBG-NEXT: loc("[[FILE]]":{{[0-9]+}}:{{[0-9]+}})
+// DBG-NEXT: loc("[[FILE]]":{{[0-9]+}}:{{[0-9]+}})
+// DBG-NEXT: } loc("[[FILE]]":{{[0-9]+}}:{{[0-9]+}})
+
+// DBGFILE: func @function2(
+// DBGFILE-NEXT: loc("{{.*}}location-snapshot.mlir":{{[0-9]+}}:{{[0-9]+}})
+// DBGFILE-NEXT: loc("{{.*}}location-snapshot.mlir":{{[0-9]+}}:{{[0-9]+}})
+// DBGFILE-NEXT: } loc("{{.*}}location-snapshot.mlir":{{[0-9]+}}:{{[0-9]+}})
+
+// Input for the print-debuginfo snapshot run. Unlike @function above, the ops
+// here carry no explicit loc(...) annotations.
+func.func @function2() -> i32 {
+ %1 = "foo"() : () -> i32
+ return %1 : i32
+} \ No newline at end of file
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index e4c423ce7052..5133c14414c9 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -124,6 +124,64 @@ func.func @invariant_affine_if() {
// -----
+// Positive case: the affine.if uses an empty integer set with no operands, so
+// it does not depend on the loop. The expected output has it placed before
+// the affine.for, with only the accumulating addi left in the loop body.
+func.func @hoist_invariant_affine_if_success(%lb: index, %ub: index, %step: index) -> i32 {
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant 42 : i32
+ %sum_result = affine.for %i = %lb to %ub iter_args(%acc = %cst_0) -> i32 {
+ // Both branches only use values defined outside the loop.
+ %conditional_add = affine.if affine_set<() : ()> () -> (i32) {
+ %add = arith.addi %cst_42, %cst_42 : i32
+ affine.yield %add : i32
+ } else {
+ %poison = ub.poison : i32
+ affine.yield %poison : i32
+ }
+ %sum = arith.addi %acc, %conditional_add : i32
+ affine.yield %sum : i32
+ }
+
+ // CHECK-LABEL: hoist_invariant_affine_if_success
+ // CHECK-NEXT: arith.constant 0 : i32
+ // CHECK-NEXT: %[[CST:.*]] = arith.constant 42 : i32
+ // CHECK-NEXT: %[[IF:.*]] = affine.if
+ // CHECK-NEXT: arith.addi %[[CST]], %[[CST]] : i32
+ // CHECK: affine.for
+ // CHECK-NOT: affine.if
+ // CHECK-NEXT: arith.addi %{{.*}}, %[[IF]]
+
+ return %sum_result : i32
+}
+
+// -----
+
+// Negative case: the affine.if condition takes the induction variable %i as
+// an operand, so the expected output keeps the affine.if inside the
+// affine.for.
+func.func @hoist_variant_affine_if_failure(%lb: index, %ub: index, %step: index) -> i32 {
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant 42 : i32
+ %ind_7 = arith.constant 7 : index
+ %sum_result = affine.for %i = %lb to %ub iter_args(%acc = %cst_0) -> i32 {
+ // The set constraint d1 - d0 >= 0 is applied to (%i, %ind_7).
+ %conditional_add = affine.if affine_set<(d0, d1) : (d1 - d0 >= 0)> (%i, %ind_7) -> (i32) {
+ %add = arith.addi %cst_42, %cst_42 : i32
+ affine.yield %add : i32
+ } else {
+ %poison = ub.poison : i32
+ affine.yield %poison : i32
+ }
+ %sum = arith.addi %acc, %conditional_add : i32
+ affine.yield %sum : i32
+ }
+
+ // CHECK-LABEL: hoist_variant_affine_if_failure
+ // CHECK-NEXT: arith.constant 0 : i32
+ // CHECK-NEXT: %[[CST:.*]] = arith.constant 42 : i32
+ // CHECK-NEXT: arith.constant 7 : index
+ // CHECK-NEXT: affine.for
+ // CHECK-NEXT: %[[IF:.*]] = affine.if
+ // CHECK: arith.addi %{{.*}}, %[[IF]]
+
+ return %sum_result : i32
+}
+
+// -----
+
func.func @hoist_affine_for_with_unknown_trip_count(%lb: index, %ub: index) {
affine.for %arg0 = 0 to 10 {
affine.for %arg1 = %lb to %ub {
@@ -383,6 +441,69 @@ func.func @parallel_loop_with_invariant() {
// -----
+// Positive case for scf: the scf.if condition is a constant defined outside
+// the loop. The expected output has the scf.if placed before the scf.for,
+// with only the accumulating addi left in the loop body.
+func.func @hoist_invariant_scf_if_success(%lb: index, %ub: index, %step: index) -> i32 {
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant 42 : i32
+ %true = arith.constant true
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ // Neither the condition nor the branches reference %i or %acc.
+ %conditional_add = scf.if %true -> (i32) {
+ %add = arith.addi %cst_42, %cst_42 : i32
+ scf.yield %add : i32
+ } else {
+ %poison = ub.poison : i32
+ scf.yield %poison : i32
+ }
+ %sum = arith.addi %acc, %conditional_add : i32
+ scf.yield %sum : i32
+ }
+
+ // CHECK-LABEL: hoist_invariant_scf_if_success
+ // CHECK-NEXT: arith.constant 0 : i32
+ // CHECK-NEXT: %[[CST:.*]] = arith.constant 42 : i32
+ // CHECK-NEXT: %[[TRUE:.*]] = arith.constant true
+ // CHECK-NEXT: %[[IF:.*]] = scf.if %[[TRUE]]
+ // CHECK-NEXT: arith.addi %[[CST]], %[[CST]] : i32
+ // CHECK: scf.for
+ // CHECK-NOT: scf.if
+ // CHECK-NEXT: arith.addi %{{.*}}, %[[IF]]
+
+ return %sum_result : i32
+}
+
+// -----
+
+// Negative case for scf: the scf.if condition is computed from the induction
+// variable %i, so the expected output keeps both the cmpi and the scf.if
+// inside the scf.for.
+func.func @hoist_variant_scf_if_failure(%lb: index, %ub: index, %step: index) -> i32 {
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant 42 : i32
+ %ind_7 = arith.constant 7 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ // Loop-dependent condition: unsigned %i < 7.
+ %cond = arith.cmpi ult, %i, %ind_7 : index
+ %conditional_add = scf.if %cond -> (i32) {
+ %add = arith.addi %cst_42, %cst_42 : i32
+ scf.yield %add : i32
+ } else {
+ %poison = ub.poison : i32
+ scf.yield %poison : i32
+ }
+ %sum = arith.addi %acc, %conditional_add : i32
+ scf.yield %sum : i32
+ }
+
+ // CHECK-LABEL: hoist_variant_scf_if_failure
+ // CHECK-NEXT: arith.constant 0 : i32
+ // CHECK-NEXT: %[[CST_42:.*]] = arith.constant 42 : i32
+ // CHECK-NEXT: %[[CST_7:.*]] = arith.constant 7 : index
+ // CHECK-NEXT: scf.for %[[IV:.*]] = %{{.*}} to %{{.*}}
+ // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[IV]], %[[CST_7]]
+ // CHECK-NEXT: %[[IF:.*]] = scf.if %[[CMP]]
+ // CHECK-NEXT: arith.addi %[[CST_42]], %[[CST_42]] : i32
+ // CHECK: arith.addi %{{.*}}, %[[IF]]
+
+ return %sum_result : i32
+}
+
+// -----
+
func.func private @make_val() -> (index)
// CHECK-LABEL: func @nested_uses_inside
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 2ca5f4963752..ae7d344b7167 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -64,9 +64,6 @@ func.func @remap_call_1_to_1(%arg0: i64) {
// Contents of the old block are moved to the new block.
// CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown
-// The new block arguments are used in "test.return".
-// CHECK-NEXT: notifyOperationModified: test.return
-
// The old block is erased.
// CHECK-NEXT: notifyBlockErased
@@ -390,8 +387,8 @@ func.func @caller() {
// CHECK: %[[call:.*]]:2 = call @callee() : () -> (f16, f16)
%0:2 = func.call @callee() : () -> (f32, i24)
- // CHECK: %[[cast1:.*]] = "test.cast"() : () -> i24
- // CHECK: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32
+ // CHECK-DAG: %[[cast1:.*]] = "test.cast"() : () -> i24
+ // CHECK-DAG: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32
// CHECK: "test.some_user"(%[[cast0]], %[[cast1]]) : (f32, i24) -> ()
// expected-remark @below{{'test.some_user' is not legalizable}}
"test.some_user"(%0#0, %0#1) : (f32, i24) -> ()
@@ -450,7 +447,7 @@ func.func @fold_legalization() -> i32 {
// -----
// CHECK-LABEL: func @convert_detached_signature()
-// CHECK: "test.legal_op_with_region"() ({
+// CHECK: "test.legal_op"() ({
// CHECK: ^bb0(%arg0: f64):
// CHECK: "test.return"() : () -> ()
// CHECK: }) : () -> ()
@@ -483,3 +480,20 @@ func.func @test_1_to_n_block_signature_conversion() {
"test.return"() : () -> ()
}
+// -----
+
+// CHECK: notifyOperationInserted: test.step_1
+// CHECK: notifyOperationReplaced: test.multiple_1_to_n_replacement
+// CHECK: notifyOperationErased: test.multiple_1_to_n_replacement
+// CHECK: notifyOperationInserted: test.legal_op
+// CHECK: notifyOperationReplaced: test.step_1
+// CHECK: notifyOperationErased: test.step_1
+
+// CHECK-LABEL: func @test_multiple_1_to_n_replacement()
+// CHECK: %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16)
+// CHECK: %[[cast:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1, %[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16, f16, f16) -> f16
+// CHECK: "test.valid"(%[[cast]]) : (f16) -> ()
+// Input for the back-to-back 1:N replacement pattern. The expected output
+// replaces the single f16 result with a four-result test.legal_op whose
+// results feed one test.cast, and test.invalid becomes test.valid.
+func.func @test_multiple_1_to_n_replacement() {
+ %0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
+ "test.invalid"(%0) : (f16) -> ()
+}
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 09c5b4b2a0ad..d0b62e71ab0c 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -139,7 +139,7 @@ struct TestDecomposeCallGraphTypes
tupleType.getFlattenedTypes(types);
return success();
});
- typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+ typeConverter.addSourceMaterialization(buildMakeTupleOp);
typeConverter.addTargetMaterialization(buildDecomposeTuple);
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index a470497fdbb5..5b7c36c9b97b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -785,7 +785,7 @@ struct TestDetachedSignatureConversion : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
if (op->getNumRegions() != 1)
return failure();
- OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
+ OperationState state(op->getLoc(), "test.legal_op", operands,
op->getResultTypes(), {}, BlockRange());
Region *newRegion = state.addRegion();
rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
@@ -1234,6 +1234,41 @@ public:
}
};
+/// Conversion pattern that performs two consecutive 1 -> 2 result
+/// replacements inside a single rewrite: the matched op is first replaced by
+/// "test.step_1", which is then itself replaced by "test.legal_op".
+class TestMultiple1ToNReplacement : public ConversionPattern {
+public:
+  TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1,
+                          ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Creates an operand-less op named `newName` that has two results for
+    // every result of `oldOp` (same type, repeated) and carries `oldOp`'s
+    // attributes, then replaces result i of `oldOp` with new results
+    // [2 * i, 2 * i + 1]. Returns the newly created op.
+    auto doubleResults = [&rewriter](Operation *oldOp,
+                                     StringRef newName) -> Operation * {
+      SmallVector<Type> doubledTypes;
+      for (Type resultType : oldOp->getResultTypes())
+        doubledTypes.append(/*NumInputs=*/2, resultType);
+      OperationState state(oldOp->getLoc(), newName,
+                           /*operands=*/{}, doubledTypes, oldOp->getAttrs());
+      Operation *created = rewriter.create(state);
+      SmallVector<ValueRange> replacements;
+      for (size_t idx = 0, numResults = oldOp->getNumResults();
+           idx < numResults; ++idx)
+        replacements.push_back(created->getResults().slice(2 * idx, 2));
+      rewriter.replaceOpWithMultiple(oldOp, replacements);
+      return created;
+    };
+
+    // First hop: test.multiple_1_to_n_replacement -> test.step_1.
+    Operation *intermediate = doubleResults(op, "test.step_1");
+    // Second hop: test.step_1 -> test.legal_op.
+    doubleResults(intermediate, "test.legal_op");
+    return success();
+  }
+};
+
} // namespace
namespace {
@@ -1241,7 +1276,6 @@ struct TestTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
TestTypeConverter() {
addConversion(convertType);
- addArgumentMaterialization(materializeCast);
addSourceMaterialization(materializeCast);
}
@@ -1319,7 +1353,8 @@ struct TestLegalizePatternDriver
TestUndoPropertiesModification, TestEraseOp,
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
- TestPassthroughInvalidOp>(&getContext(), converter);
+ TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
+ &getContext(), converter);
patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1330,8 +1365,7 @@ struct TestLegalizePatternDriver
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
TerminatorOp, OneRegionOp>();
- target.addLegalOp(
- OperationName("test.legal_op_with_region", &getContext()));
+ target.addLegalOp(OperationName("test.legal_op", &getContext()));
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index ac904c3e01c9..83db1188861a 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -149,7 +149,7 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(),
tosaConv2DOp.getWeight(), tosaConv2DOp.getBias(),
tosaConv2DOp.getPadAttr(), tosaConv2DOp.getStrideAttr(),
- tosaConv2DOp.getDilationAttr());
+ tosaConv2DOp.getDilationAttr(), tosaConv2DOp.getAccTypeAttr());
// Create rescale to quantized type
double inputScale = inputQType.getScale();
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp
index 2cc1fb5d39d7..a03bf0a1023d 100644
--- a/mlir/test/lib/Transforms/TestDialectConversion.cpp
+++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp
@@ -28,7 +28,6 @@ namespace {
struct PDLLTypeConverter : public TypeConverter {
PDLLTypeConverter() {
addConversion(convertType);
- addArgumentMaterialization(materializeCast);
addSourceMaterialization(materializeCast);
}
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index fb4c75b53379..8785d6d36007 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -103,6 +103,42 @@ def testFuseIntoContainingOpCompact(target):
@run
@create_sequence
+def testFuseOpCompact(target):
+    # Build a fuse op from plain Python lists, with cleanup enabled.
+    fuse_options = {
+        "tile_sizes": [4, 8],
+        "tile_interchange": [0, 1],
+        "apply_cleanup": True,
+    }
+    structured.FuseOp(target, **fuse_options)
+    # CHECK-LABEL: TEST: testFuseOpCompact
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+    # CHECK-SAME: interchange [0, 1] apply_cleanup = true
+    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
+def testFuseOpNoArg(target):
+ # Build a fuse op from the target alone, passing no tile sizes,
+ # interchange or cleanup flag; the expected output has a single result.
+ structured.FuseOp(target)
+ # CHECK-LABEL: TEST: testFuseOpNoArg
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}} = transform.structured.fuse %{{.*}} :
+ # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+@run
+@create_sequence
+def testFuseOpAttributes(target):
+    # Same op as the compact form, but the sizes and interchange are passed
+    # as pre-built DenseI64ArrayAttr values rather than Python lists.
+    structured.FuseOp(
+        target,
+        tile_sizes=DenseI64ArrayAttr.get([4, 8]),
+        tile_interchange=DenseI64ArrayAttr.get([0, 1]),
+    )
+    # CHECK-LABEL: TEST: testFuseOpAttributes
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+    # CHECK-SAME: interchange [0, 1]
+    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
def testGeneralize(target):
structured.GeneralizeOp(target)
# CHECK-LABEL: TEST: testGeneralize
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 6d3a8db8c24b..0d12c35d96be 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -306,7 +306,7 @@ def testUnrankedMemRefWithOffsetCallback():
log(arr)
with Context():
- # The module takes a subview of the argument memref, casts it to an unranked memref and
+ # The module takes a subview of the argument memref, casts it to an unranked memref and
# calls the callback with it.
module = Module.parse(
r"""
diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py
index d59c6a6bc424..5a2ed684d298 100644
--- a/mlir/test/python/ir/dialects.py
+++ b/mlir/test/python/ir/dialects.py
@@ -121,3 +121,39 @@ def testAppendPrefixSearchPath():
sys.path.append(".")
_cext.globals.append_dialect_search_prefix("custom_dialect")
assert _cext.globals._check_dialect_module_loaded("custom")
+
+
+# CHECK-LABEL: TEST: testDialectLoadOnCreate
+@run
+def testDialectLoadOnCreate():
+ # Part 1: a context created with an empty load-on-create list. An
+ # "arith.addi" op is created while unregistered dialects are allowed, then
+ # verified after they are disallowed again.
+ with Context(load_on_create_dialects=[]) as ctx:
+ ctx.emit_error_diagnostics = True
+ ctx.allow_unregistered_dialects = True
+
+ # Prints each diagnostic so the test output can be matched against it.
+ # NOTE(review): returning True presumably marks the diagnostic as
+ # handled; confirm against the bindings.
+ def callback(d):
+ # CHECK: DIAGNOSTIC
+ # CHECK-SAME: op created with unregistered dialect
+ print(f"DIAGNOSTIC={d.message}")
+ return True
+
+ handler = ctx.attach_diagnostic_handler(callback)
+ loc = Location.unknown(ctx)
+ try:
+ op = Operation.create("arith.addi", loc=loc)
+ ctx.allow_unregistered_dialects = False
+ op.verify()
+ except MLIRError as e:
+ # An MLIRError here is deliberately ignored; the test only matches on
+ # the diagnostic printed by the callback.
+ pass
+
+ # Part 2: a context that lists "func" in load_on_create_dialects; a
+ # "func.func" op is created without enabling unregistered dialects.
+ with Context(load_on_create_dialects=["func"]) as ctx:
+ loc = Location.unknown(ctx)
+ fn = Operation.create("func.func", loc=loc)
+
+ # Part 3: the module-level load-on-create list, printed before and after
+ # appending "func".
+ # TODO: This may require an update if a site wide policy is set.
+ # CHECK: Load on create: []
+ print(f"Load on create: {get_load_on_create_dialects()}")
+ append_load_on_create_dialect("func")
+ # CHECK: Load on create:
+ # CHECK-SAME: func
+ print(f"Load on create: {get_load_on_create_dialects()}")
+ # NOTE(review): prints the list once more without a label; no pattern
+ # above refers to this line, so it looks redundant. Confirm before removing.
+ print(get_load_on_create_dialects())
diff --git a/mlir/test/tblgen-lsp-server/templ-arg-check.test b/mlir/test/tblgen-lsp-server/templ-arg-check.test
new file mode 100644
index 000000000000..cda9b79a1f46
--- /dev/null
+++ b/mlir/test/tblgen-lsp-server/templ-arg-check.test
@@ -0,0 +1,15 @@
+// RUN: tblgen-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"tablegen","capabilities":{},"trace":"off"}}
+// -----
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
+ "uri":"test:///foo.td",
+ "languageId":"tablegen",
+ "version":1,
+ "text":"class Foo<int i>;\ndef : Foo<\"\">;"
+}}}
+// CHECK: "method": "textDocument/publishDiagnostics",
+// CHECK: "message": "Value specified for template argument 'Foo:i' is of type string; expected type int: \"\"",
+// -----
+{"jsonrpc":"2.0","id":3,"method":"shutdown"}
+// -----
+{"jsonrpc":"2.0","method":"exit"}
diff --git a/mlir/utils/pygments/README.md b/mlir/utils/pygments/README.md
new file mode 100644
index 000000000000..838faceb01b0
--- /dev/null
+++ b/mlir/utils/pygments/README.md
@@ -0,0 +1,45 @@
+## Pygments Lexer for MLIR
+
+This file contains a simple Pygments lexer configuration for MLIR, derived from
+the version used in the original CGO paper. Pygments allows for advanced
+configurable syntax highlighting of any code. This lexer is known to be
+incomplete and supports mostly core IR with a subset of built-in types.
+Additions and customizations are welcome.
+
+### Standalone Usage
+
+Install Pygments, e.g., by running `pip install Pygments` or a Python package
+manager of your choosing. Use the standalone `pygmentize` command by
+instructing it to load the custom lexer:
+
+```
+pygmentize -l /path/to/mlir_lexer.py:MlirLexer -x myfile.mlir
+```
+
+This will produce highlighted output in the terminal. Other output formats are
+available; see the Pygments [documentation](https://pygments.org/docs/) for
+more information.
+
+### LaTeX Usage
+
+First, make sure your distribution includes the `minted` package and list it
+in the preamble.
+
+```latex
+\usepackage{minted}
+```
+
+Place the `mlir_lexer.py` in a place where the `latex` binary can find it,
+typically in the working directory next to the main `.tex` file. Note that you
+will have to invoke `latex` with the `-shell-escape` flag. See the `minted`
+package [documentation](https://ctan.org/pkg/minted?lang=en) for more
+information.
+
+Leverage the custom lexer facility of `minted` to use this lexer in your
+document as:
+
+```latex
+\begin{minted}{mlir_lexer.py:MlirLexer -x}
+ ... your code here ...
+\end{minted}
+```
diff --git a/mlir/utils/pygments/mlir_lexer.py b/mlir/utils/pygments/mlir_lexer.py
new file mode 100644
index 000000000000..179a058e9110
--- /dev/null
+++ b/mlir/utils/pygments/mlir_lexer.py
@@ -0,0 +1,38 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from pygments.lexer import RegexLexer
+from pygments.token import *
+
+
+class MlirLexer(RegexLexer):
+    """Simple Pygments lexer for the textual form of MLIR.
+
+    Rules in the "root" state are tried top to bottom at each position and
+    the first one that matches wins, so patterns that overlap with a more
+    general rule (comments vs. operators, floats vs. integers) come first.
+    """
+
+    name = "MLIR"
+    aliases = ["mlir"]
+    filenames = ["*.mlir"]
+
+    tokens = {
+        "root": [
+            # Must precede the operator rule, which would otherwise consume
+            # the two leading slashes of a comment one at a time.
+            (r"\/\/.*\n", Comment.Single),
+            (r"%[a-zA-Z0-9_]+", Name.Variable),
+            # `*` (not `+`) so that one-character symbols such as `@f` match.
+            (r"@[a-zA-Z_][a-zA-Z0-9_]*", Name.Function),
+            (r"\^[a-zA-Z0-9_]+", Name.Label),
+            (r"#[a-zA-Z0-9_]+", Name.Constant),
+            (r"![a-zA-Z0-9_]+", Keyword.Type),
+            (r"[a-zA-Z_][a-zA-Z0-9_]*\.", Name.Entity),
+            (r"memref[^.]", Keyword.Type),
+            (r"index", Keyword.Type),
+            (r"i[0-9]+", Keyword.Type),
+            (r"f[0-9]+", Keyword.Type),
+            # Floats come before integers so that `1.5` is a single token
+            # rather than an integer followed by a stray fraction. Requiring
+            # a leading digit keeps a lone `.` from being tagged as a number.
+            (r"[0-9]+\.[0-9]*", Number.Float),
+            (r"[0-9]+", Number.Integer),
+            (r'"[^"]*"', String.Double),
+            (r"affine_map", Keyword.Reserved),
+            # TODO: this should be within affine maps only
+            # Character class: matches any one arithmetic operator. Without
+            # the brackets the pattern only matched the literal text `+-*/`.
+            (r"[+\-*\/]", Operator),
+            (r"floordiv", Operator.Word),
+            (r"ceildiv", Operator.Word),
+            (r"mod", Operator.Word),
+            # Character class: matches any one bracket, comma or brace.
+            # Without the brackets the pattern only matched `[]<>,{}` as a
+            # single literal run.
+            (r"[()\[\]<>,{}]", Punctuation),
+        ]
+    }