diff options
| author | Matthias Springer <mspringer@nvidia.com> | 2025-03-21 14:49:28 +0100 |
|---|---|---|
| committer | Matthias Springer <mspringer@nvidia.com> | 2025-03-21 14:49:28 +0100 |
| commit | 85c0b6be5c046b342987ff3523836bd87806e971 (patch) | |
| tree | 780e382130e6231325a1df443d4a137ef683d14b | |
| parent | 53a395fda32cb0edd899202b6614595185b01ef1 (diff) | |
[mlir][IR] Add `ShapedTypeInterface`users/matthias-springer/scalar_type_interface
54 files changed, 317 insertions, 164 deletions
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h index 1f40eb6fc693..50b419bce78e 100644 --- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h @@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices( bool isMultipleOfSMETileVectorType(VectorType vType); /// Creates a vector type for the SME tile of `elementType`. -VectorType getSMETileTypeForElement(Type elementType); +VectorType getSMETileTypeForElement(ScalarTypeInterface elementType); /// Erase trivially dead tile ops from a function. void eraseTriviallyDeadTileOps(IRRewriter &rewriter, diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td index 0208e8cdbf29..7e79a17119c5 100644 --- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td +++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td @@ -295,6 +295,8 @@ def VectorType : DialectType<(type Type:$elementType )> { let printerPredicate = "!$_val.isScalable()"; + // Note: Element type must implement ScalarTypeInterface. + let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))"; } def VectorTypeWithScalableDims : DialectType<(type @@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type )> { let printerPredicate = "$_val.isScalable()"; // Note: order of serialization does not match order of builder. - let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)"; + let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)"; } } diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td index 8aa2c5557015..71bd4df762d2 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -16,7 +16,40 @@ include "mlir/IR/OpBase.td" -def FloatTypeInterface : TypeInterface<"FloatType"> { +//===----------------------------------------------------------------------===// +// ScalarTypeInterface +//===----------------------------------------------------------------------===// + +def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + Indication that this type is a scalar type. + + The bitwidth of a scalar type is a fixed constant but may be unknown in the + absence of data layout information. + + Scalar types are POD (plain-old-data) entities that have an in-memory + representation: scalar values can be loaded/store from/to memory, so + abstract types like function types or async tokens cannot be scalar types. + + Scalar types should be limited to types that can lower to something that + egress dialects would consider a valid vector element type. + }]; + + let methods = [ + InterfaceMethod<[{ + Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a + bitwidth that is known in the absence of data layout information. + }], + "std::optional<uint64_t>", "getInherentBitwidth", (ins)> + ]; +} + +//===----------------------------------------------------------------------===// +// FloatTypeInterface +//===----------------------------------------------------------------------===// + +def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> { let cppNamespace = "::mlir"; let description = [{ This type interface should be implemented by all floating-point types. It diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index df1e02732617..a1950cda6318 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -275,7 +275,7 @@ public: scalableDims(other.getScalableDims()) {} /// Build from scratch. - Builder(ArrayRef<int64_t> shape, Type elementType, + Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType, ArrayRef<bool> scalableDims = {}) : elementType(elementType), shape(shape), scalableDims(scalableDims) {} @@ -286,7 +286,7 @@ public: return *this; } - Builder &setElementType(Type newElementType) { + Builder &setElementType(ScalarTypeInterface newElementType) { elementType = newElementType; return *this; } @@ -312,7 +312,7 @@ public: } private: - Type elementType; + ScalarTypeInterface elementType; CopyOnWriteArrayRef<int64_t> shape; CopyOnWriteArrayRef<bool> scalableDims; }; diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index af474b3e3ec4..2f03d5191385 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic, DeclareTypeInterfaceMethods< FloatTypeInterface, ["getFloatSemantics"] # declaredInterfaceMethods>]> { + + let extraClassDeclaration = [{ + /// Return the bitwidth of this type. This is an interface method of + /// ScalarTypeInterface. + std::optional<uint64_t> getInherentBitwidth() { + return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth()); + } + }]; } // Float types that are cached in MLIRContext. @@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic, : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> { let extraClassDeclaration = [{ static }] # name # [{Type get(MLIRContext *context); + + /// Return the bitwidth of this type. This is an interface method of + /// ScalarTypeInterface. + std::optional<uint64_t> getInherentBitwidth() { + return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth()); + } }]; } @@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> { // IndexType //===----------------------------------------------------------------------===// -def Builtin_Index : Builtin_Type<"Index", "index"> { +def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> { let summary = "Integer-like type with unknown platform-dependent bit width"; let description = [{ Syntax: @@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> { let extraClassDeclaration = [{ static IndexType get(MLIRContext *context); + /// Return the bitwidth of this type. This is an interface method of + /// ScalarTypeInterface. + std::optional<uint64_t> getInherentBitwidth() const { + return std::nullopt; + } + /// Storage bit width used for IndexType by internal compiler data /// structures. static constexpr unsigned kInternalStorageBitWidth = 64; @@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> { // IntegerType //===----------------------------------------------------------------------===// -def Builtin_Integer : Builtin_Type<"Integer", "integer"> { +def Builtin_Integer + : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> { let summary = "Integer type with arbitrary precision up to a fixed limit"; let description = [{ Syntax: @@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> { /// Return null if the scaled element type cannot be represented. IntegerType scaleElementBitwidth(unsigned scale); + /// Return the bitwidth of this type. This is an interface method of + /// ScalarTypeInterface. + std::optional<uint64_t> getInherentBitwidth() const { + return static_cast<uint64_t>(getWidth()); + } + /// Integer representation maximal bitwidth. /// Note: This is aligned with the maximum width of llvm::IntegerType. static constexpr unsigned kMaxWidth = (1 << 24) - 1; @@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [ // VectorType //===----------------------------------------------------------------------===// -def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> { - let cppFunctionName = "isValidVectorTypeElementType"; -} - def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface, ValueSemantics], "Type"> { let summary = "Multi-dimensional SIMD vector type"; @@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", }]; let parameters = (ins ArrayRefParameter<"int64_t">:$shape, - Builtin_VectorTypeElementType:$elementType, + AnyScalarType:$elementType, ArrayRefParameter<"bool">:$scalableDims ); let builders = [ TypeBuilderWithInferredContext<(ins - "ArrayRef<int64_t>":$shape, "Type":$elementType, + "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType, CArg<"ArrayRef<bool>", "{}">:$scalableDims ), [{ // While `scalableDims` is optional, its default value should be diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 601517717978..709c7dc213ff 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "", list<Pred> predicateList = predicates; } +def AnyScalarType : Type< + CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">, + "scalable type", "::mlir::ScalarTypeInterface">; + // Integer types. // Any integer type irrespective of its width and signedness semantics. diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp index 21bb0ec3d0d5..1ccd16e1b3ab 100644 --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() { if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) return nullptr; - return getChecked<VectorType>(loc, dimensions, elementType, scalableDims); + auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType); + if (!scalarElementType) { + emitWrongTokenError("vector type requires scalar element type"); + return nullptr; + } + + return getChecked<VectorType>(loc, dimensions, scalarElementType, + scalableDims); } /// Parse a dimension list in a vector type. This populates the dimension list. diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index a080adf0f810..80e8c239689b 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) { MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType) { return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)), - unwrap(elementType))); + cast<ScalarTypeInterface>(unwrap(elementType)))); } MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType) { return wrap(VectorType::getChecked( unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)), - unwrap(elementType))); + cast<ScalarTypeInterface>(unwrap(elementType)))); } MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape, const bool *scalable, MlirType elementType) { - return wrap(VectorType::get( - llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType), - llvm::ArrayRef(scalable, static_cast<size_t>(rank)))); + return wrap( + VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)), + cast<ScalarTypeInterface>(unwrap(elementType)), + llvm::ArrayRef(scalable, static_cast<size_t>(rank)))); } MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, @@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, MlirType elementType) { return wrap(VectorType::getChecked( unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)), - unwrap(elementType), + cast<ScalarTypeInterface>(unwrap(elementType)), llvm::ArrayRef(scalable, static_cast<size_t>(rank)))); } diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 949424db7c4d..bedebabc4908 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType); - Type i32 = rewriter.getI32Type(); + auto i32 = rewriter.getI32Type(); // Get the type size in bytes. DataLayout dataLayout = DataLayout::closest(gpuOp); @@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, int64_t numBits = vectorType.getNumElements() * elemType.getIntOrFloatBitWidth(); - Type i32 = rewriter.getI32Type(); + auto i32 = rewriter.getI32Type(); Type intrinsicInType = numBits <= 32 ? (Type)rewriter.getIntegerType(numBits) : (Type)VectorType::get(numBits / 32, i32); @@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> { operand = rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand); } - auto llvmVecType = typeConverter->convertType(mlir::VectorType::get( - 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType)); + auto llvmVecType = typeConverter->convertType( + mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(), + cast<ScalarTypeInterface>(llvmSrcIntType))); Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType); operand = rewriter.create<LLVM::InsertElementOp>( loc, undefVec, operand, createI32Constant(rewriter, loc, 0)); diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 27be54728c1a..d17a610e2ac2 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op, if (saturateFP8) in = clampInput(rewriter, loc, outElemType, in); auto inVectorTy = dyn_cast<VectorType>(in.getType()); - VectorType truncResType = VectorType::get(4, outElemType); + VectorType truncResType = + VectorType::get(4, cast<ScalarTypeInterface>(outElemType)); if (!inVectorTy) { Value asFloat = castToF32(in, loc, rewriter); Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>( @@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite( Location loc = op.getLoc(); Value in = op.getIn(); Type outElemType = getElementTypeOrSelf(op.getOut().getType()); - VectorType truncResType = VectorType::get(2, outElemType); + VectorType truncResType = + VectorType::get(2, cast<ScalarTypeInterface>(outElemType)); auto inVectorTy = dyn_cast<VectorType>(in.getType()); // Handle the case where input type is not a vector type diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 9c4dfa27b144..13ff632c18b4 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); else - dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); + dstAttrType = VectorType::get(dstAttrType.getShape(), + cast<ScalarTypeInterface>(dstElemType)); dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); } @@ -908,7 +909,8 @@ public: // cases. Extend them to 32-bit and do comparision then. Type type = rewriter.getI32Type(); if (auto vectorType = dyn_cast<VectorType>(dstType)) - type = VectorType::get(vectorType.getShape(), type); + type = VectorType::get(vectorType.getShape(), + cast<ScalarTypeInterface>(type)); Value extLhs = rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs()); Value extRhs = diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp index 9c6de938a710..d984ab5d932b 100644 --- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp +++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp @@ -33,7 +33,8 @@ public: /// arm.neon.intr.sdot LogicalResult matchAndRewrite(Sdot2dOp op, PatternRewriter &rewriter) const override { - Type elemType = cast<VectorType>(op.getB().getType()).getElementType(); + ScalarTypeInterface elemType = + cast<VectorType>(op.getB().getType()).getElementType(); int length = cast<VectorType>(op.getB().getType()).getShape()[0] * Sdot2dOp::kReductionSize; VectorType flattenedVectorType = VectorType::get({length}, elemType); diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index 4bd94bcebf29..6a04bd39f2d8 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType())); Type i1Type = builder.getI1Type(); if (auto vecType = dyn_cast<VectorType>(lhs.getType())) - i1Type = VectorType::get(vecType.getShape(), i1Type); + i1Type = + VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type)); Value cmp = builder.create<LLVM::FCmpOp>( loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, lhs, rhs); diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index ea251e4564ea..d170a1f01dad 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const { if (!elementType) return {}; if (type.getShape().empty()) - return VectorType::get({1}, elementType); - Type vectorType = VectorType::get(type.getShape().back(), elementType, + return VectorType::get({1}, cast<ScalarTypeInterface>(elementType)); + Type vectorType = VectorType::get(type.getShape().back(), + cast<ScalarTypeInterface>(elementType), type.getScalableDims().back()); assert(LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"); diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index 1b83794b5f45..6676477b9e34 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> { if (auto vectorType = dyn_cast<VectorType>(type)) { assert(vectorType.getRank() == 1); int count = vectorType.getNumElements(); - intType = VectorType::get(count, intType); + intType = VectorType::get(count, cast<ScalarTypeInterface>(intType)); SmallVector<Value> signSplat(count, signMask); signMask = @@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { auto operandType = adaptor.getRhs().getType(); if (auto vectorType = dyn_cast<VectorType>(operandType)) { auto shape = vectorType.getShape(); - intType = VectorType::get(shape, scalarIntType); + intType = + VectorType::get(shape, cast<ScalarTypeInterface>(scalarIntType)); } // Per GL Pow extended instruction spec: diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 6e0adfc1e0ff..d1e10ef2e80f 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -152,7 +152,8 @@ static Value optionallyTruncateOrExtend(Location loc, Value value, static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { - auto vectorType = VectorType::get(numElements, toBroadcast.getType()); + auto vectorType = VectorType::get( + numElements, cast<ScalarTypeInterface>(toBroadcast.getType())); auto llvmVectorType = typeConverter.convertType(vectorType); auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); Value broadcasted = rewriter.create<LLVM::PoisonOp>(loc, llvmVectorType); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index eaefe9e38579..5ed167bde089 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -631,7 +631,7 @@ getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info) { Type elType = regInfo.registerLLVMType; if (auto vecType = dyn_cast<VectorType>(elType)) elType = vecType.getElementType(); - return VectorType::get(shape, elType); + return VectorType::get(shape, cast<ScalarTypeInterface>(elType)); } /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. @@ -802,7 +802,8 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, // must load each element individually. if (!isTransposeLoad) { if (!isa<VectorType>(loadedElType)) { - loadedElType = VectorType::get({1}, loadedElType); + loadedElType = + VectorType::get({1}, cast<ScalarTypeInterface>(loadedElType)); } for (int i = 0; i < vectorType.getShape()[0]; i++) { diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 95db83118559..8cb35e1cab93 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1480,7 +1480,10 @@ struct UnrollTransferWriteConversion // argument into `transfer_write` to become a scalar. We solve // this by broadcasting the scalar to a 0D vector. xferVec = b.create<vector::BroadcastOp>( - loc, VectorType::get({}, extracted.getType()), extracted); + loc, + VectorType::get( + {}, cast<ScalarTypeInterface>(extracted.getType())), + extracted); } else { xferVec = extracted; } diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp index 7dd4be66d2bd..87c94f23b515 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp @@ -98,7 +98,7 @@ static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc, int64_t bitwidth = vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); - Type allBitsType = rewriter.getIntegerType(bitwidth); + auto allBitsType = rewriter.getIntegerType(bitwidth); auto allBitsVecType = VectorType::get({1}, allBitsType); Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val); Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0); diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp index eaaafaf68767..38df408ad3b0 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -936,7 +936,8 @@ isVectorizableLoopPtrFactory(const DenseSet<Operation *> ¶llelLoops, static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy) { assert(!isa<VectorType>(scalarTy) && "Expected scalar type"); - return VectorType::get(strategy->vectorSizes, scalarTy); + return VectorType::get(strategy->vectorSizes, + cast<ScalarTypeInterface>(scalarTy)); } /// Tries to transform a scalar constant into a vector constant. Returns the @@ -1195,7 +1196,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp, VectorizationState &state) { MemRefType memRefType = loadOp.getMemRefType(); Type elementType = memRefType.getElementType(); - auto vectorType = VectorType::get(state.strategy->vectorSizes, elementType); + auto vectorType = VectorType::get(state.strategy->vectorSizes, + cast<ScalarTypeInterface>(elementType)); // Replace map operands with operands from the vector loop nest. SmallVector<Value, 8> mapOperands; @@ -1426,7 +1428,8 @@ static Operation *widenOp(Operation *op, VectorizationState &state) { SmallVector<Type, 8> vectorTypes; for (Value result : op->getResults()) vectorTypes.push_back( - VectorType::get(state.strategy->vectorSizes, result.getType())); + VectorType::get(state.strategy->vectorSizes, + cast<ScalarTypeInterface>(result.getType()))); SmallVector<Value, 8> vectorOperands; for (Value operand : op->getOperands()) { @@ -1832,7 +1835,6 @@ verifyLoopNesting(const std::vector<SmallVector<AffineForOp, 2>> &loops) { return success(); } - /// External utility to vectorize affine loops in 'loops' using the n-D /// vectorization factors in 'vectorSizes'. By default, each vectorization /// factor is applied inner-to-outer to the loops of each loop nest. diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp index 61f8d82a615d..3d00efa72ec5 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp @@ -581,7 +581,8 @@ struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> { Type narrowTy = rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth()); if (auto vecTy = dyn_cast<VectorType>(resultType)) - narrowTy = VectorType::get(vecTy.getShape(), narrowTy); + narrowTy = VectorType::get(vecTy.getShape(), + cast<ScalarTypeInterface>(narrowTy)); // Sign or zero-extend the result. Let the matching conversion pattern // legalize the extension op. diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp index 2a1271dfd6bd..013fb0019755 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp @@ -162,10 +162,10 @@ public: Value tiledAcc = extractOperand(op.getAcc(), accPermutationMap, accOffsets); - auto inputElementType = - cast<ShapedType>(tiledLhs.getType()).getElementType(); - auto accElementType = - cast<ShapedType>(tiledAcc.getType()).getElementType(); + auto inputElementType = cast<ScalarTypeInterface>( + cast<ShapedType>(tiledLhs.getType()).getElementType()); + auto accElementType = cast<ScalarTypeInterface>( + cast<ShapedType>(tiledAcc.getType()).getElementType()); auto inputExpandedType = VectorType::get({2, 8}, inputElementType); auto outputExpandedType = VectorType::get({2, 2}, accElementType); diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp index 1f7305a5f814..3975b400950e 100644 --- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp +++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp @@ -111,7 +111,7 @@ bool isMultipleOfSMETileVectorType(VectorType vType) { vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0; } -VectorType getSMETileTypeForElement(Type elementType) { +VectorType getSMETileTypeForElement(ScalarTypeInterface elementType) { unsigned minNumElts = getSMETileSliceMinNumElts(elementType); return VectorType::get({minNumElts, minNumElts}, elementType, {true, true}); } diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp index edd7f607f24f..9f7082ca9360 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -89,7 +89,7 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { return failure(); } -// This side effect models "program termination". +// This side effect models "program termination". void AssertOp::getEffects( SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) { @@ -480,8 +480,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, ArrayRef<ValueRange> caseOperands) { DenseIntElementsAttr caseValuesAttr; if (!caseValues.empty()) { - ShapedType caseValueType = VectorType::get( - static_cast<int64_t>(caseValues.size()), value.getType()); + ShapedType caseValueType = + VectorType::get(static_cast<int64_t>(caseValues.size()), + cast<ScalarTypeInterface>(value.getType())); caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); } build(builder, result, value, defaultDestination, defaultOperands, @@ -494,8 +495,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, ArrayRef<ValueRange> caseOperands) { DenseIntElementsAttr caseValuesAttr; if (!caseValues.empty()) { - ShapedType caseValueType = VectorType::get( - static_cast<int64_t>(caseValues.size()), value.getType()); + ShapedType caseValueType = + VectorType::get(static_cast<int64_t>(caseValues.size()), + cast<ScalarTypeInterface>(value.getType())); caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); } build(builder, result, value, defaultDestination, defaultOperands, @@ -550,7 +552,8 @@ static ParseResult parseSwitchOpCases( if (!values.empty()) { ShapedType caseValueType = - VectorType::get(static_cast<int64_t>(values.size()), flagType); + VectorType::get(static_cast<int64_t>(values.size()), + cast<ScalarTypeInterface>(flagType)); caseValues = DenseIntElementsAttr::get(caseValueType, values); } return success(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 5370de501a85..833eb96baadc 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -548,8 +548,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, ArrayRef<int32_t> branchWeights) { DenseIntElementsAttr caseValuesAttr; if (!caseValues.empty()) { - ShapedType caseValueType = VectorType::get( - static_cast<int64_t>(caseValues.size()), value.getType()); + ShapedType caseValueType = + VectorType::get(static_cast<int64_t>(caseValues.size()), + cast<ScalarTypeInterface>(value.getType())); caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); } @@ -564,8 +565,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, ArrayRef<int32_t> branchWeights) { DenseIntElementsAttr caseValuesAttr; if (!caseValues.empty()) { - ShapedType caseValueType = VectorType::get( - static_cast<int64_t>(caseValues.size()), value.getType()); + ShapedType caseValueType = + VectorType::get(static_cast<int64_t>(caseValues.size()), + cast<ScalarTypeInterface>(value.getType())); caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); } @@ -611,8 +613,8 @@ static ParseResult parseSwitchOpCases( if (failed(parser.parseCommaSeparatedList(parseCase))) return failure(); - ShapedType caseValueType = - VectorType::get(static_cast<int64_t>(values.size()), flagType); + ShapedType caseValueType = VectorType::get( + static_cast<int64_t>(values.size()), cast<ScalarTypeInterface>(flagType)); caseValues = DenseIntElementsAttr::get(caseValueType, values); return parser.parseRSquare(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index 8f39ede721c9..5e790de461ce 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -946,7 +946,8 @@ Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements, // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as // scalable/non-scalable. - return VectorType::get(numElements, elementType, {isScalable}); + return VectorType::get(numElements, cast<ScalarTypeInterface>(elementType), + {isScalable}); } Type mlir::LLVM::getVectorType(Type elementType, @@ -966,7 +967,7 @@ Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) { "to be either builtin or LLVM dialect type"); if (useLLVM) return LLVMFixedVectorType::get(elementType, numElements); - return VectorType::get(numElements, elementType); + return VectorType::get(numElements, cast<ScalarTypeInterface>(elementType)); } Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) { @@ -981,7 +982,8 @@ Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) { // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as // scalable/non-scalable. - return VectorType::get(numElements, elementType, /*scalableDims=*/true); + return VectorType::get(numElements, cast<ScalarTypeInterface>(elementType), + /*scalableDims=*/true); } llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 2dcd897330d1..e4909c4ee0f6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -249,7 +249,8 @@ struct VectorizationState { scalableDims.append(scalableVecDims.begin(), scalableVecDims.end()); } - return VectorType::get(vectorShape, elementType, scalableDims); + return VectorType::get(vectorShape, cast<ScalarTypeInterface>(elementType), + scalableDims); } /// Masks an operation with the canonical vector mask if the operation needs @@ -1338,9 +1339,10 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, assert(vecOperand && "Vector operand couldn't be found"); if (firstMaxRankedType) { - auto vecType = VectorType::get(firstMaxRankedType.getShape(), - getElementTypeOrSelf(vecOperand.getType()), - firstMaxRankedType.getScalableDims()); + auto vecType = VectorType::get( + firstMaxRankedType.getShape(), + cast<ScalarTypeInterface>(getElementTypeOrSelf(vecOperand.getType())), + firstMaxRankedType.getScalableDims()); vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType)); } else { vecOperands.push_back(vecOperand); @@ -1351,7 +1353,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, for (Type resultType : op->getResultTypes()) { resultTypes.push_back( firstMaxRankedType - ? VectorType::get(firstMaxRankedType.getShape(), resultType, + ? VectorType::get(firstMaxRankedType.getShape(), + cast<ScalarTypeInterface>(resultType), firstMaxRankedType.getScalableDims()) : resultType); } @@ -1632,8 +1635,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, // Create ShapeCastOp. SmallVector<int64_t> destShape(inputVectorSizes); destShape.append(innerTiles.begin(), innerTiles.end()); - auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape), - packOp.getDestType().getElementType()); + auto tiledPackType = VectorType::get( + getTiledPackShape(packOp, destShape), + cast<ScalarTypeInterface>(packOp.getDestType().getElementType())); auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead); @@ -1768,8 +1772,9 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, // Collapse the vector to the size required by result. RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType( stripMineTensorType, packMetadata.reassociations); - mlir::VectorType vecCollapsedType = - VectorType::get(collapsedType.getShape(), collapsedType.getElementType()); + mlir::VectorType vecCollapsedType = VectorType::get( + collapsedType.getShape(), + cast<ScalarTypeInterface>(collapsedType.getElementType())); vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>( loc, vecCollapsedType, transposeOp->getResult(0)); @@ -2473,8 +2478,10 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, !VectorType::isValidElementType(dstElementType)) return failure(); - auto readType = VectorType::get(srcType.getShape(), srcElementType); - auto writeType = VectorType::get(dstType.getShape(), dstElementType); + auto readType = VectorType::get(srcType.getShape(), + cast<ScalarTypeInterface>(srcElementType)); + auto writeType = VectorType::get(dstType.getShape(), + cast<ScalarTypeInterface>(dstElementType)); Location loc = copyOp->getLoc(); Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); @@ -2839,7 +2846,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, return failure(); } } - auto vecType = VectorType::get(vecShape, sourceType.getElementType()); + auto vecType = VectorType::get( + vecShape, cast<ScalarTypeInterface>(sourceType.getElementType())); // 3. Generate TransferReadOp + TransferWriteOp ReifiedRankedShapedTypeDims reifiedSrcSizes; @@ -2943,8 +2951,9 @@ struct PadOpVectorizationWithInsertSlicePattern if (insertOp.getDest() == padOp.getResult()) return failure(); - auto vecType = VectorType::get(padOp.getType().getShape(), - padOp.getType().getElementType()); + auto vecType = VectorType::get( + padOp.getType().getShape(), + cast<ScalarTypeInterface>(padOp.getType().getElementType())); unsigned vecRank = vecType.getRank(); unsigned tensorRank = insertOp.getType().getRank(); @@ -3366,9 +3375,12 @@ struct Conv1DGenerator Type lhsEltType = lhsShapedType.getElementType(); Type rhsEltType = rhsShapedType.getElementType(); Type resEltType = resShapedType.getElementType(); - auto lhsType = VectorType::get(lhsShape, lhsEltType); - auto rhsType = VectorType::get(rhsShape, rhsEltType); - auto resType = VectorType::get(resShape, resEltType); + auto lhsType = + VectorType::get(lhsShape, cast<ScalarTypeInterface>(lhsEltType)); + auto rhsType = + VectorType::get(rhsShape, cast<ScalarTypeInterface>(rhsEltType)); + auto resType = + VectorType::get(resShape, cast<ScalarTypeInterface>(resEltType)); // Zero padding with the corresponding dimensions for lhs, rhs and res. SmallVector<Value> lhsPadding(lhsShape.size(), zero); SmallVector<Value> rhsPadding(rhsShape.size(), zero); @@ -3595,13 +3607,14 @@ struct Conv1DGenerator // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1, cSize}, - lhsEltType, /*scalableDims=*/{false, false, scalableChDim}); + cast<ScalarTypeInterface>(lhsEltType), + /*scalableDims=*/{false, false, scalableChDim}); VectorType rhsType = - VectorType::get({kwSize, cSize}, rhsEltType, + VectorType::get({kwSize, cSize}, cast<ScalarTypeInterface>(rhsEltType), /*scalableDims=*/{false, scalableChDim}); - VectorType resType = - VectorType::get({nSize, wSize, cSize}, resEltType, - /*scalableDims=*/{false, false, scalableChDim}); + VectorType resType = VectorType::get( + {nSize, wSize, cSize}, cast<ScalarTypeInterface>(resEltType), + /*scalableDims=*/{false, false, scalableChDim}); // Masks the input xfer Op along the channel dim, iff the corresponding // scalable flag is set. @@ -3685,10 +3698,10 @@ struct Conv1DGenerator // Note - the scalable flags are ignored as flattening combined with // scalable vectorization is not supported. SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize}; - auto lhsTypeAfterFlattening = - VectorType::get(inOutFlattenSliceSizes, lhsEltType); - auto resTypeAfterFlattening = - VectorType::get(inOutFlattenSliceSizes, resEltType); + auto lhsTypeAfterFlattening = VectorType::get( + inOutFlattenSliceSizes, cast<ScalarTypeInterface>(lhsEltType)); + auto resTypeAfterFlattening = VectorType::get( + inOutFlattenSliceSizes, cast<ScalarTypeInterface>(resEltType)); // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c} for (int64_t kw = 0; kw < kwSize; ++kw) { @@ -3708,7 +3721,10 @@ struct Conv1DGenerator if (flatten) { // Un-flatten the output vector (restore the channel dimension) resVals[w] = rewriter.create<vector::ShapeCastOp>( - loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]); + loc, + VectorType::get(inOutSliceSizes, + cast<ScalarTypeInterface>(resEltType)), + resVals[w]); } } } diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index a26e380232a9..bdfedabe2364 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -64,7 +64,8 @@ static std::optional<VectorShape> vectorShape(Value value) { // Broadcasts scalar type into vector type (iff shape is non-scalar). static Type broadcast(Type type, std::optional<VectorShape> shape) { assert(!isa<VectorType>(type) && "must be scalar type"); - return shape ? VectorType::get(shape->sizes, type, shape->scalableFlags) + return shape ? VectorType::get(shape->sizes, cast<ScalarTypeInterface>(type), + shape->scalableFlags) : type; } @@ -156,7 +157,8 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, // Stitch results together into one large vector. Type resultEltType = cast<VectorType>(results[0].getType()).getElementType(); - Type resultExpandedType = VectorType::get(expandedShape, resultEltType); + Type resultExpandedType = + VectorType::get(expandedShape, cast<ScalarTypeInterface>(resultEltType)); Value result = builder.create<arith::ConstantOp>( resultExpandedType, builder.getZeroAttr(resultExpandedType)); @@ -166,7 +168,8 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, // Reshape back to the original vector shape. return builder.create<vector::ShapeCastOp>( - VectorType::get(inputShape, resultEltType), result); + VectorType::get(inputShape, cast<ScalarTypeInterface>(resultEltType)), + result); } //----------------------------------------------------------------------------// diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index 556922a64b09..b2b914cd6642 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -684,7 +684,8 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn)); Type elementType = getElementTypeOrSelf(memref.getType()); - auto vt = VectorType::get(vectorShape, elementType); + auto vt = + VectorType::get(vectorShape, cast<ScalarTypeInterface>(elementType)); Value res = b.create<vector::SplatOp>(loc, vt, loads[0]); foreachIndividualVectorElement( res, diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 7c0d36964865..7281e0da7f7f 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -6,9 +6,9 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "TypeDetail.h" #include "mlir/Dialect/Quant/IR/Quant.h" -#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" @@ -34,7 +34,7 @@ double getMaxScale(Type expressedType) { return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble(); } -} // namespace +} // namespace unsigned QuantizedType::getFlags() const { return static_cast<ImplType *>(impl)->flags; @@ -146,7 +146,7 @@ Type QuantizedType::castFromStorageType(Type candidateType) { if (llvm::isa<VectorType>(candidateType)) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(), - getStorageType()); + llvm::cast<ScalarTypeInterface>(getStorageType())); } return nullptr; @@ -172,7 +172,8 @@ Type QuantizedType::castToStorageType(Type quantizedType) { return UnrankedTensorType::get(storageType); } if (llvm::isa<VectorType>(quantizedType)) { - return VectorType::get(sType.getShape(), storageType); + return VectorType::get(sType.getShape(), + llvm::cast<ScalarTypeInterface>(storageType)); } } @@ -200,7 +201,8 @@ Type QuantizedType::castFromExpressedType(Type candidateType) { } if (llvm::isa<VectorType>(candidateType)) { // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - return VectorType::get(candidateShapedType.getShape(), *this); + return VectorType::get(candidateShapedType.getShape(), + llvm::cast<ScalarTypeInterface>(*this)); } } @@ -227,7 +229,8 @@ Type QuantizedType::castToExpressedType(Type quantizedType) { return UnrankedTensorType::get(expressedType); } if (llvm::isa<VectorType>(quantizedType)) { - return VectorType::get(sType.getShape(), expressedType); + return VectorType::get(sType.getShape(), + llvm::cast<ScalarTypeInterface>(expressedType)); } } diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp index 62c7a7128d63..7cd7bc8da850 100644 --- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp @@ -39,7 +39,8 @@ Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const { if (dyn_cast<UnrankedTensorType>(inputType)) return UnrankedTensorType::get(elementalType); if (auto vectorType = dyn_cast<VectorType>(inputType)) - return VectorType::get(vectorType.getShape(), elementalType); + return VectorType::get(vectorType.getShape(), + cast<ScalarTypeInterface>(elementalType)); // If the expressed types match, just use the new elemental type. if (elementalType.getExpressedType() == expressedType) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp index d8dfe164458e..0b0a309c02c3 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp @@ -42,7 +42,8 @@ static Type getUnaryOpResultType(Type operandType) { Builder builder(operandType.getContext()); Type resultType = builder.getIntegerType(1); if (auto vecType = llvm::dyn_cast<VectorType>(operandType)) - return VectorType::get(vecType.getNumElements(), resultType); + return VectorType::get(vecType.getNumElements(), + cast<ScalarTypeInterface>(resultType)); return resultType; } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index a60410d01ac5..77305de066c1 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -366,7 +366,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv, return nullptr; } - return VectorType::get(type.getShape(), elementType); + return VectorType::get(type.getShape(), + cast<ScalarTypeInterface>(elementType)); } if (type.getRank() <= 1 && type.getNumElements() == 1) @@ -392,7 +393,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv, auto elementType = convertScalarType(targetEnv, options, scalarType, storageClass); if (elementType) - return VectorType::get(type.getShape(), elementType); + return VectorType::get(type.getShape(), + cast<ScalarTypeInterface>(elementType)); return nullptr; } @@ -417,7 +419,7 @@ convertComplexType(const spirv::TargetEnv &targetEnv, return nullptr; } - return VectorType::get(2, elementType); + return VectorType::get(2, cast<ScalarTypeInterface>(elementType)); } /// Converts a tensor `type` to a suitable type under the given `targetEnv`. @@ -770,8 +772,9 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, case spirv::BuiltIn::WorkgroupId: case spirv::BuiltIn::LocalInvocationId: case spirv::BuiltIn::GlobalInvocationId: { - auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType), - spirv::StorageClass::Input); + auto ptrType = spirv::PointerType::get( + VectorType::get({3}, cast<ScalarTypeInterface>(integerType)), + spirv::StorageClass::Input); std::string name = getBuiltinVarName(builtin, prefix, suffix); newVarOp = builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp index 07cf26926a1d..e0337ae7e916 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp @@ -496,7 +496,8 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> { Type vectorType = srcElemType; if (!isa<VectorType>(srcElemType)) - vectorType = VectorType::get({ratio}, dstElemType); + vectorType = + VectorType::get({ratio}, cast<ScalarTypeInterface>(dstElemType)); // If both the source and destination are vector types, we need to make // sure the scalar type is the same for composite construction later. @@ -511,7 +512,8 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> { // SPIR-V. Type castType = srcElemVecType.getElementType(); if (count > 1) - castType = VectorType::get({count}, castType); + castType = + VectorType::get({count}, cast<ScalarTypeInterface>(castType)); for (Value &c : components) c = rewriter.create<spirv::BitcastOp>(loc, castType, c); diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp index b19495bc3744..9a416eb15ef8 100644 --- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp +++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp @@ -118,7 +118,7 @@ Type VulkanLayoutUtils::decorateType(VectorType vectorType, // times its scalar alignment." size = elementSize * numElements; alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4; - return VectorType::get(numElements, memberType); + return VectorType::get(numElements, cast<ScalarTypeInterface>(memberType)); } Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index b2eca539194a..54e43089dc8e 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -57,7 +57,8 @@ static bool isInvariantArg(BlockArgument arg, Block *block) { /// Constructs vector type for element type. static VectorType vectorType(VL vl, Type etp) { - return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization); + return VectorType::get(vl.vectorLength, cast<ScalarTypeInterface>(etp), + vl.enableVLAVectorization); } /// Constructs vector type from a memref value. diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 0258f797143c..acd508a9b35d 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -1236,7 +1236,8 @@ Type Merger::inferType(ExprId e, Value src) const { // Inspect source type. For vector types, apply the same // vectorization to the destination type. if (auto vtp = dyn_cast<VectorType>(src.getType())) - return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims()); + return VectorType::get(vtp.getNumElements(), cast<ScalarTypeInterface>(dtp), + vtp.getScalableDims()); return dtp; } diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp index a7aa25eae264..cc26463a84d5 100644 --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -179,7 +179,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2, // Compose the final broadcasted type if (resultCompositeKind == VectorType::getTypeID()) - return VectorType::get(resultShape, elementType); + return VectorType::get(resultShape, cast<ScalarTypeInterface>(elementType)); if (resultCompositeKind == RankedTensorType::getTypeID()) return RankedTensorType::get(resultShape, elementType); return elementType; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 8e0e723cf4ed..73fe27bf12e1 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2419,7 +2419,8 @@ Value BroadcastOp::createOrFoldBroadcastOp( Location loc = value.getLoc(); Type elementType = getElementTypeOrSelf(value.getType()); VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType()); - VectorType dstVectorType = VectorType::get(dstShape, elementType); + VectorType dstVectorType = + VectorType::get(dstShape, cast<ScalarTypeInterface>(elementType)); // Step 2. If scalar -> dstShape broadcast, just do it. if (!srcVectorType) { @@ -2481,7 +2482,8 @@ Value BroadcastOp::createOrFoldBroadcastOp( .empty() && "unexpected \"dim-1\" broadcast"); - VectorType broadcastType = VectorType::get(broadcastShape, elementType); + VectorType broadcastType = + VectorType::get(broadcastShape, cast<ScalarTypeInterface>(elementType)); assert(vector::isBroadcastableTo(value.getType(), broadcastType) == vector::BroadcastableToResult::Success && "must be broadcastable"); @@ -5914,9 +5916,9 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result, Value source) { result.addOperands(source); MemRefType memRefType = llvm::cast<MemRefType>(source.getType()); - VectorType vectorType = - VectorType::get(extractShape(memRefType), - getElementTypeOrSelf(getElementTypeOrSelf(memRefType))); + VectorType vectorType = VectorType::get( + extractShape(memRefType), cast<ScalarTypeInterface>(getElementTypeOrSelf( + getElementTypeOrSelf(memRefType)))); result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(), memRefType.getMemorySpace())); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index fec3c6c52e5e..225df20e37fa 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -112,9 +112,9 @@ public: // %a = [%u, %v] // .. // %x = [%a,%b,%c,%d] - VectorType resType = - VectorType::get(dstType.getShape().drop_front(), eltType, - dstType.getScalableDims().drop_front()); + VectorType resType = VectorType::get( + dstType.getShape().drop_front(), cast<ScalarTypeInterface>(eltType), + dstType.getScalableDims().drop_front()); Value result = rewriter.create<ub::PoisonOp>(loc, dstType); if (m == 0) { // Stetch at start. diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index c6627b5ec0d7..c659bfc67a21 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -1367,7 +1367,8 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( mul = rew.create<vector::ShapeCastOp>( loc, VectorType::get({lhsRows, rhsColumns}, - getElementTypeOrSelf(op.getAcc().getType())), + cast<ScalarTypeInterface>( + getElementTypeOrSelf(op.getAcc().getType()))), mul); // ACC must be C(m, n) or C(n, m). diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 3b38505becd1..e22a3c0f4dfc 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -107,7 +107,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> { /// ```mlir /// %subview = memref.subview %M (...) /// : memref<100x3xf32> to memref<100xf32, strided<[3]>> -/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>> +/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, +/// strided<[3]>> /// ``` /// ==> /// ```mlir @@ -200,7 +201,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> { Location loc = op.getLoc(); Type elemTy = resultTy.getElementType(); // Vector type with a single element. Used to generate `vector.loads`. - VectorType elemVecTy = VectorType::get({1}, elemTy); + VectorType elemVecTy = + VectorType::get({1}, cast<ScalarTypeInterface>(elemTy)); Value condMask = op.getMask(); Value base = op.getBase(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index e214257de2cd..7953e91f65f4 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1192,7 +1192,8 @@ struct WarpOpExtractScalar : public WarpDistributionPattern { return failure(); int64_t elementsPerLane = extractSrcType.getShape()[0] / warpOp.getWarpSize(); - distributedVecType = VectorType::get({elementsPerLane}, elType); + distributedVecType = + VectorType::get({elementsPerLane}, cast<ScalarTypeInterface>(elType)); } else { distributedVecType = extractSrcType; } @@ -1711,8 +1712,8 @@ struct WarpOpReduction : public WarpDistributionPattern { // Return vector that will be reduced from the WarpExecuteOnLane0Op. unsigned operandIndex = yieldOperand->getOperandNumber(); SmallVector<Value> yieldValues = {reductionOp.getVector()}; - SmallVector<Type> retTypes = { - VectorType::get({numElements}, reductionOp.getType())}; + SmallVector<Type> retTypes = {VectorType::get( + {numElements}, cast<ScalarTypeInterface>(reductionOp.getType()))}; if (reductionOp.getAcc()) { yieldValues.push_back(reductionOp.getAcc()); retTypes.push_back(reductionOp.getAcc().getType()); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index cf6efaa04ae4..19424e7854e5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -297,12 +297,14 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc, auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() / emulatedElemTy.getIntOrFloatBitWidth(); auto newLoad = rewriter.create<vector::LoadOp>( - loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base, - getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices)); + loc, + VectorType::get(numContainerElemsToLoad, + cast<ScalarTypeInterface>(containerElemTy)), + base, getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices)); return rewriter.create<vector::BitCastOp>( loc, VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem, - emulatedElemTy), + cast<ScalarTypeInterface>(emulatedElemTy)), newLoad); } @@ -358,7 +360,8 @@ static void atomicRMW(OpBuilder &builder, Location loc, // Load the original value from memory, and cast it to the original element // type. - auto oneElemVecType = VectorType::get({1}, origValue.getType()); + auto oneElemVecType = + VectorType::get({1}, cast<ScalarTypeInterface>(origValue.getType())); Value origVecValue = builder.create<vector::FromElementsOp>( loc, oneElemVecType, ValueRange{origValue}); @@ -378,8 +381,9 @@ static void nonAtomicRMW(OpBuilder &builder, Location loc, VectorValue valueToStore, Value mask) { assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector"); - auto oneElemVecType = - VectorType::get({1}, linearizedMemref.getType().getElementType()); + auto oneElemVecType = VectorType::get( + {1}, + cast<ScalarTypeInterface>(linearizedMemref.getType().getElementType())); Value origVecValue = builder.create<vector::LoadOp>( loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex}); origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(), @@ -559,7 +563,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { // Basic case: storing full bytes. auto numElements = origElements / emulatedPerContainerElem; auto bitCast = rewriter.create<vector::BitCastOp>( - loc, VectorType::get(numElements, containerElemTy), + loc, + VectorType::get(numElements, + cast<ScalarTypeInterface>(containerElemTy)), op.getValueToStore()); rewriter.replaceOpWithNewOp<vector::StoreOp>( op, bitCast.getResult(), memrefBase, @@ -665,7 +671,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { auto memrefElemType = getElementTypeOrSelf(memrefBase.getType()); auto storeType = VectorType::get( {originType.getNumElements() / emulatedPerContainerElem}, - memrefElemType); + cast<ScalarTypeInterface>(memrefElemType)); auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, fullWidthStorePart); rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase, @@ -794,7 +800,8 @@ struct ConvertVectorMaskedStore final auto numElements = (origElements + emulatedPerContainerElem - 1) / emulatedPerContainerElem; - auto newType = VectorType::get(numElements, containerElemTy); + auto newType = VectorType::get(numElements, + cast<ScalarTypeInterface>(containerElemTy)); auto passThru = rewriter.create<arith::ConstantOp>( loc, newType, rewriter.getZeroAttr(newType)); @@ -803,7 +810,8 @@ struct ConvertVectorMaskedStore final newMask.value()->getResult(0), passThru); auto newBitCastType = - VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy); + VectorType::get(numElements * emulatedPerContainerElem, + cast<ScalarTypeInterface>(emulatedElemTy)); Value valueToStore = rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad); valueToStore = rewriter.create<arith::SelectOp>( @@ -1032,9 +1040,11 @@ struct ConvertVectorMaskedLoad final auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements, emulatedPerContainerElem); - auto loadType = VectorType::get(numElements, containerElemTy); + auto loadType = VectorType::get(numElements, + cast<ScalarTypeInterface>(containerElemTy)); auto newBitcastType = - VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy); + VectorType::get(numElements * emulatedPerContainerElem, + cast<ScalarTypeInterface>(emulatedElemTy)); auto emptyVector = rewriter.create<arith::ConstantOp>( loc, newBitcastType, rewriter.getZeroAttr(newBitcastType)); @@ -1188,13 +1198,17 @@ struct ConvertVectorTransferRead final emulatedPerContainerElem); auto newRead = rewriter.create<vector::TransferReadOp>( - loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(), + loc, + VectorType::get(numElements, + cast<ScalarTypeInterface>(containerElemTy)), + adaptor.getSource(), getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices), newPadding); auto bitCast = rewriter.create<vector::BitCastOp>( loc, - VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy), + VectorType::get(numElements * emulatedPerContainerElem, + cast<ScalarTypeInterface>(emulatedElemTy)), newRead); Value result = bitCast->getResult(0); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index dc46ed17a374..1339e3f49eab 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -607,7 +607,8 @@ struct BubbleDownVectorBitCastForExtract Location loc = extractOp.getLoc(); Value packedValue = rewriter.create<vector::ExtractOp>( loc, castOp.getSource(), index / expandRatio); - Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType()); + Type packedVecType = VectorType::get( + /*shape=*/{1}, cast<ScalarTypeInterface>(packedValue.getType())); Value zero = rewriter.create<arith::ConstantOp>( loc, packedVecType, rewriter.getZeroAttr(packedVecType)); packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero, @@ -1059,7 +1060,7 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, // If we can assume all indices fit in 32-bit, we perform the vector // comparison in 32-bit to get a higher degree of SIMD parallelism. // Otherwise we perform the vector comparison using 64-bit indices. - Type idxType = + ScalarTypeInterface idxType = force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); DenseIntElementsAttr indicesAttr; if (dim == 0 && force32BitVectorIndices) { diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 7b56cd0cf0e9..fa0ac4e47bac 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -337,7 +337,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, auto sourceShape = sourceShapedType.getShape(); assert(sourceShape.size() == readShape.size() && "expected same ranks."); auto maskType = VectorType::get(readShape, builder.getI1Type()); - auto vectorType = VectorType::get(readShape, padValue.getType()); + auto vectorType = + VectorType::get(readShape, cast<ScalarTypeInterface>(padValue.getType())); assert(padValue.getType() == sourceShapedType.getElementType() && "expected same pad element type to match source element type"); int64_t readRank = readShape.size(); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 78c242571935..31ccb14de0cf 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -368,8 +368,9 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() { "tensor descriptor shape is not distributable"); if (chunkSize > 1) return VectorType::get({chunkSize / wiDataSize, wiDataSize}, - getElementType()); - return VectorType::get({wiDataSize}, getElementType()); + llvm::cast<ScalarTypeInterface>(getElementType())); + return VectorType::get({wiDataSize}, + llvm::cast<ScalarTypeInterface>(getElementType())); } // Case 2: block loads/stores @@ -393,7 +394,7 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() { tensorSize *= getArrayLength(); return VectorType::get({tensorSize / (sgSize * wiDataSize), wiDataSize}, - getElementType()); + llvm::cast<ScalarTypeInterface>(getElementType())); } } // namespace xegpu diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 3924d082f062..2c8d75aaf259 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -211,11 +211,12 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, //===----------------------------------------------------------------------===// bool VectorType::isValidElementType(Type t) { - return isValidVectorTypeElementType(t); + return llvm::isa<ScalarTypeInterface>(t); } LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, - ArrayRef<int64_t> shape, Type elementType, + ArrayRef<int64_t> shape, + ScalarTypeInterface elementType, ArrayRef<bool> scalableDims) { if (!isValidElementType(elementType)) return emitError() @@ -248,7 +249,8 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) { VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const { - return VectorType::get(shape.value_or(getShape()), elementType, + return VectorType::get(shape.value_or(getShape()), + llvm::cast<ScalarTypeInterface>(elementType), getScalableDims()); } diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index a07189ae1323..54c540b28fdb 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -139,8 +139,8 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder, if (iface.isConvertibleInstruction(inst->getOpcode())) return iface.convertInstruction(odsBuilder, inst, llvmOperands, moduleImport); - // TODO: Implement the `convertInstruction` hooks in the - // `LLVMDialectLLVMIRImportInterface` and move the following include there. + // TODO: Implement the `convertInstruction` hooks in the + // `LLVMDialectLLVMIRImportInterface` and move the following include there. #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc" return failure(); } @@ -813,7 +813,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) { SmallVector<int64_t> shape(arrayShape); shape.push_back(numElements.getKnownMinValue()); - return VectorType::get(shape, elementType); + return VectorType::get(shape, cast<ScalarTypeInterface>(elementType)); } Type ModuleImport::getBuiltinTypeForAttr(Type type) { diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 273817d53d30..6b2726970e94 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -882,7 +882,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, return emitError(unknownLoc, "OpTypeVector references undefined <id> ") << operands[1]; } - typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy); + typeMap[operands[0]] = + VectorType::get({operands[2]}, cast<ScalarTypeInterface>(elementTy)); } break; case spirv::Opcode::OpTypePointer: { return processOpTypePointer(operands); diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir index 51612446d2e6..5be76d7fd387 100644 --- a/mlir/test/IR/invalid-builtin-types.mlir +++ b/mlir/test/IR/invalid-builtin-types.mlir @@ -115,7 +115,7 @@ func.func @illegaltype(i21312312323120) // expected-error {{invalid integer widt // ----- // Test no nested vector. -// expected-error@+1 {{failed to verify 'elementType': integer or index or floating-point}} +// expected-error@+1 {{failed to verify 'elementType': vector type requires scalar element type}} func.func @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>) // ----- diff --git a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp index 1e45ab57ebcc..67eb832c471d 100644 --- a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp +++ b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp @@ -48,7 +48,8 @@ static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) { const unsigned lmul = eltCount * sew / 64; unsigned n = lmul > 8 ? llvm::Log2_32(lmul) - 2 : 1; - return {n, VectorType::get({eltCount >> (n - 1)}, eltTy, {true})}; + return {n, VectorType::get({eltCount >> (n - 1)}, + cast<ScalarTypeInterface>(eltTy), {true})}; } /// Replace math.cos(v) operation with vcix.v.iv(v). diff --git a/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp b/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp index 305f87948981..98fb71d9355e 100644 --- a/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp +++ b/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp @@ -31,7 +31,7 @@ TEST_F(ArmSMETest, TestTileTypeConversion) { populateArmSMEToLLVMConversionPatterns(llvmConverterWithArmSMEConversion, patterns); - Type i32 = IntegerType::get(&context, 32); + auto i32 = IntegerType::get(&context, 32); auto smeTileType = VectorType::get({4, 4}, i32, {true, true}); // An unmodified LLVMTypeConverer should fail to convert an ArmSME tile type. diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp index bc4066ed210e..abb33f5bedea 100644 --- a/mlir/unittests/IR/ShapedTypeTest.cpp +++ b/mlir/unittests/IR/ShapedTypeTest.cpp @@ -110,10 +110,10 @@ TEST(ShapedTypeTest, CloneTensor) { TEST(ShapedTypeTest, CloneVector) { MLIRContext context; - Type i32 = IntegerType::get(&context, 32); - Type f32 = Float32Type::get(&context); + auto i32 = IntegerType::get(&context, 32); + auto f32 = Float32Type::get(&context); - Type vectorOriginalType = i32; + auto vectorOriginalType = i32; llvm::SmallVector<int64_t> vectorOriginalShape({10, 20}); ShapedType vectorType = VectorType::get(vectorOriginalShape, vectorOriginalType); @@ -123,7 +123,7 @@ TEST(ShapedTypeTest, CloneVector) { ASSERT_EQ(vectorType.clone(vectorNewShape), VectorType::get(vectorNewShape, vectorOriginalType)); // Update type. - Type vectorNewType = f32; + auto vectorNewType = f32; ASSERT_NE(vectorOriginalType, vectorNewType); ASSERT_EQ(vectorType.clone(vectorNewType), VectorType::get(vectorOriginalShape, vectorNewType)); @@ -134,7 +134,7 @@ TEST(ShapedTypeTest, CloneVector) { TEST(ShapedTypeTest, VectorTypeBuilder) { MLIRContext context; - Type f32 = Float32Type::get(&context); + auto f32 = Float32Type::get(&context); SmallVector<int64_t> shape{2, 4, 8, 9, 1}; SmallVector<bool> scalableDims{true, false, true, false, false}; |
