summaryrefslogtreecommitdiff
path: root/llvm/lib/Target/DirectX/DXILOpLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/DirectX/DXILOpLowering.cpp')
-rw-r--r--llvm/lib/Target/DirectX/DXILOpLowering.cpp128
1 files changed, 69 insertions, 59 deletions
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 5f84cdcfda6d..fb708a61dd31 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -1,15 +1,12 @@
-//===- DXILOpLower.cpp - Lowering LLVM intrinsic to DIXLOp function -------===//
+//===- DXILOpLowering.cpp - Lowering to DXIL operations -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
-///
-/// \file This file contains passes and utilities to lower llvm intrinsic call
-/// to DXILOp function call.
-//===----------------------------------------------------------------------===//
+#include "DXILOpLowering.h"
#include "DXILConstants.h"
#include "DXILIntrinsicExpansion.h"
#include "DXILOpBuilder.h"
@@ -73,77 +70,90 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
return NewOperands;
}
-static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
- IRBuilder<> B(M.getContext());
- DXILOpBuilder OpBuilder(M, B);
- for (User *U : make_early_inc_range(F.users())) {
- CallInst *CI = dyn_cast<CallInst>(U);
- if (!CI)
- continue;
-
- SmallVector<Value *> Args;
- B.SetInsertPoint(CI);
- if (isVectorArgExpansion(F)) {
- SmallVector<Value *> NewArgs = argVectorFlatten(CI, B);
- Args.append(NewArgs.begin(), NewArgs.end());
- } else
- Args.append(CI->arg_begin(), CI->arg_end());
-
- Expected<CallInst *> OpCallOrErr = OpBuilder.tryCreateOp(DXILOp, Args,
- F.getReturnType());
- if (Error E = OpCallOrErr.takeError()) {
- std::string Message(toString(std::move(E)));
- DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
- CI->getDebugLoc());
- M.getContext().diagnose(Diag);
- continue;
+namespace {
+class OpLowerer {
+ Module &M;
+ DXILOpBuilder OpBuilder;
+
+public:
+ OpLowerer(Module &M) : M(M), OpBuilder(M) {}
+
+ void replaceFunction(Function &F,
+ llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
+ for (User *U : make_early_inc_range(F.users())) {
+ CallInst *CI = dyn_cast<CallInst>(U);
+ if (!CI)
+ continue;
+
+ if (Error E = ReplaceCall(CI)) {
+ std::string Message(toString(std::move(E)));
+ DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
+ CI->getDebugLoc());
+ M.getContext().diagnose(Diag);
+ continue;
+ }
}
- CallInst *OpCall = *OpCallOrErr;
+ if (F.user_empty())
+ F.eraseFromParent();
+ }
- CI->replaceAllUsesWith(OpCall);
- CI->eraseFromParent();
+ void replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp) {
+ bool IsVectorArgExpansion = isVectorArgExpansion(F);
+ replaceFunction(F, [&](CallInst *CI) -> Error {
+ SmallVector<Value *> Args;
+ OpBuilder.getIRB().SetInsertPoint(CI);
+ if (IsVectorArgExpansion) {
+ SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
+ Args.append(NewArgs.begin(), NewArgs.end());
+ } else
+ Args.append(CI->arg_begin(), CI->arg_end());
+
+ Expected<CallInst *> OpCall =
+ OpBuilder.tryCreateOp(DXILOp, Args, F.getReturnType());
+ if (Error E = OpCall.takeError())
+ return E;
+
+ CI->replaceAllUsesWith(*OpCall);
+ CI->eraseFromParent();
+ return Error::success();
+ });
}
- if (F.user_empty())
- F.eraseFromParent();
-}
-static bool lowerIntrinsics(Module &M) {
- bool Updated = false;
+ bool lowerIntrinsics() {
+ bool Updated = false;
- for (Function &F : make_early_inc_range(M.functions())) {
- if (!F.isDeclaration())
- continue;
- Intrinsic::ID ID = F.getIntrinsicID();
- switch (ID) {
- default:
- continue;
+ for (Function &F : make_early_inc_range(M.functions())) {
+ if (!F.isDeclaration())
+ continue;
+ Intrinsic::ID ID = F.getIntrinsicID();
+ switch (ID) {
+ default:
+ continue;
#define DXIL_OP_INTRINSIC(OpCode, Intrin) \
case Intrin: \
- lowerIntrinsic(OpCode, F, M); \
+ replaceFunctionWithOp(F, OpCode); \
break;
#include "DXILOperation.inc"
+ }
+ Updated = true;
}
- Updated = true;
- }
- return Updated;
-}
-
-namespace {
-/// A pass that transforms external global definitions into declarations.
-class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
-public:
- PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
- if (lowerIntrinsics(M))
- return PreservedAnalyses::none();
- return PreservedAnalyses::all();
+ return Updated;
}
};
} // namespace
+PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &) {
+ if (OpLowerer(M).lowerIntrinsics())
+ return PreservedAnalyses::none();
+ return PreservedAnalyses::all();
+}
+
namespace {
class DXILOpLoweringLegacy : public ModulePass {
public:
- bool runOnModule(Module &M) override { return lowerIntrinsics(M); }
+ bool runOnModule(Module &M) override {
+ return OpLowerer(M).lowerIntrinsics();
+ }
StringRef getPassName() const override { return "DXIL Op Lowering"; }
DXILOpLoweringLegacy() : ModulePass(ID) {}