summaryrefslogtreecommitdiff
path: root/llvm/lib/IR/ProfDataUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/IR/ProfDataUtils.cpp')
-rw-r--r--llvm/lib/IR/ProfDataUtils.cpp51
1 files changed, 39 insertions, 12 deletions
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 51e78dc5e6c0..992ce34e0003 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -19,6 +19,7 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/CommandLine.h"
@@ -40,9 +41,6 @@ namespace {
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes
-// The index at which the weights vector starts
-constexpr unsigned WeightsIdx = 1;
-
// the minimum number of operands for MD_prof nodes with branch weights
constexpr unsigned MinBWOps = 3;
@@ -75,6 +73,7 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData,
assert(isBranchWeightMD(ProfileData) && "wrong metadata");
unsigned NOps = ProfileData->getNumOperands();
+ unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
Weights.resize(NOps - WeightsIdx);
@@ -82,8 +81,8 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData,
ConstantInt *Weight =
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
assert(Weight && "Malformed branch_weight in MD_prof node");
- assert(Weight->getValue().getActiveBits() <= 32 &&
- "Too many bits for uint32_t");
+ assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
+ "Too many bits for MD_prof branch_weight");
Weights[Idx - WeightsIdx] = Weight->getZExtValue();
}
}
@@ -123,6 +122,30 @@ bool hasValidBranchWeightMD(const Instruction &I) {
return getValidBranchWeightMDNode(I);
}
+bool hasBranchWeightOrigin(const Instruction &I) {
+ auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+ return hasBranchWeightOrigin(ProfileData);
+}
+
+bool hasBranchWeightOrigin(const MDNode *ProfileData) {
+ if (!isBranchWeightMD(ProfileData))
+ return false;
+ auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
+ // NOTE: if we ever have more types of branch weight provenance,
+ // we need to check the string value is "expected". For now, we
+ // supply a more generic API, and avoid the spurious comparisons.
+ assert(ProfDataName == nullptr || ProfDataName->getString() == "expected");
+ return ProfDataName != nullptr;
+}
+
+unsigned getBranchWeightOffset(const MDNode *ProfileData) {
+ return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
+}
+
+unsigned getNumBranchWeights(const MDNode &ProfileData) {
+ return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData);
+}
+
MDNode *getBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
if (!isBranchWeightMD(ProfileData))
@@ -132,7 +155,7 @@ MDNode *getBranchWeightMDNode(const Instruction &I) {
MDNode *getValidBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = getBranchWeightMDNode(I);
- if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors())
+ if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
return ProfileData;
return nullptr;
}
@@ -191,7 +214,8 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
return false;
if (ProfDataName->getString() == "branch_weights") {
- for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) {
+ unsigned Offset = getBranchWeightOffset(ProfileData);
+ for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
assert(V && "Malformed branch_weight in MD_prof node");
TotalVal += V->getValue().getZExtValue();
@@ -212,9 +236,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
}
-void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
+ bool IsExpected) {
MDBuilder MDB(I.getContext());
- MDNode *BranchWeights = MDB.createBranchWeights(Weights);
+ MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
}
@@ -241,9 +266,11 @@ void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
if (ProfDataName->getString() == "branch_weights" &&
ProfileData->getNumOperands() > 0) {
// Using APInt::div may be expensive, but most cases should fit 64 bits.
- APInt Val(128, mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1))
- ->getValue()
- .getZExtValue());
+ APInt Val(128,
+ mdconst::dyn_extract<ConstantInt>(
+ ProfileData->getOperand(getBranchWeightOffset(ProfileData)))
+ ->getValue()
+ .getZExtValue());
Val *= APS;
Vals.push_back(MDB.createConstant(ConstantInt::get(
Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX))));