1//===- Utils.h - IR2Vec/MIR2Vec Tool Classes ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file contains the IR2VecTool and MIR2VecTool class definitions for
11/// the llvm-ir2vec embedding generation tool.
12///
13//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_TOOLS_LLVM_IR2VEC_UTILS_UTILS_H
16#define LLVM_TOOLS_LLVM_IR2VEC_UTILS_UTILS_H
17
18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/ADT/DenseMap.h"
20#include "llvm/Analysis/IR2Vec.h"
21#include "llvm/CodeGen/MIR2Vec.h"
22#include "llvm/CodeGen/MIRParser/MIRParser.h"
23#include "llvm/CodeGen/MachineFunction.h"
24#include "llvm/CodeGen/MachineModuleInfo.h"
25#include "llvm/CodeGen/TargetInstrInfo.h"
26#include "llvm/CodeGen/TargetRegisterInfo.h"
27#include "llvm/IR/BasicBlock.h"
28#include "llvm/IR/Function.h"
29#include "llvm/IR/InstIterator.h"
30#include "llvm/IR/Instructions.h"
31#include "llvm/IR/LLVMContext.h"
32#include "llvm/IR/Module.h"
33#include "llvm/IR/PassInstrumentation.h"
34#include "llvm/IR/PassManager.h"
35#include "llvm/IR/Type.h"
36#include "llvm/Support/Debug.h"
37#include "llvm/Support/Error.h"
38#include "llvm/Support/WithColor.h"
39#include "llvm/Support/raw_ostream.h"
40#include "llvm/Target/TargetMachine.h"
41#include <memory>
42#include <string>
43#include <vector>
44
45#define DEBUG_TYPE "ir2vec"
46
47namespace llvm {
48
/// Tool name used as the prefix of error/warning messages.
///
/// Declared `constexpr` so the pointer itself is const: it cannot be
/// reassigned by accident, and translation units that include this header
/// without using it get no "defined but not used" warning. `static` keeps the
/// original internal linkage (one private copy per translation unit).
static constexpr const char *ToolName = "llvm-ir2vec";
51
/// Granularity of the embeddings the tool emits.
enum EmbeddingLevel {
  InstructionLevel = 0, ///< One embedding per instruction.
  BasicBlockLevel = 1,  ///< One embedding per basic block.
  FunctionLevel = 2     ///< One embedding per function.
};
58
/// A single knowledge-graph triplet.
///
/// Head and Tail are indices into an EntityList. Relation is a relation ID:
/// ir2vec::RelationType for IR triplets, or mir2vec::MIRRelationType for MIR
/// triplets. Operand relations in both are encoded as ArgRelation + N.
///
/// NOTE: the members are declared in the order Head, Tail, Relation. Aggregate
/// initialization therefore reads Triplet{Head, Tail, Relation}, not the
/// (head, relation, tail) order in which triplets are usually written.
struct Triplet {
  unsigned Head = 0;     ///< Index of the head entity in the entity list.
  unsigned Tail = 0;     ///< Index of the tail entity in the entity list.
  unsigned Relation = 0; ///< Relation ID (see the struct comment).
};

/// All triplets generated for a function or module, plus relation metadata.
struct TripletResult {
  /// Highest relation ID used. Operand relations make this ArgRelation + N.
  unsigned MaxRelation = 0;
  /// Every generated triplet.
  std::vector<Triplet> Triplets;
};
73
74/// Entity mappings: [entity_name]
75using EntityList = std::vector<std::string>;
76using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
77
78namespace ir2vec {
79
/// Relation IDs written into IR2Vec triplets.
enum RelationType {
  /// Links an instruction to its type.
  TypeRelation = 0,
  /// Links an instruction to the instruction that follows it.
  NextRelation = 1,
  /// First operand relation; operand N is encoded as ArgRelation + N.
  ArgRelation = 2
};
86
87/// Helper class for collecting IR triplets and generating embeddings
88class IR2VecTool {
89private:
90 Module &M;
91 ModuleAnalysisManager MAM;
92 std::unique_ptr<Vocabulary> Vocab;
93
94public:
95 explicit IR2VecTool(Module &M) : M(M) {}
96
97 /// Creates the embedding object for downstream embedding streaming
98 Expected<std::unique_ptr<Embedder>>
99 createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const;
100
101 /// Initialize the IR2Vec vocabulary from the specified file path.
102 Error initializeVocabulary(StringRef VocabPath);
103
104 /// Generate triplets for a single function
105 /// Returns a TripletResult with:
106 /// - Triplets: vector of all (subject, object, relation) tuples
107 /// - MaxRelation: highest Arg relation ID used, or NextRelation if none
108 TripletResult generateTriplets(const Function &F) const;
109
110 /// Get triplets for the entire module
111 TripletResult generateTriplets() const;
112
113 /// Collect triplets for the module and dump output to stream
114 /// Output format: MAX_RELATION=N header followed by relationships
115 void writeTripletsToStream(raw_ostream &OS) const;
116
117 /// Generate entity mappings for the entire vocabulary
118 /// Returns EntityList containing all entity strings
119 static EntityList collectEntityMappings();
120
121 /// Dump entity ID to string mappings
122 static void writeEntitiesToStream(raw_ostream &OS);
123
124 // Get embedding for a single function
125 Expected<Embedding> getFunctionEmbedding(const Function &F,
126 IR2VecKind Kind) const;
127
128 /// Get embeddings for all functions in the module
129 Expected<FuncEmbMap> getFunctionEmbeddingsMap(IR2VecKind Kind) const;
130
131 /// Get embeddings for all basic blocks in a function
132 Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
133 IR2VecKind Kind) const;
134 /// Get embeddings for all instructions in a function
135 Expected<InstEmbeddingsMap> getInstEmbeddingsMap(const Function &F,
136 IR2VecKind Kind) const;
137
138 /// Generate embeddings for the entire module
139 void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
140
141 /// Generate embeddings for a single function
142 void writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
143 EmbeddingLevel Level) const;
144};
145
146} // namespace ir2vec
147
148namespace mir2vec {
149
/// Relation IDs written into MIR2Vec triplets. Unlike ir2vec::RelationType,
/// there is no type relation, so numbering starts at the sequential relation.
enum MIRRelationType {
  /// Links an instruction to the instruction that follows it.
  MIRNextRelation = 0,
  /// First operand relation; operand N is encoded as MIRArgRelation + N.
  MIRArgRelation = 1
};
155
156/// Helper class for MIR2Vec embedding generation
157class MIR2VecTool {
158private:
159 MachineModuleInfo &MMI;
160 std::unique_ptr<MIRVocabulary> Vocab;
161
162public:
163 explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
164
165 /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
166 bool initializeVocabulary(const Module &M);
167
168 /// Initialize vocabulary with layout information only.
169 /// This creates a minimal vocabulary with correct layout but no actual
170 /// embeddings. Sufficient for generating training data and entity mappings.
171 ///
172 /// Note: Requires target-specific information from the first machine function
173 /// to determine the vocabulary layout (number of opcodes, register classes).
174 ///
175 /// FIXME: Use --target option to get target info directly, avoiding the need
176 /// to parse machine functions for pre-training operations.
177 bool initializeVocabularyForLayout(const Module &M);
178
179 /// Get triplets for a single machine function
180 /// Returns TripletResult containing MaxRelation and vector of Triplets
181 TripletResult generateTriplets(const MachineFunction &MF) const;
182
183 /// Get triplets for the entire module
184 /// Returns TripletResult containing aggregated MaxRelation and all Triplets
185 TripletResult generateTriplets(const Module &M) const;
186
187 /// Collect triplets for the module and write to output stream
188 /// Output format: MAX_RELATION=N header followed by relationships
189 void writeTripletsToStream(const Module &M, raw_ostream &OS) const;
190
191 /// Generate entity mappings for the entire vocabulary
192 EntityList collectEntityMappings() const;
193
194 /// Generate entity mappings and write to output stream
195 void writeEntitiesToStream(raw_ostream &OS) const;
196
197 /// Generate embeddings for all machine functions in the module
198 void writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
199 EmbeddingLevel Level) const;
200
201 /// Generate embeddings for a specific machine function
202 void writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
203 EmbeddingLevel Level) const;
204
205 /// Get the MIR vocabulary instance
206 const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
207};
208
209/// Helper structure to hold MIR context
210struct MIRContext {
211 LLVMContext Context; // CRITICAL: Must be first for proper destruction order
212 std::unique_ptr<Module> M;
213 std::unique_ptr<MachineModuleInfo> MMI;
214 std::unique_ptr<TargetMachine> TM;
215};
216
217} // namespace mir2vec
218
219} // namespace llvm
220
221#endif // LLVM_TOOLS_LLVM_IR2VEC_UTILS_UTILS_H
222