diff options
Diffstat (limited to 'llvm/lib/Analysis/IR2Vec.cpp')
| -rw-r--r-- | llvm/lib/Analysis/IR2Vec.cpp | 241 |
1 files changed, 128 insertions, 113 deletions
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 95f30fd3f427..99afc0601d52 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -32,11 +32,11 @@ using namespace ir2vec; #define DEBUG_TYPE "ir2vec" STATISTIC(VocabMissCounter, - "Number of lookups to entites not present in the vocabulary"); + "Number of lookups to entities not present in the vocabulary"); namespace llvm { namespace ir2vec { -static cl::OptionCategory IR2VecCategory("IR2Vec Options"); +cl::OptionCategory IR2VecCategory("IR2Vec Options"); // FIXME: Use a default vocab when not specified static cl::opt<std::string> @@ -52,6 +52,15 @@ cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5), cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2), cl::desc("Weight for argument embeddings"), cl::cat(IR2VecCategory)); +cl::opt<IR2VecKind> IR2VecEmbeddingKind( + "ir2vec-kind", cl::Optional, + cl::values(clEnumValN(IR2VecKind::Symbolic, "symbolic", + "Generate symbolic embeddings"), + clEnumValN(IR2VecKind::FlowAware, "flow-aware", + "Generate flow-aware embeddings")), + cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"), + cl::cat(IR2VecCategory)); + } // namespace ir2vec } // namespace llvm @@ -123,8 +132,12 @@ bool Embedding::approximatelyEquals(const Embedding &RHS, double Tolerance) const { assert(this->size() == RHS.size() && "Vectors must have the same dimension"); for (size_t Itr = 0; Itr < this->size(); ++Itr) - if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) + if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) { + LLVM_DEBUG(errs() << "Embedding mismatch at index " << Itr << ": " + << (*this)[Itr] << " vs " << RHS[Itr] + << "; Tolerance: " << Tolerance << "\n"); return false; + } return true; } @@ -141,14 +154,16 @@ void Embedding::print(raw_ostream &OS) const { Embedder::Embedder(const Function &F, const Vocabulary &Vocab) : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()), - OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) { -} + OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight), + FuncVector(Embedding(Dimension)) {} std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab) { switch (Mode) { case IR2VecKind::Symbolic: return std::make_unique<SymbolicEmbedder>(F, Vocab); + case IR2VecKind::FlowAware: + return std::make_unique<FlowAwareEmbedder>(F, Vocab); } return nullptr; } @@ -180,6 +195,17 @@ const Embedding &Embedder::getFunctionVector() const { return FuncVector; } +void Embedder::computeEmbeddings() const { + if (F.isDeclaration()) + return; + + // Consider only the basic blocks that are reachable from entry + for (const BasicBlock *BB : depth_first(&F)) { + computeEmbeddings(*BB); + FuncVector += BBVecMap[BB]; + } +} + void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { Embedding BBVector(Dimension, 0); @@ -187,7 +213,7 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { for (const auto &I : BB.instructionsWithoutDebug()) { Embedding ArgEmb(Dimension, 0); for (const auto &Op : I.operands()) - ArgEmb += Vocab[Op]; + ArgEmb += Vocab[*Op]; auto InstVector = Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; InstVecMap[&I] = InstVector; @@ -196,51 +222,75 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { BBVecMap[&BB] = BBVector; } -void SymbolicEmbedder::computeEmbeddings() const { - if (F.isDeclaration()) - return; +void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const { + Embedding BBVector(Dimension, 0); - // Consider only the basic blocks that are reachable from entry - for (const BasicBlock *BB : depth_first(&F)) { - computeEmbeddings(*BB); - FuncVector += BBVecMap[BB]; + // We consider only the non-debug and non-pseudo instructions + for (const auto &I : BB.instructionsWithoutDebug()) { + // TODO: Handle call instructions differently. + // For now, we treat them like other instructions + Embedding ArgEmb(Dimension, 0); + for (const auto &Op : I.operands()) { + // If the operand is defined elsewhere, we use its embedding + if (const auto *DefInst = dyn_cast<Instruction>(Op)) { + auto DefIt = InstVecMap.find(DefInst); + assert(DefIt != InstVecMap.end() && + "Instruction should have been processed before its operands"); + ArgEmb += DefIt->second; + continue; + } + // If the operand is not defined by an instruction, we use the vocabulary + else { + LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: " + << *Op << "=" << Vocab[*Op][0] << "\n"); + ArgEmb += Vocab[*Op]; + } + } + // Create the instruction vector by combining opcode, type, and arguments + // embeddings + auto InstVector = + Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; + InstVecMap[&I] = InstVector; + BBVector += InstVector; } + BBVecMap[&BB] = BBVector; } // ==----------------------------------------------------------------------===// // Vocabulary //===----------------------------------------------------------------------===// -Vocabulary::Vocabulary(VocabVector &&Vocab) - : Vocab(std::move(Vocab)), Valid(true) {} +unsigned Vocabulary::getDimension() const { + assert(isValid() && "IR2Vec Vocabulary is invalid"); + return Vocab[0].size(); +} -bool Vocabulary::isValid() const { - return Vocab.size() == Vocabulary::expectedSize() && Valid; +unsigned Vocabulary::getSlotIndex(unsigned Opcode) { + assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); + return Opcode - 1; // Convert to zero-based index } -size_t Vocabulary::size() const { - assert(Valid && "IR2Vec Vocabulary is invalid"); - return Vocab.size(); +unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) { + assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID"); + return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID)); } -unsigned Vocabulary::getDimension() const { - assert(Valid && "IR2Vec Vocabulary is invalid"); - return Vocab[0].size(); +unsigned Vocabulary::getSlotIndex(const Value &Op) { + unsigned Index = static_cast<unsigned>(getOperandKind(&Op)); + assert(Index < MaxOperandKinds && "Invalid OperandKind"); + return MaxOpcodes + MaxCanonicalTypeIDs + Index; } const Embedding &Vocabulary::operator[](unsigned Opcode) const { - assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); - return Vocab[Opcode - 1]; + return Vocab[getSlotIndex(Opcode)]; } -const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const { - assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID"); - return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)]; +const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const { + return Vocab[getSlotIndex(TypeID)]; } -const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const { - OperandKind ArgKind = getOperandKind(Arg); - return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)]; +const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const { + return Vocab[getSlotIndex(Arg)]; } StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { @@ -254,43 +304,21 @@ StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { return "UnknownOpcode"; } +StringRef Vocabulary::getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) { + unsigned Index = static_cast<unsigned>(CType); + assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID"); + return CanonicalTypeNames[Index]; +} + +Vocabulary::CanonicalTypeID +Vocabulary::getCanonicalTypeID(Type::TypeID TypeID) { + unsigned Index = static_cast<unsigned>(TypeID); + assert(Index < MaxTypeIDs && "Invalid TypeID"); + return TypeIDMapping[Index]; +} + StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) { - switch (TypeID) { - case Type::VoidTyID: - return "VoidTy"; - case Type::HalfTyID: - case Type::BFloatTyID: - case Type::FloatTyID: - case Type::DoubleTyID: - case Type::X86_FP80TyID: - case Type::FP128TyID: - case Type::PPC_FP128TyID: - return "FloatTy"; - case Type::IntegerTyID: - return "IntegerTy"; - case Type::FunctionTyID: - return "FunctionTy"; - case Type::StructTyID: - return "StructTy"; - case Type::ArrayTyID: - return "ArrayTy"; - case Type::PointerTyID: - case Type::TypedPointerTyID: - return "PointerTy"; - case Type::FixedVectorTyID: - case Type::ScalableVectorTyID: - return "VectorTy"; - case Type::LabelTyID: - return "LabelTy"; - case Type::TokenTyID: - return "TokenTy"; - case Type::MetadataTyID: - return "MetadataTy"; - case Type::X86_AMXTyID: - case Type::TargetExtTyID: - return "UnknownTy"; - } - return "UnknownTy"; + return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID)); } StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) { @@ -299,20 +327,6 @@ StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) { return OperandKindNames[Index]; } -Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) { - VocabVector DummyVocab; - float DummyVal = 0.1f; - // Create a dummy vocabulary with entries for all opcodes, types, and - // operand - for ([[maybe_unused]] unsigned _ : - seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs + - Vocabulary::MaxOperandKinds)) { - DummyVocab.push_back(Embedding(Dim, DummyVal)); - DummyVal += 0.1f; - } - return DummyVocab; -} - // Helper function to classify an operand into OperandKind Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { if (isa<Function>(Op)) @@ -324,34 +338,18 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { return OperandKind::VariableID; } -unsigned Vocabulary::getNumericID(unsigned Opcode) { - assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); - return Opcode - 1; // Convert to zero-based index -} - -unsigned Vocabulary::getNumericID(Type::TypeID TypeID) { - assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID"); - return MaxOpcodes + static_cast<unsigned>(TypeID); -} - -unsigned Vocabulary::getNumericID(const Value *Op) { - unsigned Index = static_cast<unsigned>(getOperandKind(Op)); - assert(Index < MaxOperandKinds && "Invalid OperandKind"); - return MaxOpcodes + MaxTypeIDs + Index; -} - StringRef Vocabulary::getStringKey(unsigned Pos) { - assert(Pos < Vocabulary::expectedSize() && - "Position out of bounds in vocabulary"); + assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary"); // Opcode if (Pos < MaxOpcodes) return getVocabKeyForOpcode(Pos + 1); // Type - if (Pos < MaxOpcodes + MaxTypeIDs) - return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes)); + if (Pos < MaxOpcodes + MaxCanonicalTypeIDs) + return getVocabKeyForCanonicalTypeID( + static_cast<CanonicalTypeID>(Pos - MaxOpcodes)); // Operand return getVocabKeyForOperandKind( - static_cast<OperandKind>(Pos - MaxOpcodes - MaxTypeIDs)); + static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs)); } // For now, assume vocabulary is stable unless explicitly invalidated. @@ -361,6 +359,21 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA, return !(PAC.preservedWhenStateless()); } +Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) { + VocabVector DummyVocab; + DummyVocab.reserve(NumCanonicalEntries); + float DummyVal = 0.1f; + // Create a dummy vocabulary with entries for all opcodes, types, and + // operands + for ([[maybe_unused]] unsigned _ : + seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs + + Vocabulary::MaxOperandKinds)) { + DummyVocab.push_back(Embedding(Dim, DummyVal)); + DummyVal += 0.1f; + } + return DummyVocab; +} + // ==----------------------------------------------------------------------===// // IR2VecVocabAnalysis //===----------------------------------------------------------------------===// @@ -452,7 +465,8 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { // Handle Opcodes std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes, - Embedding(Dim, 0)); + Embedding(Dim)); + NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes); for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) { StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1); auto It = OpcVocab.find(VocabKey.str()); @@ -464,14 +478,15 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(), NumericOpcodeEmbeddings.end()); - // Handle Types - std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs, - Embedding(Dim, 0)); - for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) { - StringRef VocabKey = - Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID)); + // Handle Types - only canonical types are present in vocabulary + std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs, + Embedding(Dim)); + NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs); + for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) { + StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID( + static_cast<Vocabulary::CanonicalTypeID>(CTypeID)); if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) { - NumericTypeEmbeddings[TypeID] = It->second; + NumericTypeEmbeddings[CTypeID] = It->second; continue; } handleMissingEntity(VocabKey.str()); @@ -481,7 +496,8 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { // Handle Arguments/Operands std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds, - Embedding(Dim, 0)); + Embedding(Dim)); + NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds); for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) { Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind); StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind); @@ -552,8 +568,7 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M, assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid"); for (Function &F : M) { - std::unique_ptr<Embedder> Emb = - Embedder::create(IR2VecKind::Symbolic, F, Vocabulary); + auto Emb = Embedder::create(IR2VecEmbeddingKind, F, Vocabulary); if (!Emb) { OS << "Error creating IR2Vec embeddings \n"; continue; |
