diff options
Diffstat (limited to 'llvm/lib/Target/SPIRV')
22 files changed, 744 insertions, 332 deletions
diff --git a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp index 7f5f7d0b1e4d..25e285e35f93 100644 --- a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp @@ -138,7 +138,7 @@ ConvergenceRegion::ConvergenceRegion( SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits) : DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry), Exits(std::move(Exits)), Blocks(std::move(Blocks)) { - for (auto *BB : this->Exits) + for ([[maybe_unused]] auto *BB : this->Exits) assert(this->Blocks.count(BB) != 0); assert(this->Blocks.count(this->Entry) != 0); } diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp index d96d2bf31b62..0f9a2a69e073 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp @@ -198,6 +198,8 @@ std::string getExtInstSetName(SPIRV::InstructionSet::InstructionSet Set) { return "OpenCL.std"; case SPIRV::InstructionSet::GLSL_std_450: return "GLSL.std.450"; + case SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100: + return "NonSemantic.Shader.DebugInfo.100"; case SPIRV::InstructionSet::SPV_AMD_shader_trinary_minmax: return "SPV_AMD_shader_trinary_minmax"; } @@ -206,8 +208,9 @@ std::string getExtInstSetName(SPIRV::InstructionSet::InstructionSet Set) { SPIRV::InstructionSet::InstructionSet getExtInstSetFromString(std::string SetName) { - for (auto Set : {SPIRV::InstructionSet::GLSL_std_450, - SPIRV::InstructionSet::OpenCL_std}) { + for (auto Set : + {SPIRV::InstructionSet::GLSL_std_450, SPIRV::InstructionSet::OpenCL_std, + SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100}) { if (SetName == getExtInstSetName(Set)) return Set; } diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h index 990eb1d230bc..44625793e941 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h @@ -197,6 +197,11 @@ namespace GLSLExtInst { #include "SPIRVGenTables.inc" } // namespace GLSLExtInst +namespace NonSemanticExtInst { +#define GET_NonSemanticExtInst_DECL +#include "SPIRVGenTables.inc" +} // namespace NonSemanticExtInst + namespace Opcode { #define GET_Opcode_DECL #include "SPIRVGenTables.inc" diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.h index 842958695e10..a6dd7138edf3 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.h +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.h @@ -21,7 +21,7 @@ public: ~SPIRVTargetStreamer() override; void changeSection(const MCSection *CurSection, MCSection *Section, - const MCExpr *SubSection, raw_ostream &OS) override {} + uint32_t SubSection, raw_ostream &OS) override {} }; } // namespace llvm diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index 49838e685a6d..0b93a4d85eed 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -300,6 +300,72 @@ lookupBuiltin(StringRef DemangledCall, return nullptr; } +static MachineInstr *getBlockStructInstr(Register ParamReg, + MachineRegisterInfo *MRI) { + // We expect the following sequence of instructions: + // %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca) + // or = G_GLOBAL_VALUE @block_literal_global + // %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0 + // %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN) + MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg); + assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST && + MI->getOperand(1).isReg()); + Register BitcastReg = MI->getOperand(1).getReg(); + MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg); + assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) && + BitcastMI->getOperand(2).isReg()); + Register ValueReg = BitcastMI->getOperand(2).getReg(); + MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg); + return ValueMI; +} + +// Return an integer constant corresponding to the given register and +// defined in spv_track_constant. +// TODO: maybe unify with prelegalizer pass. +static unsigned getConstFromIntrinsic(Register Reg, MachineRegisterInfo *MRI) { + MachineInstr *DefMI = MRI->getUniqueVRegDef(Reg); + assert(isSpvIntrinsic(*DefMI, Intrinsic::spv_track_constant) && + DefMI->getOperand(2).isReg()); + MachineInstr *DefMI2 = MRI->getUniqueVRegDef(DefMI->getOperand(2).getReg()); + assert(DefMI2->getOpcode() == TargetOpcode::G_CONSTANT && + DefMI2->getOperand(1).isCImm()); + return DefMI2->getOperand(1).getCImm()->getValue().getZExtValue(); +} + +// Return type of the instruction result from spv_assign_type intrinsic. +// TODO: maybe unify with prelegalizer pass. +static const Type *getMachineInstrType(MachineInstr *MI) { + MachineInstr *NextMI = MI->getNextNode(); + if (!NextMI) + return nullptr; + if (isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_name)) + if ((NextMI = NextMI->getNextNode()) == nullptr) + return nullptr; + Register ValueReg = MI->getOperand(0).getReg(); + if ((!isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_type) && + !isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_ptr_type)) || + NextMI->getOperand(1).getReg() != ValueReg) + return nullptr; + Type *Ty = getMDOperandAsType(NextMI->getOperand(2).getMetadata(), 0); + assert(Ty && "Type is expected"); + return Ty; +} + +static const Type *getBlockStructType(Register ParamReg, + MachineRegisterInfo *MRI) { + // In principle, this information should be passed to us from Clang via + // an elementtype attribute. However, said attribute requires that + // the function call be an intrinsic, which is not. Instead, we rely on being + // able to trace this to the declaration of a variable: OpenCL C specification + // section 6.12.5 should guarantee that we can do this. + MachineInstr *MI = getBlockStructInstr(ParamReg, MRI); + if (MI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE) + return MI->getOperand(1).getGlobal()->getType(); + assert(isSpvIntrinsic(*MI, Intrinsic::spv_alloca) && + "Blocks in OpenCL C must be traceable to allocation site"); + return getMachineInstrType(MI); +} + //===----------------------------------------------------------------------===// // Helper functions for building misc instructions //===----------------------------------------------------------------------===// @@ -492,16 +558,21 @@ static Register buildMemSemanticsReg(Register SemanticsRegister, static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode, const SPIRV::IncomingCall *Call, - Register TypeReg = Register(0)) { + Register TypeReg, + ArrayRef<uint32_t> ImmArgs = {}) { MachineRegisterInfo *MRI = MIRBuilder.getMRI(); auto MIB = MIRBuilder.buildInstr(Opcode); if (TypeReg.isValid()) MIB.addDef(Call->ReturnRegister).addUse(TypeReg); - for (Register ArgReg : Call->Arguments) { + unsigned Sz = Call->Arguments.size() - ImmArgs.size(); + for (unsigned i = 0; i < Sz; ++i) { + Register ArgReg = Call->Arguments[i]; if (!MRI->getRegClassOrNull(ArgReg)) MRI->setRegClass(ArgReg, &SPIRV::IDRegClass); MIB.addUse(ArgReg); } + for (uint32_t ImmArg : ImmArgs) + MIB.addImm(ImmArg); return true; } @@ -509,7 +580,7 @@ static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode, static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call, MachineIRBuilder &MIRBuilder) { if (Call->isSpirvOp()) - return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call); + return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call, Register(0)); assert(Call->Arguments.size() == 2 && "Need 2 arguments for atomic init translation"); @@ -567,7 +638,7 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { if (Call->isSpirvOp()) - return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call); + return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0)); Register ScopeRegister = buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR); @@ -694,7 +765,7 @@ static bool buildAtomicCompareExchangeInst( return true; } -/// Helper function for building an atomic load instruction. +/// Helper function for building atomic instructions. static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { @@ -719,13 +790,36 @@ static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode, MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister, Semantics, MIRBuilder, GR); MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass); + Register ValueReg = Call->Arguments[1]; + Register ValueTypeReg = GR->getSPIRVTypeID(Call->ReturnType); + // support cl_ext_float_atomics + if (Call->ReturnType->getOpcode() == SPIRV::OpTypeFloat) { + if (Opcode == SPIRV::OpAtomicIAdd) { + Opcode = SPIRV::OpAtomicFAddEXT; + } else if (Opcode == SPIRV::OpAtomicISub) { + // Translate OpAtomicISub applied to a floating type argument to + // OpAtomicFAddEXT with the negative value operand + Opcode = SPIRV::OpAtomicFAddEXT; + Register NegValueReg = + MRI->createGenericVirtualRegister(MRI->getType(ValueReg)); + MRI->setRegClass(NegValueReg, &SPIRV::IDRegClass); + GR->assignSPIRVTypeToVReg(Call->ReturnType, NegValueReg, + MIRBuilder.getMF()); + MIRBuilder.buildInstr(TargetOpcode::G_FNEG) + .addDef(NegValueReg) + .addUse(ValueReg); + insertAssignInstr(NegValueReg, nullptr, Call->ReturnType, GR, MIRBuilder, + MIRBuilder.getMF().getRegInfo()); + ValueReg = NegValueReg; + } + } MIRBuilder.buildInstr(Opcode) .addDef(Call->ReturnRegister) - .addUse(GR->getSPIRVTypeID(Call->ReturnType)) + .addUse(ValueTypeReg) .addUse(PtrRegister) .addUse(ScopeRegister) .addUse(MemSemanticsReg) - .addUse(Call->Arguments[1]); + .addUse(ValueReg); return true; } @@ -804,7 +898,7 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { if (Call->isSpirvOp()) - return buildOpFromWrapper(MIRBuilder, Opcode, Call); + return buildOpFromWrapper(MIRBuilder, Opcode, Call, Register(0)); MachineRegisterInfo *MRI = MIRBuilder.getMRI(); unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI); @@ -949,7 +1043,35 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call, const SPIRV::DemangledBuiltin *Builtin = Call->Builtin; const SPIRV::GroupBuiltin *GroupBuiltin = SPIRV::lookupGroupBuiltin(Builtin->Name); + MachineRegisterInfo *MRI = MIRBuilder.getMRI(); + if (Call->isSpirvOp()) { + if (GroupBuiltin->NoGroupOperation) + return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call, + GR->getSPIRVTypeID(Call->ReturnType)); + + // Group Operation is a literal + Register GroupOpReg = Call->Arguments[1]; + const MachineInstr *MI = getDefInstrMaybeConstant(GroupOpReg, MRI); + if (!MI || MI->getOpcode() != TargetOpcode::G_CONSTANT) + report_fatal_error( + "Group Operation parameter must be an integer constant"); + uint64_t GrpOp = MI->getOperand(1).getCImm()->getValue().getZExtValue(); + Register ScopeReg = Call->Arguments[0]; + if (!MRI->getRegClassOrNull(ScopeReg)) + MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass); + Register ValueReg = Call->Arguments[2]; + if (!MRI->getRegClassOrNull(ValueReg)) + MRI->setRegClass(ValueReg, &SPIRV::IDRegClass); + MIRBuilder.buildInstr(GroupBuiltin->Opcode) + .addDef(Call->ReturnRegister) + .addUse(GR->getSPIRVTypeID(Call->ReturnType)) + .addUse(ScopeReg) + .addImm(GrpOp) + .addUse(ValueReg); + return true; + } + Register Arg0; if (GroupBuiltin->HasBoolArg) { Register ConstRegister = Call->Arguments[0]; @@ -1371,6 +1493,14 @@ static bool generateBarrierInst(const SPIRV::IncomingCall *Call, return buildBarrierInst(Call, Opcode, MIRBuilder, GR); } +static bool generateCastToPtrInst(const SPIRV::IncomingCall *Call, + MachineIRBuilder &MIRBuilder) { + MIRBuilder.buildInstr(TargetOpcode::G_ADDRSPACE_CAST) + .addDef(Call->ReturnRegister) + .addUse(Call->Arguments[0]); + return true; +} + static bool generateDotOrFMulInst(const SPIRV::IncomingCall *Call, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { @@ -1722,6 +1852,45 @@ static bool generateSelectInst(const SPIRV::IncomingCall *Call, return true; } +static bool generateConstructInst(const SPIRV::IncomingCall *Call, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { + return buildOpFromWrapper(MIRBuilder, SPIRV::OpCompositeConstruct, Call, + GR->getSPIRVTypeID(Call->ReturnType)); +} + +static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { + const SPIRV::DemangledBuiltin *Builtin = Call->Builtin; + unsigned Opcode = + SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode; + bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR; + unsigned ArgSz = Call->Arguments.size(); + unsigned LiteralIdx = 0; + if (Opcode == SPIRV::OpCooperativeMatrixLoadKHR && ArgSz > 3) + LiteralIdx = 3; + else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4) + LiteralIdx = 4; + SmallVector<uint32_t, 1> ImmArgs; + MachineRegisterInfo *MRI = MIRBuilder.getMRI(); + if (LiteralIdx > 0) + ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI)); + Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType); + if (Opcode == SPIRV::OpCooperativeMatrixLengthKHR) { + SPIRVType *CoopMatrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]); + if (!CoopMatrType) + report_fatal_error("Can't find a register's type definition"); + MIRBuilder.buildInstr(Opcode) + .addDef(Call->ReturnRegister) + .addUse(TypeReg) + .addUse(CoopMatrType->getOperand(0).getReg()); + return true; + } + return buildOpFromWrapper(MIRBuilder, Opcode, Call, + IsSet ? TypeReg : Register(0), ImmArgs); +} + static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { @@ -1826,7 +1995,10 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call, .addDef(GlobalWorkSize) .addUse(GR->getSPIRVTypeID(SpvFieldTy)) .addUse(GWSPtr); - Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy); + const SPIRVSubtarget &ST = + cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget()); + Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(), + SpvFieldTy, *ST.getInstrInfo()); } else { Const = GR->buildConstantInt(0, MIRBuilder, SpvTy); } @@ -1847,68 +2019,6 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call, .addUse(TmpReg); } -static MachineInstr *getBlockStructInstr(Register ParamReg, - MachineRegisterInfo *MRI) { - // We expect the following sequence of instructions: - // %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca) - // or = G_GLOBAL_VALUE @block_literal_global - // %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0 - // %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN) - MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg); - assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST && - MI->getOperand(1).isReg()); - Register BitcastReg = MI->getOperand(1).getReg(); - MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg); - assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) && - BitcastMI->getOperand(2).isReg()); - Register ValueReg = BitcastMI->getOperand(2).getReg(); - MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg); - return ValueMI; -} - -// Return an integer constant corresponding to the given register and -// defined in spv_track_constant. -// TODO: maybe unify with prelegalizer pass. -static unsigned getConstFromIntrinsic(Register Reg, MachineRegisterInfo *MRI) { - MachineInstr *DefMI = MRI->getUniqueVRegDef(Reg); - assert(isSpvIntrinsic(*DefMI, Intrinsic::spv_track_constant) && - DefMI->getOperand(2).isReg()); - MachineInstr *DefMI2 = MRI->getUniqueVRegDef(DefMI->getOperand(2).getReg()); - assert(DefMI2->getOpcode() == TargetOpcode::G_CONSTANT && - DefMI2->getOperand(1).isCImm()); - return DefMI2->getOperand(1).getCImm()->getValue().getZExtValue(); -} - -// Return type of the instruction result from spv_assign_type intrinsic. -// TODO: maybe unify with prelegalizer pass. -static const Type *getMachineInstrType(MachineInstr *MI) { - MachineInstr *NextMI = MI->getNextNode(); - if (isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_name)) - NextMI = NextMI->getNextNode(); - Register ValueReg = MI->getOperand(0).getReg(); - if (!isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_type) || - NextMI->getOperand(1).getReg() != ValueReg) - return nullptr; - Type *Ty = getMDOperandAsType(NextMI->getOperand(2).getMetadata(), 0); - assert(Ty && "Type is expected"); - return Ty; -} - -static const Type *getBlockStructType(Register ParamReg, - MachineRegisterInfo *MRI) { - // In principle, this information should be passed to us from Clang via - // an elementtype attribute. However, said attribute requires that - // the function call be an intrinsic, which is not. Instead, we rely on being - // able to trace this to the declaration of a variable: OpenCL C specification - // section 6.12.5 should guarantee that we can do this. - MachineInstr *MI = getBlockStructInstr(ParamReg, MRI); - if (MI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE) - return MI->getOperand(1).getGlobal()->getType(); - assert(isSpvIntrinsic(*MI, Intrinsic::spv_alloca) && - "Blocks in OpenCL C must be traceable to allocation site"); - return getMachineInstrType(MI); -} - // TODO: maybe move to the global register. static SPIRVType * getOrCreateSPIRVDeviceEventPointer(MachineIRBuilder &MIRBuilder, @@ -2322,6 +2432,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall, return generateAtomicFloatingInst(Call.get(), MIRBuilder, GR); case SPIRV::Barrier: return generateBarrierInst(Call.get(), MIRBuilder, GR); + case SPIRV::CastToPtr: + return generateCastToPtrInst(Call.get(), MIRBuilder); case SPIRV::Dot: return generateDotOrFMulInst(Call.get(), MIRBuilder, GR); case SPIRV::Wave: @@ -2340,6 +2452,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall, return generateSampleImageInst(DemangledCall, Call.get(), MIRBuilder, GR); case SPIRV::Select: return generateSelectInst(Call.get(), MIRBuilder); + case SPIRV::Construct: + return generateConstructInst(Call.get(), MIRBuilder, GR); case SPIRV::SpecConstant: return generateSpecConstantInst(Call.get(), MIRBuilder, GR); case SPIRV::Enqueue: @@ -2358,6 +2472,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall, return generateGroupUniformInst(Call.get(), MIRBuilder, GR); case SPIRV::KernelClock: return generateKernelClockInst(Call.get(), MIRBuilder, GR); + case SPIRV::CoopMatr: + return generateCoopMatrInst(Call.get(), MIRBuilder, GR); } return false; } @@ -2376,7 +2492,7 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall, if (hasBuiltinTypePrefix(TypeStr)) { // OpenCL builtin types in demangled call strings have the following format: // e.g. ocl_image2d_ro - bool IsOCLBuiltinType = TypeStr.consume_front("ocl_"); + [[maybe_unused]] bool IsOCLBuiltinType = TypeStr.consume_front("ocl_"); assert(IsOCLBuiltinType && "Invalid OpenCL builtin prefix"); // Check if this is pointer to a builtin type and not just pointer @@ -2482,6 +2598,22 @@ static SPIRVType *getPipeType(const TargetExtType *ExtensionType, ExtensionType->getIntParameter(0))); } +static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { + assert(ExtensionType->getNumIntParameters() == 4 && + "Invalid number of parameters for SPIR-V coop matrices builtin!"); + assert(ExtensionType->getNumTypeParameters() == 1 && + "SPIR-V coop matrices builtin type must have a type parameter!"); + const SPIRVType *ElemType = + GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder); + // Create or get an existing type from GlobalRegistry. + return GR->getOrCreateOpTypeCoopMatr( + MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(0), + ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2), + ExtensionType->getIntParameter(3)); +} + static SPIRVType * getImageType(const TargetExtType *ExtensionType, const SPIRV::AccessQualifier::AccessQualifier Qualifier, @@ -2612,6 +2744,9 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType, case SPIRV::OpTypeSampledImage: TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR); break; + case SPIRV::OpTypeCooperativeMatrixKHR: + TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR); + break; default: TargetType = getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR); diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td index edc9e1a33d9f..fb88332ab890 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td @@ -26,6 +26,7 @@ class InstructionSet<bits<32> value> { def OpenCL_std : InstructionSet<0>; def GLSL_std_450 : InstructionSet<1>; def SPV_AMD_shader_trinary_minmax : InstructionSet<2>; +def NonSemantic_Shader_DebugInfo_100 : InstructionSet<3>; // Define various builtin groups def BuiltinGroup : GenericEnum { @@ -59,6 +60,9 @@ def IntelSubgroups : BuiltinGroup; def AtomicFloating : BuiltinGroup; def GroupUniform : BuiltinGroup; def KernelClock : BuiltinGroup; +def CastToPtr : BuiltinGroup; +def Construct : BuiltinGroup; +def CoopMatr : BuiltinGroup; //===----------------------------------------------------------------------===// // Class defining a demangled builtin record. The information in the record @@ -113,6 +117,9 @@ def : DemangledBuiltin<"__spirv_ImageSampleExplicitLod", OpenCL_std, SampleImage // Select builtin record: def : DemangledBuiltin<"__spirv_Select", OpenCL_std, Select, 3, 3>; +// Composite Construct builtin record: +def : DemangledBuiltin<"__spirv_CompositeConstruct", OpenCL_std, Construct, 1, 0>; + //===----------------------------------------------------------------------===// // Class defining an extended builtin record used for lowering into an // OpExtInst instruction. @@ -170,6 +177,17 @@ class GLSLExtInst<string name, bits<32> value> { bits<32> Value = value; } +def NonSemanticExtInst : GenericEnum { + let FilterClass = "NonSemanticExtInst"; + let NameField = "Name"; + let ValueField = "Value"; +} + +class NonSemanticExtInst<string name, bits<32> value> { + string Name = name; + bits<32> Value = value; +} + // Multiclass used to define at the same time both a demangled builtin record // and a corresponding extended builtin record. multiclass DemangledExtendedBuiltin<string name, InstructionSet set, int number> { @@ -183,6 +201,10 @@ multiclass DemangledExtendedBuiltin<string name, InstructionSet set, int number> if !eq(set, GLSL_std_450) then { def : GLSLExtInst<name, number>; } + + if !eq(set, NonSemantic_Shader_DebugInfo_100) then { + def : NonSemanticExtInst<name, number>; + } } // Extended builtin records: @@ -430,6 +452,50 @@ defm : DemangledExtendedBuiltin<"NMin", GLSL_std_450, 79>; defm : DemangledExtendedBuiltin<"NMax", GLSL_std_450, 80>; defm : DemangledExtendedBuiltin<"NClamp", GLSL_std_450, 81>; +defm : DemangledExtendedBuiltin<"DebugInfoNone", NonSemantic_Shader_DebugInfo_100, 0>; +defm : DemangledExtendedBuiltin<"DebugCompilationUnit", NonSemantic_Shader_DebugInfo_100, 1>; +defm : DemangledExtendedBuiltin<"DebugTypeBasic", NonSemantic_Shader_DebugInfo_100, 2>; +defm : DemangledExtendedBuiltin<"DebugTypePointer", NonSemantic_Shader_DebugInfo_100, 3>; +defm : DemangledExtendedBuiltin<"DebugTypeQualifier", NonSemantic_Shader_DebugInfo_100, 4>; +defm : DemangledExtendedBuiltin<"DebugTypeArray", NonSemantic_Shader_DebugInfo_100, 5>; +defm : DemangledExtendedBuiltin<"DebugTypeVector", NonSemantic_Shader_DebugInfo_100, 6>; +defm : DemangledExtendedBuiltin<"DebugTypedef", NonSemantic_Shader_DebugInfo_100, 7>; +defm : DemangledExtendedBuiltin<"DebugTypeFunction", NonSemantic_Shader_DebugInfo_100, 8>; +defm : DemangledExtendedBuiltin<"DebugTypeEnum", NonSemantic_Shader_DebugInfo_100, 9>; +defm : DemangledExtendedBuiltin<"DebugTypeComposite", NonSemantic_Shader_DebugInfo_100, 10>; +defm : DemangledExtendedBuiltin<"DebugTypeMember", NonSemantic_Shader_DebugInfo_100, 11>; +defm : DemangledExtendedBuiltin<"DebugTypeInheritance", NonSemantic_Shader_DebugInfo_100, 12>; +defm : DemangledExtendedBuiltin<"DebugTypePtrToMember", NonSemantic_Shader_DebugInfo_100, 13>; +defm : DemangledExtendedBuiltin<"DebugTypeTemplate", NonSemantic_Shader_DebugInfo_100, 14>; +defm : DemangledExtendedBuiltin<"DebugTypeTemplateParameter", NonSemantic_Shader_DebugInfo_100, 15>; +defm : DemangledExtendedBuiltin<"DebugTypeTemplateTemplateParameter", NonSemantic_Shader_DebugInfo_100, 16>; +defm : DemangledExtendedBuiltin<"DebugTypeTemplateParameterPack", NonSemantic_Shader_DebugInfo_100, 17>; +defm : DemangledExtendedBuiltin<"DebugGlobalVariable", NonSemantic_Shader_DebugInfo_100, 18>; +defm : DemangledExtendedBuiltin<"DebugFunctionDeclaration", NonSemantic_Shader_DebugInfo_100, 19>; +defm : DemangledExtendedBuiltin<"DebugFunction", NonSemantic_Shader_DebugInfo_100, 20>; +defm : DemangledExtendedBuiltin<"DebugLexicalBlock", NonSemantic_Shader_DebugInfo_100, 21>; +defm : DemangledExtendedBuiltin<"DebugLexicalBlockDiscriminator", NonSemantic_Shader_DebugInfo_100, 22>; +defm : DemangledExtendedBuiltin<"DebugScope", NonSemantic_Shader_DebugInfo_100, 23>; +defm : DemangledExtendedBuiltin<"DebugNoScope", NonSemantic_Shader_DebugInfo_100, 24>; +defm : DemangledExtendedBuiltin<"DebugInlinedAt", NonSemantic_Shader_DebugInfo_100, 25>; +defm : DemangledExtendedBuiltin<"DebugLocalVariable", NonSemantic_Shader_DebugInfo_100, 26>; +defm : DemangledExtendedBuiltin<"DebugInlinedVariable", NonSemantic_Shader_DebugInfo_100, 27>; +defm : DemangledExtendedBuiltin<"DebugDeclare", NonSemantic_Shader_DebugInfo_100, 28>; +defm : DemangledExtendedBuiltin<"DebugValue", NonSemantic_Shader_DebugInfo_100, 29>; +defm : DemangledExtendedBuiltin<"DebugOperation", NonSemantic_Shader_DebugInfo_100, 30>; +defm : DemangledExtendedBuiltin<"DebugExpression", NonSemantic_Shader_DebugInfo_100, 31>; +defm : DemangledExtendedBuiltin<"DebugMacroDef", NonSemantic_Shader_DebugInfo_100, 32>; +defm : DemangledExtendedBuiltin<"DebugMacroUndef", NonSemantic_Shader_DebugInfo_100, 33>; +defm : DemangledExtendedBuiltin<"DebugImportedEntity", NonSemantic_Shader_DebugInfo_100, 34>; +defm : DemangledExtendedBuiltin<"DebugSource", NonSemantic_Shader_DebugInfo_100, 35>; +defm : DemangledExtendedBuiltin<"DebugFunctionDefinition", NonSemantic_Shader_DebugInfo_100, 101>; +defm : DemangledExtendedBuiltin<"DebugSourceContinued", NonSemantic_Shader_DebugInfo_100, 102>; +defm : DemangledExtendedBuiltin<"DebugLine", NonSemantic_Shader_DebugInfo_100, 103>; +defm : DemangledExtendedBuiltin<"DebugNoLine", NonSemantic_Shader_DebugInfo_100, 104>; +defm : DemangledExtendedBuiltin<"DebugBuildIdentifier", NonSemantic_Shader_DebugInfo_100, 105>; +defm : DemangledExtendedBuiltin<"DebugStoragePath", NonSemantic_Shader_DebugInfo_100, 106>; +defm : DemangledExtendedBuiltin<"DebugEntryPoint", NonSemantic_Shader_DebugInfo_100, 107>; +defm : DemangledExtendedBuiltin<"DebugTypeMatrix", NonSemantic_Shader_DebugInfo_100, 108>; //===----------------------------------------------------------------------===// // Class defining an native builtin record used for direct translation into a // SPIR-V instruction. @@ -532,6 +598,7 @@ defm : DemangledNativeBuiltin<"__spirv_AtomicAnd", OpenCL_std, Atomic, 4, 4, OpA defm : DemangledNativeBuiltin<"atomic_exchange", OpenCL_std, Atomic, 2, 4, OpAtomicExchange>; defm : DemangledNativeBuiltin<"atomic_exchange_explicit", OpenCL_std, Atomic, 2, 4, OpAtomicExchange>; defm : DemangledNativeBuiltin<"AtomicEx__spirv_change", OpenCL_std, Atomic, 2, 4, OpAtomicExchange>; +defm : DemangledNativeBuiltin<"__spirv_AtomicExchange", OpenCL_std, Atomic, 4, 4, OpAtomicExchange>; defm : DemangledNativeBuiltin<"atomic_work_item_fence", OpenCL_std, Atomic, 1, 3, OpMemoryBarrier>; defm : DemangledNativeBuiltin<"__spirv_MemoryBarrier", OpenCL_std, Atomic, 2, 2, OpMemoryBarrier>; defm : DemangledNativeBuiltin<"atomic_fetch_add", OpenCL_std, Atomic, 2, 4, OpAtomicIAdd>; @@ -539,11 +606,11 @@ defm : DemangledNativeBuiltin<"atomic_fetch_sub", OpenCL_std, Atomic, 2, 4, OpAt defm : DemangledNativeBuiltin<"atomic_fetch_or", OpenCL_std, Atomic, 2, 4, OpAtomicOr>; defm : DemangledNativeBuiltin<"atomic_fetch_xor", OpenCL_std, Atomic, 2, 4, OpAtomicXor>; defm : DemangledNativeBuiltin<"atomic_fetch_and", OpenCL_std, Atomic, 2, 4, OpAtomicAnd>; -defm : DemangledNativeBuiltin<"atomic_fetch_add_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicIAdd>; -defm : DemangledNativeBuiltin<"atomic_fetch_sub_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicISub>; -defm : DemangledNativeBuiltin<"atomic_fetch_or_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicOr>; -defm : DemangledNativeBuiltin<"atomic_fetch_xor_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicXor>; -defm : DemangledNativeBuiltin<"atomic_fetch_and_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicAnd>; +defm : DemangledNativeBuiltin<"atomic_fetch_add_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicIAdd>; +defm : DemangledNativeBuiltin<"atomic_fetch_sub_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicISub>; +defm : DemangledNativeBuiltin<"atomic_fetch_or_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicOr>; +defm : DemangledNativeBuiltin<"atomic_fetch_xor_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicXor>; +defm : DemangledNativeBuiltin<"atomic_fetch_and_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicAnd>; defm : DemangledNativeBuiltin<"atomic_flag_test_and_set", OpenCL_std, Atomic, 1, 1, OpAtomicFlagTestAndSet>; defm : DemangledNativeBuiltin<"__spirv_AtomicFlagTestAndSet", OpenCL_std, Atomic, 3, 3, OpAtomicFlagTestAndSet>; defm : DemangledNativeBuiltin<"atomic_flag_test_and_set_explicit", OpenCL_std, Atomic, 2, 3, OpAtomicFlagTestAndSet>; @@ -595,6 +662,23 @@ defm : DemangledNativeBuiltin<"__spirv_GroupWaitEvents", OpenCL_std, AsyncCopy, defm : DemangledNativeBuiltin<"__spirv_Load", OpenCL_std, LoadStore, 1, 3, OpLoad>; defm : DemangledNativeBuiltin<"__spirv_Store", OpenCL_std, LoadStore, 2, 4, OpStore>; +// Address Space Qualifier Functions/Pointers Conversion Instructions: +defm : DemangledNativeBuiltin<"to_global", OpenCL_std, CastToPtr, 1, 1, OpGenericCastToPtr>; +defm : DemangledNativeBuiltin<"to_local", OpenCL_std, CastToPtr, 1, 1, OpGenericCastToPtr>; +defm : DemangledNativeBuiltin<"to_private", OpenCL_std, CastToPtr, 1, 1, OpGenericCastToPtr>; +defm : DemangledNativeBuiltin<"__spirv_GenericCastToPtr_ToGlobal", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>; +defm : DemangledNativeBuiltin<"__spirv_GenericCastToPtr_ToLocal", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>; +defm : DemangledNativeBuiltin<"__spirv_GenericCastToPtr_ToPrivate", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>; +defm : DemangledNativeBuiltin<"__spirv_GenericCastToPtrExplicit_ToGlobal", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>; +defm : DemangledNativeBuiltin<"__spirv_GenericCastToPtrExplicit_ToLocal", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>; +defm : DemangledNativeBuiltin<"__spirv_GenericCastToPtrExplicit_ToPrivate", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>; + +// Cooperative Matrix builtin records: +defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLoadKHR", OpenCL_std, CoopMatr, 2, 0, OpCooperativeMatrixLoadKHR>; +defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixStoreKHR>; +defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixMulAddKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixMulAddKHR>; +defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLengthKHR", OpenCL_std, CoopMatr, 1, 1, OpCooperativeMatrixLengthKHR>; + //===----------------------------------------------------------------------===// // Class defining a work/sub group builtin that should be translated into a // SPIR-V instruction using the defined properties. @@ -682,9 +766,17 @@ multiclass DemangledGroupBuiltin<string name, int level /* OnlyWork/OnlySub/... } } +multiclass DemangledGroupBuiltinWrapper<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> { + def : DemangledBuiltin<name, OpenCL_std, Group, minNumArgs, maxNumArgs>; + def : GroupBuiltin<name, operation>; +} + defm : DemangledGroupBuiltin<"group_all", WorkOrSub, OpGroupAll>; +defm : DemangledGroupBuiltinWrapper<"__spirv_GroupAll", 2, 2, OpGroupAll>; defm : DemangledGroupBuiltin<"group_any", WorkOrSub, OpGroupAny>; +defm : DemangledGroupBuiltinWrapper<"__spirv_GroupAny", 2, 2, OpGroupAny>; defm : DemangledGroupBuiltin<"group_broadcast", WorkOrSub, OpGroupBroadcast>; +defm : DemangledGroupBuiltinWrapper<"__spirv_GroupBroadcast", 3, 3, OpGroupBroadcast>; defm : DemangledGroupBuiltin<"group_non_uniform_broadcast", OnlySub, OpGroupNonUniformBroadcast>; defm : DemangledGroupBuiltin<"group_broadcast_first", OnlySub, OpGroupNonUniformBroadcastFirst>; @@ -719,41 +811,49 @@ defm : DemangledGroupBuiltin<"group_scan_inclusive_adds", WorkOrSub, OpGroupIAdd defm : DemangledGroupBuiltin<"group_reduce_addu", WorkOrSub, OpGroupIAdd>; defm : DemangledGroupBuiltin<"group_scan_exclusive_addu", WorkOrSub, OpGroupIAdd>; defm : DemangledGroupBuiltin<"group_scan_inclusive_addu", WorkOrSub, OpGroupIAdd>; +defm : DemangledGroupBuiltinWrapper<"__spirv_GroupIAdd", 3, 3, OpGroupIAdd>; defm : DemangledGroupBuiltin<"group_fadd", WorkOrSub, OpGroupFAdd>; defm : DemangledGroupBuiltin<"group_reduce_addf", WorkOrSub, OpGroupFAdd>; defm : DemangledGroupBuiltin<"group_scan_exclusive_addf", WorkOrSub, OpGroupFAdd>; defm : DemangledGroupBuiltin<"group_scan_inclusive_addf", WorkOrSub, OpGroupFAdd>; +defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFAdd", 3, 3, OpGroupFAdd>; defm : DemangledGroupBuiltin<"group_fmin", WorkOrSub, OpGroupFMin>; defm : DemangledGroupBuiltin<"group_reduce_minf", WorkOrSub, OpGroupFMin>; defm : DemangledGroupBuiltin<"group_scan_exclusive_minf", WorkOrSub, OpGroupFMin>; defm : DemangledGroupBuiltin<"group_scan_inclusive_minf", WorkOrSub, OpGroupFMin>; +defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFMin", 3, 3, OpGroupFMin>; defm : DemangledGroupBuiltin<"group_umin", WorkOrSub, OpGroupUMin>; defm : DemangledGroupBuiltin<"group_reduce_minu", WorkOrSub, OpGroupUMin>; defm : DemangledGroupBuiltin<"group_scan_exclusive_minu", WorkOrSub, OpGroupUMin>; defm : DemangledGroupBuiltin<"group_scan_inclusive_minu", WorkOrSub, OpGroupUMin>; +defm : DemangledGroupBuiltinWrapper<"__spirv_GroupUMin", 3, 3, OpGroupUMin>; defm : DemangledGroupBuiltin<"group_smin", WorkOrSub, OpGroupSMin>; defm : DemangledGroupBuiltin<"group_reduce_mins", WorkOrSub, OpGroupSMin>; defm : DemangledGroupBuiltin<"group_scan_exclusive_mins", WorkOrSub, OpGroupSMin>; defm : DemangledGroupBuiltin<"group_scan_inclusive_mins", WorkOrSub, OpGroupSMin>; +defm : DemangledGroupBuiltinWrapper<"__spirv_GroupSMin", 3, 3, OpGroupSMin>; defm : DemangledGroupBuiltin<"group_fmax", WorkOrSub, OpGroupFMax>; defm : DemangledGroupBuiltin<"group_reduce_maxf", WorkOrSub, OpGroupFMax>; defm : DemangledGroupBuiltin<"group_scan_exclusive_maxf", WorkOrSub, OpGroupFMax>; defm : DemangledGroupBuiltin<"group_scan_inclusive_maxf", WorkOrSub, OpGroupFMax>; +defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFMax", 3, 3, OpGroupFMax>; defm : DemangledGroupBuiltin<"group_umax", WorkOrSub, OpGroupUMax>; defm : DemangledGroupBuiltin<"group_reduce_maxu", WorkOrSub, OpGroupUMax>; defm : DemangledGroupBuiltin<"group_scan_exclusive_maxu", WorkOrSub, OpGroupUMax>; defm : DemangledGroupBuiltin<"group_scan_inclusive_maxu", WorkOrSub, OpGroupUMax>; +defm : DemangledGroupBuiltinWrapper<"__spirv_GroupUMax", 3, 3, OpGroupUMax>; defm : DemangledGroupBuiltin<"group_smax", WorkOrSub, OpGroupSMax>; defm : DemangledGroupBuiltin<"group_reduce_maxs", WorkOrSub, OpGroupSMax>; defm : DemangledGroupBuiltin<"group_scan_exclusive_maxs", WorkOrSub, OpGroupSMax>; defm : DemangledGroupBuiltin<"group_scan_inclusive_maxs", WorkOrSub, OpGroupSMax>; +defm : DemangledGroupBuiltinWrapper<"__spirv_GroupSMax", 3, 3, OpGroupSMax>; // cl_khr_subgroup_non_uniform_arithmetic defm : DemangledGroupBuiltin<"group_non_uniform_iadd", WorkOrSub, OpGroupNonUniformIAdd>; @@ -997,8 +1097,6 @@ multiclass DemangledAtomicFloatingBuiltin<string name, bits<8> minNumArgs, bits< defm : DemangledAtomicFloatingBuiltin<"AddEXT", 4, 4, OpAtomicFAddEXT>; defm : DemangledAtomicFloatingBuiltin<"MinEXT", 4, 4, OpAtomicFMinEXT>; defm : DemangledAtomicFloatingBuiltin<"MaxEXT", 4, 4, OpAtomicFMaxEXT>; -// TODO: add support for cl_ext_float_atomics to enable performing atomic operations -// on floating-point numbers in memory (float arguments for atomic_fetch_add, ...) //===----------------------------------------------------------------------===// // Class defining a sub group builtin that should be translated into a @@ -1407,7 +1505,7 @@ def : BuiltinType<"spirv.DeviceEvent", OpTypeDeviceEvent>; def : BuiltinType<"spirv.Image", OpTypeImage>; def : BuiltinType<"spirv.SampledImage", OpTypeSampledImage>; def : BuiltinType<"spirv.Pipe", OpTypePipe>; - +def : BuiltinType<"spirv.CooperativeMatrixKHR", OpTypeCooperativeMatrixKHR>; //===----------------------------------------------------------------------===// // Class matching an OpenCL builtin type name to an equivalent SPIR-V diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index 75aa1823b11f..c7c244cfa897 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -66,6 +66,8 @@ static const std::map<std::string, SPIRV::Extension::Extension> SPIRV::Extension::Extension::SPV_INTEL_function_pointers}, {"SPV_KHR_shader_clock", SPIRV::Extension::Extension::SPV_KHR_shader_clock}, + {"SPV_KHR_cooperative_matrix", + SPIRV::Extension::Extension::SPV_KHR_cooperative_matrix}, }; bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName, diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h index 2ec3fb35ca04..a37e65a47eda 100644 --- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h +++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h @@ -16,6 +16,7 @@ #include "MCTargetDesc/SPIRVBaseInfo.h" #include "MCTargetDesc/SPIRVMCTargetDesc.h" +#include "SPIRVUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" @@ -51,152 +52,87 @@ public: void addDep(DTSortableEntry *E) { Deps.push_back(E); } }; -struct SpecialTypeDescriptor { - enum SpecialTypeKind { - STK_Empty = 0, - STK_Image, - STK_SampledImage, - STK_Sampler, - STK_Pipe, - STK_DeviceEvent, - STK_Pointer, - STK_Last = -1 - }; - SpecialTypeKind Kind; - - unsigned Hash; - - SpecialTypeDescriptor() = delete; - SpecialTypeDescriptor(SpecialTypeKind K) : Kind(K) { Hash = Kind; } - - unsigned getHash() const { return Hash; } - - virtual ~SpecialTypeDescriptor() {} -}; - -struct ImageTypeDescriptor : public SpecialTypeDescriptor { - union ImageAttrs { - struct BitFlags { - unsigned Dim : 3; - unsigned Depth : 2; - unsigned Arrayed : 1; - unsigned MS : 1; - unsigned Sampled : 2; - unsigned ImageFormat : 6; - unsigned AQ : 2; - } Flags; - unsigned Val; - }; - - ImageTypeDescriptor(const Type *SampledTy, unsigned Dim, unsigned Depth, - unsigned Arrayed, unsigned MS, unsigned Sampled, - unsigned ImageFormat, unsigned AQ = 0) - : SpecialTypeDescriptor(SpecialTypeKind::STK_Image) { - ImageAttrs Attrs; - Attrs.Val = 0; - Attrs.Flags.Dim = Dim; - Attrs.Flags.Depth = Depth; - Attrs.Flags.Arrayed = Arrayed; - Attrs.Flags.MS = MS; - Attrs.Flags.Sampled = Sampled; - Attrs.Flags.ImageFormat = ImageFormat; - Attrs.Flags.AQ = AQ; - Hash = (DenseMapInfo<Type *>().getHashValue(SampledTy) & 0xffff) ^ - ((Attrs.Val << 8) | Kind); - } - - static bool classof(const SpecialTypeDescriptor *TD) { - return TD->Kind == SpecialTypeKind::STK_Image; - } -}; - -struct SampledImageTypeDescriptor : public SpecialTypeDescriptor { - SampledImageTypeDescriptor(const Type *SampledTy, const MachineInstr *ImageTy) - : SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage) { - assert(ImageTy->getOpcode() == SPIRV::OpTypeImage); - ImageTypeDescriptor TD( - SampledTy, ImageTy->getOperand(2).getImm(), - ImageTy->getOperand(3).getImm(), ImageTy->getOperand(4).getImm(), - ImageTy->getOperand(5).getImm(), ImageTy->getOperand(6).getImm(), - ImageTy->getOperand(7).getImm(), ImageTy->getOperand(8).getImm()); - Hash = TD.getHash() ^ Kind; - } - - static bool classof(const SpecialTypeDescriptor *TD) { - return TD->Kind == SpecialTypeKind::STK_SampledImage; - } -}; - -struct SamplerTypeDescriptor : public SpecialTypeDescriptor { - SamplerTypeDescriptor() - : SpecialTypeDescriptor(SpecialTypeKind::STK_Sampler) { - Hash = Kind; - } - - static bool classof(const SpecialTypeDescriptor *TD) { - return TD->Kind == SpecialTypeKind::STK_Sampler; - } +enum SpecialTypeKind { + STK_Empty = 0, + STK_Image, + STK_SampledImage, + STK_Sampler, + STK_Pipe, + STK_DeviceEvent, + STK_Pointer, + STK_Last = -1 }; -struct PipeTypeDescriptor : public SpecialTypeDescriptor { - - PipeTypeDescriptor(uint8_t AQ) - : SpecialTypeDescriptor(SpecialTypeKind::STK_Pipe) { - Hash = (AQ << 8) | Kind; - } - - static bool classof(const SpecialTypeDescriptor *TD) { - return TD->Kind == SpecialTypeKind::STK_Pipe; +using SpecialTypeDescriptor = std::tuple<const Type *, unsigned, unsigned>; + +union ImageAttrs { + struct BitFlags { + unsigned Dim : 3; + unsigned Depth : 2; + unsigned Arrayed : 1; + unsigned MS : 1; + unsigned Sampled : 2; + unsigned ImageFormat : 6; + unsigned AQ : 2; + } Flags; + unsigned Val; + + ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS, + unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0) { + Val = 0; + Flags.Dim = Dim; + Flags.Depth = Depth; + Flags.Arrayed = Arrayed; + Flags.MS = MS; + Flags.Sampled = Sampled; + Flags.ImageFormat = ImageFormat; + Flags.AQ = AQ; } }; -struct DeviceEventTypeDescriptor : public SpecialTypeDescriptor { - - DeviceEventTypeDescriptor() - : SpecialTypeDescriptor(SpecialTypeKind::STK_DeviceEvent) { - Hash = Kind; - } - - static bool classof(const SpecialTypeDescriptor *TD) { - return TD->Kind == SpecialTypeKind::STK_DeviceEvent; - } -}; - -struct PointerTypeDescriptor : public SpecialTypeDescriptor { - const Type *ElementType; - unsigned AddressSpace; - - PointerTypeDescriptor() = delete; - PointerTypeDescriptor(const Type *ElementType, unsigned AddressSpace) - : SpecialTypeDescriptor(SpecialTypeKind::STK_Pointer), - ElementType(ElementType), AddressSpace(AddressSpace) { - Hash = (DenseMapInfo<Type *>().getHashValue(ElementType) & 0xffff) ^ - ((AddressSpace << 8) | Kind); - } - - static bool classof(const SpecialTypeDescriptor *TD) { - return TD->Kind == SpecialTypeKind::STK_Pointer; - } -}; +inline SpecialTypeDescriptor +make_descr_image(const Type *SampledTy, unsigned Dim, unsigned Depth, + unsigned Arrayed, unsigned MS, unsigned Sampled, + unsigned ImageFormat, unsigned AQ = 0) { + return std::make_tuple( + SampledTy, + ImageAttrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ).Val, + SpecialTypeKind::STK_Image); +} + +inline SpecialTypeDescriptor +make_descr_sampled_image(const Type *SampledTy, const MachineInstr *ImageTy) { + assert(ImageTy->getOpcode() == SPIRV::OpTypeImage); + return std::make_tuple( + SampledTy, + ImageAttrs( + ImageTy->getOperand(2).getImm(), ImageTy->getOperand(3).getImm(), + ImageTy->getOperand(4).getImm(), ImageTy->getOperand(5).getImm(), + ImageTy->getOperand(6).getImm(), ImageTy->getOperand(7).getImm(), + ImageTy->getOperand(8).getImm()) + .Val, + SpecialTypeKind::STK_SampledImage); +} + +inline SpecialTypeDescriptor make_descr_sampler() { + return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_Sampler); +} + +inline SpecialTypeDescriptor make_descr_pipe(uint8_t AQ) { + return std::make_tuple(nullptr, AQ, SpecialTypeKind::STK_Pipe); +} + +inline SpecialTypeDescriptor make_descr_event() { + return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_DeviceEvent); +} + +inline SpecialTypeDescriptor make_descr_pointee(const Type *ElementType, + unsigned AddressSpace) { + return std::make_tuple(ElementType, AddressSpace, + SpecialTypeKind::STK_Pointer); +} } // namespace SPIRV -template <> struct DenseMapInfo<SPIRV::SpecialTypeDescriptor> { - static inline SPIRV::SpecialTypeDescriptor getEmptyKey() { - return SPIRV::SpecialTypeDescriptor( - SPIRV::SpecialTypeDescriptor::STK_Empty); - } - static inline SPIRV::SpecialTypeDescriptor getTombstoneKey() { - return SPIRV::SpecialTypeDescriptor(SPIRV::SpecialTypeDescriptor::STK_Last); - } - static unsigned getHashValue(SPIRV::SpecialTypeDescriptor Val) { - return Val.getHash(); - } - static bool isEqual(SPIRV::SpecialTypeDescriptor LHS, - SPIRV::SpecialTypeDescriptor RHS) { - return getHashValue(LHS) == getHashValue(RHS); - } -}; - template <typename KeyTy> class SPIRVDuplicatesTrackerBase { public: // NOTE: using MapVector instead of DenseMap helps getting everything ordered @@ -282,12 +218,12 @@ public: MachineModuleInfo *MMI); void add(const Type *Ty, const MachineFunction *MF, Register R) { - TT.add(Ty, MF, R); + TT.add(unifyPtrType(Ty), MF, R); } - void add(const Type *PointerElementType, unsigned AddressSpace, + void add(const Type *PointeeTy, unsigned AddressSpace, const MachineFunction *MF, Register R) { - ST.add(SPIRV::PointerTypeDescriptor(PointerElementType, AddressSpace), MF, + ST.add(SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF, R); } @@ -317,13 +253,13 @@ public: } Register find(const Type *Ty, const MachineFunction *MF) { - return TT.find(const_cast<Type *>(Ty), MF); + return TT.find(unifyPtrType(Ty), MF); } - Register find(const Type *PointerElementType, unsigned AddressSpace, + Register find(const Type *PointeeTy, unsigned AddressSpace, const MachineFunction *MF) { return ST.find( - SPIRV::PointerTypeDescriptor(PointerElementType, AddressSpace), MF); + SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF); } Register find(const Constant *C, const MachineFunction *MF) { diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 7b8e3230bf55..dd5884096b85 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -69,7 +69,7 @@ class SPIRVEmitIntrinsics DenseSet<Instruction *> AggrStores; // deduce element type of untyped pointers - Type *deduceElementType(Value *I); + Type *deduceElementType(Value *I, bool UnknownElemTypeI8); Type *deduceElementTypeHelper(Value *I); Type *deduceElementTypeHelper(Value *I, std::unordered_set<Value *> &Visited); Type *deduceElementTypeByValueDeep(Type *ValueTy, Value *Operand, @@ -105,7 +105,8 @@ class SPIRVEmitIntrinsics void replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B); void processInstrAfterVisit(Instruction *I, IRBuilder<> &B); - void insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B); + bool insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B, + bool UnknownElemTypeI8); void insertAssignTypeIntrs(Instruction *I, IRBuilder<> &B); void insertAssignPtrTypeTargetExt(TargetExtType *AssignedType, Value *V, IRBuilder<> &B); @@ -367,6 +368,26 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper( if (Ty) break; } + } else if (auto *CI = dyn_cast<CallInst>(I)) { + static StringMap<unsigned> ResTypeByArg = { + {"to_global", 0}, + {"to_local", 0}, + {"to_private", 0}, + {"__spirv_GenericCastToPtr_ToGlobal", 0}, + {"__spirv_GenericCastToPtr_ToLocal", 0}, + {"__spirv_GenericCastToPtr_ToPrivate", 0}, + {"__spirv_GenericCastToPtrExplicit_ToGlobal", 0}, + {"__spirv_GenericCastToPtrExplicit_ToLocal", 0}, + {"__spirv_GenericCastToPtrExplicit_ToPrivate", 0}}; + // TODO: maybe improve performance by caching demangled names + if (Function *CalledF = CI->getCalledFunction()) { + std::string DemangledName = + getOclOrSpirvBuiltinDemangledName(CalledF->getName()); + auto AsArgIt = ResTypeByArg.find(DemangledName); + if (AsArgIt != ResTypeByArg.end()) + Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second), + Visited); + } } // remember the found relationship @@ -460,10 +481,10 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper( return OrigTy; } -Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) { +Type *SPIRVEmitIntrinsics::deduceElementType(Value *I, bool UnknownElemTypeI8) { if (Type *Ty = deduceElementTypeHelper(I)) return Ty; - return IntegerType::getInt8Ty(I->getContext()); + return UnknownElemTypeI8 ? IntegerType::getInt8Ty(I->getContext()) : nullptr; } // If the Instruction has Pointer operands with unresolved types, this function @@ -652,10 +673,8 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) { AggrConst = cast<Constant>(COp); ResTy = B.getInt32Ty(); } else if (auto *COp = dyn_cast<ConstantAggregateZero>(Op)) { - if (!Op->getType()->isVectorTy()) { - AggrConst = cast<Constant>(COp); - ResTy = B.getInt32Ty(); - } + AggrConst = cast<Constant>(COp); + ResTy = Op->getType()->isVectorTy() ? COp->getType() : B.getInt32Ty(); } if (AggrConst) { SmallVector<Value *> Args; @@ -1152,16 +1171,23 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV, B.CreateIntrinsic(Intrinsic::spv_unref_global, GV.getType(), &GV); } -void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I, - IRBuilder<> &B) { +// Return true, if we can't decide what is the pointee type now and will get +// back to the question later. Return false is spv_assign_ptr_type is not needed +// or can be inserted immediately. +bool SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I, + IRBuilder<> &B, + bool UnknownElemTypeI8) { reportFatalOnTokenType(I); if (!isPointerTy(I->getType()) || !requireAssignType(I) || isa<BitCastInst>(I)) - return; + return false; setInsertPointAfterDef(B, I); - Type *ElemTy = deduceElementType(I); - buildAssignPtr(B, ElemTy, I); + if (Type *ElemTy = deduceElementType(I, UnknownElemTypeI8)) { + buildAssignPtr(B, ElemTy, I); + return false; + } + return true; } void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, @@ -1199,7 +1225,7 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, buildAssignPtr(B, PType->getElementType(), Op); } else if (isPointerTy(OpTy)) { Type *ElemTy = GR->findDeducedElementType(Op); - buildAssignPtr(B, ElemTy ? ElemTy : deduceElementType(Op), Op); + buildAssignPtr(B, ElemTy ? ElemTy : deduceElementType(Op, true), Op); } else { CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type, {OpTy}, Op, Op, {}, B); @@ -1235,8 +1261,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I, } bool IsPhi = isa<PHINode>(I), BPrepared = false; for (const auto &Op : I->operands()) { - if ((isa<ConstantAggregateZero>(Op) && Op->getType()->isVectorTy()) || - isa<PHINode>(I) || isa<SwitchInst>(I)) + if (isa<PHINode>(I) || isa<SwitchInst>(I)) TrackConstants = false; if ((isa<ConstantData>(Op) || isa<ConstantExpr>(Op)) && TrackConstants) { unsigned OpNo = Op.getOperandNo(); @@ -1395,10 +1420,15 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { if (isConvergenceIntrinsic(I)) continue; - insertAssignPtrTypeIntrs(I, B); + bool Postpone = insertAssignPtrTypeIntrs(I, B, false); + // if Postpone is true, we can't decide on pointee type yet insertAssignTypeIntrs(I, B); insertPtrCastOrAssignTypeInstr(I, B); insertSpirvDecorations(I, B); + // if instruction requires a pointee type set, let's check if we know it + // already, and force it to be i8 if not + if (Postpone && !GR->findAssignPtrTypeInstr(I)) + insertAssignPtrTypeIntrs(I, B, true); } for (auto &I : instructions(Func)) diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index d434e0b5efbc..5558c7a5a4a5 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -394,7 +394,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull( Constant *Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth, unsigned ElemCnt, bool ZeroAsNull) { - // Find a constant vector in DT or build a new one. + // Find a constant vector or array in DT or build a new one. Register Res = DT.find(CA, CurMF); // If no values are attached, the composite is null constant. bool IsNull = Val->isNullValue() && ZeroAsNull; @@ -474,20 +474,28 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val, ZeroAsNull); } -Register -SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I, - SPIRVType *SpvType, - const SPIRVInstrInfo &TII) { +Register SPIRVGlobalRegistry::getOrCreateConstIntArray( + uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType, + const SPIRVInstrInfo &TII) { const Type *LLVMTy = getTypeForSPIRVType(SpvType); assert(LLVMTy->isArrayTy()); const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy); Type *LLVMBaseTy = LLVMArrTy->getElementType(); - auto *ConstInt = ConstantInt::get(LLVMBaseTy, Val); - auto *ConstArr = - ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt}); + Constant *CI = ConstantInt::get(LLVMBaseTy, Val); SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy); - return getOrCreateCompositeOrNull(ConstInt, I, SpvType, TII, ConstArr, BW, + // The following is reasonably unique key that is better that [Val]. The naive + // alternative would be something along the lines of: + // SmallVector<Constant *> NumCI(Num, CI); + // Constant *UniqueKey = + // ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI); + // that would be a truly unique but dangerous key, because it could lead to + // the creation of constants of arbitrary length (that is, the parameter of + // memset) which were missing in the original module. + Constant *UniqueKey = ConstantStruct::getAnon( + {PoisonValue::get(const_cast<ArrayType *>(LLVMArrTy)), + ConstantInt::get(LLVMBaseTy, Val), ConstantInt::get(LLVMBaseTy, Num)}); + return getOrCreateCompositeOrNull(CI, I, SpvType, TII, UniqueKey, BW, LLVMArrTy->getNumElements()); } @@ -546,24 +554,6 @@ SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, } Register -SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, - MachineIRBuilder &MIRBuilder, - SPIRVType *SpvType, bool EmitIR) { - const Type *LLVMTy = getTypeForSPIRVType(SpvType); - assert(LLVMTy->isArrayTy()); - const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy); - Type *LLVMBaseTy = LLVMArrTy->getElementType(); - const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); - auto ConstArr = - ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt}); - SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); - unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy); - return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR, - ConstArr, BW, - LLVMArrTy->getNumElements()); -} - -Register SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) { const Type *LLVMTy = getTypeForSPIRVType(SpvType); @@ -936,7 +926,7 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); TypesInProcessing.erase(Ty); VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; - SPIRVToLLVMType[SpirvType] = Ty; + SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty); Register Reg = DT.find(Ty, &MIRBuilder.getMF()); // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type // will be added later. For special types it is already added to DT. @@ -1080,12 +1070,14 @@ bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const { return IntType && IntType->getOperand(2).getImm() != 0; } +SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) { + return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer + ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg()) + : nullptr; +} + unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) { - SPIRVType *PtrType = getSPIRVTypeForVReg(PtrReg); - SPIRVType *ElemType = - PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer - ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg()) - : nullptr; + SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg)); return ElemType ? ElemType->getOpcode() : 0; } @@ -1122,9 +1114,9 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage( uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled, SPIRV::ImageFormat::ImageFormat ImageFormat, SPIRV::AccessQualifier::AccessQualifier AccessQual) { - SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth, - Arrayed, Multisampled, Sampled, ImageFormat, - AccessQual); + auto TD = SPIRV::make_descr_image(SPIRVToLLVMType.lookup(SampledType), Dim, + Depth, Arrayed, Multisampled, Sampled, + ImageFormat, AccessQual); if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) return Res; Register ResVReg = createTypeVReg(MIRBuilder); @@ -1143,7 +1135,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage( SPIRVType * SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) { - SPIRV::SamplerTypeDescriptor TD; + auto TD = SPIRV::make_descr_sampler(); if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) return Res; Register ResVReg = createTypeVReg(MIRBuilder); @@ -1154,7 +1146,7 @@ SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) { SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe( MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AccessQual) { - SPIRV::PipeTypeDescriptor TD(AccessQual); + auto TD = SPIRV::make_descr_pipe(AccessQual); if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) return Res; Register ResVReg = createTypeVReg(MIRBuilder); @@ -1166,7 +1158,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe( SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent( MachineIRBuilder &MIRBuilder) { - SPIRV::DeviceEventTypeDescriptor TD; + auto TD = SPIRV::make_descr_event(); if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) return Res; Register ResVReg = createTypeVReg(MIRBuilder); @@ -1176,7 +1168,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent( SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage( SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) { - SPIRV::SampledImageTypeDescriptor TD( + auto TD = SPIRV::make_descr_sampled_image( SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef( ImageType->getOperand(1).getReg())), ImageType); @@ -1189,6 +1181,26 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage( .addUse(getSPIRVTypeID(ImageType)); } +SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr( + MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType, + const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns, + uint32_t Use) { + Register ResVReg = DT.find(ExtensionType, &MIRBuilder.getMF()); + if (ResVReg.isValid()) + return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg); + ResVReg = createTypeVReg(MIRBuilder); + SPIRVType *SpirvTy = + MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR) + .addDef(ResVReg) + .addUse(getSPIRVTypeID(ElemType)) + .addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true)) + .addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true)) + .addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true)) + .addUse(buildConstantInt(Use, MIRBuilder, nullptr, true)); + DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg); + return SpirvTy; +} + SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode( const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) { Register ResVReg = DT.find(Ty, &MIRBuilder.getMF()); @@ -1268,7 +1280,7 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy, SPIRVType *SpirvType) { assert(CurMF == SpirvType->getMF()); VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; - SPIRVToLLVMType[SpirvType] = LLVMTy; + SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy); return SpirvType; } diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index db01f68f48de..a45e1ccd0717 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -292,6 +292,8 @@ public: return Res->second; } + // Return a pointee's type, or nullptr otherwise. + SPIRVType *getPointeeType(SPIRVType *PtrType); // Return a pointee's type op code, or 0 otherwise. unsigned getPointeeTypeOp(Register PtrReg); @@ -327,6 +329,12 @@ public: return Ret; } + // Return true if the type is an aggregate type. + bool isAggregateType(SPIRVType *Type) const { + return Type && (Type->getOpcode() == SPIRV::OpTypeStruct && + Type->getOpcode() == SPIRV::OpTypeArray); + } + // Whether the given VReg has an OpTypeXXX instruction mapped to it with the // given opcode (e.g. OpTypeFloat). bool isScalarOfType(Register VReg, unsigned TypeOpcode) const; @@ -449,13 +457,11 @@ public: Register getOrCreateConstVector(APFloat Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, bool ZeroAsNull = true); - Register getOrCreateConsIntArray(uint64_t Val, MachineInstr &I, - SPIRVType *SpvType, - const SPIRVInstrInfo &TII); + Register getOrCreateConstIntArray(uint64_t Val, size_t Num, MachineInstr &I, + SPIRVType *SpvType, + const SPIRVInstrInfo &TII); Register getOrCreateConsIntVector(uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR = true); - Register getOrCreateConsIntArray(uint64_t Val, MachineIRBuilder &MIRBuilder, - SPIRVType *SpvType, bool EmitIR = true); Register getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder, SPIRVType *SpvType); Register buildConstantSampler(Register Res, unsigned AddrMode, unsigned Param, @@ -514,7 +520,11 @@ public: SPIRVType *getOrCreateOpTypeSampledImage(SPIRVType *ImageType, MachineIRBuilder &MIRBuilder); - + SPIRVType *getOrCreateOpTypeCoopMatr(MachineIRBuilder &MIRBuilder, + const TargetExtType *ExtensionType, + const SPIRVType *ElemType, + uint32_t Scope, uint32_t Rows, + uint32_t Columns, uint32_t Use); SPIRVType * getOrCreateOpTypePipe(MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AccQual); diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp index 5ccbaf12ddee..4383d1c5c0e2 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp @@ -339,6 +339,7 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg())); break; case SPIRV::OpPtrCastToGeneric: + case SPIRV::OpGenericCastToPtr: validateAccessChain(STI, MRI, GR, MI); break; case SPIRV::OpInBoundsPtrAccessChain: diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h index 6fc200abf462..77356b7512a7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h @@ -68,6 +68,11 @@ public: // extra instructions required to preserve validity of SPIR-V code imposed by // the standard. void finalizeLowering(MachineFunction &MF) const override; + + MVT getPreferredSwitchConditionType(LLVMContext &Context, + EVT ConditionVT) const override { + return ConditionVT.getSimpleVT(); + } }; } // namespace llvm diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index dedfd5e6e32d..63549b06e967 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -211,6 +211,9 @@ def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins), def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res), (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols), "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">; +def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res), + (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use), + "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">; // 3.42.7 Constant-Creation Instructions @@ -864,3 +867,16 @@ def OpAsmINTEL: Op<5610, (outs ID:$res), (ins TYPE:$type, TYPE:$asm_type, ID:$ta "$res = OpAsmINTEL $type $asm_type $target $asm">; def OpAsmCallINTEL: Op<5611, (outs ID:$res), (ins TYPE:$type, ID:$asm, variable_ops), "$res = OpAsmCallINTEL $type $asm">; + +// SPV_KHR_cooperative_matrix +def OpCooperativeMatrixLoadKHR: Op<4457, (outs ID:$res), + (ins TYPE:$resType, ID:$pointer, ID:$memory_layout, variable_ops), + "$res = OpCooperativeMatrixLoadKHR $resType $pointer $memory_layout">; +def OpCooperativeMatrixStoreKHR: Op<4458, (outs), + (ins ID:$pointer, ID:$objectToStore, ID:$memory_layout, variable_ops), + "OpCooperativeMatrixStoreKHR $pointer $objectToStore $memory_layout">; +def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res), + (ins TYPE:$type, ID:$A, ID:$B, ID:$C, variable_ops), + "$res = OpCooperativeMatrixMulAddKHR $type $A $B $C">; +def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type), + "$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index db83172f7fa9..9be736ce88ce 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -173,6 +173,9 @@ private: bool selectFmix(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectRsqrt(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I, int OpIdx) const; void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I, @@ -469,6 +472,18 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg, return selectExtInst(ResVReg, ResType, I, CL::sin, GL::Sin); case TargetOpcode::G_FTAN: return selectExtInst(ResVReg, ResType, I, CL::tan, GL::Tan); + case TargetOpcode::G_FACOS: + return selectExtInst(ResVReg, ResType, I, CL::acos, GL::Acos); + case TargetOpcode::G_FASIN: + return selectExtInst(ResVReg, ResType, I, CL::asin, GL::Asin); + case TargetOpcode::G_FATAN: + return selectExtInst(ResVReg, ResType, I, CL::atan, GL::Atan); + case TargetOpcode::G_FCOSH: + return selectExtInst(ResVReg, ResType, I, CL::cosh, GL::Cosh); + case TargetOpcode::G_FSINH: + return selectExtInst(ResVReg, ResType, I, CL::sinh, GL::Sinh); + case TargetOpcode::G_FTANH: + return selectExtInst(ResVReg, ResType, I, CL::tanh, GL::Tanh); case TargetOpcode::G_FSQRT: return selectExtInst(ResVReg, ResType, I, CL::sqrt, GL::Sqrt); @@ -831,7 +846,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg, unsigned Num = getIConstVal(I.getOperand(2).getReg(), MRI); SPIRVType *ValTy = GR.getOrCreateSPIRVIntegerType(8, I, TII); SPIRVType *ArrTy = GR.getOrCreateSPIRVArrayType(ValTy, Num, I, TII); - Register Const = GR.getOrCreateConsIntArray(Val, I, ArrTy, TII); + Register Const = GR.getOrCreateConstIntArray(Val, Num, I, ArrTy, TII); SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType( ArrTy, I, TII, SPIRV::StorageClass::UniformConstant); // TODO: check if we have such GV, add init, use buildGlobalVariable. @@ -1102,7 +1117,7 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg, if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) { Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass); SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType( - SrcPtrTy, I, TII, SPIRV::StorageClass::Generic); + GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic); MachineBasicBlock &BB = *I.getParent(); const DebugLoc &DL = I.getDebugLoc(); bool Success = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric)) @@ -1315,6 +1330,23 @@ bool SPIRVInstructionSelector::selectFmix(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } +bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + + assert(I.getNumOperands() == 3); + assert(I.getOperand(2).isReg()); + MachineBasicBlock &BB = *I.getParent(); + + return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450)) + .addImm(GL::InverseSqrt) + .addUse(I.getOperand(2).getReg()) + .constrainAllUses(TII, TRI, RBI); +} + bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { @@ -1413,20 +1445,50 @@ static unsigned getArrayComponentCount(MachineRegisterInfo *MRI, } // Return true if the type represents a constant register -static bool isConstReg(MachineRegisterInfo *MRI, SPIRVType *OpDef) { +static bool isConstReg(MachineRegisterInfo *MRI, SPIRVType *OpDef, + SmallPtrSet<SPIRVType *, 4> &Visited) { if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE && OpDef->getOperand(1).isReg()) { if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg())) OpDef = RefDef; } - return OpDef->getOpcode() == TargetOpcode::G_CONSTANT || - OpDef->getOpcode() == TargetOpcode::G_FCONSTANT; + + if (Visited.contains(OpDef)) + return true; + Visited.insert(OpDef); + + unsigned Opcode = OpDef->getOpcode(); + switch (Opcode) { + case TargetOpcode::G_CONSTANT: + case TargetOpcode::G_FCONSTANT: + return true; + case TargetOpcode::G_INTRINSIC: + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: + case TargetOpcode::G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS: + return cast<GIntrinsic>(*OpDef).getIntrinsicID() == + Intrinsic::spv_const_composite; + case TargetOpcode::G_BUILD_VECTOR: + case TargetOpcode::G_SPLAT_VECTOR: { + for (unsigned i = OpDef->getNumExplicitDefs(); i < OpDef->getNumOperands(); + i++) { + SPIRVType *OpNestedDef = + OpDef->getOperand(i).isReg() + ? MRI->getVRegDef(OpDef->getOperand(i).getReg()) + : nullptr; + if (OpNestedDef && !isConstReg(MRI, OpNestedDef, Visited)) + return false; + } + return true; + } + } + return false; } // Return true if the virtual register represents a constant static bool isConstReg(MachineRegisterInfo *MRI, Register OpReg) { + SmallPtrSet<SPIRVType *, 4> Visited; if (SPIRVType *OpDef = MRI->getVRegDef(OpReg)) - return isConstReg(MRI, OpDef); + return isConstReg(MRI, OpDef, Visited); return false; } @@ -1740,15 +1802,18 @@ bool SPIRVInstructionSelector::selectOpUndef(Register ResVReg, static bool isImm(const MachineOperand &MO, MachineRegisterInfo *MRI) { assert(MO.isReg()); const SPIRVType *TypeInst = MRI->getVRegDef(MO.getReg()); - if (TypeInst->getOpcode() != SPIRV::ASSIGN_TYPE) - return false; - assert(TypeInst->getOperand(1).isReg()); - MachineInstr *ImmInst = MRI->getVRegDef(TypeInst->getOperand(1).getReg()); - return ImmInst->getOpcode() == TargetOpcode::G_CONSTANT; + if (TypeInst->getOpcode() == SPIRV::ASSIGN_TYPE) { + assert(TypeInst->getOperand(1).isReg()); + MachineInstr *ImmInst = MRI->getVRegDef(TypeInst->getOperand(1).getReg()); + return ImmInst->getOpcode() == TargetOpcode::G_CONSTANT; + } + return TypeInst->getOpcode() == SPIRV::OpConstantI; } static int64_t foldImm(const MachineOperand &MO, MachineRegisterInfo *MRI) { const SPIRVType *TypeInst = MRI->getVRegDef(MO.getReg()); + if (TypeInst->getOpcode() == SPIRV::OpConstantI) + return TypeInst->getOperand(2).getImm(); MachineInstr *ImmInst = MRI->getVRegDef(TypeInst->getOperand(1).getReg()); assert(ImmInst->getOpcode() == TargetOpcode::G_CONSTANT); return ImmInst->getOperand(1).getCImm()->getZExtValue(); @@ -1850,8 +1915,10 @@ bool SPIRVInstructionSelector::wrapIntoSpecConstantOp( Register OpReg = I.getOperand(i).getReg(); SPIRVType *OpDefine = MRI->getVRegDef(OpReg); SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg); - if (!OpDefine || !OpType || isConstReg(MRI, OpDefine) || - OpDefine->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) { + SmallPtrSet<SPIRVType *, 4> Visited; + if (!OpDefine || !OpType || isConstReg(MRI, OpDefine, Visited) || + OpDefine->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST || + GR.isAggregateType(OpType)) { // The case of G_ADDRSPACE_CAST inside spv_const_composite() is processed // by selectAddrSpaceCast() CompositeArgs.push_back(OpReg); @@ -1992,6 +2059,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectAny(ResVReg, ResType, I); case Intrinsic::spv_lerp: return selectFmix(ResVReg, ResType, I); + case Intrinsic::spv_rsqrt: + return selectRsqrt(ResVReg, ResType, I); case Intrinsic::spv_lifetime_start: case Intrinsic::spv_lifetime_end: { unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index 57fbf3b3f8f1..6c7c3af19965 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -278,6 +278,12 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { G_FCOS, G_FSIN, G_FTAN, + G_FACOS, + G_FASIN, + G_FATAN, + G_FCOSH, + G_FSINH, + G_FTANH, G_FSQRT, G_FFLOOR, G_FRINT, diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp index 2744c25d1bc7..0747dd1bbaf4 100644 --- a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp @@ -17,6 +17,8 @@ #include "SPIRVSubtarget.h" #include "SPIRVTargetMachine.h" #include "SPIRVUtils.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/CodeGen/IntrinsicLowering.h" #include "llvm/IR/CFG.h" @@ -71,7 +73,7 @@ public: /// terminator will take. llvm::Value *createExitVariable( BasicBlock *BB, - const std::unordered_map<BasicBlock *, ConstantInt *> &TargetToValue) { + const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) { auto *T = BB->getTerminator(); if (isa<ReturnInst>(T)) return nullptr; @@ -98,12 +100,12 @@ public: } // TODO: add support for switch cases. - assert(false && "Unhandled terminator type."); + llvm_unreachable("Unhandled terminator type."); } /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|. void replaceBranchTargets(BasicBlock *BB, - const std::unordered_set<BasicBlock *> ToReplace, + const SmallPtrSet<BasicBlock *, 4> &ToReplace, BasicBlock *NewTarget) { auto *T = BB->getTerminator(); if (isa<ReturnInst>(T)) @@ -133,7 +135,7 @@ public: bool runOnConvergenceRegionNoRecurse(LoopInfo &LI, const SPIRV::ConvergenceRegion *CR) { // Gather all the exit targets for this region. - std::unordered_set<BasicBlock *> ExitTargets; + SmallPtrSet<BasicBlock *, 4> ExitTargets; for (BasicBlock *Exit : CR->Exits) { for (BasicBlock *Target : gatherSuccessors(Exit)) { if (CR->Blocks.count(Target) == 0) @@ -164,9 +166,10 @@ public: // Creating one constant per distinct exit target. This will be route to the // correct target. - std::unordered_map<BasicBlock *, ConstantInt *> TargetToValue; + DenseMap<BasicBlock *, ConstantInt *> TargetToValue; for (BasicBlock *Target : SortedExitTargets) - TargetToValue.emplace(Target, Builder.getInt32(TargetToValue.size())); + TargetToValue.insert( + std::make_pair(Target, Builder.getInt32(TargetToValue.size()))); // Creating one variable per exit node, set to the constant matching the // targeted external block. @@ -184,12 +187,12 @@ public: } // Creating the switch to jump to the correct exit target. - std::vector<std::pair<BasicBlock *, ConstantInt *>> CasesList( - TargetToValue.begin(), TargetToValue.end()); - llvm::SwitchInst *Sw = - Builder.CreateSwitch(node, CasesList[0].first, CasesList.size() - 1); - for (size_t i = 1; i < CasesList.size(); i++) - Sw->addCase(CasesList[i].second, CasesList[i].first); + llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0], + SortedExitTargets.size() - 1); + for (size_t i = 1; i < SortedExitTargets.size(); i++) { + BasicBlock *BB = SortedExitTargets[i]; + Sw->addCase(TargetToValue[BB], BB); + } // Fix exit branches to redirect to the new exit. for (auto Exit : CR->Exits) diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 30a6c474f467..ac0aa682ea4b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1168,6 +1168,15 @@ void addInstrRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::AsmINTEL); } break; + case SPIRV::OpTypeCooperativeMatrixKHR: + if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix)) + report_fatal_error( + "OpTypeCooperativeMatrixKHR type requires the " + "following SPIR-V extension: SPV_KHR_cooperative_matrix", + false); + Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix); + Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR); + break; default: break; } diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index adc5b36af6f1..0ea2f176565e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -41,7 +41,8 @@ public: static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR, const SPIRVSubtarget &STI, - DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) { + DenseMap<MachineInstr *, Type *> &TargetExtConstTypes, + SmallSet<Register, 4> &TrackedConstRegs) { MachineRegisterInfo &MRI = MF.getRegInfo(); DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT; SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites; @@ -80,6 +81,7 @@ addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR, } } GR->add(Const, &MF, SrcReg); + TrackedConstRegs.insert(SrcReg); if (Const->getType()->isTargetExtTy()) { // remember association so that we can restore it when assign types MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); @@ -121,7 +123,9 @@ addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR, MI->eraseFromParent(); } -static void foldConstantsIntoIntrinsics(MachineFunction &MF) { +static void +foldConstantsIntoIntrinsics(MachineFunction &MF, + const SmallSet<Register, 4> &TrackedConstRegs) { SmallVector<MachineInstr *, 10> ToErase; MachineRegisterInfo &MRI = MF.getRegInfo(); const unsigned AssignNameOperandShift = 2; @@ -137,7 +141,8 @@ static void foldConstantsIntoIntrinsics(MachineFunction &MF) { MI.removeOperand(NumOp); MI.addOperand(MachineOperand::CreateImm( ConstMI->getOperand(1).getCImm()->getZExtValue())); - if (MRI.use_empty(ConstMI->getOperand(0).getReg())) + Register DefReg = ConstMI->getOperand(0).getReg(); + if (MRI.use_empty(DefReg) && !TrackedConstRegs.contains(DefReg)) ToErase.push_back(ConstMI); } } @@ -271,6 +276,21 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR, return SpirvTy; } +// To support current approach and limitations wrt. bit width here we widen a +// scalar register with a bit width greater than 1 to valid sizes and cap it to +// 64 width. +static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) { + LLT RegType = MRI.getType(Reg); + if (!RegType.isScalar()) + return; + unsigned Sz = RegType.getScalarSizeInBits(); + if (Sz == 1) + return; + unsigned NewSz = std::min(std::max(1u << Log2_32_Ceil(Sz), 8u), 64u); + if (NewSz != Sz) + MRI.setType(Reg, LLT::scalar(NewSz)); +} + static std::pair<Register, unsigned> createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI, const SPIRVGlobalRegistry &GR) { @@ -395,6 +415,7 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineRegisterInfo &MRI = MF.getRegInfo(); SmallVector<MachineInstr *, 10> ToErase; + DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT; for (MachineBasicBlock *MBB : post_order(&MF)) { if (MBB->empty()) @@ -406,6 +427,11 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineInstr &MI = *MII; unsigned MIOp = MI.getOpcode(); + // validate bit width of scalar registers + for (const auto &MOP : MI.operands()) + if (MOP.isReg()) + widenScalarLLTNextPow2(MOP.getReg(), MRI); + if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) { Register Reg = MI.getOperand(1).getReg(); MIB.setInsertPt(*MI.getParent(), MI.getIterator()); @@ -441,6 +467,7 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR, // %rctmp = G_CONSTANT ty Val // %rc = ASSIGN_TYPE %rctmp, %cty Register Reg = MI.getOperand(0).getReg(); + bool NeedAssignType = true; if (MRI.hasOneUse(Reg)) { MachineInstr &UseMI = *MRI.use_instr_begin(Reg); if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) || @@ -453,7 +480,20 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR, Ty = TargetExtIt == TargetExtConstTypes.end() ? MI.getOperand(1).getCImm()->getType() : TargetExtIt->second; - GR->add(MI.getOperand(1).getCImm(), &MF, Reg); + const ConstantInt *OpCI = MI.getOperand(1).getCImm(); + Register PrimaryReg = GR->find(OpCI, &MF); + if (!PrimaryReg.isValid()) { + GR->add(OpCI, &MF, Reg); + } else if (PrimaryReg != Reg && + MRI.getType(Reg) == MRI.getType(PrimaryReg)) { + auto *RCReg = MRI.getRegClassOrNull(Reg); + auto *RCPrimary = MRI.getRegClassOrNull(PrimaryReg); + if (!RCReg || RCPrimary == RCReg) { + RegsAlreadyAddedToDT[&MI] = PrimaryReg; + ToErase.push_back(&MI); + NeedAssignType = false; + } + } } else if (MIOp == TargetOpcode::G_FCONSTANT) { Ty = MI.getOperand(1).getFPImm()->getType(); } else { @@ -472,14 +512,10 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR, MI.getNumExplicitOperands() - MI.getNumExplicitDefs(); Ty = VectorType::get(ElemTy, NumElts, false); } - insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI); + if (NeedAssignType) + insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI); } else if (MIOp == TargetOpcode::G_GLOBAL_VALUE) { propagateSPIRVType(&MI, GR, MRI, MIB); - } else if (MIOp == TargetOpcode::G_BITREVERSE) { - Register Reg = MI.getOperand(0).getReg(); - LLT RegType = MRI.getType(Reg); - if (RegType.getSizeInBits() < 32) - MRI.setType(Reg, LLT::scalar(32)); } if (MII == Begin) @@ -488,8 +524,12 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR, --MII; } } - for (MachineInstr *MI : ToErase) + for (MachineInstr *MI : ToErase) { + auto It = RegsAlreadyAddedToDT.find(MI); + if (RegsAlreadyAddedToDT.contains(MI)) + MRI.replaceRegWith(MI->getOperand(0).getReg(), It->second); MI->eraseFromParent(); + } // Address the case when IRTranslator introduces instructions with new // registers without SPIRVType associated. @@ -821,8 +861,10 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) { MachineIRBuilder MIB(MF); // a registry of target extension constants DenseMap<MachineInstr *, Type *> TargetExtConstTypes; - addConstantsToTrack(MF, GR, ST, TargetExtConstTypes); - foldConstantsIntoIntrinsics(MF); + // to keep record of tracked constants + SmallSet<Register, 4> TrackedConstRegs; + addConstantsToTrack(MF, GR, ST, TargetExtConstTypes, TrackedConstRegs); + foldConstantsIntoIntrinsics(MF, TrackedConstRegs); insertBitcasts(MF, GR, MIB); generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes); processSwitches(MF, GR, MIB); diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 318c5cebb7a4..96601dd8796c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -302,6 +302,7 @@ defm SPV_INTEL_inline_assembly : ExtensionOperand<107>; defm SPV_INTEL_cache_controls : ExtensionOperand<108>; defm SPV_INTEL_global_variable_host_access : ExtensionOperand<109>; defm SPV_INTEL_global_variable_fpga_decorations : ExtensionOperand<110>; +defm SPV_KHR_cooperative_matrix : ExtensionOperand<111>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -478,6 +479,7 @@ defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_gl defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>; defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>; defm CacheControlsINTEL : CapabilityOperand<6441, 0, 0, [SPV_INTEL_cache_controls], []>; +defm CooperativeMatrixKHR : CapabilityOperand<6022, 0, 0, [SPV_KHR_cooperative_matrix], []>; //===----------------------------------------------------------------------===// // Multiclass used to define SourceLanguage enum values and at the same time diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index c1b90b0e9d88..927683ad7e32 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -253,7 +253,11 @@ SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) { MachineInstr *getDefInstrMaybeConstant(Register &ConstReg, const MachineRegisterInfo *MRI) { - MachineInstr *ConstInstr = MRI->getVRegDef(ConstReg); + MachineInstr *MI = MRI->getVRegDef(ConstReg); + MachineInstr *ConstInstr = + MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT + ? MRI->getVRegDef(MI->getOperand(1).getReg()) + : MI; if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) { if (GI->is(Intrinsic::spv_track_constant)) { ConstReg = ConstInstr->getOperand(2).getReg(); diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index c131eecb1c13..12725d6bac14 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -160,5 +160,29 @@ inline Type *toTypedPointer(Type *Ty) { : Ty; } +inline Type *toTypedFunPointer(FunctionType *FTy) { + Type *OrigRetTy = FTy->getReturnType(); + Type *RetTy = toTypedPointer(OrigRetTy); + bool IsUntypedPtr = false; + for (Type *PTy : FTy->params()) { + if (isUntypedPointerTy(PTy)) { + IsUntypedPtr = true; + break; + } + } + if (!IsUntypedPtr && RetTy == OrigRetTy) + return FTy; + SmallVector<Type *> ParamTys; + for (Type *PTy : FTy->params()) + ParamTys.push_back(toTypedPointer(PTy)); + return FunctionType::get(RetTy, ParamTys, FTy->isVarArg()); +} + +inline const Type *unifyPtrType(const Type *Ty) { + if (auto FTy = dyn_cast<FunctionType>(Ty)) + return toTypedFunPointer(const_cast<FunctionType *>(FTy)); + return toTypedPointer(const_cast<Type *>(Ty)); +} + } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H |
