summary refs log tree commit diff
path: root/llvm/lib/Analysis/IR2Vec.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis/IR2Vec.cpp')
-rw-r--r-- llvm/lib/Analysis/IR2Vec.cpp 241
1 file changed, 128 insertions, 113 deletions
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 95f30fd3f427..99afc0601d52 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -32,11 +32,11 @@ using namespace ir2vec;
#define DEBUG_TYPE "ir2vec"
STATISTIC(VocabMissCounter,
- "Number of lookups to entites not present in the vocabulary");
+ "Number of lookups to entities not present in the vocabulary");
namespace llvm {
namespace ir2vec {
-static cl::OptionCategory IR2VecCategory("IR2Vec Options");
+cl::OptionCategory IR2VecCategory("IR2Vec Options");
// FIXME: Use a default vocab when not specified
static cl::opt<std::string>
@@ -52,6 +52,15 @@ cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
cl::desc("Weight for argument embeddings"),
cl::cat(IR2VecCategory));
+cl::opt<IR2VecKind> IR2VecEmbeddingKind(
+ "ir2vec-kind", cl::Optional,
+ cl::values(clEnumValN(IR2VecKind::Symbolic, "symbolic",
+ "Generate symbolic embeddings"),
+ clEnumValN(IR2VecKind::FlowAware, "flow-aware",
+ "Generate flow-aware embeddings")),
+ cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
+ cl::cat(IR2VecCategory));
+
} // namespace ir2vec
} // namespace llvm
@@ -123,8 +132,12 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
double Tolerance) const {
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
for (size_t Itr = 0; Itr < this->size(); ++Itr)
- if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance)
+ if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) {
+ LLVM_DEBUG(errs() << "Embedding mismatch at index " << Itr << ": "
+ << (*this)[Itr] << " vs " << RHS[Itr]
+ << "; Tolerance: " << Tolerance << "\n");
return false;
+ }
return true;
}
@@ -141,14 +154,16 @@ void Embedding::print(raw_ostream &OS) const {
Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
- OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
-}
+ OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
+ FuncVector(Embedding(Dimension)) {}
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
const Vocabulary &Vocab) {
switch (Mode) {
case IR2VecKind::Symbolic:
return std::make_unique<SymbolicEmbedder>(F, Vocab);
+ case IR2VecKind::FlowAware:
+ return std::make_unique<FlowAwareEmbedder>(F, Vocab);
}
return nullptr;
}
@@ -180,6 +195,17 @@ const Embedding &Embedder::getFunctionVector() const {
return FuncVector;
}
+void Embedder::computeEmbeddings() const {
+ if (F.isDeclaration())
+ return;
+
+ // Consider only the basic blocks that are reachable from entry
+ for (const BasicBlock *BB : depth_first(&F)) {
+ computeEmbeddings(*BB);
+ FuncVector += BBVecMap[BB];
+ }
+}
+
void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);
@@ -187,7 +213,7 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
for (const auto &I : BB.instructionsWithoutDebug()) {
Embedding ArgEmb(Dimension, 0);
for (const auto &Op : I.operands())
- ArgEmb += Vocab[Op];
+ ArgEmb += Vocab[*Op];
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
InstVecMap[&I] = InstVector;
@@ -196,51 +222,75 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
BBVecMap[&BB] = BBVector;
}
-void SymbolicEmbedder::computeEmbeddings() const {
- if (F.isDeclaration())
- return;
+/// Builds the embedding of \p BB and stores it in BBVecMap. Every non-debug
+/// instruction gets a vector made of its opcode, type and operand embeddings;
+/// that vector is recorded in InstVecMap and added to the block vector. An
+/// operand defined by an instruction contributes that instruction's stored
+/// vector when one exists; every other operand contributes its vocabulary
+/// entry.
+void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
+ Embedding BBVector(Dimension, 0);
- // Consider only the basic blocks that are reachable from entry
- for (const BasicBlock *BB : depth_first(&F)) {
- computeEmbeddings(*BB);
- FuncVector += BBVecMap[BB];
+ // We consider only the non-debug and non-pseudo instructions
+ for (const auto &I : BB.instructionsWithoutDebug()) {
+ // TODO: Handle call instructions differently.
+ // For now, we treat them like other instructions
+ Embedding ArgEmb(Dimension, 0);
+ for (const auto &Op : I.operands()) {
+ // Reuse the vector of the defining instruction when it has already been
+ // computed. It can legitimately be missing: with the depth-first block
+ // walk in Embedder::computeEmbeddings, a phi operand reached through a
+ // back edge (or defined later in this same block) has no entry yet.
+ // Asserting on that aborts assert-enabled builds and dereferences end()
+ // otherwise, so fall through to the vocabulary lookup instead.
+ if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
+ auto DefIt = InstVecMap.find(DefInst);
+ if (DefIt != InstVecMap.end()) {
+ ArgEmb += DefIt->second;
+ continue;
+ }
+ }
+ // No stored instruction vector for this operand: use the vocabulary.
+ LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
+ << *Op << "=" << Vocab[*Op][0] << "\n");
+ ArgEmb += Vocab[*Op];
+ }
+ // Create the instruction vector by combining opcode, type, and arguments
+ // embeddings
+ auto InstVector =
+ Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ InstVecMap[&I] = InstVector;
+ BBVector += InstVector;
}
+ BBVecMap[&BB] = BBVector;
}
// ==----------------------------------------------------------------------===//
// Vocabulary
//===----------------------------------------------------------------------===//
-Vocabulary::Vocabulary(VocabVector &&Vocab)
- : Vocab(std::move(Vocab)), Valid(true) {}
+unsigned Vocabulary::getDimension() const {
+ assert(isValid() && "IR2Vec Vocabulary is invalid");
+ return Vocab[0].size();
+}
-bool Vocabulary::isValid() const {
- return Vocab.size() == Vocabulary::expectedSize() && Valid;
+unsigned Vocabulary::getSlotIndex(unsigned Opcode) {
+ assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+ return Opcode - 1; // Convert to zero-based index
}
-size_t Vocabulary::size() const {
- assert(Valid && "IR2Vec Vocabulary is invalid");
- return Vocab.size();
+unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
+ assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+ return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
}
-unsigned Vocabulary::getDimension() const {
- assert(Valid && "IR2Vec Vocabulary is invalid");
- return Vocab[0].size();
+unsigned Vocabulary::getSlotIndex(const Value &Op) {
+ unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
+ assert(Index < MaxOperandKinds && "Invalid OperandKind");
+ return MaxOpcodes + MaxCanonicalTypeIDs + Index;
}
const Embedding &Vocabulary::operator[](unsigned Opcode) const {
- assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
- return Vocab[Opcode - 1];
+ return Vocab[getSlotIndex(Opcode)];
}
-const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const {
- assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID");
- return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)];
+const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const {
+ return Vocab[getSlotIndex(TypeID)];
}
-const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
- OperandKind ArgKind = getOperandKind(Arg);
- return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
+const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
+ return Vocab[getSlotIndex(Arg)];
}
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
@@ -254,43 +304,21 @@ StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
return "UnknownOpcode";
}
+StringRef Vocabulary::getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) {
+ unsigned Index = static_cast<unsigned>(CType);
+ assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID");
+ return CanonicalTypeNames[Index];
+}
+
+Vocabulary::CanonicalTypeID
+Vocabulary::getCanonicalTypeID(Type::TypeID TypeID) {
+ unsigned Index = static_cast<unsigned>(TypeID);
+ assert(Index < MaxTypeIDs && "Invalid TypeID");
+ return TypeIDMapping[Index];
+}
+
StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
- switch (TypeID) {
- case Type::VoidTyID:
- return "VoidTy";
- case Type::HalfTyID:
- case Type::BFloatTyID:
- case Type::FloatTyID:
- case Type::DoubleTyID:
- case Type::X86_FP80TyID:
- case Type::FP128TyID:
- case Type::PPC_FP128TyID:
- return "FloatTy";
- case Type::IntegerTyID:
- return "IntegerTy";
- case Type::FunctionTyID:
- return "FunctionTy";
- case Type::StructTyID:
- return "StructTy";
- case Type::ArrayTyID:
- return "ArrayTy";
- case Type::PointerTyID:
- case Type::TypedPointerTyID:
- return "PointerTy";
- case Type::FixedVectorTyID:
- case Type::ScalableVectorTyID:
- return "VectorTy";
- case Type::LabelTyID:
- return "LabelTy";
- case Type::TokenTyID:
- return "TokenTy";
- case Type::MetadataTyID:
- return "MetadataTy";
- case Type::X86_AMXTyID:
- case Type::TargetExtTyID:
- return "UnknownTy";
- }
- return "UnknownTy";
+ return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID));
}
StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
@@ -299,20 +327,6 @@ StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
return OperandKindNames[Index];
}
-Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
- VocabVector DummyVocab;
- float DummyVal = 0.1f;
- // Create a dummy vocabulary with entries for all opcodes, types, and
- // operand
- for ([[maybe_unused]] unsigned _ :
- seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs +
- Vocabulary::MaxOperandKinds)) {
- DummyVocab.push_back(Embedding(Dim, DummyVal));
- DummyVal += 0.1f;
- }
- return DummyVocab;
-}
-
// Helper function to classify an operand into OperandKind
Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
if (isa<Function>(Op))
@@ -324,34 +338,18 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}
-unsigned Vocabulary::getNumericID(unsigned Opcode) {
- assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
- return Opcode - 1; // Convert to zero-based index
-}
-
-unsigned Vocabulary::getNumericID(Type::TypeID TypeID) {
- assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
- return MaxOpcodes + static_cast<unsigned>(TypeID);
-}
-
-unsigned Vocabulary::getNumericID(const Value *Op) {
- unsigned Index = static_cast<unsigned>(getOperandKind(Op));
- assert(Index < MaxOperandKinds && "Invalid OperandKind");
- return MaxOpcodes + MaxTypeIDs + Index;
-}
-
StringRef Vocabulary::getStringKey(unsigned Pos) {
- assert(Pos < Vocabulary::expectedSize() &&
- "Position out of bounds in vocabulary");
+ assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
// Opcode
if (Pos < MaxOpcodes)
return getVocabKeyForOpcode(Pos + 1);
// Type
- if (Pos < MaxOpcodes + MaxTypeIDs)
- return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
+ if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+ return getVocabKeyForCanonicalTypeID(
+ static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
// Operand
return getVocabKeyForOperandKind(
- static_cast<OperandKind>(Pos - MaxOpcodes - MaxTypeIDs));
+ static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
}
// For now, assume vocabulary is stable unless explicitly invalidated.
@@ -361,6 +359,21 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
return !(PAC.preservedWhenStateless());
}
+Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
+ VocabVector DummyVocab;
+ DummyVocab.reserve(NumCanonicalEntries);
+ float DummyVal = 0.1f;
+ // Create a dummy vocabulary with entries for all opcodes, types, and
+ // operands
+ for ([[maybe_unused]] unsigned _ :
+ seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
+ Vocabulary::MaxOperandKinds)) {
+ DummyVocab.push_back(Embedding(Dim, DummyVal));
+ DummyVal += 0.1f;
+ }
+ return DummyVocab;
+}
+
// ==----------------------------------------------------------------------===//
// IR2VecVocabAnalysis
//===----------------------------------------------------------------------===//
@@ -452,7 +465,8 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Opcodes
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
- Embedding(Dim, 0));
+ Embedding(Dim));
+ NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
auto It = OpcVocab.find(VocabKey.str());
@@ -464,14 +478,15 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
NumericOpcodeEmbeddings.end());
- // Handle Types
- std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
- Embedding(Dim, 0));
- for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {
- StringRef VocabKey =
- Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID));
+ // Handle Types - only canonical types are present in vocabulary
+ std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
+ Embedding(Dim));
+ NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs);
+ for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
+ StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
+ static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
- NumericTypeEmbeddings[TypeID] = It->second;
+ NumericTypeEmbeddings[CTypeID] = It->second;
continue;
}
handleMissingEntity(VocabKey.str());
@@ -481,7 +496,8 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Arguments/Operands
std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
- Embedding(Dim, 0));
+ Embedding(Dim));
+ NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds);
for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
@@ -552,8 +568,7 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
for (Function &F : M) {
- std::unique_ptr<Embedder> Emb =
- Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+ auto Emb = Embedder::create(IR2VecEmbeddingKind, F, Vocabulary);
if (!Emb) {
OS << "Error creating IR2Vec embeddings \n";
continue;