1//===- Utils.h - IR2Vec/MIR2Vec Tool Classes ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file contains the IR2VecTool and MIR2VecTool class definitions for
11/// the llvm-ir2vec embedding generation tool.
12///
13//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_TOOLS_LLVM_IR2VEC_UTILS_UTILS_H
16#define LLVM_TOOLS_LLVM_IR2VEC_UTILS_UTILS_H
17
18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/Analysis/IR2Vec.h"
20#include "llvm/CodeGen/MIR2Vec.h"
21#include "llvm/CodeGen/MIRParser/MIRParser.h"
22#include "llvm/CodeGen/MachineFunction.h"
23#include "llvm/CodeGen/MachineModuleInfo.h"
24#include "llvm/CodeGen/TargetInstrInfo.h"
25#include "llvm/CodeGen/TargetRegisterInfo.h"
26#include "llvm/IR/BasicBlock.h"
27#include "llvm/IR/Function.h"
28#include "llvm/IR/InstIterator.h"
29#include "llvm/IR/Instructions.h"
30#include "llvm/IR/LLVMContext.h"
31#include "llvm/IR/Module.h"
32#include "llvm/IR/PassInstrumentation.h"
33#include "llvm/IR/PassManager.h"
34#include "llvm/IR/Type.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/Error.h"
37#include "llvm/Support/WithColor.h"
38#include "llvm/Support/raw_ostream.h"
39#include "llvm/Target/TargetMachine.h"
40#include <memory>
41#include <string>
42#include <vector>
43
44#define DEBUG_TYPE "ir2vec"
45
46namespace llvm {
47
/// Tool name used as the prefix of diagnostics (error reporting).
///
/// Declared `inline constexpr` rather than `static`: a `static` variable in a
/// header gives every translation unit its own private, mutable copy, whereas
/// an inline variable has a single definition program-wide and the pointer
/// itself cannot be reassigned.
inline constexpr const char *ToolName = "llvm-ir2vec";
50
/// Granularity at which embeddings are produced and written out.
enum EmbeddingLevel {
  /// One embedding per instruction.
  InstructionLevel = 0,
  /// One embedding per basic block.
  BasicBlockLevel = 1,
  /// One embedding per function.
  FunctionLevel = 2,
};
57
/// One knowledge-graph edge, stored as indices rather than strings.
///
/// Head and Tail index entities in an EntityList; Relation identifies the
/// kind of edge between them. Members are declared Head, Tail, Relation, so
/// positional aggregate initialisation reads `Triplet{Head, Tail, Relation}`.
struct Triplet {
  /// Position of the head entity in the entity list.
  unsigned Head = 0;
  /// Position of the tail entity in the entity list.
  unsigned Tail = 0;
  /// Relation kind (see the RelationType enum).
  unsigned Relation = 0;
};

/// Everything produced by a triplet-generation pass.
struct TripletResult {
  /// Largest relation index that occurs in Triplets; operand relations are
  /// encoded as ArgRelation + N, so this can exceed the named enumerators.
  unsigned MaxRelation = 0;
  /// Every triplet that was generated.
  std::vector<Triplet> Triplets;
};

/// Entity names; position i holds the name of the entity with index i.
using EntityList = std::vector<std::string>;
75
76namespace ir2vec {
77
78/// Relation types for triplet generation
79enum RelationType {
80 TypeRelation = 0, ///< Instruction to type relationship
81 NextRelation = 1, ///< Sequential instruction relationship
82 ArgRelation = 2 ///< Instruction to operand relationship (ArgRelation + N)
83};
84
85/// Helper class for collecting IR triplets and generating embeddings
86class IR2VecTool {
87private:
88 Module &M;
89 ModuleAnalysisManager MAM;
90 std::unique_ptr<Vocabulary> Vocab;
91
92public:
93 explicit IR2VecTool(Module &M) : M(M) {}
94
95 /// Initialize the IR2Vec vocabulary from the specified file path.
96 Error initializeVocabulary(StringRef VocabPath);
97
98 /// Generate triplets for a single function
99 /// Returns a TripletResult with:
100 /// - Triplets: vector of all (subject, object, relation) tuples
101 /// - MaxRelation: highest Arg relation ID used, or NextRelation if none
102 TripletResult generateTriplets(const Function &F) const;
103
104 /// Get triplets for the entire module
105 TripletResult generateTriplets() const;
106
107 /// Collect triplets for the module and dump output to stream
108 /// Output format: MAX_RELATION=N header followed by relationships
109 void writeTripletsToStream(raw_ostream &OS) const;
110
111 /// Generate entity mappings for the entire vocabulary
112 /// Returns EntityList containing all entity strings
113 static EntityList collectEntityMappings();
114
115 /// Dump entity ID to string mappings
116 static void writeEntitiesToStream(raw_ostream &OS);
117
118 /// Generate embeddings for the entire module
119 void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
120
121 /// Generate embeddings for a single function
122 void writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
123 EmbeddingLevel Level) const;
124};
125
126} // namespace ir2vec
127
128namespace mir2vec {
129
130/// Relation types for MIR2Vec triplet generation
131enum MIRRelationType {
132 MIRNextRelation = 0, ///< Sequential instruction relationship
133 MIRArgRelation = 1 ///< Instruction to operand relationship (ArgRelation + N)
134};
135
136/// Helper class for MIR2Vec embedding generation
137class MIR2VecTool {
138private:
139 MachineModuleInfo &MMI;
140 std::unique_ptr<MIRVocabulary> Vocab;
141
142public:
143 explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
144
145 /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
146 bool initializeVocabulary(const Module &M);
147
148 /// Initialize vocabulary with layout information only.
149 /// This creates a minimal vocabulary with correct layout but no actual
150 /// embeddings. Sufficient for generating training data and entity mappings.
151 ///
152 /// Note: Requires target-specific information from the first machine function
153 /// to determine the vocabulary layout (number of opcodes, register classes).
154 ///
155 /// FIXME: Use --target option to get target info directly, avoiding the need
156 /// to parse machine functions for pre-training operations.
157 bool initializeVocabularyForLayout(const Module &M);
158
159 /// Get triplets for a single machine function
160 /// Returns TripletResult containing MaxRelation and vector of Triplets
161 TripletResult generateTriplets(const MachineFunction &MF) const;
162
163 /// Get triplets for the entire module
164 /// Returns TripletResult containing aggregated MaxRelation and all Triplets
165 TripletResult generateTriplets(const Module &M) const;
166
167 /// Collect triplets for the module and write to output stream
168 /// Output format: MAX_RELATION=N header followed by relationships
169 void writeTripletsToStream(const Module &M, raw_ostream &OS) const;
170
171 /// Generate entity mappings for the entire vocabulary
172 EntityList collectEntityMappings() const;
173
174 /// Generate entity mappings and write to output stream
175 void writeEntitiesToStream(raw_ostream &OS) const;
176
177 /// Generate embeddings for all machine functions in the module
178 void writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
179 EmbeddingLevel Level) const;
180
181 /// Generate embeddings for a specific machine function
182 void writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
183 EmbeddingLevel Level) const;
184
185 /// Get the MIR vocabulary instance
186 const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
187};
188
189/// Helper structure to hold MIR context
190struct MIRContext {
191 LLVMContext Context; // CRITICAL: Must be first for proper destruction order
192 std::unique_ptr<Module> M;
193 std::unique_ptr<MachineModuleInfo> MMI;
194 std::unique_ptr<TargetMachine> TM;
195};
196
197} // namespace mir2vec
198
199} // namespace llvm
200
201#endif // LLVM_TOOLS_LLVM_IR2VEC_UTILS_UTILS_H
202