//===- Utils.cpp - IR2Vec/MIR2Vec Embedding Generation Tool -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the IR2VecTool and MIR2VecTool classes for
/// IR2Vec/MIR2Vec embedding generation.
///
//===----------------------------------------------------------------------===//
14
15#include "Utils.h"
16#include "llvm/ADT/ArrayRef.h"
17#include "llvm/Analysis/IR2Vec.h"
18#include "llvm/IR/BasicBlock.h"
19#include "llvm/IR/Function.h"
20#include "llvm/IR/InstIterator.h"
21#include "llvm/IR/Instructions.h"
22#include "llvm/IR/LLVMContext.h"
23#include "llvm/IR/Module.h"
24#include "llvm/IR/PassInstrumentation.h"
25#include "llvm/IR/PassManager.h"
26#include "llvm/IR/Type.h"
27#include "llvm/Support/Debug.h"
28#include "llvm/Support/Errc.h"
29#include "llvm/Support/Error.h"
30#include "llvm/Support/raw_ostream.h"
31
32#include "llvm/CodeGen/MIR2Vec.h"
33#include "llvm/CodeGen/MIRParser/MIRParser.h"
34#include "llvm/CodeGen/MachineFunction.h"
35#include "llvm/CodeGen/MachineModuleInfo.h"
36#include "llvm/CodeGen/TargetInstrInfo.h"
37#include "llvm/CodeGen/TargetRegisterInfo.h"
38#include "llvm/Target/TargetMachine.h"
39
40#define DEBUG_TYPE "ir2vec"
41
42namespace llvm {
43
44namespace ir2vec {
45
46Error IR2VecTool::initializeVocabulary(StringRef VocabPath) {
47 auto VocabOrErr = Vocabulary::fromFile(VocabFilePath: VocabPath);
48 if (!VocabOrErr)
49 return VocabOrErr.takeError();
50
51 Vocab = std::make_unique<Vocabulary>(args: std::move(*VocabOrErr));
52
53 if (!Vocab->isValid())
54 return createStringError(EC: errc::invalid_argument,
55 S: "Failed to initialize IR2Vec vocabulary");
56 return Error::success();
57}
58
59TripletResult IR2VecTool::generateTriplets(const Function &F) const {
60 if (F.isDeclaration())
61 return {};
62
63 TripletResult Result;
64 Result.MaxRelation = 0;
65
66 unsigned MaxRelation = NextRelation;
67 unsigned PrevOpcode = 0;
68 bool HasPrevOpcode = false;
69
70 for (const BasicBlock &BB : F) {
71 for (const auto &I : BB) {
72 if (I.isDebugOrPseudoInst())
73 continue;
74 unsigned Opcode = Vocabulary::getIndex(Opcode: I.getOpcode());
75 unsigned TypeID = Vocabulary::getIndex(TypeID: I.getType()->getTypeID());
76
77 // Add "Next" relationship with previous instruction
78 if (HasPrevOpcode) {
79 Result.Triplets.push_back(x: {.Head: PrevOpcode, .Tail: Opcode, .Relation: NextRelation});
80 LLVM_DEBUG(dbgs() << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1)
81 << '\t'
82 << Vocabulary::getVocabKeyForOpcode(Opcode + 1)
83 << '\t' << "Next\n");
84 }
85
86 // Add "Type" relationship
87 Result.Triplets.push_back(x: {.Head: Opcode, .Tail: TypeID, .Relation: TypeRelation});
88 LLVM_DEBUG(
89 dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
90 << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
91 << '\t' << "Type\n");
92
93 // Add "Arg" relationships
94 unsigned ArgIndex = 0;
95 for (const Use &U : I.operands()) {
96 unsigned OperandID = Vocabulary::getIndex(Op: *U.get());
97 unsigned RelationID = ArgRelation + ArgIndex;
98 Result.Triplets.push_back(x: {.Head: Opcode, .Tail: OperandID, .Relation: RelationID});
99
100 LLVM_DEBUG({
101 StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
102 Vocabulary::getOperandKind(U.get()));
103 dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
104 << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
105 });
106
107 ++ArgIndex;
108 }
109 // Only update MaxRelation if there were operands
110 if (ArgIndex > 0)
111 MaxRelation = std::max(a: MaxRelation, b: ArgRelation + ArgIndex - 1);
112 PrevOpcode = Opcode;
113 HasPrevOpcode = true;
114 }
115 }
116
117 Result.MaxRelation = MaxRelation;
118 return Result;
119}
120
121TripletResult IR2VecTool::generateTriplets() const {
122 TripletResult Result;
123 Result.MaxRelation = NextRelation;
124
125 for (const Function &F : M.getFunctionDefs()) {
126 TripletResult FuncResult = generateTriplets(F);
127 Result.MaxRelation = std::max(a: Result.MaxRelation, b: FuncResult.MaxRelation);
128 Result.Triplets.insert(position: Result.Triplets.end(), first: FuncResult.Triplets.begin(),
129 last: FuncResult.Triplets.end());
130 }
131
132 return Result;
133}
134
135void IR2VecTool::writeTripletsToStream(raw_ostream &OS) const {
136 auto Result = generateTriplets();
137 OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
138 for (const auto &T : Result.Triplets)
139 OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
140}
141
142EntityList IR2VecTool::collectEntityMappings() {
143 auto EntityLen = Vocabulary::getCanonicalSize();
144 EntityList Result;
145 for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
146 Result.push_back(x: Vocabulary::getStringKey(Pos: EntityID).str());
147 return Result;
148}
149
150void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
151 auto Entities = collectEntityMappings();
152 OS << Entities.size() << "\n";
153 for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
154 OS << Entities[EntityID] << '\t' << EntityID << '\n';
155}
156
157Expected<std::unique_ptr<Embedder>>
158IR2VecTool::createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const {
159 if (!Vocab || !Vocab->isValid())
160 return createStringError(
161 EC: errc::invalid_argument,
162 S: "Vocabulary is not valid. IR2VecTool not initialized.");
163
164 if (F.isDeclaration())
165 return createStringError(EC: errc::invalid_argument,
166 S: "Function is a declaration.");
167
168 auto Emb = Embedder::create(Mode: Kind, F, Vocab: *Vocab);
169 if (!Emb)
170 return createStringError(EC: errc::invalid_argument,
171 Fmt: "Failed to create embedder for function '%s'.",
172 Vals: F.getName().str().c_str());
173
174 return std::move(Emb);
175}
176
177Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
178 IR2VecKind Kind) const {
179 auto Emb = createIR2VecEmbedder(F, Kind);
180 if (!Emb)
181 return Emb.takeError();
182
183 return (*Emb)->getFunctionVector();
184}
185
186Expected<FuncEmbMap>
187IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
188 FuncEmbMap Result;
189
190 for (const Function &F : M.getFunctionDefs()) {
191 auto Emb = getFunctionEmbedding(F, Kind);
192 if (!Emb)
193 return Emb.takeError();
194 Result.try_emplace(Key: &F, Args: std::move(*Emb));
195 }
196
197 return Result;
198}
199
200Expected<BBEmbeddingsMap>
201IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
202 auto Emb = createIR2VecEmbedder(F, Kind);
203 if (!Emb)
204 return Emb.takeError();
205
206 BBEmbeddingsMap Result;
207
208 for (const BasicBlock &BB : F)
209 Result.try_emplace(Key: &BB, Args: (*Emb)->getBBVector(BB));
210
211 return Result;
212}
213
214Expected<InstEmbeddingsMap>
215IR2VecTool::getInstEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
216 auto Emb = createIR2VecEmbedder(F, Kind);
217 if (!Emb)
218 return Emb.takeError();
219
220 InstEmbeddingsMap Result;
221
222 for (const Instruction &I : instructions(F))
223 Result.try_emplace(Key: &I, Args: (*Emb)->getInstVector(I));
224
225 return Result;
226}
227
228void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
229 EmbeddingLevel Level) const {
230 for (const Function &F : M.getFunctionDefs())
231 writeEmbeddingsToStream(F, OS, Level);
232}
233
234void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
235 EmbeddingLevel Level) const {
236 auto IR2VecEmbedderObj = createIR2VecEmbedder(F, Kind: IR2VecEmbeddingKind);
237 if (!IR2VecEmbedderObj) {
238 WithColor::error(OS&: errs(), Prefix: ToolName)
239 << toString(E: IR2VecEmbedderObj.takeError()) << "\n";
240 return;
241 }
242 auto Emb = std::move(*IR2VecEmbedderObj);
243
244 OS << "Function: " << F.getName() << "\n";
245
246 // Generate embeddings based on the specified level
247 switch (Level) {
248 case FunctionLevel:
249 Emb->getFunctionVector().print(OS);
250 break;
251 case BasicBlockLevel:
252 for (const BasicBlock &BB : F) {
253 OS << BB.getName() << ":";
254 Emb->getBBVector(BB).print(OS);
255 }
256 break;
257 case InstructionLevel:
258 for (const Instruction &I : instructions(F)) {
259 OS << I;
260 Emb->getInstVector(I).print(OS);
261 }
262 break;
263 }
264}
265
266} // namespace ir2vec
267
268namespace mir2vec {
269
270bool MIR2VecTool::initializeVocabulary(const Module &M) {
271 MIR2VecVocabProvider Provider(MMI);
272 auto VocabOrErr = Provider.getVocabulary(M);
273 if (!VocabOrErr) {
274 WithColor::error(OS&: errs(), Prefix: ToolName)
275 << "Failed to load MIR2Vec vocabulary - "
276 << toString(E: VocabOrErr.takeError()) << "\n";
277 return false;
278 }
279 Vocab = std::make_unique<MIRVocabulary>(args: std::move(*VocabOrErr));
280 return true;
281}
282
283bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
284 for (const Function &F : M.getFunctionDefs()) {
285 MachineFunction *MF = MMI.getMachineFunction(F);
286 if (!MF)
287 continue;
288
289 const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
290 const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
291 const MachineRegisterInfo &MRI = MF->getRegInfo();
292
293 auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, Dim: 1);
294 if (!VocabOrErr) {
295 WithColor::error(OS&: errs(), Prefix: ToolName)
296 << "Failed to create dummy vocabulary - "
297 << toString(E: VocabOrErr.takeError()) << "\n";
298 return false;
299 }
300 Vocab = std::make_unique<MIRVocabulary>(args: std::move(*VocabOrErr));
301 return true;
302 }
303
304 WithColor::error(OS&: errs(), Prefix: ToolName)
305 << "No machine functions found to initialize vocabulary\n";
306 return false;
307}
308
309TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
310 TripletResult Result;
311 Result.MaxRelation = MIRNextRelation;
312
313 if (!Vocab) {
314 WithColor::error(OS&: errs(), Prefix: ToolName)
315 << "MIR Vocabulary must be initialized for triplet generation.\n";
316 return Result;
317 }
318
319 unsigned PrevOpcode = 0;
320 bool HasPrevOpcode = false;
321 for (const MachineBasicBlock &MBB : MF) {
322 for (const MachineInstr &MI : MBB) {
323 // Skip debug instructions
324 if (MI.isDebugInstr())
325 continue;
326
327 // Get opcode entity ID
328 unsigned OpcodeID = Vocab->getEntityIDForOpcode(Opcode: MI.getOpcode());
329
330 // Add "Next" relationship with previous instruction
331 if (HasPrevOpcode) {
332 Result.Triplets.push_back(x: {.Head: PrevOpcode, .Tail: OpcodeID, .Relation: MIRNextRelation});
333 LLVM_DEBUG(dbgs() << Vocab->getStringKey(PrevOpcode) << '\t'
334 << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
335 }
336
337 // Add "Arg" relationships for operands
338 unsigned ArgIndex = 0;
339 for (const MachineOperand &MO : MI.operands()) {
340 auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
341 unsigned RelationID = MIRArgRelation + ArgIndex;
342 Result.Triplets.push_back(x: {.Head: OpcodeID, .Tail: OperandID, .Relation: RelationID});
343 LLVM_DEBUG({
344 std::string OperandStr = Vocab->getStringKey(OperandID);
345 dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr << '\t'
346 << "Arg" << ArgIndex << '\n';
347 });
348
349 ++ArgIndex;
350 }
351
352 // Update MaxRelation if there were operands
353 if (ArgIndex > 0)
354 Result.MaxRelation =
355 std::max(a: Result.MaxRelation, b: MIRArgRelation + ArgIndex - 1);
356
357 PrevOpcode = OpcodeID;
358 HasPrevOpcode = true;
359 }
360 }
361
362 return Result;
363}
364
365TripletResult MIR2VecTool::generateTriplets(const Module &M) const {
366 TripletResult Result;
367 Result.MaxRelation = MIRNextRelation;
368
369 for (const Function &F : M.getFunctionDefs()) {
370 MachineFunction *MF = MMI.getMachineFunction(F);
371 if (!MF) {
372 WithColor::warning(OS&: errs(), Prefix: ToolName)
373 << "No MachineFunction for " << F.getName() << "\n";
374 continue;
375 }
376
377 TripletResult FuncResult = generateTriplets(MF: *MF);
378 Result.MaxRelation = std::max(a: Result.MaxRelation, b: FuncResult.MaxRelation);
379 Result.Triplets.insert(position: Result.Triplets.end(), first: FuncResult.Triplets.begin(),
380 last: FuncResult.Triplets.end());
381 }
382
383 return Result;
384}
385
386void MIR2VecTool::writeTripletsToStream(const Module &M,
387 raw_ostream &OS) const {
388 auto Result = generateTriplets(M);
389 OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
390 for (const auto &T : Result.Triplets)
391 OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
392}
393
394EntityList MIR2VecTool::collectEntityMappings() const {
395 if (!Vocab) {
396 WithColor::error(OS&: errs(), Prefix: ToolName)
397 << "Vocabulary must be initialized for entity mappings.\n";
398 return {};
399 }
400
401 const unsigned EntityCount = Vocab->getCanonicalSize();
402 EntityList Result;
403 for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
404 Result.push_back(x: Vocab->getStringKey(Pos: EntityID));
405
406 return Result;
407}
408
409void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
410 auto Entities = collectEntityMappings();
411 if (Entities.empty())
412 return;
413
414 OS << Entities.size() << "\n";
415 for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
416 OS << Entities[EntityID] << '\t' << EntityID << '\n';
417}
418
419void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
420 EmbeddingLevel Level) const {
421 if (!Vocab) {
422 WithColor::error(OS&: errs(), Prefix: ToolName) << "Vocabulary not initialized.\n";
423 return;
424 }
425
426 for (const Function &F : M.getFunctionDefs()) {
427 MachineFunction *MF = MMI.getMachineFunction(F);
428 if (!MF) {
429 WithColor::warning(OS&: errs(), Prefix: ToolName)
430 << "No MachineFunction for " << F.getName() << "\n";
431 continue;
432 }
433
434 writeEmbeddingsToStream(MF&: *MF, OS, Level);
435 }
436}
437
438void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
439 EmbeddingLevel Level) const {
440 if (!Vocab) {
441 WithColor::error(OS&: errs(), Prefix: ToolName) << "Vocabulary not initialized.\n";
442 return;
443 }
444
445 auto Emb = MIREmbedder::create(Mode: MIR2VecKind::Symbolic, MF, Vocab: *Vocab);
446 if (!Emb) {
447 WithColor::error(OS&: errs(), Prefix: ToolName)
448 << "Failed to create embedder for " << MF.getName() << "\n";
449 return;
450 }
451
452 OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
453
454 // Generate embeddings based on the specified level
455 switch (Level) {
456 case FunctionLevel:
457 OS << "Function vector: ";
458 Emb->getMFunctionVector().print(OS);
459 break;
460 case BasicBlockLevel:
461 OS << "Basic block vectors:\n";
462 for (const MachineBasicBlock &MBB : MF) {
463 OS << "MBB " << MBB.getName() << ": ";
464 Emb->getMBBVector(MBB).print(OS);
465 }
466 break;
467 case InstructionLevel:
468 OS << "Instruction vectors:\n";
469 for (const MachineBasicBlock &MBB : MF) {
470 for (const MachineInstr &MI : MBB) {
471 OS << MI << " -> ";
472 Emb->getMInstVector(MI).print(OS);
473 }
474 }
475 break;
476 }
477}
478
479} // namespace mir2vec
480
481} // namespace llvm
482