summary refs log tree commit diff
path: root/llvm/lib/Target/DirectX/DXILOpLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/DirectX/DXILOpLowering.cpp')
-rw-r--r-- llvm/lib/Target/DirectX/DXILOpLowering.cpp 55
1 files changed, 53 insertions, 2 deletions
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 3e334b0ec298..f09e322f88e1 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -30,6 +30,48 @@
using namespace llvm;
using namespace llvm::dxil;
+/// Report whether \p F is one of the dx_dot2 / dx_dot3 / dx_dot4 intrinsics.
+///
+/// \returns true for exactly those three intrinsic IDs and false for every
+/// other function.
+static bool isVectorArgExpansion(Function &F) {
+  const Intrinsic::ID IID = F.getIntrinsicID();
+  return IID == Intrinsic::dx_dot2 || IID == Intrinsic::dx_dot3 ||
+         IID == Intrinsic::dx_dot4;
+}
+
+/// Split the fixed-width vector \p Arg into its scalar elements.
+///
+/// Calls Builder.CreateExtractElement once per lane, using an i32 constant as
+/// the lane index, and returns the resulting values in lane order.
+///
+/// \param Arg     Value whose type must be a FixedVectorType.
+/// \param Builder Builder used to create the element extractions.
+/// \returns One value per vector lane, lane 0 first.
+static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
+  // cast<> asserts on a non-vector argument instead of producing a null
+  // pointer that would be dereferenced unchecked below.
+  auto *VecTy = cast<FixedVectorType>(Arg->getType());
+  const unsigned NumElts = VecTy->getNumElements();
+  // The index type does not change between lanes, so look it up once.
+  Type *IndexTy = Type::getInt32Ty(Arg->getContext());
+
+  // Declared with the function's return type so the result is returned
+  // directly rather than converted between SmallVector instantiations.
+  SmallVector<Value *> ExtractedElements;
+  ExtractedElements.reserve(NumElts);
+  for (unsigned I = 0; I < NumElts; ++I) {
+    Value *Index = ConstantInt::get(IndexTy, I);
+    ExtractedElements.push_back(Builder.CreateExtractElement(Arg, Index));
+  }
+  return ExtractedElements;
+}
+
+/// Flatten the vector arguments of the call \p Orig into one scalar list.
+///
+/// Every call argument must be a fixed vector with the same element type and
+/// element count as the first argument; this is checked with assertions only.
+/// Each argument is expanded in order with populateOperands, so the result
+/// holds all lanes of argument 0, then all lanes of argument 1, and so on.
+///
+/// \param Orig    Call whose arguments are flattened; needs >= 1 argument.
+/// \param Builder Builder handed to populateOperands for the extractions.
+/// \returns The concatenated scalar elements of all arguments.
+static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
+                                             IRBuilder<> &Builder) {
+  // arg_size() counts only real call arguments. The callee operand (and any
+  // operand-bundle operands) are excluded, unlike getNumOperands() - 1.
+  const unsigned NumArgs = Orig->arg_size();
+  assert(NumArgs > 0 && "expected at least one vector argument");
+  [[maybe_unused]] auto *VecArg0 =
+      dyn_cast<FixedVectorType>(Orig->getArgOperand(0)->getType());
+  assert(VecArg0 && "first argument must be a fixed vector");
+
+  SmallVector<Value *> NewOperands;
+  for (unsigned I = 0; I < NumArgs; ++I) {
+    Value *Arg = Orig->getArgOperand(I);
+    [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
+    assert(VecArg && "every argument must be a fixed vector");
+    assert(VecArg0->getElementType() == VecArg->getElementType() &&
+           "vector arguments must share an element type");
+    assert(VecArg0->getNumElements() == VecArg->getNumElements() &&
+           "vector arguments must share an element count");
+    auto Elements = populateOperands(Arg, Builder);
+    NewOperands.append(Elements.begin(), Elements.end());
+  }
+  return NewOperands;
+}
+
static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
IRBuilder<> B(M.getContext());
DXILOpBuilder DXILB(M, B);
@@ -39,9 +81,18 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
if (!CI)
continue;
+ SmallVector<Value *> Args;
+ Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
+ Args.emplace_back(DXILOpArg);
B.SetInsertPoint(CI);
- CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(),
- OverloadTy, CI->args());
+ if (isVectorArgExpansion(F)) {
+ SmallVector<Value *> NewArgs = argVectorFlatten(CI, B);
+ Args.append(NewArgs.begin(), NewArgs.end());
+ } else
+ Args.append(CI->arg_begin(), CI->arg_end());
+
+ CallInst *DXILCI =
+ DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args);
CI->replaceAllUsesWith(DXILCI);
CI->eraseFromParent();