diff options
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp')
| -rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index ad0158086044..3206c264f99d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -69,7 +69,8 @@ public: void outputOpFunctionEnd(); void outputExtFuncDecls(); void outputExecutionModeFromMDNode(Register Reg, MDNode *Node, - SPIRV::ExecutionMode::ExecutionMode EM); + SPIRV::ExecutionMode::ExecutionMode EM, + unsigned ExpectMDOps, int64_t DefVal); void outputExecutionModeFromNumthreadsAttribute( const Register &Reg, const Attribute &Attr, SPIRV::ExecutionMode::ExecutionMode EM); @@ -422,12 +423,19 @@ static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst, } void SPIRVAsmPrinter::outputExecutionModeFromMDNode( - Register Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM) { + Register Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM, + unsigned ExpectMDOps, int64_t DefVal) { MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionMode); Inst.addOperand(MCOperand::createReg(Reg)); Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(EM))); addOpsFromMDNode(Node, Inst, MAI); + // reqd_work_group_size and work_group_size_hint require 3 operands, + // if metadata contains less operands, just add a default value + unsigned NodeSz = Node->getNumOperands(); + if (ExpectMDOps > 0 && NodeSz < ExpectMDOps) + for (unsigned i = NodeSz; i < ExpectMDOps; ++i) + Inst.addOperand(MCOperand::createImm(DefVal)); outputMCInst(Inst); } @@ -473,17 +481,17 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { Register FReg = MAI->getFuncReg(&F); assert(FReg.isValid()); if (MDNode *Node = F.getMetadata("reqd_work_group_size")) - outputExecutionModeFromMDNode(FReg, Node, - SPIRV::ExecutionMode::LocalSize); + outputExecutionModeFromMDNode(FReg, Node, SPIRV::ExecutionMode::LocalSize, + 3, 1); if (Attribute Attr = F.getFnAttribute("hlsl.numthreads"); Attr.isValid()) outputExecutionModeFromNumthreadsAttribute( FReg, Attr, SPIRV::ExecutionMode::LocalSize); if (MDNode *Node = F.getMetadata("work_group_size_hint")) outputExecutionModeFromMDNode(FReg, Node, - SPIRV::ExecutionMode::LocalSizeHint); + SPIRV::ExecutionMode::LocalSizeHint, 3, 1); if (MDNode *Node = F.getMetadata("intel_reqd_sub_group_size")) outputExecutionModeFromMDNode(FReg, Node, - SPIRV::ExecutionMode::SubgroupSize); + SPIRV::ExecutionMode::SubgroupSize, 0, 0); if (MDNode *Node = F.getMetadata("vec_type_hint")) { MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionMode); |
