diff options
Diffstat (limited to 'llvm/lib/Transforms/IPO/FunctionSpecialization.cpp')
| -rw-r--r-- | llvm/lib/Transforms/IPO/FunctionSpecialization.cpp | 50 |
1 files changed, 45 insertions, 5 deletions
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp index 9196a0147c43..30459caee160 100644 --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -89,6 +89,8 @@ static cl::opt<bool> SpecializeLiteralConstant( "Enable specialization of functions that take a literal constant as an " "argument")); +extern cl::opt<bool> ProfcheckDisableMetadataFixes; + bool InstCostVisitor::canEliminateSuccessor(BasicBlock *BB, BasicBlock *Succ) const { unsigned I = 0; @@ -784,9 +786,31 @@ bool FunctionSpecializer::run() { // Update the known call sites to call the clone. for (CallBase *Call : S.CallSites) { + Function *Clone = S.Clone; LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *Call - << " to call " << S.Clone->getName() << "\n"); + << " to call " << Clone->getName() << "\n"); Call->setCalledFunction(S.Clone); + auto &BFI = GetBFI(*Call->getFunction()); + std::optional<uint64_t> Count = + BFI.getBlockProfileCount(Call->getParent()); + if (Count && !ProfcheckDisableMetadataFixes) { + std::optional<llvm::Function::ProfileCount> MaybeCloneCount = + Clone->getEntryCount(); + assert(MaybeCloneCount && "Clone entry count was not set!"); + uint64_t CallCount = *Count + MaybeCloneCount->getCount(); + Clone->setEntryCount(CallCount); + if (std::optional<llvm::Function::ProfileCount> MaybeOriginalCount = + S.F->getEntryCount()) { + uint64_t OriginalCount = MaybeOriginalCount->getCount(); + if (OriginalCount >= CallCount) { + S.F->setEntryCount(OriginalCount - CallCount); + } else { + // This should generally not happen as that would mean there are + // more computed calls to the function than what was recorded. + LLVM_DEBUG(S.F->setEntryCount(0)); + } + } + } } Clones.push_back(S.Clone); @@ -838,14 +862,24 @@ bool FunctionSpecializer::run() { } void FunctionSpecializer::removeDeadFunctions() { - for (Function *F : FullySpecialized) { + for (Function *F : DeadFunctions) { LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead function " << F->getName() << "\n"); if (FAM) FAM->clear(*F, F->getName()); + + // Remove all the callsites that were proven unreachable once, and replace + // them with poison. + for (User *U : make_early_inc_range(F->users())) { + assert((isa<CallInst>(U) || isa<InvokeInst>(U)) && + "User of dead function must be call or invoke"); + Instruction *CS = cast<Instruction>(U); + CS->replaceAllUsesWith(PoisonValue::get(CS->getType())); + CS->eraseFromParent(); + } F->eraseFromParent(); } - FullySpecialized.clear(); + DeadFunctions.clear(); } /// Clone the function \p F and remove the ssa_copy intrinsics added by @@ -1033,6 +1067,9 @@ Function *FunctionSpecializer::createSpecialization(Function *F, // clone must. Clone->setLinkage(GlobalValue::InternalLinkage); + if (F->getEntryCount() && !ProfcheckDisableMetadataFixes) + Clone->setEntryCount(0); + // Initialize the lattice state of the arguments of the function clone, // marking the argument on which we specialized the function constant // with the given value. @@ -1206,8 +1243,11 @@ void FunctionSpecializer::updateCallSites(Function *F, const Spec *Begin, // If the function has been completely specialized, the original function // is no longer needed. Mark it unreachable. - if (NCallsLeft == 0 && Solver.isArgumentTrackedFunction(F)) { + // NOTE: If the address of a function is taken, we cannot treat it as dead + // function. + if (NCallsLeft == 0 && Solver.isArgumentTrackedFunction(F) && + !F->hasAddressTaken()) { Solver.markFunctionUnreachable(F); - FullySpecialized.insert(F); + DeadFunctions.insert(F); } } |
