summary | refs | log | tree | commit | diff
path: root/llvm/lib/Target/SPIRV
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV')
-rw-r--r-- llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp | 2
-rw-r--r-- llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp | 7
-rw-r--r-- llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h | 5
-rw-r--r-- llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.h | 2
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 279
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVBuiltins.td | 114
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 2
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h | 226
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 64
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 94
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 22
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 1
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVISelLowering.h | 5
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 16
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 95
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 6
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp | 27
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 9
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 68
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td | 2
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 6
-rw-r--r-- llvm/lib/Target/SPIRV/SPIRVUtils.h | 24
22 files changed, 744 insertions, 332 deletions
diff --git a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
index 7f5f7d0b1e4d..25e285e35f93 100644
--- a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
@@ -138,7 +138,7 @@ ConvergenceRegion::ConvergenceRegion(
SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits)
: DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry),
Exits(std::move(Exits)), Blocks(std::move(Blocks)) {
- for (auto *BB : this->Exits)
+ for ([[maybe_unused]] auto *BB : this->Exits)
assert(this->Blocks.count(BB) != 0);
assert(this->Blocks.count(this->Entry) != 0);
}
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
index d96d2bf31b62..0f9a2a69e073 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
@@ -198,6 +198,8 @@ std::string getExtInstSetName(SPIRV::InstructionSet::InstructionSet Set) {
return "OpenCL.std";
case SPIRV::InstructionSet::GLSL_std_450:
return "GLSL.std.450";
+ case SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100:
+ return "NonSemantic.Shader.DebugInfo.100";
case SPIRV::InstructionSet::SPV_AMD_shader_trinary_minmax:
return "SPV_AMD_shader_trinary_minmax";
}
@@ -206,8 +208,9 @@ std::string getExtInstSetName(SPIRV::InstructionSet::InstructionSet Set) {
SPIRV::InstructionSet::InstructionSet
getExtInstSetFromString(std::string SetName) {
- for (auto Set : {SPIRV::InstructionSet::GLSL_std_450,
- SPIRV::InstructionSet::OpenCL_std}) {
+ for (auto Set :
+ {SPIRV::InstructionSet::GLSL_std_450, SPIRV::InstructionSet::OpenCL_std,
+ SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100}) {
if (SetName == getExtInstSetName(Set))
return Set;
}
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
index 990eb1d230bc..44625793e941 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
@@ -197,6 +197,11 @@ namespace GLSLExtInst {
#include "SPIRVGenTables.inc"
} // namespace GLSLExtInst
+namespace NonSemanticExtInst {
+#define GET_NonSemanticExtInst_DECL
+#include "SPIRVGenTables.inc"
+} // namespace NonSemanticExtInst
+
namespace Opcode {
#define GET_Opcode_DECL
#include "SPIRVGenTables.inc"
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.h
index 842958695e10..a6dd7138edf3 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.h
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVTargetStreamer.h
@@ -21,7 +21,7 @@ public:
~SPIRVTargetStreamer() override;
void changeSection(const MCSection *CurSection, MCSection *Section,
- const MCExpr *SubSection, raw_ostream &OS) override {}
+ uint32_t SubSection, raw_ostream &OS) override {}
};
} // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 49838e685a6d..0b93a4d85eed 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -300,6 +300,72 @@ lookupBuiltin(StringRef DemangledCall,
return nullptr;
}
+static MachineInstr *getBlockStructInstr(Register ParamReg,
+ MachineRegisterInfo *MRI) {
+ // Returns the instruction defining %0, reached from ParamReg (%2). We expect the following sequence of instructions:
+ // %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca)
+ // or = G_GLOBAL_VALUE @block_literal_global
+ // %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0
+ // %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN)
+ MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg); // def of %2; NOTE(review): dereferenced below without a null check
+ assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST &&
+ MI->getOperand(1).isReg());
+ Register BitcastReg = MI->getOperand(1).getReg(); // %1, the register being address-space cast
+ MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg); // NOTE(review): also dereferenced unchecked
+ assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) &&
+ BitcastMI->getOperand(2).isReg());
+ Register ValueReg = BitcastMI->getOperand(2).getReg(); // %0, the value fed to spv_bitcast
+ MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg);
+ return ValueMI; // opcode of the result is not validated here; getBlockStructType inspects it
+}
+
+// Return an integer constant corresponding to the given register and
+// defined in spv_track_constant (Reg <- spv_track_constant <- G_CONSTANT; the shape is enforced by assert only).
+// TODO: maybe unify with prelegalizer pass.
+static unsigned getConstFromIntrinsic(Register Reg, MachineRegisterInfo *MRI) {
+ MachineInstr *DefMI = MRI->getUniqueVRegDef(Reg); // expected to be the spv_track_constant intrinsic
+ assert(isSpvIntrinsic(*DefMI, Intrinsic::spv_track_constant) &&
+ DefMI->getOperand(2).isReg());
+ MachineInstr *DefMI2 = MRI->getUniqueVRegDef(DefMI->getOperand(2).getReg()); // def of the tracked value, expected G_CONSTANT
+ assert(DefMI2->getOpcode() == TargetOpcode::G_CONSTANT &&
+ DefMI2->getOperand(1).isCImm());
+ return DefMI2->getOperand(1).getCImm()->getValue().getZExtValue(); // NOTE(review): presumably narrows a 64-bit value to unsigned - confirm
+}
+
+// Return the type attached to MI's result by the spv_assign_type / spv_assign_ptr_type intrinsic that follows MI, or nullptr if no such annotation is found.
+// TODO: maybe unify with prelegalizer pass.
+static const Type *getMachineInstrType(MachineInstr *MI) {
+ MachineInstr *NextMI = MI->getNextNode();
+ if (!NextMI)
+ return nullptr;
+ if (isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_name)) // step over one spv_assign_name, if present
+ if ((NextMI = NextMI->getNextNode()) == nullptr)
+ return nullptr;
+ Register ValueReg = MI->getOperand(0).getReg(); // operand 0 of MI, matched against the annotation's operand 1
+ if ((!isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_type) &&
+ !isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_ptr_type)) ||
+ NextMI->getOperand(1).getReg() != ValueReg) // the annotation must refer to MI's own register
+ return nullptr;
+ Type *Ty = getMDOperandAsType(NextMI->getOperand(2).getMetadata(), 0); // type is carried in the metadata operand
+ assert(Ty && "Type is expected");
+ return Ty;
+}
+
+static const Type *getBlockStructType(Register ParamReg,
+ MachineRegisterInfo *MRI) {
+ // In principle, this information should be passed to us from Clang via
+ // an elementtype attribute. However, said attribute requires that
+ // the function call be an intrinsic, which this is not. Instead, we rely on being
+ // able to trace this to the declaration of a variable: OpenCL C specification
+ // section 6.12.5 should guarantee that we can do this.
+ MachineInstr *MI = getBlockStructInstr(ParamReg, MRI);
+ if (MI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE)
+ return MI->getOperand(1).getGlobal()->getType(); // global block literal: use the type reported by the global
+ assert(isSpvIntrinsic(*MI, Intrinsic::spv_alloca) &&
+ "Blocks in OpenCL C must be traceable to allocation site");
+ return getMachineInstrType(MI); // alloca case: type comes from the annotation after the alloca; may be nullptr
+}
+
//===----------------------------------------------------------------------===//
// Helper functions for building misc instructions
//===----------------------------------------------------------------------===//
@@ -492,16 +558,21 @@ static Register buildMemSemanticsReg(Register SemanticsRegister,
static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode,
const SPIRV::IncomingCall *Call,
- Register TypeReg = Register(0)) {
+ Register TypeReg,
+ ArrayRef<uint32_t> ImmArgs = {}) {
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
auto MIB = MIRBuilder.buildInstr(Opcode);
if (TypeReg.isValid()) // a valid type register means the instruction also gets a result def
MIB.addDef(Call->ReturnRegister).addUse(TypeReg);
- for (Register ArgReg : Call->Arguments) {
+ unsigned Sz = Call->Arguments.size() - ImmArgs.size(); // trailing ImmArgs.size() arguments are not added as register uses; NOTE(review): unsigned subtraction wraps if ImmArgs outnumber the arguments
+ for (unsigned i = 0; i < Sz; ++i) {
+ Register ArgReg = Call->Arguments[i];
if (!MRI->getRegClassOrNull(ArgReg)) // only assign a class when none is set yet
MRI->setRegClass(ArgReg, &SPIRV::IDRegClass);
MIB.addUse(ArgReg);
}
+ for (uint32_t ImmArg : ImmArgs) // immediates are appended after all register operands
+ MIB.addImm(ImmArg);
return true;
}
@@ -509,7 +580,7 @@ static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode,
static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder) {
if (Call->isSpirvOp())
- return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call);
+ return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call, Register(0));
assert(Call->Arguments.size() == 2 &&
"Need 2 arguments for atomic init translation");
@@ -567,7 +638,7 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
if (Call->isSpirvOp())
- return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call);
+ return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
Register ScopeRegister =
buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
@@ -694,7 +765,7 @@ static bool buildAtomicCompareExchangeInst(
return true;
}
-/// Helper function for building an atomic load instruction.
+/// Helper function for building atomic instructions.
static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
@@ -719,13 +790,36 @@ static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
Semantics, MIRBuilder, GR);
MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+ Register ValueReg = Call->Arguments[1];
+ Register ValueTypeReg = GR->getSPIRVTypeID(Call->ReturnType);
+ // support cl_ext_float_atomics
+ if (Call->ReturnType->getOpcode() == SPIRV::OpTypeFloat) {
+ if (Opcode == SPIRV::OpAtomicIAdd) {
+ Opcode = SPIRV::OpAtomicFAddEXT;
+ } else if (Opcode == SPIRV::OpAtomicISub) {
+ // Translate OpAtomicISub applied to a floating type argument to
+ // OpAtomicFAddEXT with the negative value operand
+ Opcode = SPIRV::OpAtomicFAddEXT;
+ Register NegValueReg =
+ MRI->createGenericVirtualRegister(MRI->getType(ValueReg));
+ MRI->setRegClass(NegValueReg, &SPIRV::IDRegClass);
+ GR->assignSPIRVTypeToVReg(Call->ReturnType, NegValueReg,
+ MIRBuilder.getMF());
+ MIRBuilder.buildInstr(TargetOpcode::G_FNEG)
+ .addDef(NegValueReg)
+ .addUse(ValueReg);
+ insertAssignInstr(NegValueReg, nullptr, Call->ReturnType, GR, MIRBuilder,
+ MIRBuilder.getMF().getRegInfo());
+ ValueReg = NegValueReg;
+ }
+ }
MIRBuilder.buildInstr(Opcode)
.addDef(Call->ReturnRegister)
- .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+ .addUse(ValueTypeReg)
.addUse(PtrRegister)
.addUse(ScopeRegister)
.addUse(MemSemanticsReg)
- .addUse(Call->Arguments[1]);
+ .addUse(ValueReg);
return true;
}
@@ -804,7 +898,7 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
if (Call->isSpirvOp())
- return buildOpFromWrapper(MIRBuilder, Opcode, Call);
+ return buildOpFromWrapper(MIRBuilder, Opcode, Call, Register(0));
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI);
@@ -949,7 +1043,35 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
const SPIRV::GroupBuiltin *GroupBuiltin =
SPIRV::lookupGroupBuiltin(Builtin->Name);
+
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+ if (Call->isSpirvOp()) {
+ if (GroupBuiltin->NoGroupOperation)
+ return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
+ GR->getSPIRVTypeID(Call->ReturnType));
+
+ // Group Operation is a literal
+ Register GroupOpReg = Call->Arguments[1];
+ const MachineInstr *MI = getDefInstrMaybeConstant(GroupOpReg, MRI);
+ if (!MI || MI->getOpcode() != TargetOpcode::G_CONSTANT)
+ report_fatal_error(
+ "Group Operation parameter must be an integer constant");
+ uint64_t GrpOp = MI->getOperand(1).getCImm()->getValue().getZExtValue();
+ Register ScopeReg = Call->Arguments[0];
+ if (!MRI->getRegClassOrNull(ScopeReg))
+ MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
+ Register ValueReg = Call->Arguments[2];
+ if (!MRI->getRegClassOrNull(ValueReg))
+ MRI->setRegClass(ValueReg, &SPIRV::IDRegClass);
+ MIRBuilder.buildInstr(GroupBuiltin->Opcode)
+ .addDef(Call->ReturnRegister)
+ .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+ .addUse(ScopeReg)
+ .addImm(GrpOp)
+ .addUse(ValueReg);
+ return true;
+ }
+
Register Arg0;
if (GroupBuiltin->HasBoolArg) {
Register ConstRegister = Call->Arguments[0];
@@ -1371,6 +1493,14 @@ static bool generateBarrierInst(const SPIRV::IncomingCall *Call,
return buildBarrierInst(Call, Opcode, MIRBuilder, GR);
}
+static bool generateCastToPtrInst(const SPIRV::IncomingCall *Call,
+ MachineIRBuilder &MIRBuilder) {
+ MIRBuilder.buildInstr(TargetOpcode::G_ADDRSPACE_CAST) // the builtin is emitted as a generic address-space cast
+ .addDef(Call->ReturnRegister)
+ .addUse(Call->Arguments[0]); // only the first argument is used; any further arguments are ignored here
+ return true;
+}
+
static bool generateDotOrFMulInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
@@ -1722,6 +1852,45 @@ static bool generateSelectInst(const SPIRV::IncomingCall *Call,
return true;
}
+static bool generateConstructInst(const SPIRV::IncomingCall *Call,
+ MachineIRBuilder &MIRBuilder,
+ SPIRVGlobalRegistry *GR) {
+ return buildOpFromWrapper(MIRBuilder, SPIRV::OpCompositeConstruct, Call, // every call argument is forwarded as an operand
+ GR->getSPIRVTypeID(Call->ReturnType)); // result type id is taken from the call's return type
+}
+
+static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
+ MachineIRBuilder &MIRBuilder,
+ SPIRVGlobalRegistry *GR) {
+ const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+ unsigned Opcode =
+ SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode; // NOTE(review): lookup result is dereferenced without a null check
+ bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR; // selects whether a type register (and thus a result def) is passed on below
+ unsigned ArgSz = Call->Arguments.size();
+ unsigned LiteralIdx = 0; // index of the argument emitted as an immediate; 0 means none
+ if (Opcode == SPIRV::OpCooperativeMatrixLoadKHR && ArgSz > 3)
+ LiteralIdx = 3;
+ else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4)
+ LiteralIdx = 4;
+ SmallVector<uint32_t, 1> ImmArgs;
+ MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+ if (LiteralIdx > 0)
+ ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI)); // NOTE(review): buildOpFromWrapper drops the trailing arguments, so this assumes the literal is the last one - confirm
+ Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
+ if (Opcode == SPIRV::OpCooperativeMatrixLengthKHR) { // built by hand: the operand is a type register, not the call argument
+ SPIRVType *CoopMatrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
+ if (!CoopMatrType)
+ report_fatal_error("Can't find a register's type definition");
+ MIRBuilder.buildInstr(Opcode)
+ .addDef(Call->ReturnRegister)
+ .addUse(TypeReg)
+ .addUse(CoopMatrType->getOperand(0).getReg()); // operand 0 of the argument's SPIR-V type instruction
+ return true;
+ }
+ return buildOpFromWrapper(MIRBuilder, Opcode, Call,
+ IsSet ? TypeReg : Register(0), ImmArgs); // wrapper adds the leading arguments as uses and ImmArgs as immediates
+}
+
static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
@@ -1826,7 +1995,10 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call,
.addDef(GlobalWorkSize)
.addUse(GR->getSPIRVTypeID(SpvFieldTy))
.addUse(GWSPtr);
- Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy);
+ const SPIRVSubtarget &ST =
+ cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+ Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(),
+ SpvFieldTy, *ST.getInstrInfo());
} else {
Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
}
@@ -1847,68 +2019,6 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call,
.addUse(TmpReg);
}
-static MachineInstr *getBlockStructInstr(Register ParamReg,
- MachineRegisterInfo *MRI) {
- // We expect the following sequence of instructions:
- // %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca)
- // or = G_GLOBAL_VALUE @block_literal_global
- // %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0
- // %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN)
- MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg);
- assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST &&
- MI->getOperand(1).isReg());
- Register BitcastReg = MI->getOperand(1).getReg();
- MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg);
- assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) &&
- BitcastMI->getOperand(2).isReg());
- Register ValueReg = BitcastMI->getOperand(2).getReg();
- MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg);
- return ValueMI;
-}
-
-// Return an integer constant corresponding to the given register and
-// defined in spv_track_constant.
-// TODO: maybe unify with prelegalizer pass.
-static unsigned getConstFromIntrinsic(Register Reg, MachineRegisterInfo *MRI) {
- MachineInstr *DefMI = MRI->getUniqueVRegDef(Reg);
- assert(isSpvIntrinsic(*DefMI, Intrinsic::spv_track_constant) &&
- DefMI->getOperand(2).isReg());
- MachineInstr *DefMI2 = MRI->getUniqueVRegDef(DefMI->getOperand(2).getReg());
- assert(DefMI2->getOpcode() == TargetOpcode::G_CONSTANT &&
- DefMI2->getOperand(1).isCImm());
- return DefMI2->getOperand(1).getCImm()->getValue().getZExtValue();
-}
-
-// Return type of the instruction result from spv_assign_type intrinsic.
-// TODO: maybe unify with prelegalizer pass.
-static const Type *getMachineInstrType(MachineInstr *MI) {
- MachineInstr *NextMI = MI->getNextNode();
- if (isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_name))
- NextMI = NextMI->getNextNode();
- Register ValueReg = MI->getOperand(0).getReg();
- if (!isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_type) ||
- NextMI->getOperand(1).getReg() != ValueReg)
- return nullptr;
- Type *Ty = getMDOperandAsType(NextMI->getOperand(2).getMetadata(), 0);
- assert(Ty && "Type is expected");
- return Ty;
-}
-
-static const Type *getBlockStructType(Register ParamReg,
- MachineRegisterInfo *MRI) {
- // In principle, this information should be passed to us from Clang via
- // an elementtype attribute. However, said attribute requires that
- // the function call be an intrinsic, which is not. Instead, we rely on being
- // able to trace this to the declaration of a variable: OpenCL C specification
- // section 6.12.5 should guarantee that we can do this.
- MachineInstr *MI = getBlockStructInstr(ParamReg, MRI);
- if (MI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE)
- return MI->getOperand(1).getGlobal()->getType();
- assert(isSpvIntrinsic(*MI, Intrinsic::spv_alloca) &&
- "Blocks in OpenCL C must be traceable to allocation site");
- return getMachineInstrType(MI);
-}
-
// TODO: maybe move to the global register.
static SPIRVType *
getOrCreateSPIRVDeviceEventPointer(MachineIRBuilder &MIRBuilder,
@@ -2322,6 +2432,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateAtomicFloatingInst(Call.get(), MIRBuilder, GR);
case SPIRV::Barrier:
return generateBarrierInst(Call.get(), MIRBuilder, GR);
+ case SPIRV::CastToPtr:
+ return generateCastToPtrInst(Call.get(), MIRBuilder);
case SPIRV::Dot:
return generateDotOrFMulInst(Call.get(), MIRBuilder, GR);
case SPIRV::Wave:
@@ -2340,6 +2452,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateSampleImageInst(DemangledCall, Call.get(), MIRBuilder, GR);
case SPIRV::Select:
return generateSelectInst(Call.get(), MIRBuilder);
+ case SPIRV::Construct:
+ return generateConstructInst(Call.get(), MIRBuilder, GR);
case SPIRV::SpecConstant:
return generateSpecConstantInst(Call.get(), MIRBuilder, GR);
case SPIRV::Enqueue:
@@ -2358,6 +2472,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateGroupUniformInst(Call.get(), MIRBuilder, GR);
case SPIRV::KernelClock:
return generateKernelClockInst(Call.get(), MIRBuilder, GR);
+ case SPIRV::CoopMatr:
+ return generateCoopMatrInst(Call.get(), MIRBuilder, GR);
}
return false;
}
@@ -2376,7 +2492,7 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
if (hasBuiltinTypePrefix(TypeStr)) {
// OpenCL builtin types in demangled call strings have the following format:
// e.g. ocl_image2d_ro
- bool IsOCLBuiltinType = TypeStr.consume_front("ocl_");
+ [[maybe_unused]] bool IsOCLBuiltinType = TypeStr.consume_front("ocl_");
assert(IsOCLBuiltinType && "Invalid OpenCL builtin prefix");
// Check if this is pointer to a builtin type and not just pointer
@@ -2482,6 +2598,22 @@ static SPIRVType *getPipeType(const TargetExtType *ExtensionType,
ExtensionType->getIntParameter(0)));
}
+static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType,
+ MachineIRBuilder &MIRBuilder,
+ SPIRVGlobalRegistry *GR) {
+ assert(ExtensionType->getNumIntParameters() == 4 &&
+ "Invalid number of parameters for SPIR-V coop matrices builtin!");
+ assert(ExtensionType->getNumTypeParameters() == 1 &&
+ "SPIR-V coop matrices builtin type must have a type parameter!");
+ const SPIRVType *ElemType =
+ GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder); // SPIR-V type of the single type parameter
+ // Create or get an existing type from GlobalRegistry, forwarding the element type and integer parameters 0..3 in order.
+ return GR->getOrCreateOpTypeCoopMatr(
+ MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(0),
+ ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
+ ExtensionType->getIntParameter(3)); // NOTE(review): presumably scope, rows, columns, use - confirm against getOrCreateOpTypeCoopMatr
+}
+
static SPIRVType *
getImageType(const TargetExtType *ExtensionType,
const SPIRV::AccessQualifier::AccessQualifier Qualifier,
@@ -2612,6 +2744,9 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
case SPIRV::OpTypeSampledImage:
TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
break;
+ case SPIRV::OpTypeCooperativeMatrixKHR:
+ TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
+ break;
default:
TargetType =
getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index edc9e1a33d9f..fb88332ab890 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -26,6 +26,7 @@ class InstructionSet<bits<32> value> {
def OpenCL_std : InstructionSet<0>;
def GLSL_std_450 : InstructionSet<1>;
def SPV_AMD_shader_trinary_minmax : InstructionSet<2>;
+def NonSemantic_Shader_DebugInfo_100 : InstructionSet<3>;
// Define various builtin groups
def BuiltinGroup : GenericEnum {
@@ -59,6 +60,9 @@ def IntelSubgroups : BuiltinGroup;
def AtomicFloating : BuiltinGroup;
def GroupUniform : BuiltinGroup;
def KernelClock : BuiltinGroup;
+def CastToPtr : BuiltinGroup;
+def Construct : BuiltinGroup;
+def CoopMatr : BuiltinGroup;
//===----------------------------------------------------------------------===//
// Class defining a demangled builtin record. The information in the record
@@ -113,6 +117,9 @@ def : DemangledBuiltin<"__spirv_ImageSampleExplicitLod", OpenCL_std, SampleImage
// Select builtin record:
def : DemangledBuiltin<"__spirv_Select", OpenCL_std, Select, 3, 3>;
+// Composite Construct builtin record:
+def : DemangledBuiltin<"__spirv_CompositeConstruct", OpenCL_std, Construct, 1, 0>;
+
//===----------------------------------------------------------------------===//
// Class defining an extended builtin record used for lowering into an
// OpExtInst instruction.
@@ -170,6 +177,17 @@ class GLSLExtInst<string name, bits<32> value> {
bits<32> Value = value;
}
+def NonSemanticExtInst : GenericEnum { // enum generated from the records of class NonSemanticExtInst
+ let FilterClass = "NonSemanticExtInst";
+ let NameField = "Name"; // record field supplying the enumerator name
+ let ValueField = "Value"; // record field supplying the enumerator value
+}
+
+class NonSemanticExtInst<string name, bits<32> value> { // one record per instruction: a name paired with a 32-bit number
+ string Name = name;
+ bits<32> Value = value;
+}
+
// Multiclass used to define at the same time both a demangled builtin record
// and a corresponding extended builtin record.
multiclass DemangledExtendedBuiltin<string name, InstructionSet set, int number> {
@@ -183,6 +201,10 @@ multiclass DemangledExtendedBuiltin<string name, InstructionSet set, int number>
if !eq(set, GLSL_std_450) then {
def : GLSLExtInst<name, number>;
}
+
+ if !eq(set, NonSemantic_Shader_DebugInfo_100) then {
+ def : NonSemanticExtInst<name, number>;
+ }
}
// Extended builtin records:
@@ -430,6 +452,50 @@ defm : DemangledExtendedBuiltin<"NMin", GLSL_std_450, 79>;
defm : DemangledExtendedBuiltin<"NMax", GLSL_std_450, 80>;
defm : DemangledExtendedBuiltin<"NClamp", GLSL_std_450, 81>;
+defm : DemangledExtendedBuiltin<"DebugInfoNone", NonSemantic_Shader_DebugInfo_100, 0>;
+defm : DemangledExtendedBuiltin<"DebugCompilationUnit", NonSemantic_Shader_DebugInfo_100, 1>;
+defm : DemangledExtendedBuiltin<"DebugTypeBasic", NonSemantic_Shader_DebugInfo_100, 2>;
+defm : DemangledExtendedBuiltin<"DebugTypePointer", NonSemantic_Shader_DebugInfo_100, 3>;
+defm : DemangledExtendedBuiltin<"DebugTypeQualifier", NonSemantic_Shader_DebugInfo_100, 4>;
+defm : DemangledExtendedBuiltin<"DebugTypeArray", NonSemantic_Shader_DebugInfo_100, 5>;
+defm : DemangledExtendedBuiltin<"DebugTypeVector", NonSemantic_Shader_DebugInfo_100, 6>;
+defm : DemangledExtendedBuiltin<"DebugTypedef", NonSemantic_Shader_DebugInfo_100, 7>;
+defm : DemangledExtendedBuiltin<"DebugTypeFunction", NonSemantic_Shader_DebugInfo_100, 8>;
+defm : DemangledExtendedBuiltin<"DebugTypeEnum", NonSemantic_Shader_DebugInfo_100, 9>;
+defm : DemangledExtendedBuiltin<"DebugTypeComposite", NonSemantic_Shader_DebugInfo_100, 10>;
+defm : DemangledExtendedBuiltin<"DebugTypeMember", NonSemantic_Shader_DebugInfo_100, 11>;
+defm : DemangledExtendedBuiltin<"DebugTypeInheritance", NonSemantic_Shader_DebugInfo_100, 12>;
+defm : DemangledExtendedBuiltin<"DebugTypePtrToMember", NonSemantic_Shader_DebugInfo_100, 13>;
+defm : DemangledExtendedBuiltin<"DebugTypeTemplate", NonSemantic_Shader_DebugInfo_100, 14>;
+defm : DemangledExtendedBuiltin<"DebugTypeTemplateParameter", NonSemantic_Shader_DebugInfo_100, 15>;
+defm : DemangledExtendedBuiltin<"DebugTypeTemplateTemplateParameter", NonSemantic_Shader_DebugInfo_100, 16>;
+defm : DemangledExtendedBuiltin<"DebugTypeTemplateParameterPack", NonSemantic_Shader_DebugInfo_100, 17>;
+defm : DemangledExtendedBuiltin<"DebugGlobalVariable", NonSemantic_Shader_DebugInfo_100, 18>;
+defm : DemangledExtendedBuiltin<"DebugFunctionDeclaration", NonSemantic_Shader_DebugInfo_100, 19>;
+defm : DemangledExtendedBuiltin<"DebugFunction", NonSemantic_Shader_DebugInfo_100, 20>;
+defm : DemangledExtendedBuiltin<"DebugLexicalBlock", NonSemantic_Shader_DebugInfo_100, 21>;
+defm : DemangledExtendedBuiltin<"DebugLexicalBlockDiscriminator", NonSemantic_Shader_DebugInfo_100, 22>;
+defm : DemangledExtendedBuiltin<"DebugScope", NonSemantic_Shader_DebugInfo_100, 23>;
+defm : DemangledExtendedBuiltin<"DebugNoScope", NonSemantic_Shader_DebugInfo_100, 24>;
+defm : DemangledExtendedBuiltin<"DebugInlinedAt", NonSemantic_Shader_DebugInfo_100, 25>;
+defm : DemangledExtendedBuiltin<"DebugLocalVariable", NonSemantic_Shader_DebugInfo_100, 26>;
+defm : DemangledExtendedBuiltin<"DebugInlinedVariable", NonSemantic_Shader_DebugInfo_100, 27>;
+defm : DemangledExtendedBuiltin<"DebugDeclare", NonSemantic_Shader_DebugInfo_100, 28>;
+defm : DemangledExtendedBuiltin<"DebugValue", NonSemantic_Shader_DebugInfo_100, 29>;
+defm : DemangledExtendedBuiltin<"DebugOperation", NonSemantic_Shader_DebugInfo_100, 30>;
+defm : DemangledExtendedBuiltin<"DebugExpression", NonSemantic_Shader_DebugInfo_100, 31>;
+defm : DemangledExtendedBuiltin<"DebugMacroDef", NonSemantic_Shader_DebugInfo_100, 32>;
+defm : DemangledExtendedBuiltin<"DebugMacroUndef", NonSemantic_Shader_DebugInfo_100, 33>;
+defm : DemangledExtendedBuiltin<"DebugImportedEntity", NonSemantic_Shader_DebugInfo_100, 34>;
+defm : DemangledExtendedBuiltin<"DebugSource", NonSemantic_Shader_DebugInfo_100, 35>;
+defm : DemangledExtendedBuiltin<"DebugFunctionDefinition", NonSemantic_Shader_DebugInfo_100, 101>;
+defm : DemangledExtendedBuiltin<"DebugSourceContinued", NonSemantic_Shader_DebugInfo_100, 102>;
+defm : DemangledExtendedBuiltin<"DebugLine", NonSemantic_Shader_DebugInfo_100, 103>;
+defm : DemangledExtendedBuiltin<"DebugNoLine", NonSemantic_Shader_DebugInfo_100, 104>;
+defm : DemangledExtendedBuiltin<"DebugBuildIdentifier", NonSemantic_Shader_DebugInfo_100, 105>;
+defm : DemangledExtendedBuiltin<"DebugStoragePath", NonSemantic_Shader_DebugInfo_100, 106>;
+defm : DemangledExtendedBuiltin<"DebugEntryPoint", NonSemantic_Shader_DebugInfo_100, 107>;
+defm : DemangledExtendedBuiltin<"DebugTypeMatrix", NonSemantic_Shader_DebugInfo_100, 108>;
//===----------------------------------------------------------------------===//
// Class defining an native builtin record used for direct translation into a
// SPIR-V instruction.
@@ -532,6 +598,7 @@ defm : DemangledNativeBuiltin<"__spirv_AtomicAnd", OpenCL_std, Atomic, 4, 4, OpA
defm : DemangledNativeBuiltin<"atomic_exchange", OpenCL_std, Atomic, 2, 4, OpAtomicExchange>;
defm : DemangledNativeBuiltin<"atomic_exchange_explicit", OpenCL_std, Atomic, 2, 4, OpAtomicExchange>;
defm : DemangledNativeBuiltin<"AtomicEx__spirv_change", OpenCL_std, Atomic, 2, 4, OpAtomicExchange>;
+defm : DemangledNativeBuiltin<"__spirv_AtomicExchange", OpenCL_std, Atomic, 4, 4, OpAtomicExchange>;
defm : DemangledNativeBuiltin<"atomic_work_item_fence", OpenCL_std, Atomic, 1, 3, OpMemoryBarrier>;
defm : DemangledNativeBuiltin<"__spirv_MemoryBarrier", OpenCL_std, Atomic, 2, 2, OpMemoryBarrier>;
defm : DemangledNativeBuiltin<"atomic_fetch_add", OpenCL_std, Atomic, 2, 4, OpAtomicIAdd>;
@@ -539,11 +606,11 @@ defm : DemangledNativeBuiltin<"atomic_fetch_sub", OpenCL_std, Atomic, 2, 4, OpAt
defm : DemangledNativeBuiltin<"atomic_fetch_or", OpenCL_std, Atomic, 2, 4, OpAtomicOr>;
defm : DemangledNativeBuiltin<"atomic_fetch_xor", OpenCL_std, Atomic, 2, 4, OpAtomicXor>;
defm : DemangledNativeBuiltin<"atomic_fetch_and", OpenCL_std, Atomic, 2, 4, OpAtomicAnd>;
-defm : DemangledNativeBuiltin<"atomic_fetch_add_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicIAdd>;
-defm : DemangledNativeBuiltin<"atomic_fetch_sub_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicISub>;
-defm : DemangledNativeBuiltin<"atomic_fetch_or_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicOr>;
-defm : DemangledNativeBuiltin<"atomic_fetch_xor_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicXor>;
-defm : DemangledNativeBuiltin<"atomic_fetch_and_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicAnd>;
+defm : DemangledNativeBuiltin<"atomic_fetch_add_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicIAdd>;
+defm : DemangledNativeBuiltin<"atomic_fetch_sub_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicISub>;
+defm : DemangledNativeBuiltin<"atomic_fetch_or_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicOr>;
+defm : DemangledNativeBuiltin<"atomic_fetch_xor_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicXor>;
+defm : DemangledNativeBuiltin<"atomic_fetch_and_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicAnd>;
defm : DemangledNativeBuiltin<"atomic_flag_test_and_set", OpenCL_std, Atomic, 1, 1, OpAtomicFlagTestAndSet>;
defm : DemangledNativeBuiltin<"__spirv_AtomicFlagTestAndSet", OpenCL_std, Atomic, 3, 3, OpAtomicFlagTestAndSet>;
defm : DemangledNativeBuiltin<"atomic_flag_test_and_set_explicit", OpenCL_std, Atomic, 2, 3, OpAtomicFlagTestAndSet>;
@@ -595,6 +662,23 @@ defm : DemangledNativeBuiltin<"__spirv_GroupWaitEvents", OpenCL_std, AsyncCopy,
defm : DemangledNativeBuiltin<"__spirv_Load", OpenCL_std, LoadStore, 1, 3, OpLoad>;
defm : DemangledNativeBuiltin<"__spirv_Store", OpenCL_std, LoadStore, 2, 4, OpStore>;
+// Address Space Qualifier Functions/Pointers Conversion Instructions:
+defm : DemangledNativeBuiltin<"to_global", OpenCL_std, CastToPtr, 1, 1, OpGenericCastToPtr>;
+defm : DemangledNativeBuiltin<"to_local", OpenCL_std, CastToPtr, 1, 1, OpGenericCastToPtr>;
+defm : DemangledNativeBuiltin<"to_private", OpenCL_std, CastToPtr, 1, 1, OpGenericCastToPtr>;
+defm : DemangledNativeBuiltin<"__spirv_GenericCastToPtr_ToGlobal", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;
+defm : DemangledNativeBuiltin<"__spirv_GenericCastToPtr_ToLocal", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;
+defm : DemangledNativeBuiltin<"__spirv_GenericCastToPtr_ToPrivate", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;
+defm : DemangledNativeBuiltin<"__spirv_GenericCastToPtrExplicit_ToGlobal", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;
+defm : DemangledNativeBuiltin<"__spirv_GenericCastToPtrExplicit_ToLocal", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;
+defm : DemangledNativeBuiltin<"__spirv_GenericCastToPtrExplicit_ToPrivate", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;
+
+// Cooperative Matrix builtin records:
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLoadKHR", OpenCL_std, CoopMatr, 2, 0, OpCooperativeMatrixLoadKHR>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixStoreKHR>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixMulAddKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixMulAddKHR>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLengthKHR", OpenCL_std, CoopMatr, 1, 1, OpCooperativeMatrixLengthKHR>;
+
//===----------------------------------------------------------------------===//
// Class defining a work/sub group builtin that should be translated into a
// SPIR-V instruction using the defined properties.
@@ -682,9 +766,17 @@ multiclass DemangledGroupBuiltin<string name, int level /* OnlyWork/OnlySub/...
}
}
+multiclass DemangledGroupBuiltinWrapper<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> {
+ def : DemangledBuiltin<name, OpenCL_std, Group, minNumArgs, maxNumArgs>;
+ def : GroupBuiltin<name, operation>;
+}
+
defm : DemangledGroupBuiltin<"group_all", WorkOrSub, OpGroupAll>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupAll", 2, 2, OpGroupAll>;
defm : DemangledGroupBuiltin<"group_any", WorkOrSub, OpGroupAny>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupAny", 2, 2, OpGroupAny>;
defm : DemangledGroupBuiltin<"group_broadcast", WorkOrSub, OpGroupBroadcast>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupBroadcast", 3, 3, OpGroupBroadcast>;
defm : DemangledGroupBuiltin<"group_non_uniform_broadcast", OnlySub, OpGroupNonUniformBroadcast>;
defm : DemangledGroupBuiltin<"group_broadcast_first", OnlySub, OpGroupNonUniformBroadcastFirst>;
@@ -719,41 +811,49 @@ defm : DemangledGroupBuiltin<"group_scan_inclusive_adds", WorkOrSub, OpGroupIAdd
defm : DemangledGroupBuiltin<"group_reduce_addu", WorkOrSub, OpGroupIAdd>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_addu", WorkOrSub, OpGroupIAdd>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_addu", WorkOrSub, OpGroupIAdd>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupIAdd", 3, 3, OpGroupIAdd>;
defm : DemangledGroupBuiltin<"group_fadd", WorkOrSub, OpGroupFAdd>;
defm : DemangledGroupBuiltin<"group_reduce_addf", WorkOrSub, OpGroupFAdd>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_addf", WorkOrSub, OpGroupFAdd>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_addf", WorkOrSub, OpGroupFAdd>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFAdd", 3, 3, OpGroupFAdd>;
defm : DemangledGroupBuiltin<"group_fmin", WorkOrSub, OpGroupFMin>;
defm : DemangledGroupBuiltin<"group_reduce_minf", WorkOrSub, OpGroupFMin>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_minf", WorkOrSub, OpGroupFMin>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_minf", WorkOrSub, OpGroupFMin>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFMin", 3, 3, OpGroupFMin>;
defm : DemangledGroupBuiltin<"group_umin", WorkOrSub, OpGroupUMin>;
defm : DemangledGroupBuiltin<"group_reduce_minu", WorkOrSub, OpGroupUMin>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_minu", WorkOrSub, OpGroupUMin>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_minu", WorkOrSub, OpGroupUMin>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupUMin", 3, 3, OpGroupUMin>;
defm : DemangledGroupBuiltin<"group_smin", WorkOrSub, OpGroupSMin>;
defm : DemangledGroupBuiltin<"group_reduce_mins", WorkOrSub, OpGroupSMin>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_mins", WorkOrSub, OpGroupSMin>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_mins", WorkOrSub, OpGroupSMin>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupSMin", 3, 3, OpGroupSMin>;
defm : DemangledGroupBuiltin<"group_fmax", WorkOrSub, OpGroupFMax>;
defm : DemangledGroupBuiltin<"group_reduce_maxf", WorkOrSub, OpGroupFMax>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_maxf", WorkOrSub, OpGroupFMax>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_maxf", WorkOrSub, OpGroupFMax>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFMax", 3, 3, OpGroupFMax>;
defm : DemangledGroupBuiltin<"group_umax", WorkOrSub, OpGroupUMax>;
defm : DemangledGroupBuiltin<"group_reduce_maxu", WorkOrSub, OpGroupUMax>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_maxu", WorkOrSub, OpGroupUMax>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_maxu", WorkOrSub, OpGroupUMax>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupUMax", 3, 3, OpGroupUMax>;
defm : DemangledGroupBuiltin<"group_smax", WorkOrSub, OpGroupSMax>;
defm : DemangledGroupBuiltin<"group_reduce_maxs", WorkOrSub, OpGroupSMax>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_maxs", WorkOrSub, OpGroupSMax>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_maxs", WorkOrSub, OpGroupSMax>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupSMax", 3, 3, OpGroupSMax>;
// cl_khr_subgroup_non_uniform_arithmetic
defm : DemangledGroupBuiltin<"group_non_uniform_iadd", WorkOrSub, OpGroupNonUniformIAdd>;
@@ -997,8 +1097,6 @@ multiclass DemangledAtomicFloatingBuiltin<string name, bits<8> minNumArgs, bits<
defm : DemangledAtomicFloatingBuiltin<"AddEXT", 4, 4, OpAtomicFAddEXT>;
defm : DemangledAtomicFloatingBuiltin<"MinEXT", 4, 4, OpAtomicFMinEXT>;
defm : DemangledAtomicFloatingBuiltin<"MaxEXT", 4, 4, OpAtomicFMaxEXT>;
-// TODO: add support for cl_ext_float_atomics to enable performing atomic operations
-// on floating-point numbers in memory (float arguments for atomic_fetch_add, ...)
//===----------------------------------------------------------------------===//
// Class defining a sub group builtin that should be translated into a
@@ -1407,7 +1505,7 @@ def : BuiltinType<"spirv.DeviceEvent", OpTypeDeviceEvent>;
def : BuiltinType<"spirv.Image", OpTypeImage>;
def : BuiltinType<"spirv.SampledImage", OpTypeSampledImage>;
def : BuiltinType<"spirv.Pipe", OpTypePipe>;
-
+def : BuiltinType<"spirv.CooperativeMatrixKHR", OpTypeCooperativeMatrixKHR>;
//===----------------------------------------------------------------------===//
// Class matching an OpenCL builtin type name to an equivalent SPIR-V
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 75aa1823b11f..c7c244cfa897 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -66,6 +66,8 @@ static const std::map<std::string, SPIRV::Extension::Extension>
SPIRV::Extension::Extension::SPV_INTEL_function_pointers},
{"SPV_KHR_shader_clock",
SPIRV::Extension::Extension::SPV_KHR_shader_clock},
+ {"SPV_KHR_cooperative_matrix",
+ SPIRV::Extension::Extension::SPV_KHR_cooperative_matrix},
};
bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName,
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index 2ec3fb35ca04..a37e65a47eda 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -16,6 +16,7 @@
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "MCTargetDesc/SPIRVMCTargetDesc.h"
+#include "SPIRVUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
@@ -51,152 +52,87 @@ public:
void addDep(DTSortableEntry *E) { Deps.push_back(E); }
};
-struct SpecialTypeDescriptor {
- enum SpecialTypeKind {
- STK_Empty = 0,
- STK_Image,
- STK_SampledImage,
- STK_Sampler,
- STK_Pipe,
- STK_DeviceEvent,
- STK_Pointer,
- STK_Last = -1
- };
- SpecialTypeKind Kind;
-
- unsigned Hash;
-
- SpecialTypeDescriptor() = delete;
- SpecialTypeDescriptor(SpecialTypeKind K) : Kind(K) { Hash = Kind; }
-
- unsigned getHash() const { return Hash; }
-
- virtual ~SpecialTypeDescriptor() {}
-};
-
-struct ImageTypeDescriptor : public SpecialTypeDescriptor {
- union ImageAttrs {
- struct BitFlags {
- unsigned Dim : 3;
- unsigned Depth : 2;
- unsigned Arrayed : 1;
- unsigned MS : 1;
- unsigned Sampled : 2;
- unsigned ImageFormat : 6;
- unsigned AQ : 2;
- } Flags;
- unsigned Val;
- };
-
- ImageTypeDescriptor(const Type *SampledTy, unsigned Dim, unsigned Depth,
- unsigned Arrayed, unsigned MS, unsigned Sampled,
- unsigned ImageFormat, unsigned AQ = 0)
- : SpecialTypeDescriptor(SpecialTypeKind::STK_Image) {
- ImageAttrs Attrs;
- Attrs.Val = 0;
- Attrs.Flags.Dim = Dim;
- Attrs.Flags.Depth = Depth;
- Attrs.Flags.Arrayed = Arrayed;
- Attrs.Flags.MS = MS;
- Attrs.Flags.Sampled = Sampled;
- Attrs.Flags.ImageFormat = ImageFormat;
- Attrs.Flags.AQ = AQ;
- Hash = (DenseMapInfo<Type *>().getHashValue(SampledTy) & 0xffff) ^
- ((Attrs.Val << 8) | Kind);
- }
-
- static bool classof(const SpecialTypeDescriptor *TD) {
- return TD->Kind == SpecialTypeKind::STK_Image;
- }
-};
-
-struct SampledImageTypeDescriptor : public SpecialTypeDescriptor {
- SampledImageTypeDescriptor(const Type *SampledTy, const MachineInstr *ImageTy)
- : SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage) {
- assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
- ImageTypeDescriptor TD(
- SampledTy, ImageTy->getOperand(2).getImm(),
- ImageTy->getOperand(3).getImm(), ImageTy->getOperand(4).getImm(),
- ImageTy->getOperand(5).getImm(), ImageTy->getOperand(6).getImm(),
- ImageTy->getOperand(7).getImm(), ImageTy->getOperand(8).getImm());
- Hash = TD.getHash() ^ Kind;
- }
-
- static bool classof(const SpecialTypeDescriptor *TD) {
- return TD->Kind == SpecialTypeKind::STK_SampledImage;
- }
-};
-
-struct SamplerTypeDescriptor : public SpecialTypeDescriptor {
- SamplerTypeDescriptor()
- : SpecialTypeDescriptor(SpecialTypeKind::STK_Sampler) {
- Hash = Kind;
- }
-
- static bool classof(const SpecialTypeDescriptor *TD) {
- return TD->Kind == SpecialTypeKind::STK_Sampler;
- }
+enum SpecialTypeKind {
+ STK_Empty = 0,
+ STK_Image,
+ STK_SampledImage,
+ STK_Sampler,
+ STK_Pipe,
+ STK_DeviceEvent,
+ STK_Pointer,
+ STK_Last = -1
};
-struct PipeTypeDescriptor : public SpecialTypeDescriptor {
-
- PipeTypeDescriptor(uint8_t AQ)
- : SpecialTypeDescriptor(SpecialTypeKind::STK_Pipe) {
- Hash = (AQ << 8) | Kind;
- }
-
- static bool classof(const SpecialTypeDescriptor *TD) {
- return TD->Kind == SpecialTypeKind::STK_Pipe;
+using SpecialTypeDescriptor = std::tuple<const Type *, unsigned, unsigned>;
+
+union ImageAttrs {
+ struct BitFlags {
+ unsigned Dim : 3;
+ unsigned Depth : 2;
+ unsigned Arrayed : 1;
+ unsigned MS : 1;
+ unsigned Sampled : 2;
+ unsigned ImageFormat : 6;
+ unsigned AQ : 2;
+ } Flags;
+ unsigned Val;
+
+ ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
+ unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0) {
+ Val = 0;
+ Flags.Dim = Dim;
+ Flags.Depth = Depth;
+ Flags.Arrayed = Arrayed;
+ Flags.MS = MS;
+ Flags.Sampled = Sampled;
+ Flags.ImageFormat = ImageFormat;
+ Flags.AQ = AQ;
}
};
-struct DeviceEventTypeDescriptor : public SpecialTypeDescriptor {
-
- DeviceEventTypeDescriptor()
- : SpecialTypeDescriptor(SpecialTypeKind::STK_DeviceEvent) {
- Hash = Kind;
- }
-
- static bool classof(const SpecialTypeDescriptor *TD) {
- return TD->Kind == SpecialTypeKind::STK_DeviceEvent;
- }
-};
-
-struct PointerTypeDescriptor : public SpecialTypeDescriptor {
- const Type *ElementType;
- unsigned AddressSpace;
-
- PointerTypeDescriptor() = delete;
- PointerTypeDescriptor(const Type *ElementType, unsigned AddressSpace)
- : SpecialTypeDescriptor(SpecialTypeKind::STK_Pointer),
- ElementType(ElementType), AddressSpace(AddressSpace) {
- Hash = (DenseMapInfo<Type *>().getHashValue(ElementType) & 0xffff) ^
- ((AddressSpace << 8) | Kind);
- }
-
- static bool classof(const SpecialTypeDescriptor *TD) {
- return TD->Kind == SpecialTypeKind::STK_Pointer;
- }
-};
+inline SpecialTypeDescriptor
+make_descr_image(const Type *SampledTy, unsigned Dim, unsigned Depth,
+ unsigned Arrayed, unsigned MS, unsigned Sampled,
+ unsigned ImageFormat, unsigned AQ = 0) {
+ return std::make_tuple(
+ SampledTy,
+ ImageAttrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ).Val,
+ SpecialTypeKind::STK_Image);
+}
+
+inline SpecialTypeDescriptor
+make_descr_sampled_image(const Type *SampledTy, const MachineInstr *ImageTy) {
+ assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
+ return std::make_tuple(
+ SampledTy,
+ ImageAttrs(
+ ImageTy->getOperand(2).getImm(), ImageTy->getOperand(3).getImm(),
+ ImageTy->getOperand(4).getImm(), ImageTy->getOperand(5).getImm(),
+ ImageTy->getOperand(6).getImm(), ImageTy->getOperand(7).getImm(),
+ ImageTy->getOperand(8).getImm())
+ .Val,
+ SpecialTypeKind::STK_SampledImage);
+}
+
+inline SpecialTypeDescriptor make_descr_sampler() {
+ return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_Sampler);
+}
+
+inline SpecialTypeDescriptor make_descr_pipe(uint8_t AQ) {
+ return std::make_tuple(nullptr, AQ, SpecialTypeKind::STK_Pipe);
+}
+
+inline SpecialTypeDescriptor make_descr_event() {
+ return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_DeviceEvent);
+}
+
+inline SpecialTypeDescriptor make_descr_pointee(const Type *ElementType,
+ unsigned AddressSpace) {
+ return std::make_tuple(ElementType, AddressSpace,
+ SpecialTypeKind::STK_Pointer);
+}
} // namespace SPIRV
-template <> struct DenseMapInfo<SPIRV::SpecialTypeDescriptor> {
- static inline SPIRV::SpecialTypeDescriptor getEmptyKey() {
- return SPIRV::SpecialTypeDescriptor(
- SPIRV::SpecialTypeDescriptor::STK_Empty);
- }
- static inline SPIRV::SpecialTypeDescriptor getTombstoneKey() {
- return SPIRV::SpecialTypeDescriptor(SPIRV::SpecialTypeDescriptor::STK_Last);
- }
- static unsigned getHashValue(SPIRV::SpecialTypeDescriptor Val) {
- return Val.getHash();
- }
- static bool isEqual(SPIRV::SpecialTypeDescriptor LHS,
- SPIRV::SpecialTypeDescriptor RHS) {
- return getHashValue(LHS) == getHashValue(RHS);
- }
-};
-
template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
public:
// NOTE: using MapVector instead of DenseMap helps getting everything ordered
@@ -282,12 +218,12 @@ public:
MachineModuleInfo *MMI);
void add(const Type *Ty, const MachineFunction *MF, Register R) {
- TT.add(Ty, MF, R);
+ TT.add(unifyPtrType(Ty), MF, R);
}
- void add(const Type *PointerElementType, unsigned AddressSpace,
+ void add(const Type *PointeeTy, unsigned AddressSpace,
const MachineFunction *MF, Register R) {
- ST.add(SPIRV::PointerTypeDescriptor(PointerElementType, AddressSpace), MF,
+ ST.add(SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF,
R);
}
@@ -317,13 +253,13 @@ public:
}
Register find(const Type *Ty, const MachineFunction *MF) {
- return TT.find(const_cast<Type *>(Ty), MF);
+ return TT.find(unifyPtrType(Ty), MF);
}
- Register find(const Type *PointerElementType, unsigned AddressSpace,
+ Register find(const Type *PointeeTy, unsigned AddressSpace,
const MachineFunction *MF) {
return ST.find(
- SPIRV::PointerTypeDescriptor(PointerElementType, AddressSpace), MF);
+ SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF);
}
Register find(const Constant *C, const MachineFunction *MF) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 7b8e3230bf55..dd5884096b85 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -69,7 +69,7 @@ class SPIRVEmitIntrinsics
DenseSet<Instruction *> AggrStores;
// deduce element type of untyped pointers
- Type *deduceElementType(Value *I);
+ Type *deduceElementType(Value *I, bool UnknownElemTypeI8);
Type *deduceElementTypeHelper(Value *I);
Type *deduceElementTypeHelper(Value *I, std::unordered_set<Value *> &Visited);
Type *deduceElementTypeByValueDeep(Type *ValueTy, Value *Operand,
@@ -105,7 +105,8 @@ class SPIRVEmitIntrinsics
void replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B);
void processInstrAfterVisit(Instruction *I, IRBuilder<> &B);
- void insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B);
+ bool insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B,
+ bool UnknownElemTypeI8);
void insertAssignTypeIntrs(Instruction *I, IRBuilder<> &B);
void insertAssignPtrTypeTargetExt(TargetExtType *AssignedType, Value *V,
IRBuilder<> &B);
@@ -367,6 +368,26 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
if (Ty)
break;
}
+ } else if (auto *CI = dyn_cast<CallInst>(I)) {
+ static StringMap<unsigned> ResTypeByArg = {
+ {"to_global", 0},
+ {"to_local", 0},
+ {"to_private", 0},
+ {"__spirv_GenericCastToPtr_ToGlobal", 0},
+ {"__spirv_GenericCastToPtr_ToLocal", 0},
+ {"__spirv_GenericCastToPtr_ToPrivate", 0},
+ {"__spirv_GenericCastToPtrExplicit_ToGlobal", 0},
+ {"__spirv_GenericCastToPtrExplicit_ToLocal", 0},
+ {"__spirv_GenericCastToPtrExplicit_ToPrivate", 0}};
+ // TODO: maybe improve performance by caching demangled names
+ if (Function *CalledF = CI->getCalledFunction()) {
+ std::string DemangledName =
+ getOclOrSpirvBuiltinDemangledName(CalledF->getName());
+ auto AsArgIt = ResTypeByArg.find(DemangledName);
+ if (AsArgIt != ResTypeByArg.end())
+ Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second),
+ Visited);
+ }
}
// remember the found relationship
@@ -460,10 +481,10 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
return OrigTy;
}
-Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
+Type *SPIRVEmitIntrinsics::deduceElementType(Value *I, bool UnknownElemTypeI8) {
if (Type *Ty = deduceElementTypeHelper(I))
return Ty;
- return IntegerType::getInt8Ty(I->getContext());
+ return UnknownElemTypeI8 ? IntegerType::getInt8Ty(I->getContext()) : nullptr;
}
// If the Instruction has Pointer operands with unresolved types, this function
@@ -652,10 +673,8 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
AggrConst = cast<Constant>(COp);
ResTy = B.getInt32Ty();
} else if (auto *COp = dyn_cast<ConstantAggregateZero>(Op)) {
- if (!Op->getType()->isVectorTy()) {
- AggrConst = cast<Constant>(COp);
- ResTy = B.getInt32Ty();
- }
+ AggrConst = cast<Constant>(COp);
+ ResTy = Op->getType()->isVectorTy() ? COp->getType() : B.getInt32Ty();
}
if (AggrConst) {
SmallVector<Value *> Args;
@@ -1152,16 +1171,23 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV,
B.CreateIntrinsic(Intrinsic::spv_unref_global, GV.getType(), &GV);
}
-void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
- IRBuilder<> &B) {
+// Return true if we can't decide what the pointee type is now and will get
+// back to the question later. Return false if spv_assign_ptr_type is not needed
+// or can be inserted immediately.
+bool SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
+ IRBuilder<> &B,
+ bool UnknownElemTypeI8) {
reportFatalOnTokenType(I);
if (!isPointerTy(I->getType()) || !requireAssignType(I) ||
isa<BitCastInst>(I))
- return;
+ return false;
setInsertPointAfterDef(B, I);
- Type *ElemTy = deduceElementType(I);
- buildAssignPtr(B, ElemTy, I);
+ if (Type *ElemTy = deduceElementType(I, UnknownElemTypeI8)) {
+ buildAssignPtr(B, ElemTy, I);
+ return false;
+ }
+ return true;
}
void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
@@ -1199,7 +1225,7 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
buildAssignPtr(B, PType->getElementType(), Op);
} else if (isPointerTy(OpTy)) {
Type *ElemTy = GR->findDeducedElementType(Op);
- buildAssignPtr(B, ElemTy ? ElemTy : deduceElementType(Op), Op);
+ buildAssignPtr(B, ElemTy ? ElemTy : deduceElementType(Op, true), Op);
} else {
CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type,
{OpTy}, Op, Op, {}, B);
@@ -1235,8 +1261,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
}
bool IsPhi = isa<PHINode>(I), BPrepared = false;
for (const auto &Op : I->operands()) {
- if ((isa<ConstantAggregateZero>(Op) && Op->getType()->isVectorTy()) ||
- isa<PHINode>(I) || isa<SwitchInst>(I))
+ if (isa<PHINode>(I) || isa<SwitchInst>(I))
TrackConstants = false;
if ((isa<ConstantData>(Op) || isa<ConstantExpr>(Op)) && TrackConstants) {
unsigned OpNo = Op.getOperandNo();
@@ -1395,10 +1420,15 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
if (isConvergenceIntrinsic(I))
continue;
- insertAssignPtrTypeIntrs(I, B);
+ bool Postpone = insertAssignPtrTypeIntrs(I, B, false);
+ // if Postpone is true, we can't decide on pointee type yet
insertAssignTypeIntrs(I, B);
insertPtrCastOrAssignTypeInstr(I, B);
insertSpirvDecorations(I, B);
+ // if instruction requires a pointee type set, let's check if we know it
+ // already, and force it to be i8 if not
+ if (Postpone && !GR->findAssignPtrTypeInstr(I))
+ insertAssignPtrTypeIntrs(I, B, true);
}
for (auto &I : instructions(Func))
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index d434e0b5efbc..5558c7a5a4a5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -394,7 +394,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
Constant *Val, MachineInstr &I, SPIRVType *SpvType,
const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
unsigned ElemCnt, bool ZeroAsNull) {
- // Find a constant vector in DT or build a new one.
+ // Find a constant vector or array in DT or build a new one.
Register Res = DT.find(CA, CurMF);
// If no values are attached, the composite is null constant.
bool IsNull = Val->isNullValue() && ZeroAsNull;
@@ -474,20 +474,28 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
ZeroAsNull);
}
-Register
-SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
- SPIRVType *SpvType,
- const SPIRVInstrInfo &TII) {
+Register SPIRVGlobalRegistry::getOrCreateConstIntArray(
+ uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType,
+ const SPIRVInstrInfo &TII) {
const Type *LLVMTy = getTypeForSPIRVType(SpvType);
assert(LLVMTy->isArrayTy());
const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
Type *LLVMBaseTy = LLVMArrTy->getElementType();
- auto *ConstInt = ConstantInt::get(LLVMBaseTy, Val);
- auto *ConstArr =
- ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
+ Constant *CI = ConstantInt::get(LLVMBaseTy, Val);
SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
- return getOrCreateCompositeOrNull(ConstInt, I, SpvType, TII, ConstArr, BW,
+ // The following is a reasonably unique key that is better than [Val]. The
+ // naive alternative would be something along the lines of:
+ // SmallVector<Constant *> NumCI(Num, CI);
+ // Constant *UniqueKey =
+ // ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI);
+ // that would be a truly unique but dangerous key, because it could lead to
+ // the creation of constants of arbitrary length (that is, the parameter of
+ // memset) which were missing in the original module.
+ Constant *UniqueKey = ConstantStruct::getAnon(
+ {PoisonValue::get(const_cast<ArrayType *>(LLVMArrTy)),
+ ConstantInt::get(LLVMBaseTy, Val), ConstantInt::get(LLVMBaseTy, Num)});
+ return getOrCreateCompositeOrNull(CI, I, SpvType, TII, UniqueKey, BW,
LLVMArrTy->getNumElements());
}
@@ -546,24 +554,6 @@ SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
}
Register
-SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val,
- MachineIRBuilder &MIRBuilder,
- SPIRVType *SpvType, bool EmitIR) {
- const Type *LLVMTy = getTypeForSPIRVType(SpvType);
- assert(LLVMTy->isArrayTy());
- const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
- Type *LLVMBaseTy = LLVMArrTy->getElementType();
- const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
- auto ConstArr =
- ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
- SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
- unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
- return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
- ConstArr, BW,
- LLVMArrTy->getNumElements());
-}
-
-Register
SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType) {
const Type *LLVMTy = getTypeForSPIRVType(SpvType);
@@ -936,7 +926,7 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
TypesInProcessing.erase(Ty);
VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
- SPIRVToLLVMType[SpirvType] = Ty;
+ SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
Register Reg = DT.find(Ty, &MIRBuilder.getMF());
// Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
// will be added later. For special types it is already added to DT.
@@ -1080,12 +1070,14 @@ bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
return IntType && IntType->getOperand(2).getImm() != 0;
}
+SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
+ return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
+ ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
+ : nullptr;
+}
+
unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
- SPIRVType *PtrType = getSPIRVTypeForVReg(PtrReg);
- SPIRVType *ElemType =
- PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
- ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
- : nullptr;
+ SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg));
return ElemType ? ElemType->getOpcode() : 0;
}
@@ -1122,9 +1114,9 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
SPIRV::ImageFormat::ImageFormat ImageFormat,
SPIRV::AccessQualifier::AccessQualifier AccessQual) {
- SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth,
- Arrayed, Multisampled, Sampled, ImageFormat,
- AccessQual);
+ auto TD = SPIRV::make_descr_image(SPIRVToLLVMType.lookup(SampledType), Dim,
+ Depth, Arrayed, Multisampled, Sampled,
+ ImageFormat, AccessQual);
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
return Res;
Register ResVReg = createTypeVReg(MIRBuilder);
@@ -1143,7 +1135,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
SPIRVType *
SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
- SPIRV::SamplerTypeDescriptor TD;
+ auto TD = SPIRV::make_descr_sampler();
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
return Res;
Register ResVReg = createTypeVReg(MIRBuilder);
@@ -1154,7 +1146,7 @@ SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual) {
- SPIRV::PipeTypeDescriptor TD(AccessQual);
+ auto TD = SPIRV::make_descr_pipe(AccessQual);
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
return Res;
Register ResVReg = createTypeVReg(MIRBuilder);
@@ -1166,7 +1158,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
MachineIRBuilder &MIRBuilder) {
- SPIRV::DeviceEventTypeDescriptor TD;
+ auto TD = SPIRV::make_descr_event();
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
return Res;
Register ResVReg = createTypeVReg(MIRBuilder);
@@ -1176,7 +1168,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
- SPIRV::SampledImageTypeDescriptor TD(
+ auto TD = SPIRV::make_descr_sampled_image(
SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
ImageType->getOperand(1).getReg())),
ImageType);
@@ -1189,6 +1181,26 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
.addUse(getSPIRVTypeID(ImageType));
}
+SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
+ MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
+ const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
+ uint32_t Use) {
+ Register ResVReg = DT.find(ExtensionType, &MIRBuilder.getMF());
+ if (ResVReg.isValid())
+ return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
+ ResVReg = createTypeVReg(MIRBuilder);
+ SPIRVType *SpirvTy =
+ MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
+ .addDef(ResVReg)
+ .addUse(getSPIRVTypeID(ElemType))
+ .addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true))
+ .addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true))
+ .addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true))
+ .addUse(buildConstantInt(Use, MIRBuilder, nullptr, true));
+ DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
+ return SpirvTy;
+}
+
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
@@ -1268,7 +1280,7 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
SPIRVType *SpirvType) {
assert(CurMF == SpirvType->getMF());
VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
- SPIRVToLLVMType[SpirvType] = LLVMTy;
+ SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
return SpirvType;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index db01f68f48de..a45e1ccd0717 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -292,6 +292,8 @@ public:
return Res->second;
}
+ // Return a pointee's type, or nullptr otherwise.
+ SPIRVType *getPointeeType(SPIRVType *PtrType);
// Return a pointee's type op code, or 0 otherwise.
unsigned getPointeeTypeOp(Register PtrReg);
@@ -327,6 +329,12 @@ public:
return Ret;
}
+ // Return true if the type is an aggregate (OpTypeStruct or OpTypeArray).
+ bool isAggregateType(SPIRVType *Type) const {
+ return Type && (Type->getOpcode() == SPIRV::OpTypeStruct ||
+ Type->getOpcode() == SPIRV::OpTypeArray);
+ }
+
// Whether the given VReg has an OpTypeXXX instruction mapped to it with the
// given opcode (e.g. OpTypeFloat).
bool isScalarOfType(Register VReg, unsigned TypeOpcode) const;
@@ -449,13 +457,11 @@ public:
Register getOrCreateConstVector(APFloat Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
- Register getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
- SPIRVType *SpvType,
- const SPIRVInstrInfo &TII);
+ Register getOrCreateConstIntArray(uint64_t Val, size_t Num, MachineInstr &I,
+ SPIRVType *SpvType,
+ const SPIRVInstrInfo &TII);
Register getOrCreateConsIntVector(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR = true);
- Register getOrCreateConsIntArray(uint64_t Val, MachineIRBuilder &MIRBuilder,
- SPIRVType *SpvType, bool EmitIR = true);
Register getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType);
Register buildConstantSampler(Register Res, unsigned AddrMode, unsigned Param,
@@ -514,7 +520,11 @@ public:
SPIRVType *getOrCreateOpTypeSampledImage(SPIRVType *ImageType,
MachineIRBuilder &MIRBuilder);
-
+ SPIRVType *getOrCreateOpTypeCoopMatr(MachineIRBuilder &MIRBuilder,
+ const TargetExtType *ExtensionType,
+ const SPIRVType *ElemType,
+ uint32_t Scope, uint32_t Rows,
+ uint32_t Columns, uint32_t Use);
SPIRVType *
getOrCreateOpTypePipe(MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual);
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 5ccbaf12ddee..4383d1c5c0e2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -339,6 +339,7 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
break;
case SPIRV::OpPtrCastToGeneric:
+ case SPIRV::OpGenericCastToPtr:
validateAccessChain(STI, MRI, GR, MI);
break;
case SPIRV::OpInBoundsPtrAccessChain:
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
index 6fc200abf462..77356b7512a7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
@@ -68,6 +68,11 @@ public:
// extra instructions required to preserve validity of SPIR-V code imposed by
// the standard.
void finalizeLowering(MachineFunction &MF) const override;
+
+ MVT getPreferredSwitchConditionType(LLVMContext &Context,
+ EVT ConditionVT) const override {
+ return ConditionVT.getSimpleVT(); // Use the condition's own simple type.
+ }
};
} // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index dedfd5e6e32d..63549b06e967 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -211,6 +211,9 @@ def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins),
def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res),
(ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
"$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
+def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res),
+ (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
+ "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">;
// 3.42.7 Constant-Creation Instructions
@@ -864,3 +867,16 @@ def OpAsmINTEL: Op<5610, (outs ID:$res), (ins TYPE:$type, TYPE:$asm_type, ID:$ta
"$res = OpAsmINTEL $type $asm_type $target $asm">;
def OpAsmCallINTEL: Op<5611, (outs ID:$res), (ins TYPE:$type, ID:$asm, variable_ops),
"$res = OpAsmCallINTEL $type $asm">;
+
+// SPV_KHR_cooperative_matrix
+def OpCooperativeMatrixLoadKHR: Op<4457, (outs ID:$res),
+ (ins TYPE:$resType, ID:$pointer, ID:$memory_layout, variable_ops),
+ "$res = OpCooperativeMatrixLoadKHR $resType $pointer $memory_layout">;
+def OpCooperativeMatrixStoreKHR: Op<4458, (outs),
+ (ins ID:$pointer, ID:$objectToStore, ID:$memory_layout, variable_ops),
+ "OpCooperativeMatrixStoreKHR $pointer $objectToStore $memory_layout">;
+def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res),
+ (ins TYPE:$type, ID:$A, ID:$B, ID:$C, variable_ops),
+ "$res = OpCooperativeMatrixMulAddKHR $type $A $B $C">;
+def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type),
+ "$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index db83172f7fa9..9be736ce88ce 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -173,6 +173,9 @@ private:
bool selectFmix(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectRsqrt(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
int OpIdx) const;
void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
@@ -469,6 +472,18 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
return selectExtInst(ResVReg, ResType, I, CL::sin, GL::Sin);
case TargetOpcode::G_FTAN:
return selectExtInst(ResVReg, ResType, I, CL::tan, GL::Tan);
+ case TargetOpcode::G_FACOS:
+ return selectExtInst(ResVReg, ResType, I, CL::acos, GL::Acos);
+ case TargetOpcode::G_FASIN:
+ return selectExtInst(ResVReg, ResType, I, CL::asin, GL::Asin);
+ case TargetOpcode::G_FATAN:
+ return selectExtInst(ResVReg, ResType, I, CL::atan, GL::Atan);
+ case TargetOpcode::G_FCOSH:
+ return selectExtInst(ResVReg, ResType, I, CL::cosh, GL::Cosh);
+ case TargetOpcode::G_FSINH:
+ return selectExtInst(ResVReg, ResType, I, CL::sinh, GL::Sinh);
+ case TargetOpcode::G_FTANH:
+ return selectExtInst(ResVReg, ResType, I, CL::tanh, GL::Tanh);
case TargetOpcode::G_FSQRT:
return selectExtInst(ResVReg, ResType, I, CL::sqrt, GL::Sqrt);
@@ -831,7 +846,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
unsigned Num = getIConstVal(I.getOperand(2).getReg(), MRI);
SPIRVType *ValTy = GR.getOrCreateSPIRVIntegerType(8, I, TII);
SPIRVType *ArrTy = GR.getOrCreateSPIRVArrayType(ValTy, Num, I, TII);
- Register Const = GR.getOrCreateConsIntArray(Val, I, ArrTy, TII);
+ Register Const = GR.getOrCreateConstIntArray(Val, Num, I, ArrTy, TII);
SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
// TODO: check if we have such GV, add init, use buildGlobalVariable.
@@ -1102,7 +1117,7 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
- SrcPtrTy, I, TII, SPIRV::StorageClass::Generic);
+ GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
MachineBasicBlock &BB = *I.getParent();
const DebugLoc &DL = I.getDebugLoc();
bool Success = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
@@ -1315,6 +1330,23 @@ bool SPIRVInstructionSelector::selectFmix(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
+// Select the spv_rsqrt intrinsic as an OpExtInst that invokes InverseSqrt from
+// the GLSL.std.450 extended instruction set. The value to transform is operand
+// 2 of the intrinsic instruction and must be a register.
+bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+ assert(I.getNumOperands() == 3);
+ assert(I.getOperand(2).isReg());
+
+ const Register SrcReg = I.getOperand(2).getReg();
+ MachineBasicBlock &MBB = *I.getParent();
+ const DebugLoc &Loc = I.getDebugLoc();
+
+ // Operand order: result, result type, instruction set, opcode, argument.
+ return BuildMI(MBB, I, Loc, TII.get(SPIRV::OpExtInst))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+ .addImm(GL::InverseSqrt)
+ .addUse(SrcReg)
+ .constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -1413,20 +1445,50 @@ static unsigned getArrayComponentCount(MachineRegisterInfo *MRI,
}
// Return true if the type represents a constant register
-static bool isConstReg(MachineRegisterInfo *MRI, SPIRVType *OpDef) {
+static bool isConstReg(MachineRegisterInfo *MRI, SPIRVType *OpDef,
+ SmallPtrSet<SPIRVType *, 4> &Visited) {
if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE &&
OpDef->getOperand(1).isReg()) {
if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg()))
OpDef = RefDef;
}
- return OpDef->getOpcode() == TargetOpcode::G_CONSTANT ||
- OpDef->getOpcode() == TargetOpcode::G_FCONSTANT;
+
+ // Visited holds the definitions this walk has already entered. Reaching one
+ // of them again yields true without descending into it a second time.
+ // insert() reports whether OpDef was newly added, so one lookup is enough.
+ if (!Visited.insert(OpDef).second)
+ return true;
+
+ unsigned Opcode = OpDef->getOpcode();
+ switch (Opcode) {
+ case TargetOpcode::G_CONSTANT:
+ case TargetOpcode::G_FCONSTANT:
+ return true;
+ // Among intrinsics only spv_const_composite counts as a constant.
+ case TargetOpcode::G_INTRINSIC:
+ case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
+ case TargetOpcode::G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS:
+ return cast<GIntrinsic>(*OpDef).getIntrinsicID() ==
+ Intrinsic::spv_const_composite;
+ // A vector is constant unless one of its register operands has a
+ // non-constant definition; operands without a definition are ignored.
+ case TargetOpcode::G_BUILD_VECTOR:
+ case TargetOpcode::G_SPLAT_VECTOR: {
+ for (unsigned i = OpDef->getNumExplicitDefs(); i < OpDef->getNumOperands();
+ i++) {
+ SPIRVType *OpNestedDef =
+ OpDef->getOperand(i).isReg()
+ ? MRI->getVRegDef(OpDef->getOperand(i).getReg())
+ : nullptr;
+ if (OpNestedDef && !isConstReg(MRI, OpNestedDef, Visited))
+ return false;
+ }
+ return true;
+ }
+ }
+ return false;
}
// Return true if the virtual register represents a constant
static bool isConstReg(MachineRegisterInfo *MRI, Register OpReg) {
+ SmallPtrSet<SPIRVType *, 4> Visited;
if (SPIRVType *OpDef = MRI->getVRegDef(OpReg))
- return isConstReg(MRI, OpDef);
+ return isConstReg(MRI, OpDef, Visited);
return false;
}
@@ -1740,15 +1802,18 @@ bool SPIRVInstructionSelector::selectOpUndef(Register ResVReg,
static bool isImm(const MachineOperand &MO, MachineRegisterInfo *MRI) {
assert(MO.isReg());
const SPIRVType *TypeInst = MRI->getVRegDef(MO.getReg());
- if (TypeInst->getOpcode() != SPIRV::ASSIGN_TYPE)
- return false;
- assert(TypeInst->getOperand(1).isReg());
- MachineInstr *ImmInst = MRI->getVRegDef(TypeInst->getOperand(1).getReg());
- return ImmInst->getOpcode() == TargetOpcode::G_CONSTANT;
+ if (TypeInst->getOpcode() == SPIRV::ASSIGN_TYPE) {
+ assert(TypeInst->getOperand(1).isReg());
+ MachineInstr *ImmInst = MRI->getVRegDef(TypeInst->getOperand(1).getReg());
+ return ImmInst->getOpcode() == TargetOpcode::G_CONSTANT;
+ }
+ return TypeInst->getOpcode() == SPIRV::OpConstantI;
}
static int64_t foldImm(const MachineOperand &MO, MachineRegisterInfo *MRI) {
const SPIRVType *TypeInst = MRI->getVRegDef(MO.getReg());
+ if (TypeInst->getOpcode() == SPIRV::OpConstantI)
+ return TypeInst->getOperand(2).getImm();
MachineInstr *ImmInst = MRI->getVRegDef(TypeInst->getOperand(1).getReg());
assert(ImmInst->getOpcode() == TargetOpcode::G_CONSTANT);
return ImmInst->getOperand(1).getCImm()->getZExtValue();
@@ -1850,8 +1915,10 @@ bool SPIRVInstructionSelector::wrapIntoSpecConstantOp(
Register OpReg = I.getOperand(i).getReg();
SPIRVType *OpDefine = MRI->getVRegDef(OpReg);
SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg);
- if (!OpDefine || !OpType || isConstReg(MRI, OpDefine) ||
- OpDefine->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
+ SmallPtrSet<SPIRVType *, 4> Visited;
+ if (!OpDefine || !OpType || isConstReg(MRI, OpDefine, Visited) ||
+ OpDefine->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST ||
+ GR.isAggregateType(OpType)) {
// The case of G_ADDRSPACE_CAST inside spv_const_composite() is processed
// by selectAddrSpaceCast()
CompositeArgs.push_back(OpReg);
@@ -1992,6 +2059,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectAny(ResVReg, ResType, I);
case Intrinsic::spv_lerp:
return selectFmix(ResVReg, ResType, I);
+ case Intrinsic::spv_rsqrt:
+ return selectRsqrt(ResVReg, ResType, I);
case Intrinsic::spv_lifetime_start:
case Intrinsic::spv_lifetime_end: {
unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 57fbf3b3f8f1..6c7c3af19965 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -278,6 +278,12 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
G_FCOS,
G_FSIN,
G_FTAN,
+ G_FACOS,
+ G_FASIN,
+ G_FATAN,
+ G_FCOSH,
+ G_FSINH,
+ G_FTANH,
G_FSQRT,
G_FFLOOR,
G_FRINT,
diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
index 2744c25d1bc7..0747dd1bbaf4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
@@ -17,6 +17,8 @@
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/CFG.h"
@@ -71,7 +73,7 @@ public:
/// terminator will take.
llvm::Value *createExitVariable(
BasicBlock *BB,
- const std::unordered_map<BasicBlock *, ConstantInt *> &TargetToValue) {
+ const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
auto *T = BB->getTerminator();
if (isa<ReturnInst>(T))
return nullptr;
@@ -98,12 +100,12 @@ public:
}
// TODO: add support for switch cases.
- assert(false && "Unhandled terminator type.");
+ llvm_unreachable("Unhandled terminator type.");
}
/// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
void replaceBranchTargets(BasicBlock *BB,
- const std::unordered_set<BasicBlock *> ToReplace,
+ const SmallPtrSet<BasicBlock *, 4> &ToReplace,
BasicBlock *NewTarget) {
auto *T = BB->getTerminator();
if (isa<ReturnInst>(T))
@@ -133,7 +135,7 @@ public:
bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
const SPIRV::ConvergenceRegion *CR) {
// Gather all the exit targets for this region.
- std::unordered_set<BasicBlock *> ExitTargets;
+ SmallPtrSet<BasicBlock *, 4> ExitTargets;
for (BasicBlock *Exit : CR->Exits) {
for (BasicBlock *Target : gatherSuccessors(Exit)) {
if (CR->Blocks.count(Target) == 0)
@@ -164,9 +166,10 @@ public:
// Creating one constant per distinct exit target. This will be route to the
// correct target.
- std::unordered_map<BasicBlock *, ConstantInt *> TargetToValue;
+ DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
for (BasicBlock *Target : SortedExitTargets)
- TargetToValue.emplace(Target, Builder.getInt32(TargetToValue.size()));
+ TargetToValue.insert(
+ std::make_pair(Target, Builder.getInt32(TargetToValue.size())));
// Creating one variable per exit node, set to the constant matching the
// targeted external block.
@@ -184,12 +187,12 @@ public:
}
// Creating the switch to jump to the correct exit target.
- std::vector<std::pair<BasicBlock *, ConstantInt *>> CasesList(
- TargetToValue.begin(), TargetToValue.end());
- llvm::SwitchInst *Sw =
- Builder.CreateSwitch(node, CasesList[0].first, CasesList.size() - 1);
- for (size_t i = 1; i < CasesList.size(); i++)
- Sw->addCase(CasesList[i].second, CasesList[i].first);
+ llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],
+ SortedExitTargets.size() - 1);
+ for (size_t i = 1; i < SortedExitTargets.size(); i++) {
+ BasicBlock *BB = SortedExitTargets[i];
+ Sw->addCase(TargetToValue[BB], BB);
+ }
// Fix exit branches to redirect to the new exit.
for (auto Exit : CR->Exits)
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 30a6c474f467..ac0aa682ea4b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1168,6 +1168,15 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::AsmINTEL);
}
break;
+ case SPIRV::OpTypeCooperativeMatrixKHR:
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
+ report_fatal_error(
+ "OpTypeCooperativeMatrixKHR type requires the "
+ "following SPIR-V extension: SPV_KHR_cooperative_matrix",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
+ Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+ break;
default:
break;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index adc5b36af6f1..0ea2f176565e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -41,7 +41,8 @@ public:
static void
addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
const SPIRVSubtarget &STI,
- DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
+ DenseMap<MachineInstr *, Type *> &TargetExtConstTypes,
+ SmallSet<Register, 4> &TrackedConstRegs) {
MachineRegisterInfo &MRI = MF.getRegInfo();
DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
@@ -80,6 +81,7 @@ addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
}
}
GR->add(Const, &MF, SrcReg);
+ TrackedConstRegs.insert(SrcReg);
if (Const->getType()->isTargetExtTy()) {
// remember association so that we can restore it when assign types
MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
@@ -121,7 +123,9 @@ addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MI->eraseFromParent();
}
-static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
+static void
+foldConstantsIntoIntrinsics(MachineFunction &MF,
+ const SmallSet<Register, 4> &TrackedConstRegs) {
SmallVector<MachineInstr *, 10> ToErase;
MachineRegisterInfo &MRI = MF.getRegInfo();
const unsigned AssignNameOperandShift = 2;
@@ -137,7 +141,8 @@ static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
MI.removeOperand(NumOp);
MI.addOperand(MachineOperand::CreateImm(
ConstMI->getOperand(1).getCImm()->getZExtValue()));
- if (MRI.use_empty(ConstMI->getOperand(0).getReg()))
+ Register DefReg = ConstMI->getOperand(0).getReg();
+ if (MRI.use_empty(DefReg) && !TrackedConstRegs.contains(DefReg))
ToErase.push_back(ConstMI);
}
}
@@ -271,6 +276,21 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
return SpirvTy;
}
+// Work within the current bit-width limitations: a scalar register wider than
+// one bit gets its width rounded up to the next power of two, with a floor of
+// 8 and a cap of 64 bits. Non-scalar registers and 1-bit scalars are left as
+// they are, and the type is only rewritten when the width actually changes.
+static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) {
+ const LLT Ty = MRI.getType(Reg);
+ // Treat non-scalars like 1-bit scalars: neither is ever touched.
+ const unsigned OldBits = Ty.isScalar() ? Ty.getScalarSizeInBits() : 1;
+ if (OldBits == 1)
+ return;
+ const unsigned Pow2Bits = 1u << Log2_32_Ceil(OldBits);
+ const unsigned NewBits = std::max(8u, std::min(Pow2Bits, 64u));
+ if (NewBits == OldBits)
+ return;
+ MRI.setType(Reg, LLT::scalar(NewBits));
+}
+
static std::pair<Register, unsigned>
createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
const SPIRVGlobalRegistry &GR) {
@@ -395,6 +415,7 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineRegisterInfo &MRI = MF.getRegInfo();
SmallVector<MachineInstr *, 10> ToErase;
+ DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
for (MachineBasicBlock *MBB : post_order(&MF)) {
if (MBB->empty())
@@ -406,6 +427,11 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineInstr &MI = *MII;
unsigned MIOp = MI.getOpcode();
+ // widen scalar register operands to a supported bit width
+ for (const auto &MOP : MI.operands())
+ if (MOP.isReg())
+ widenScalarLLTNextPow2(MOP.getReg(), MRI);
+
if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
Register Reg = MI.getOperand(1).getReg();
MIB.setInsertPt(*MI.getParent(), MI.getIterator());
@@ -441,6 +467,7 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
// %rctmp = G_CONSTANT ty Val
// %rc = ASSIGN_TYPE %rctmp, %cty
Register Reg = MI.getOperand(0).getReg();
+ bool NeedAssignType = true;
if (MRI.hasOneUse(Reg)) {
MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
@@ -453,7 +480,20 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Ty = TargetExtIt == TargetExtConstTypes.end()
? MI.getOperand(1).getCImm()->getType()
: TargetExtIt->second;
- GR->add(MI.getOperand(1).getCImm(), &MF, Reg);
+ const ConstantInt *OpCI = MI.getOperand(1).getCImm();
+ Register PrimaryReg = GR->find(OpCI, &MF);
+ if (!PrimaryReg.isValid()) {
+ GR->add(OpCI, &MF, Reg);
+ } else if (PrimaryReg != Reg &&
+ MRI.getType(Reg) == MRI.getType(PrimaryReg)) {
+ auto *RCReg = MRI.getRegClassOrNull(Reg);
+ auto *RCPrimary = MRI.getRegClassOrNull(PrimaryReg);
+ if (!RCReg || RCPrimary == RCReg) {
+ RegsAlreadyAddedToDT[&MI] = PrimaryReg;
+ ToErase.push_back(&MI);
+ NeedAssignType = false;
+ }
+ }
} else if (MIOp == TargetOpcode::G_FCONSTANT) {
Ty = MI.getOperand(1).getFPImm()->getType();
} else {
@@ -472,14 +512,10 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
Ty = VectorType::get(ElemTy, NumElts, false);
}
- insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
+ if (NeedAssignType)
+ insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
} else if (MIOp == TargetOpcode::G_GLOBAL_VALUE) {
propagateSPIRVType(&MI, GR, MRI, MIB);
- } else if (MIOp == TargetOpcode::G_BITREVERSE) {
- Register Reg = MI.getOperand(0).getReg();
- LLT RegType = MRI.getType(Reg);
- if (RegType.getSizeInBits() < 32)
- MRI.setType(Reg, LLT::scalar(32));
}
if (MII == Begin)
@@ -488,8 +524,12 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
--MII;
}
}
- for (MachineInstr *MI : ToErase)
+ for (MachineInstr *MI : ToErase) {
+ auto It = RegsAlreadyAddedToDT.find(MI);
+ if (RegsAlreadyAddedToDT.contains(MI))
+ MRI.replaceRegWith(MI->getOperand(0).getReg(), It->second);
MI->eraseFromParent();
+ }
// Address the case when IRTranslator introduces instructions with new
// registers without SPIRVType associated.
@@ -821,8 +861,10 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
MachineIRBuilder MIB(MF);
// a registry of target extension constants
DenseMap<MachineInstr *, Type *> TargetExtConstTypes;
- addConstantsToTrack(MF, GR, ST, TargetExtConstTypes);
- foldConstantsIntoIntrinsics(MF);
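+ // registers of the constants that addConstantsToTrack adds to the registry;
+ // foldConstantsIntoIntrinsics must not erase their defining instructions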
+ SmallSet<Register, 4> TrackedConstRegs;
+ addConstantsToTrack(MF, GR, ST, TargetExtConstTypes, TrackedConstRegs);
+ foldConstantsIntoIntrinsics(MF, TrackedConstRegs);
insertBitcasts(MF, GR, MIB);
generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes);
processSwitches(MF, GR, MIB);
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 318c5cebb7a4..96601dd8796c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -302,6 +302,7 @@ defm SPV_INTEL_inline_assembly : ExtensionOperand<107>;
defm SPV_INTEL_cache_controls : ExtensionOperand<108>;
defm SPV_INTEL_global_variable_host_access : ExtensionOperand<109>;
defm SPV_INTEL_global_variable_fpga_decorations : ExtensionOperand<110>;
+defm SPV_KHR_cooperative_matrix : ExtensionOperand<111>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -478,6 +479,7 @@ defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_gl
defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>;
defm CacheControlsINTEL : CapabilityOperand<6441, 0, 0, [SPV_INTEL_cache_controls], []>;
+defm CooperativeMatrixKHR : CapabilityOperand<6022, 0, 0, [SPV_KHR_cooperative_matrix], []>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index c1b90b0e9d88..927683ad7e32 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -253,7 +253,11 @@ SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) {
MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
const MachineRegisterInfo *MRI) {
- MachineInstr *ConstInstr = MRI->getVRegDef(ConstReg);
+ MachineInstr *MI = MRI->getVRegDef(ConstReg);
+ MachineInstr *ConstInstr =
+ MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
+ ? MRI->getVRegDef(MI->getOperand(1).getReg())
+ : MI;
if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
if (GI->is(Intrinsic::spv_track_constant)) {
ConstReg = ConstInstr->getOperand(2).getReg();
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index c131eecb1c13..12725d6bac14 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -160,5 +160,29 @@ inline Type *toTypedPointer(Type *Ty) {
: Ty;
}
+// Map a function type through toTypedPointer. FTy itself is returned when
+// toTypedPointer leaves the return type unchanged and no parameter is an
+// untyped pointer. Otherwise the result is a function type whose return and
+// parameter types have each been passed through toTypedPointer, with the same
+// var-arg flag as FTy.
+inline Type *toTypedFunPointer(FunctionType *FTy) {
+ Type *const OldRetTy = FTy->getReturnType();
+ Type *const NewRetTy = toTypedPointer(OldRetTy);
+ // A changed return type alone already forces a new signature; otherwise look
+ // for the first untyped pointer among the parameters.
+ bool NeedsRewrite = NewRetTy != OldRetTy;
+ for (Type *ParamTy : FTy->params()) {
+ if (NeedsRewrite)
+ break;
+ NeedsRewrite = isUntypedPointerTy(ParamTy);
+ }
+ if (!NeedsRewrite)
+ return FTy;
+ SmallVector<Type *> NewParamTys;
+ for (Type *ParamTy : FTy->params())
+ NewParamTys.push_back(toTypedPointer(ParamTy));
+ return FunctionType::get(NewRetTy, NewParamTys, FTy->isVarArg());
+}
+
+// Const-qualified entry point for the typed-pointer mapping: function types go
+// through toTypedFunPointer, every other type through toTypedPointer. The
+// const_casts are only needed because those helpers take non-const pointers.
+inline const Type *unifyPtrType(const Type *Ty) {
+ if (const auto *FTy = dyn_cast<FunctionType>(Ty))
+ return toTypedFunPointer(const_cast<FunctionType *>(FTy));
+ return toTypedPointer(const_cast<Type *>(Ty));
+}
+
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H