1//===- Utils.h - IR2Vec/MIR2Vec Tool Classes ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file contains the IR2VecTool and MIR2VecTool class definitions for
11/// the llvm-ir2vec embedding generation tool.
12///
13//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_TOOLS_LLVM_IR2VEC_UTILS_UTILS_H
16#define LLVM_TOOLS_LLVM_IR2VEC_UTILS_UTILS_H
17
18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/ADT/DenseMap.h"
20#include "llvm/Analysis/IR2Vec.h"
21#include "llvm/CodeGen/MIR2Vec.h"
22#include "llvm/CodeGen/MIRParser/MIRParser.h"
23#include "llvm/CodeGen/MachineFunction.h"
24#include "llvm/CodeGen/MachineModuleInfo.h"
25#include "llvm/CodeGen/TargetInstrInfo.h"
26#include "llvm/CodeGen/TargetRegisterInfo.h"
27#include "llvm/IR/BasicBlock.h"
28#include "llvm/IR/Function.h"
29#include "llvm/IR/InstIterator.h"
30#include "llvm/IR/Instructions.h"
31#include "llvm/IR/LLVMContext.h"
32#include "llvm/IR/Module.h"
33#include "llvm/IR/PassInstrumentation.h"
34#include "llvm/IR/PassManager.h"
35#include "llvm/IR/Type.h"
36#include "llvm/Support/Debug.h"
37#include "llvm/Support/Error.h"
38#include "llvm/Support/WithColor.h"
39#include "llvm/Support/raw_ostream.h"
40#include "llvm/Target/TargetMachine.h"
41#include <memory>
42#include <string>
43#include <vector>
44
45#define DEBUG_TYPE "ir2vec"
46
47namespace llvm {
48
/// Tool name used as the prefix when reporting errors.
///
/// Declared `inline constexpr` so that every translation unit including this
/// header shares one immutable definition, instead of each receiving its own
/// mutable, internal-linkage copy of the pointer.
inline constexpr const char *ToolName = "llvm-ir2vec";
51
/// Granularity at which embeddings are generated.
enum EmbeddingLevel {
  /// Generate instruction-level embeddings.
  InstructionLevel,
  /// Generate basic block-level embeddings.
  BasicBlockLevel,
  /// Generate function-level embeddings.
  FunctionLevel
};
58
/// A single knowledge graph triplet (Head, Relation, Tail). Head and Tail are
/// indices that reference entities in an EntityList.
///
/// Note that the members are declared in the order Head, Tail, Relation, which
/// is also the order used by aggregate initialization.
struct Triplet {
  /// Index of the head entity in the entity list.
  unsigned Head = 0;
  /// Index of the tail entity in the entity list.
  unsigned Tail = 0;
  /// Relation type (see RelationType enum).
  unsigned Relation = 0;
};
66
67/// Result structure containing all generated triplets and metadata
68struct TripletResult {
69 unsigned MaxRelation =
70 0; ///< Highest relation index used (for ArgRelation + N)
71 std::vector<Triplet> Triplets; ///< Collection of all generated triplets
72};
73
/// Entity mappings, stored as a list of entity names: [entity_name].
using EntityList = std::vector<std::string>;
76using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
77
78namespace ir2vec {
79
/// Relation identifiers used during triplet generation.
enum RelationType {
  /// Instruction to type relationship.
  TypeRelation = 0,
  /// Sequential instruction relationship.
  NextRelation = 1,
  /// Instruction to operand relationship (ArgRelation + N).
  ArgRelation = 2
};
86
87/// Helper class for collecting IR triplets and generating embeddings
88class IR2VecTool {
89private:
90 Module &M;
91 ModuleAnalysisManager MAM;
92 std::unique_ptr<Vocabulary> Vocab;
93
94public:
95 explicit IR2VecTool(Module &M) : M(M) {}
96
97 /// Initialize the IR2Vec vocabulary from the specified file path.
98 Error initializeVocabulary(StringRef VocabPath);
99
100 /// Generate triplets for a single function
101 /// Returns a TripletResult with:
102 /// - Triplets: vector of all (subject, object, relation) tuples
103 /// - MaxRelation: highest Arg relation ID used, or NextRelation if none
104 TripletResult generateTriplets(const Function &F) const;
105
106 /// Get triplets for the entire module
107 TripletResult generateTriplets() const;
108
109 /// Collect triplets for the module and dump output to stream
110 /// Output format: MAX_RELATION=N header followed by relationships
111 void writeTripletsToStream(raw_ostream &OS) const;
112
113 /// Generate entity mappings for the entire vocabulary
114 /// Returns EntityList containing all entity strings
115 static EntityList collectEntityMappings();
116
117 /// Dump entity ID to string mappings
118 static void writeEntitiesToStream(raw_ostream &OS);
119
120 // Get embedding for a single function
121 Expected<Embedding> getFunctionEmbedding(const Function &F,
122 IR2VecKind Kind) const;
123
124 /// Get embeddings for all functions in the module
125 Expected<FuncEmbMap> getFunctionEmbeddingsMap(IR2VecKind Kind) const;
126
127 /// Get embeddings for all basic blocks in a function
128 Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
129 IR2VecKind Kind) const;
130
131 /// Generate embeddings for the entire module
132 void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
133
134 /// Generate embeddings for a single function
135 void writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
136 EmbeddingLevel Level) const;
137};
138
139} // namespace ir2vec
140
141namespace mir2vec {
142
/// Relation identifiers used during MIR2Vec triplet generation.
enum MIRRelationType {
  /// Sequential instruction relationship.
  MIRNextRelation = 0,
  /// Instruction to operand relationship (MIRArgRelation + N).
  MIRArgRelation = 1
};
148
149/// Helper class for MIR2Vec embedding generation
150class MIR2VecTool {
151private:
152 MachineModuleInfo &MMI;
153 std::unique_ptr<MIRVocabulary> Vocab;
154
155public:
156 explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
157
158 /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
159 bool initializeVocabulary(const Module &M);
160
161 /// Initialize vocabulary with layout information only.
162 /// This creates a minimal vocabulary with correct layout but no actual
163 /// embeddings. Sufficient for generating training data and entity mappings.
164 ///
165 /// Note: Requires target-specific information from the first machine function
166 /// to determine the vocabulary layout (number of opcodes, register classes).
167 ///
168 /// FIXME: Use --target option to get target info directly, avoiding the need
169 /// to parse machine functions for pre-training operations.
170 bool initializeVocabularyForLayout(const Module &M);
171
172 /// Get triplets for a single machine function
173 /// Returns TripletResult containing MaxRelation and vector of Triplets
174 TripletResult generateTriplets(const MachineFunction &MF) const;
175
176 /// Get triplets for the entire module
177 /// Returns TripletResult containing aggregated MaxRelation and all Triplets
178 TripletResult generateTriplets(const Module &M) const;
179
180 /// Collect triplets for the module and write to output stream
181 /// Output format: MAX_RELATION=N header followed by relationships
182 void writeTripletsToStream(const Module &M, raw_ostream &OS) const;
183
184 /// Generate entity mappings for the entire vocabulary
185 EntityList collectEntityMappings() const;
186
187 /// Generate entity mappings and write to output stream
188 void writeEntitiesToStream(raw_ostream &OS) const;
189
190 /// Generate embeddings for all machine functions in the module
191 void writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
192 EmbeddingLevel Level) const;
193
194 /// Generate embeddings for a specific machine function
195 void writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
196 EmbeddingLevel Level) const;
197
198 /// Get the MIR vocabulary instance
199 const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
200};
201
202/// Helper structure to hold MIR context
203struct MIRContext {
204 LLVMContext Context; // CRITICAL: Must be first for proper destruction order
205 std::unique_ptr<Module> M;
206 std::unique_ptr<MachineModuleInfo> MMI;
207 std::unique_ptr<TargetMachine> TM;
208};
209
210} // namespace mir2vec
211
212} // namespace llvm
213
214#endif // LLVM_TOOLS_LLVM_IR2VEC_UTILS_UTILS_H
215