1//===- Utils.cpp - IR2Vec/MIR2Vec Embedding Generation Tool -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the IR2VecTool and MIR2VecTool classes for
11/// IR2Vec/MIR2Vec embedding generation.
12///
13//===----------------------------------------------------------------------===//
14
15#include "Utils.h"
16#include "llvm/ADT/ArrayRef.h"
17#include "llvm/Analysis/IR2Vec.h"
18#include "llvm/IR/BasicBlock.h"
19#include "llvm/IR/Function.h"
20#include "llvm/IR/InstIterator.h"
21#include "llvm/IR/Instructions.h"
22#include "llvm/IR/LLVMContext.h"
23#include "llvm/IR/Module.h"
24#include "llvm/IR/PassInstrumentation.h"
25#include "llvm/IR/PassManager.h"
26#include "llvm/IR/Type.h"
27#include "llvm/Support/Debug.h"
28#include "llvm/Support/Errc.h"
29#include "llvm/Support/raw_ostream.h"
30
31#include "llvm/CodeGen/MIR2Vec.h"
32#include "llvm/CodeGen/MIRParser/MIRParser.h"
33#include "llvm/CodeGen/MachineFunction.h"
34#include "llvm/CodeGen/MachineModuleInfo.h"
35#include "llvm/CodeGen/TargetInstrInfo.h"
36#include "llvm/CodeGen/TargetRegisterInfo.h"
37#include "llvm/Target/TargetMachine.h"
38
39#define DEBUG_TYPE "ir2vec"
40
41namespace llvm {
42
43namespace ir2vec {
44
45Error IR2VecTool::initializeVocabulary(StringRef VocabPath) {
46 auto VocabOrErr = Vocabulary::fromFile(VocabFilePath: VocabPath);
47 if (!VocabOrErr)
48 return VocabOrErr.takeError();
49
50 Vocab = std::make_unique<Vocabulary>(args: std::move(*VocabOrErr));
51
52 if (!Vocab->isValid())
53 return createStringError(EC: errc::invalid_argument,
54 S: "Failed to initialize IR2Vec vocabulary");
55 return Error::success();
56}
57
58TripletResult IR2VecTool::generateTriplets(const Function &F) const {
59 if (F.isDeclaration())
60 return {};
61
62 TripletResult Result;
63 Result.MaxRelation = 0;
64
65 unsigned MaxRelation = NextRelation;
66 unsigned PrevOpcode = 0;
67 bool HasPrevOpcode = false;
68
69 for (const BasicBlock &BB : F) {
70 for (const auto &I : BB.instructionsWithoutDebug()) {
71 unsigned Opcode = Vocabulary::getIndex(Opcode: I.getOpcode());
72 unsigned TypeID = Vocabulary::getIndex(TypeID: I.getType()->getTypeID());
73
74 // Add "Next" relationship with previous instruction
75 if (HasPrevOpcode) {
76 Result.Triplets.push_back(x: {.Head: PrevOpcode, .Tail: Opcode, .Relation: NextRelation});
77 LLVM_DEBUG(dbgs() << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1)
78 << '\t'
79 << Vocabulary::getVocabKeyForOpcode(Opcode + 1)
80 << '\t' << "Next\n");
81 }
82
83 // Add "Type" relationship
84 Result.Triplets.push_back(x: {.Head: Opcode, .Tail: TypeID, .Relation: TypeRelation});
85 LLVM_DEBUG(
86 dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
87 << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
88 << '\t' << "Type\n");
89
90 // Add "Arg" relationships
91 unsigned ArgIndex = 0;
92 for (const Use &U : I.operands()) {
93 unsigned OperandID = Vocabulary::getIndex(Op: *U.get());
94 unsigned RelationID = ArgRelation + ArgIndex;
95 Result.Triplets.push_back(x: {.Head: Opcode, .Tail: OperandID, .Relation: RelationID});
96
97 LLVM_DEBUG({
98 StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
99 Vocabulary::getOperandKind(U.get()));
100 dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
101 << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
102 });
103
104 ++ArgIndex;
105 }
106 // Only update MaxRelation if there were operands
107 if (ArgIndex > 0)
108 MaxRelation = std::max(a: MaxRelation, b: ArgRelation + ArgIndex - 1);
109 PrevOpcode = Opcode;
110 HasPrevOpcode = true;
111 }
112 }
113
114 Result.MaxRelation = MaxRelation;
115 return Result;
116}
117
118TripletResult IR2VecTool::generateTriplets() const {
119 TripletResult Result;
120 Result.MaxRelation = NextRelation;
121
122 for (const Function &F : M.getFunctionDefs()) {
123 TripletResult FuncResult = generateTriplets(F);
124 Result.MaxRelation = std::max(a: Result.MaxRelation, b: FuncResult.MaxRelation);
125 Result.Triplets.insert(position: Result.Triplets.end(), first: FuncResult.Triplets.begin(),
126 last: FuncResult.Triplets.end());
127 }
128
129 return Result;
130}
131
132void IR2VecTool::writeTripletsToStream(raw_ostream &OS) const {
133 auto Result = generateTriplets();
134 OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
135 for (const auto &T : Result.Triplets)
136 OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
137}
138
139EntityList IR2VecTool::collectEntityMappings() {
140 auto EntityLen = Vocabulary::getCanonicalSize();
141 EntityList Result;
142 for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
143 Result.push_back(x: Vocabulary::getStringKey(Pos: EntityID).str());
144 return Result;
145}
146
147void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
148 auto Entities = collectEntityMappings();
149 OS << Entities.size() << "\n";
150 for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
151 OS << Entities[EntityID] << '\t' << EntityID << '\n';
152}
153
154void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
155 EmbeddingLevel Level) const {
156 if (!Vocab->isValid()) {
157 WithColor::error(OS&: errs(), Prefix: ToolName)
158 << "Vocabulary is not valid. IR2VecTool not initialized.\n";
159 return;
160 }
161
162 for (const Function &F : M.getFunctionDefs())
163 writeEmbeddingsToStream(F, OS, Level);
164}
165
166void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
167 EmbeddingLevel Level) const {
168 if (!Vocab || !Vocab->isValid()) {
169 WithColor::error(OS&: errs(), Prefix: ToolName)
170 << "Vocabulary is not valid. IR2VecTool not initialized.\n";
171 return;
172 }
173 if (F.isDeclaration()) {
174 OS << "Function " << F.getName() << " is a declaration, skipping.\n";
175 return;
176 }
177
178 // Create embedder for this function
179 auto Emb = Embedder::create(Mode: IR2VecEmbeddingKind, F, Vocab: *Vocab);
180 if (!Emb) {
181 WithColor::error(OS&: errs(), Prefix: ToolName)
182 << "Failed to create embedder for function " << F.getName() << "\n";
183 return;
184 }
185
186 OS << "Function: " << F.getName() << "\n";
187
188 // Generate embeddings based on the specified level
189 switch (Level) {
190 case FunctionLevel:
191 Emb->getFunctionVector().print(OS);
192 break;
193 case BasicBlockLevel:
194 for (const BasicBlock &BB : F) {
195 OS << BB.getName() << ":";
196 Emb->getBBVector(BB).print(OS);
197 }
198 break;
199 case InstructionLevel:
200 for (const Instruction &I : instructions(F)) {
201 OS << I;
202 Emb->getInstVector(I).print(OS);
203 }
204 break;
205 }
206}
207
208} // namespace ir2vec
209
210namespace mir2vec {
211
212bool MIR2VecTool::initializeVocabulary(const Module &M) {
213 MIR2VecVocabProvider Provider(MMI);
214 auto VocabOrErr = Provider.getVocabulary(M);
215 if (!VocabOrErr) {
216 WithColor::error(OS&: errs(), Prefix: ToolName)
217 << "Failed to load MIR2Vec vocabulary - "
218 << toString(E: VocabOrErr.takeError()) << "\n";
219 return false;
220 }
221 Vocab = std::make_unique<MIRVocabulary>(args: std::move(*VocabOrErr));
222 return true;
223}
224
225bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
226 for (const Function &F : M.getFunctionDefs()) {
227 MachineFunction *MF = MMI.getMachineFunction(F);
228 if (!MF)
229 continue;
230
231 const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
232 const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
233 const MachineRegisterInfo &MRI = MF->getRegInfo();
234
235 auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, Dim: 1);
236 if (!VocabOrErr) {
237 WithColor::error(OS&: errs(), Prefix: ToolName)
238 << "Failed to create dummy vocabulary - "
239 << toString(E: VocabOrErr.takeError()) << "\n";
240 return false;
241 }
242 Vocab = std::make_unique<MIRVocabulary>(args: std::move(*VocabOrErr));
243 return true;
244 }
245
246 WithColor::error(OS&: errs(), Prefix: ToolName)
247 << "No machine functions found to initialize vocabulary\n";
248 return false;
249}
250
251TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
252 TripletResult Result;
253 Result.MaxRelation = MIRNextRelation;
254
255 if (!Vocab) {
256 WithColor::error(OS&: errs(), Prefix: ToolName)
257 << "MIR Vocabulary must be initialized for triplet generation.\n";
258 return Result;
259 }
260
261 unsigned PrevOpcode = 0;
262 bool HasPrevOpcode = false;
263 for (const MachineBasicBlock &MBB : MF) {
264 for (const MachineInstr &MI : MBB) {
265 // Skip debug instructions
266 if (MI.isDebugInstr())
267 continue;
268
269 // Get opcode entity ID
270 unsigned OpcodeID = Vocab->getEntityIDForOpcode(Opcode: MI.getOpcode());
271
272 // Add "Next" relationship with previous instruction
273 if (HasPrevOpcode) {
274 Result.Triplets.push_back(x: {.Head: PrevOpcode, .Tail: OpcodeID, .Relation: MIRNextRelation});
275 LLVM_DEBUG(dbgs() << Vocab->getStringKey(PrevOpcode) << '\t'
276 << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
277 }
278
279 // Add "Arg" relationships for operands
280 unsigned ArgIndex = 0;
281 for (const MachineOperand &MO : MI.operands()) {
282 auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
283 unsigned RelationID = MIRArgRelation + ArgIndex;
284 Result.Triplets.push_back(x: {.Head: OpcodeID, .Tail: OperandID, .Relation: RelationID});
285 LLVM_DEBUG({
286 std::string OperandStr = Vocab->getStringKey(OperandID);
287 dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr << '\t'
288 << "Arg" << ArgIndex << '\n';
289 });
290
291 ++ArgIndex;
292 }
293
294 // Update MaxRelation if there were operands
295 if (ArgIndex > 0)
296 Result.MaxRelation =
297 std::max(a: Result.MaxRelation, b: MIRArgRelation + ArgIndex - 1);
298
299 PrevOpcode = OpcodeID;
300 HasPrevOpcode = true;
301 }
302 }
303
304 return Result;
305}
306
307TripletResult MIR2VecTool::generateTriplets(const Module &M) const {
308 TripletResult Result;
309 Result.MaxRelation = MIRNextRelation;
310
311 for (const Function &F : M.getFunctionDefs()) {
312 MachineFunction *MF = MMI.getMachineFunction(F);
313 if (!MF) {
314 WithColor::warning(OS&: errs(), Prefix: ToolName)
315 << "No MachineFunction for " << F.getName() << "\n";
316 continue;
317 }
318
319 TripletResult FuncResult = generateTriplets(MF: *MF);
320 Result.MaxRelation = std::max(a: Result.MaxRelation, b: FuncResult.MaxRelation);
321 Result.Triplets.insert(position: Result.Triplets.end(), first: FuncResult.Triplets.begin(),
322 last: FuncResult.Triplets.end());
323 }
324
325 return Result;
326}
327
328void MIR2VecTool::writeTripletsToStream(const Module &M,
329 raw_ostream &OS) const {
330 auto Result = generateTriplets(M);
331 OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
332 for (const auto &T : Result.Triplets)
333 OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
334}
335
336EntityList MIR2VecTool::collectEntityMappings() const {
337 if (!Vocab) {
338 WithColor::error(OS&: errs(), Prefix: ToolName)
339 << "Vocabulary must be initialized for entity mappings.\n";
340 return {};
341 }
342
343 const unsigned EntityCount = Vocab->getCanonicalSize();
344 EntityList Result;
345 for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
346 Result.push_back(x: Vocab->getStringKey(Pos: EntityID));
347
348 return Result;
349}
350
351void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
352 auto Entities = collectEntityMappings();
353 if (Entities.empty())
354 return;
355
356 OS << Entities.size() << "\n";
357 for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
358 OS << Entities[EntityID] << '\t' << EntityID << '\n';
359}
360
361void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
362 EmbeddingLevel Level) const {
363 if (!Vocab) {
364 WithColor::error(OS&: errs(), Prefix: ToolName) << "Vocabulary not initialized.\n";
365 return;
366 }
367
368 for (const Function &F : M.getFunctionDefs()) {
369 MachineFunction *MF = MMI.getMachineFunction(F);
370 if (!MF) {
371 WithColor::warning(OS&: errs(), Prefix: ToolName)
372 << "No MachineFunction for " << F.getName() << "\n";
373 continue;
374 }
375
376 writeEmbeddingsToStream(MF&: *MF, OS, Level);
377 }
378}
379
380void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
381 EmbeddingLevel Level) const {
382 if (!Vocab) {
383 WithColor::error(OS&: errs(), Prefix: ToolName) << "Vocabulary not initialized.\n";
384 return;
385 }
386
387 auto Emb = MIREmbedder::create(Mode: MIR2VecKind::Symbolic, MF, Vocab: *Vocab);
388 if (!Emb) {
389 WithColor::error(OS&: errs(), Prefix: ToolName)
390 << "Failed to create embedder for " << MF.getName() << "\n";
391 return;
392 }
393
394 OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
395
396 // Generate embeddings based on the specified level
397 switch (Level) {
398 case FunctionLevel:
399 OS << "Function vector: ";
400 Emb->getMFunctionVector().print(OS);
401 break;
402 case BasicBlockLevel:
403 OS << "Basic block vectors:\n";
404 for (const MachineBasicBlock &MBB : MF) {
405 OS << "MBB " << MBB.getName() << ": ";
406 Emb->getMBBVector(MBB).print(OS);
407 }
408 break;
409 case InstructionLevel:
410 OS << "Instruction vectors:\n";
411 for (const MachineBasicBlock &MBB : MF) {
412 for (const MachineInstr &MI : MBB) {
413 OS << MI << " -> ";
414 Emb->getMInstVector(MI).print(OS);
415 }
416 }
417 break;
418 }
419}
420
421} // namespace mir2vec
422
423} // namespace llvm
424