diff options
Diffstat (limited to 'llvm/lib/IR/ProfDataUtils.cpp')
| -rw-r--r-- | llvm/lib/IR/ProfDataUtils.cpp | 51 |
1 files changed, 39 insertions, 12 deletions
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index 51e78dc5e6c0..992ce34e0003 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -19,6 +19,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" @@ -40,9 +41,6 @@ namespace { // We maintain some constants here to ensure that we access the branch weights // correctly, and can change the behavior in the future if the layout changes -// The index at which the weights vector starts -constexpr unsigned WeightsIdx = 1; - // the minimum number of operands for MD_prof nodes with branch weights constexpr unsigned MinBWOps = 3; @@ -75,6 +73,7 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData, assert(isBranchWeightMD(ProfileData) && "wrong metadata"); unsigned NOps = ProfileData->getNumOperands(); + unsigned WeightsIdx = getBranchWeightOffset(ProfileData); assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); Weights.resize(NOps - WeightsIdx); @@ -82,8 +81,8 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData, ConstantInt *Weight = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); assert(Weight && "Malformed branch_weight in MD_prof node"); - assert(Weight->getValue().getActiveBits() <= 32 && - "Too many bits for uint32_t"); + assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) && + "Too many bits for MD_prof branch_weight"); Weights[Idx - WeightsIdx] = Weight->getZExtValue(); } } @@ -123,6 +122,30 @@ bool hasValidBranchWeightMD(const Instruction &I) { return getValidBranchWeightMDNode(I); } +bool hasBranchWeightOrigin(const Instruction &I) { + auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); + return hasBranchWeightOrigin(ProfileData); +} + +bool hasBranchWeightOrigin(const MDNode *ProfileData) { + if (!isBranchWeightMD(ProfileData)) + return false; + auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1)); + // NOTE: if we ever have more types of branch weight provenance, + // we need to check the string value is "expected". For now, we + // supply a more generic API, and avoid the spurious comparisons. + assert(ProfDataName == nullptr || ProfDataName->getString() == "expected"); + return ProfDataName != nullptr; +} + +unsigned getBranchWeightOffset(const MDNode *ProfileData) { + return hasBranchWeightOrigin(ProfileData) ? 2 : 1; +} + +unsigned getNumBranchWeights(const MDNode &ProfileData) { + return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData); +} + MDNode *getBranchWeightMDNode(const Instruction &I) { auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); if (!isBranchWeightMD(ProfileData)) @@ -132,7 +155,7 @@ MDNode *getBranchWeightMDNode(const Instruction &I) { MDNode *getValidBranchWeightMDNode(const Instruction &I) { auto *ProfileData = getBranchWeightMDNode(I); - if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors()) + if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors()) return ProfileData; return nullptr; } @@ -191,7 +214,8 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { return false; if (ProfDataName->getString() == "branch_weights") { - for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) { + unsigned Offset = getBranchWeightOffset(ProfileData); + for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) { auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); assert(V && "Malformed branch_weight in MD_prof node"); TotalVal += V->getValue().getZExtValue(); @@ -212,9 +236,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); } -void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) { +void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights, + bool IsExpected) { MDBuilder MDB(I.getContext()); - MDNode *BranchWeights = MDB.createBranchWeights(Weights); + MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected); I.setMetadata(LLVMContext::MD_prof, BranchWeights); } @@ -241,9 +266,11 @@ void scaleProfData(Instruction &I, uint64_t S, uint64_t T) { if (ProfDataName->getString() == "branch_weights" && ProfileData->getNumOperands() > 0) { // Using APInt::div may be expensive, but most cases should fit 64 bits. - APInt Val(128, mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1)) - ->getValue() - .getZExtValue()); + APInt Val(128, + mdconst::dyn_extract<ConstantInt>( + ProfileData->getOperand(getBranchWeightOffset(ProfileData))) + ->getValue() + .getZExtValue()); Val *= APS; Vals.push_back(MDB.createConstant(ConstantInt::get( Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX)))); |
