diff options
Diffstat (limited to 'llvm/lib/Target/DirectX/DXILOpLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/DirectX/DXILOpLowering.cpp | 55 |
1 files changed, 53 insertions, 2 deletions
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index 3e334b0ec298..f09e322f88e1 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -30,6 +30,48 @@ using namespace llvm; using namespace llvm::dxil; +static bool isVectorArgExpansion(Function &F) { + switch (F.getIntrinsicID()) { + case Intrinsic::dx_dot2: + case Intrinsic::dx_dot3: + case Intrinsic::dx_dot4: + return true; + } + return false; +} + +static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) { + SmallVector<Value *, 4> ExtractedElements; + auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType()); + for (unsigned I = 0; I < VecArg->getNumElements(); ++I) { + Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I); + Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index); + ExtractedElements.push_back(ExtractedElement); + } + return ExtractedElements; +} + +static SmallVector<Value *> argVectorFlatten(CallInst *Orig, + IRBuilder<> &Builder) { + // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening. + unsigned NumOperands = Orig->getNumOperands() - 1; + assert(NumOperands > 0); + Value *Arg0 = Orig->getOperand(0); + [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType()); + assert(VecArg0); + SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder); + for (unsigned I = 1; I < NumOperands; ++I) { + Value *Arg = Orig->getOperand(I); + [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType()); + assert(VecArg); + assert(VecArg0->getElementType() == VecArg->getElementType()); + assert(VecArg0->getNumElements() == VecArg->getNumElements()); + auto NextOperandList = populateOperands(Arg, Builder); + NewOperands.append(NextOperandList.begin(), NextOperandList.end()); + } + return NewOperands; +} + static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) { IRBuilder<> B(M.getContext()); DXILOpBuilder DXILB(M, B); @@ -39,9 +81,18 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) { if (!CI) continue; + SmallVector<Value *> Args; + Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp)); + Args.emplace_back(DXILOpArg); B.SetInsertPoint(CI); - CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(), - OverloadTy, CI->args()); + if (isVectorArgExpansion(F)) { + SmallVector<Value *> NewArgs = argVectorFlatten(CI, B); + Args.append(NewArgs.begin(), NewArgs.end()); + } else + Args.append(CI->arg_begin(), CI->arg_end()); + + CallInst *DXILCI = + DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args); CI->replaceAllUsesWith(DXILCI); CI->eraseFromParent(); |
