summaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp20
1 files changed, 14 insertions, 6 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
index ad0158086044..3206c264f99d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -69,7 +69,8 @@ public:
void outputOpFunctionEnd();
void outputExtFuncDecls();
void outputExecutionModeFromMDNode(Register Reg, MDNode *Node,
- SPIRV::ExecutionMode::ExecutionMode EM);
+ SPIRV::ExecutionMode::ExecutionMode EM,
+ unsigned ExpectMDOps, int64_t DefVal);
void outputExecutionModeFromNumthreadsAttribute(
const Register &Reg, const Attribute &Attr,
SPIRV::ExecutionMode::ExecutionMode EM);
@@ -422,12 +423,19 @@ static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst,
}
void SPIRVAsmPrinter::outputExecutionModeFromMDNode(
- Register Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM) {
+ Register Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM,
+ unsigned ExpectMDOps, int64_t DefVal) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
Inst.addOperand(MCOperand::createReg(Reg));
Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(EM)));
addOpsFromMDNode(Node, Inst, MAI);
+ // reqd_work_group_size and work_group_size_hint require 3 operands,
+ // if metadata contains less operands, just add a default value
+ unsigned NodeSz = Node->getNumOperands();
+ if (ExpectMDOps > 0 && NodeSz < ExpectMDOps)
+ for (unsigned i = NodeSz; i < ExpectMDOps; ++i)
+ Inst.addOperand(MCOperand::createImm(DefVal));
outputMCInst(Inst);
}
@@ -473,17 +481,17 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
Register FReg = MAI->getFuncReg(&F);
assert(FReg.isValid());
if (MDNode *Node = F.getMetadata("reqd_work_group_size"))
- outputExecutionModeFromMDNode(FReg, Node,
- SPIRV::ExecutionMode::LocalSize);
+ outputExecutionModeFromMDNode(FReg, Node, SPIRV::ExecutionMode::LocalSize,
+ 3, 1);
if (Attribute Attr = F.getFnAttribute("hlsl.numthreads"); Attr.isValid())
outputExecutionModeFromNumthreadsAttribute(
FReg, Attr, SPIRV::ExecutionMode::LocalSize);
if (MDNode *Node = F.getMetadata("work_group_size_hint"))
outputExecutionModeFromMDNode(FReg, Node,
- SPIRV::ExecutionMode::LocalSizeHint);
+ SPIRV::ExecutionMode::LocalSizeHint, 3, 1);
if (MDNode *Node = F.getMetadata("intel_reqd_sub_group_size"))
outputExecutionModeFromMDNode(FReg, Node,
- SPIRV::ExecutionMode::SubgroupSize);
+ SPIRV::ExecutionMode::SubgroupSize, 0, 0);
if (MDNode *Node = F.getMetadata("vec_type_hint")) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);