1//===- Utils.cpp - IR2Vec/MIR2Vec Embedding Generation Tool -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the IR2VecTool and MIR2VecTool classes for
11/// IR2Vec/MIR2Vec embedding generation.
12///
13//===----------------------------------------------------------------------===//
14
15#include "Utils.h"
16#include "llvm/ADT/ArrayRef.h"
17#include "llvm/Analysis/IR2Vec.h"
18#include "llvm/IR/BasicBlock.h"
19#include "llvm/IR/Function.h"
20#include "llvm/IR/InstIterator.h"
21#include "llvm/IR/Instructions.h"
22#include "llvm/IR/LLVMContext.h"
23#include "llvm/IR/Module.h"
24#include "llvm/IR/PassInstrumentation.h"
25#include "llvm/IR/PassManager.h"
26#include "llvm/IR/Type.h"
27#include "llvm/Support/Debug.h"
28#include "llvm/Support/Errc.h"
29#include "llvm/Support/Error.h"
30#include "llvm/Support/raw_ostream.h"
31
32#include "llvm/CodeGen/MIR2Vec.h"
33#include "llvm/CodeGen/MIRParser/MIRParser.h"
34#include "llvm/CodeGen/MachineFunction.h"
35#include "llvm/CodeGen/MachineModuleInfo.h"
36#include "llvm/CodeGen/TargetInstrInfo.h"
37#include "llvm/CodeGen/TargetRegisterInfo.h"
38#include "llvm/Target/TargetMachine.h"
39
40#define DEBUG_TYPE "ir2vec"
41
42namespace llvm {
43
44namespace ir2vec {
45
46Error IR2VecTool::initializeVocabulary(StringRef VocabPath) {
47 auto VocabOrErr = Vocabulary::fromFile(VocabFilePath: VocabPath);
48 if (!VocabOrErr)
49 return VocabOrErr.takeError();
50
51 Vocab = std::make_unique<Vocabulary>(args: std::move(*VocabOrErr));
52
53 if (!Vocab->isValid())
54 return createStringError(EC: errc::invalid_argument,
55 S: "Failed to initialize IR2Vec vocabulary");
56 return Error::success();
57}
58
59TripletResult IR2VecTool::generateTriplets(const Function &F) const {
60 if (F.isDeclaration())
61 return {};
62
63 TripletResult Result;
64 Result.MaxRelation = 0;
65
66 unsigned MaxRelation = NextRelation;
67 unsigned PrevOpcode = 0;
68 bool HasPrevOpcode = false;
69
70 for (const BasicBlock &BB : F) {
71 for (const auto &I : BB.instructionsWithoutDebug()) {
72 unsigned Opcode = Vocabulary::getIndex(Opcode: I.getOpcode());
73 unsigned TypeID = Vocabulary::getIndex(TypeID: I.getType()->getTypeID());
74
75 // Add "Next" relationship with previous instruction
76 if (HasPrevOpcode) {
77 Result.Triplets.push_back(x: {.Head: PrevOpcode, .Tail: Opcode, .Relation: NextRelation});
78 LLVM_DEBUG(dbgs() << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1)
79 << '\t'
80 << Vocabulary::getVocabKeyForOpcode(Opcode + 1)
81 << '\t' << "Next\n");
82 }
83
84 // Add "Type" relationship
85 Result.Triplets.push_back(x: {.Head: Opcode, .Tail: TypeID, .Relation: TypeRelation});
86 LLVM_DEBUG(
87 dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
88 << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
89 << '\t' << "Type\n");
90
91 // Add "Arg" relationships
92 unsigned ArgIndex = 0;
93 for (const Use &U : I.operands()) {
94 unsigned OperandID = Vocabulary::getIndex(Op: *U.get());
95 unsigned RelationID = ArgRelation + ArgIndex;
96 Result.Triplets.push_back(x: {.Head: Opcode, .Tail: OperandID, .Relation: RelationID});
97
98 LLVM_DEBUG({
99 StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
100 Vocabulary::getOperandKind(U.get()));
101 dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
102 << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
103 });
104
105 ++ArgIndex;
106 }
107 // Only update MaxRelation if there were operands
108 if (ArgIndex > 0)
109 MaxRelation = std::max(a: MaxRelation, b: ArgRelation + ArgIndex - 1);
110 PrevOpcode = Opcode;
111 HasPrevOpcode = true;
112 }
113 }
114
115 Result.MaxRelation = MaxRelation;
116 return Result;
117}
118
119TripletResult IR2VecTool::generateTriplets() const {
120 TripletResult Result;
121 Result.MaxRelation = NextRelation;
122
123 for (const Function &F : M.getFunctionDefs()) {
124 TripletResult FuncResult = generateTriplets(F);
125 Result.MaxRelation = std::max(a: Result.MaxRelation, b: FuncResult.MaxRelation);
126 Result.Triplets.insert(position: Result.Triplets.end(), first: FuncResult.Triplets.begin(),
127 last: FuncResult.Triplets.end());
128 }
129
130 return Result;
131}
132
133void IR2VecTool::writeTripletsToStream(raw_ostream &OS) const {
134 auto Result = generateTriplets();
135 OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
136 for (const auto &T : Result.Triplets)
137 OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
138}
139
140EntityList IR2VecTool::collectEntityMappings() {
141 auto EntityLen = Vocabulary::getCanonicalSize();
142 EntityList Result;
143 for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
144 Result.push_back(x: Vocabulary::getStringKey(Pos: EntityID).str());
145 return Result;
146}
147
148void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
149 auto Entities = collectEntityMappings();
150 OS << Entities.size() << "\n";
151 for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
152 OS << Entities[EntityID] << '\t' << EntityID << '\n';
153}
154
155Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
156 IR2VecKind Kind) const {
157 if (!Vocab || !Vocab->isValid())
158 return createStringError(
159 EC: errc::invalid_argument,
160 S: "Vocabulary is not valid. IR2VecTool not initialized.");
161
162 if (F.isDeclaration())
163 return createStringError(EC: errc::invalid_argument,
164 S: "Function is a declaration.");
165
166 auto Emb = Embedder::create(Mode: Kind, F, Vocab: *Vocab);
167 if (!Emb)
168 return createStringError(EC: errc::invalid_argument,
169 Fmt: "Failed to create embedder for function '%s'.",
170 Vals: F.getName().str().c_str());
171
172 return Emb->getFunctionVector();
173}
174
175Expected<FuncEmbMap>
176IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
177 if (!Vocab || !Vocab->isValid())
178 return createStringError(
179 EC: errc::invalid_argument,
180 S: "Vocabulary is not valid. IR2VecTool not initialized.");
181
182 FuncEmbMap Result;
183
184 for (const Function &F : M.getFunctionDefs()) {
185 auto Emb = getFunctionEmbedding(F, Kind);
186 if (!Emb)
187 return Emb.takeError();
188 Result.try_emplace(Key: &F, Args: std::move(*Emb));
189 }
190
191 return Result;
192}
193
194Expected<BBEmbeddingsMap>
195IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
196 if (!Vocab || !Vocab->isValid())
197 return createStringError(
198 EC: errc::invalid_argument,
199 S: "Vocabulary is not valid. IR2VecTool not initialized.");
200
201 BBEmbeddingsMap Result;
202
203 if (F.isDeclaration())
204 return createStringError(EC: errc::invalid_argument,
205 S: "Function is a declaration.");
206
207 auto Emb = Embedder::create(Mode: Kind, F, Vocab: *Vocab);
208 if (!Emb)
209 return createStringError(EC: errc::invalid_argument,
210 Fmt: "Failed to create embedder for function '%s'.",
211 Vals: F.getName().str().c_str());
212
213 for (const BasicBlock &BB : F)
214 Result.try_emplace(Key: &BB, Args: Emb->getBBVector(BB));
215
216 return Result;
217}
218
219void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
220 EmbeddingLevel Level) const {
221 if (!Vocab || !Vocab->isValid()) {
222 WithColor::error(OS&: errs(), Prefix: ToolName)
223 << "Vocabulary is not valid. IR2VecTool not initialized.\n";
224 return;
225 }
226
227 for (const Function &F : M.getFunctionDefs())
228 writeEmbeddingsToStream(F, OS, Level);
229}
230
231void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
232 EmbeddingLevel Level) const {
233 if (!Vocab || !Vocab->isValid()) {
234 WithColor::error(OS&: errs(), Prefix: ToolName)
235 << "Vocabulary is not valid. IR2VecTool not initialized.\n";
236 return;
237 }
238
239 if (F.isDeclaration()) {
240 OS << "Function " << F.getName() << " is a declaration, skipping.\n";
241 return;
242 }
243
244 // Create embedder for this function
245 auto Emb = Embedder::create(Mode: IR2VecEmbeddingKind, F, Vocab: *Vocab);
246 if (!Emb) {
247 WithColor::error(OS&: errs(), Prefix: ToolName)
248 << "Failed to create embedder for function " << F.getName() << "\n";
249 return;
250 }
251
252 OS << "Function: " << F.getName() << "\n";
253
254 // Generate embeddings based on the specified level
255 switch (Level) {
256 case FunctionLevel:
257 Emb->getFunctionVector().print(OS);
258 break;
259 case BasicBlockLevel:
260 for (const BasicBlock &BB : F) {
261 OS << BB.getName() << ":";
262 Emb->getBBVector(BB).print(OS);
263 }
264 break;
265 case InstructionLevel:
266 for (const Instruction &I : instructions(F)) {
267 OS << I;
268 Emb->getInstVector(I).print(OS);
269 }
270 break;
271 }
272}
273
274} // namespace ir2vec
275
276namespace mir2vec {
277
278bool MIR2VecTool::initializeVocabulary(const Module &M) {
279 MIR2VecVocabProvider Provider(MMI);
280 auto VocabOrErr = Provider.getVocabulary(M);
281 if (!VocabOrErr) {
282 WithColor::error(OS&: errs(), Prefix: ToolName)
283 << "Failed to load MIR2Vec vocabulary - "
284 << toString(E: VocabOrErr.takeError()) << "\n";
285 return false;
286 }
287 Vocab = std::make_unique<MIRVocabulary>(args: std::move(*VocabOrErr));
288 return true;
289}
290
291bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
292 for (const Function &F : M.getFunctionDefs()) {
293 MachineFunction *MF = MMI.getMachineFunction(F);
294 if (!MF)
295 continue;
296
297 const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
298 const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
299 const MachineRegisterInfo &MRI = MF->getRegInfo();
300
301 auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, Dim: 1);
302 if (!VocabOrErr) {
303 WithColor::error(OS&: errs(), Prefix: ToolName)
304 << "Failed to create dummy vocabulary - "
305 << toString(E: VocabOrErr.takeError()) << "\n";
306 return false;
307 }
308 Vocab = std::make_unique<MIRVocabulary>(args: std::move(*VocabOrErr));
309 return true;
310 }
311
312 WithColor::error(OS&: errs(), Prefix: ToolName)
313 << "No machine functions found to initialize vocabulary\n";
314 return false;
315}
316
317TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
318 TripletResult Result;
319 Result.MaxRelation = MIRNextRelation;
320
321 if (!Vocab) {
322 WithColor::error(OS&: errs(), Prefix: ToolName)
323 << "MIR Vocabulary must be initialized for triplet generation.\n";
324 return Result;
325 }
326
327 unsigned PrevOpcode = 0;
328 bool HasPrevOpcode = false;
329 for (const MachineBasicBlock &MBB : MF) {
330 for (const MachineInstr &MI : MBB) {
331 // Skip debug instructions
332 if (MI.isDebugInstr())
333 continue;
334
335 // Get opcode entity ID
336 unsigned OpcodeID = Vocab->getEntityIDForOpcode(Opcode: MI.getOpcode());
337
338 // Add "Next" relationship with previous instruction
339 if (HasPrevOpcode) {
340 Result.Triplets.push_back(x: {.Head: PrevOpcode, .Tail: OpcodeID, .Relation: MIRNextRelation});
341 LLVM_DEBUG(dbgs() << Vocab->getStringKey(PrevOpcode) << '\t'
342 << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
343 }
344
345 // Add "Arg" relationships for operands
346 unsigned ArgIndex = 0;
347 for (const MachineOperand &MO : MI.operands()) {
348 auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
349 unsigned RelationID = MIRArgRelation + ArgIndex;
350 Result.Triplets.push_back(x: {.Head: OpcodeID, .Tail: OperandID, .Relation: RelationID});
351 LLVM_DEBUG({
352 std::string OperandStr = Vocab->getStringKey(OperandID);
353 dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr << '\t'
354 << "Arg" << ArgIndex << '\n';
355 });
356
357 ++ArgIndex;
358 }
359
360 // Update MaxRelation if there were operands
361 if (ArgIndex > 0)
362 Result.MaxRelation =
363 std::max(a: Result.MaxRelation, b: MIRArgRelation + ArgIndex - 1);
364
365 PrevOpcode = OpcodeID;
366 HasPrevOpcode = true;
367 }
368 }
369
370 return Result;
371}
372
373TripletResult MIR2VecTool::generateTriplets(const Module &M) const {
374 TripletResult Result;
375 Result.MaxRelation = MIRNextRelation;
376
377 for (const Function &F : M.getFunctionDefs()) {
378 MachineFunction *MF = MMI.getMachineFunction(F);
379 if (!MF) {
380 WithColor::warning(OS&: errs(), Prefix: ToolName)
381 << "No MachineFunction for " << F.getName() << "\n";
382 continue;
383 }
384
385 TripletResult FuncResult = generateTriplets(MF: *MF);
386 Result.MaxRelation = std::max(a: Result.MaxRelation, b: FuncResult.MaxRelation);
387 Result.Triplets.insert(position: Result.Triplets.end(), first: FuncResult.Triplets.begin(),
388 last: FuncResult.Triplets.end());
389 }
390
391 return Result;
392}
393
394void MIR2VecTool::writeTripletsToStream(const Module &M,
395 raw_ostream &OS) const {
396 auto Result = generateTriplets(M);
397 OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
398 for (const auto &T : Result.Triplets)
399 OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
400}
401
402EntityList MIR2VecTool::collectEntityMappings() const {
403 if (!Vocab) {
404 WithColor::error(OS&: errs(), Prefix: ToolName)
405 << "Vocabulary must be initialized for entity mappings.\n";
406 return {};
407 }
408
409 const unsigned EntityCount = Vocab->getCanonicalSize();
410 EntityList Result;
411 for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
412 Result.push_back(x: Vocab->getStringKey(Pos: EntityID));
413
414 return Result;
415}
416
417void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
418 auto Entities = collectEntityMappings();
419 if (Entities.empty())
420 return;
421
422 OS << Entities.size() << "\n";
423 for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
424 OS << Entities[EntityID] << '\t' << EntityID << '\n';
425}
426
427void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
428 EmbeddingLevel Level) const {
429 if (!Vocab) {
430 WithColor::error(OS&: errs(), Prefix: ToolName) << "Vocabulary not initialized.\n";
431 return;
432 }
433
434 for (const Function &F : M.getFunctionDefs()) {
435 MachineFunction *MF = MMI.getMachineFunction(F);
436 if (!MF) {
437 WithColor::warning(OS&: errs(), Prefix: ToolName)
438 << "No MachineFunction for " << F.getName() << "\n";
439 continue;
440 }
441
442 writeEmbeddingsToStream(MF&: *MF, OS, Level);
443 }
444}
445
446void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
447 EmbeddingLevel Level) const {
448 if (!Vocab) {
449 WithColor::error(OS&: errs(), Prefix: ToolName) << "Vocabulary not initialized.\n";
450 return;
451 }
452
453 auto Emb = MIREmbedder::create(Mode: MIR2VecKind::Symbolic, MF, Vocab: *Vocab);
454 if (!Emb) {
455 WithColor::error(OS&: errs(), Prefix: ToolName)
456 << "Failed to create embedder for " << MF.getName() << "\n";
457 return;
458 }
459
460 OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
461
462 // Generate embeddings based on the specified level
463 switch (Level) {
464 case FunctionLevel:
465 OS << "Function vector: ";
466 Emb->getMFunctionVector().print(OS);
467 break;
468 case BasicBlockLevel:
469 OS << "Basic block vectors:\n";
470 for (const MachineBasicBlock &MBB : MF) {
471 OS << "MBB " << MBB.getName() << ": ";
472 Emb->getMBBVector(MBB).print(OS);
473 }
474 break;
475 case InstructionLevel:
476 OS << "Instruction vectors:\n";
477 for (const MachineBasicBlock &MBB : MF) {
478 for (const MachineInstr &MI : MBB) {
479 OS << MI << " -> ";
480 Emb->getMInstVector(MI).print(OS);
481 }
482 }
483 break;
484 }
485}
486
487} // namespace mir2vec
488
489} // namespace llvm
490