summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/IPO/FunctionSpecialization.cpp')
-rw-r--r--llvm/lib/Transforms/IPO/FunctionSpecialization.cpp50
1 files changed, 45 insertions, 5 deletions
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 9196a0147c43..30459caee160 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -89,6 +89,8 @@ static cl::opt<bool> SpecializeLiteralConstant(
"Enable specialization of functions that take a literal constant as an "
"argument"));
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+
bool InstCostVisitor::canEliminateSuccessor(BasicBlock *BB,
BasicBlock *Succ) const {
unsigned I = 0;
@@ -784,9 +786,31 @@ bool FunctionSpecializer::run() {
// Update the known call sites to call the clone.
for (CallBase *Call : S.CallSites) {
+ Function *Clone = S.Clone;
LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *Call
- << " to call " << S.Clone->getName() << "\n");
+ << " to call " << Clone->getName() << "\n");
Call->setCalledFunction(S.Clone);
+ auto &BFI = GetBFI(*Call->getFunction());
+ std::optional<uint64_t> Count =
+ BFI.getBlockProfileCount(Call->getParent());
+ if (Count && !ProfcheckDisableMetadataFixes) {
+ std::optional<llvm::Function::ProfileCount> MaybeCloneCount =
+ Clone->getEntryCount();
+ assert(MaybeCloneCount && "Clone entry count was not set!");
+ uint64_t CallCount = *Count + MaybeCloneCount->getCount();
+ Clone->setEntryCount(CallCount);
+ if (std::optional<llvm::Function::ProfileCount> MaybeOriginalCount =
+ S.F->getEntryCount()) {
+ uint64_t OriginalCount = MaybeOriginalCount->getCount();
+ if (OriginalCount >= CallCount) {
+ S.F->setEntryCount(OriginalCount - CallCount);
+ } else {
+ // This should generally not happen as that would mean there are
+ // more computed calls to the function than what was recorded.
+ LLVM_DEBUG(S.F->setEntryCount(0));
+ }
+ }
+ }
}
Clones.push_back(S.Clone);
@@ -838,14 +862,24 @@ bool FunctionSpecializer::run() {
}
void FunctionSpecializer::removeDeadFunctions() {
- for (Function *F : FullySpecialized) {
+ for (Function *F : DeadFunctions) {
LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead function "
<< F->getName() << "\n");
if (FAM)
FAM->clear(*F, F->getName());
+
+ // Remove all the callsites that were proven unreachable once, and replace
+ // them with poison.
+ for (User *U : make_early_inc_range(F->users())) {
+ assert((isa<CallInst>(U) || isa<InvokeInst>(U)) &&
+ "User of dead function must be call or invoke");
+ Instruction *CS = cast<Instruction>(U);
+ CS->replaceAllUsesWith(PoisonValue::get(CS->getType()));
+ CS->eraseFromParent();
+ }
F->eraseFromParent();
}
- FullySpecialized.clear();
+ DeadFunctions.clear();
}
/// Clone the function \p F and remove the ssa_copy intrinsics added by
@@ -1033,6 +1067,9 @@ Function *FunctionSpecializer::createSpecialization(Function *F,
// clone must.
Clone->setLinkage(GlobalValue::InternalLinkage);
+ if (F->getEntryCount() && !ProfcheckDisableMetadataFixes)
+ Clone->setEntryCount(0);
+
// Initialize the lattice state of the arguments of the function clone,
// marking the argument on which we specialized the function constant
// with the given value.
@@ -1206,8 +1243,11 @@ void FunctionSpecializer::updateCallSites(Function *F, const Spec *Begin,
// If the function has been completely specialized, the original function
// is no longer needed. Mark it unreachable.
- if (NCallsLeft == 0 && Solver.isArgumentTrackedFunction(F)) {
+ // NOTE: If the address of a function is taken, we cannot treat it as dead
+ // function.
+ if (NCallsLeft == 0 && Solver.isArgumentTrackedFunction(F) &&
+ !F->hasAddressTaken()) {
Solver.markFunctionUnreachable(F);
- FullySpecialized.insert(F);
+ DeadFunctions.insert(F);
}
}