1//===- Utils.cpp - IR2Vec/MIR2Vec Embedding Generation Tool -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the IR2VecTool and MIR2VecTool classes for
11/// IR2Vec/MIR2Vec embedding generation.
12///
13//===----------------------------------------------------------------------===//
14
15#include "Utils.h"
16#include "llvm/ADT/ArrayRef.h"
17#include "llvm/Analysis/IR2Vec.h"
18#include "llvm/IR/BasicBlock.h"
19#include "llvm/IR/Function.h"
20#include "llvm/IR/InstIterator.h"
21#include "llvm/IR/Instructions.h"
22#include "llvm/IR/LLVMContext.h"
23#include "llvm/IR/Module.h"
24#include "llvm/IR/PassInstrumentation.h"
25#include "llvm/IR/PassManager.h"
26#include "llvm/IR/Type.h"
27#include "llvm/Support/Debug.h"
28#include "llvm/Support/Errc.h"
29#include "llvm/Support/Error.h"
30#include "llvm/Support/raw_ostream.h"
31
32#include "llvm/CodeGen/MIR2Vec.h"
33#include "llvm/CodeGen/MIRParser/MIRParser.h"
34#include "llvm/CodeGen/MachineFunction.h"
35#include "llvm/CodeGen/MachineModuleInfo.h"
36#include "llvm/CodeGen/TargetInstrInfo.h"
37#include "llvm/CodeGen/TargetRegisterInfo.h"
38#include "llvm/Target/TargetMachine.h"
39
40#define DEBUG_TYPE "ir2vec"
41
42namespace llvm {
43
44namespace ir2vec {
45
46Error IR2VecTool::initializeVocabulary(StringRef VocabPath) {
47 auto VocabOrErr = Vocabulary::fromFile(VocabFilePath: VocabPath);
48 if (!VocabOrErr)
49 return VocabOrErr.takeError();
50
51 Vocab = std::make_unique<Vocabulary>(args: std::move(*VocabOrErr));
52
53 if (!Vocab->isValid())
54 return createStringError(EC: errc::invalid_argument,
55 S: "Failed to initialize IR2Vec vocabulary");
56 return Error::success();
57}
58
59TripletResult IR2VecTool::generateTriplets(const Function &F) const {
60 if (F.isDeclaration())
61 return {};
62
63 TripletResult Result;
64 Result.MaxRelation = 0;
65
66 unsigned MaxRelation = NextRelation;
67 unsigned PrevOpcode = 0;
68 bool HasPrevOpcode = false;
69
70 for (const BasicBlock &BB : F) {
71 for (const auto &I : BB.instructionsWithoutDebug()) {
72 unsigned Opcode = Vocabulary::getIndex(Opcode: I.getOpcode());
73 unsigned TypeID = Vocabulary::getIndex(TypeID: I.getType()->getTypeID());
74
75 // Add "Next" relationship with previous instruction
76 if (HasPrevOpcode) {
77 Result.Triplets.push_back(x: {.Head: PrevOpcode, .Tail: Opcode, .Relation: NextRelation});
78 LLVM_DEBUG(dbgs() << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1)
79 << '\t'
80 << Vocabulary::getVocabKeyForOpcode(Opcode + 1)
81 << '\t' << "Next\n");
82 }
83
84 // Add "Type" relationship
85 Result.Triplets.push_back(x: {.Head: Opcode, .Tail: TypeID, .Relation: TypeRelation});
86 LLVM_DEBUG(
87 dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
88 << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
89 << '\t' << "Type\n");
90
91 // Add "Arg" relationships
92 unsigned ArgIndex = 0;
93 for (const Use &U : I.operands()) {
94 unsigned OperandID = Vocabulary::getIndex(Op: *U.get());
95 unsigned RelationID = ArgRelation + ArgIndex;
96 Result.Triplets.push_back(x: {.Head: Opcode, .Tail: OperandID, .Relation: RelationID});
97
98 LLVM_DEBUG({
99 StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
100 Vocabulary::getOperandKind(U.get()));
101 dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
102 << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
103 });
104
105 ++ArgIndex;
106 }
107 // Only update MaxRelation if there were operands
108 if (ArgIndex > 0)
109 MaxRelation = std::max(a: MaxRelation, b: ArgRelation + ArgIndex - 1);
110 PrevOpcode = Opcode;
111 HasPrevOpcode = true;
112 }
113 }
114
115 Result.MaxRelation = MaxRelation;
116 return Result;
117}
118
119TripletResult IR2VecTool::generateTriplets() const {
120 TripletResult Result;
121 Result.MaxRelation = NextRelation;
122
123 for (const Function &F : M.getFunctionDefs()) {
124 TripletResult FuncResult = generateTriplets(F);
125 Result.MaxRelation = std::max(a: Result.MaxRelation, b: FuncResult.MaxRelation);
126 Result.Triplets.insert(position: Result.Triplets.end(), first: FuncResult.Triplets.begin(),
127 last: FuncResult.Triplets.end());
128 }
129
130 return Result;
131}
132
133void IR2VecTool::writeTripletsToStream(raw_ostream &OS) const {
134 auto Result = generateTriplets();
135 OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
136 for (const auto &T : Result.Triplets)
137 OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
138}
139
140EntityList IR2VecTool::collectEntityMappings() {
141 auto EntityLen = Vocabulary::getCanonicalSize();
142 EntityList Result;
143 for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
144 Result.push_back(x: Vocabulary::getStringKey(Pos: EntityID).str());
145 return Result;
146}
147
148void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
149 auto Entities = collectEntityMappings();
150 OS << Entities.size() << "\n";
151 for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
152 OS << Entities[EntityID] << '\t' << EntityID << '\n';
153}
154
155Expected<std::unique_ptr<Embedder>>
156IR2VecTool::createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const {
157 if (!Vocab || !Vocab->isValid())
158 return createStringError(
159 EC: errc::invalid_argument,
160 S: "Vocabulary is not valid. IR2VecTool not initialized.");
161
162 if (F.isDeclaration())
163 return createStringError(EC: errc::invalid_argument,
164 S: "Function is a declaration.");
165
166 auto Emb = Embedder::create(Mode: Kind, F, Vocab: *Vocab);
167 if (!Emb)
168 return createStringError(EC: errc::invalid_argument,
169 Fmt: "Failed to create embedder for function '%s'.",
170 Vals: F.getName().str().c_str());
171
172 return std::move(Emb);
173}
174
175Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
176 IR2VecKind Kind) const {
177 auto Emb = createIR2VecEmbedder(F, Kind);
178 if (!Emb)
179 return Emb.takeError();
180
181 return (*Emb)->getFunctionVector();
182}
183
184Expected<FuncEmbMap>
185IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
186 FuncEmbMap Result;
187
188 for (const Function &F : M.getFunctionDefs()) {
189 auto Emb = getFunctionEmbedding(F, Kind);
190 if (!Emb)
191 return Emb.takeError();
192 Result.try_emplace(Key: &F, Args: std::move(*Emb));
193 }
194
195 return Result;
196}
197
198Expected<BBEmbeddingsMap>
199IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
200 auto Emb = createIR2VecEmbedder(F, Kind);
201 if (!Emb)
202 return Emb.takeError();
203
204 BBEmbeddingsMap Result;
205
206 for (const BasicBlock &BB : F)
207 Result.try_emplace(Key: &BB, Args: (*Emb)->getBBVector(BB));
208
209 return Result;
210}
211
212Expected<InstEmbeddingsMap>
213IR2VecTool::getInstEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
214 auto Emb = createIR2VecEmbedder(F, Kind);
215 if (!Emb)
216 return Emb.takeError();
217
218 InstEmbeddingsMap Result;
219
220 for (const Instruction &I : instructions(F))
221 Result.try_emplace(Key: &I, Args: (*Emb)->getInstVector(I));
222
223 return Result;
224}
225
226void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
227 EmbeddingLevel Level) const {
228 for (const Function &F : M.getFunctionDefs())
229 writeEmbeddingsToStream(F, OS, Level);
230}
231
232void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
233 EmbeddingLevel Level) const {
234 auto IR2VecEmbedderObj = createIR2VecEmbedder(F, Kind: IR2VecEmbeddingKind);
235 if (!IR2VecEmbedderObj) {
236 WithColor::error(OS&: errs(), Prefix: ToolName)
237 << toString(E: IR2VecEmbedderObj.takeError()) << "\n";
238 return;
239 }
240 auto Emb = std::move(*IR2VecEmbedderObj);
241
242 OS << "Function: " << F.getName() << "\n";
243
244 // Generate embeddings based on the specified level
245 switch (Level) {
246 case FunctionLevel:
247 Emb->getFunctionVector().print(OS);
248 break;
249 case BasicBlockLevel:
250 for (const BasicBlock &BB : F) {
251 OS << BB.getName() << ":";
252 Emb->getBBVector(BB).print(OS);
253 }
254 break;
255 case InstructionLevel:
256 for (const Instruction &I : instructions(F)) {
257 OS << I;
258 Emb->getInstVector(I).print(OS);
259 }
260 break;
261 }
262}
263
264} // namespace ir2vec
265
266namespace mir2vec {
267
268bool MIR2VecTool::initializeVocabulary(const Module &M) {
269 MIR2VecVocabProvider Provider(MMI);
270 auto VocabOrErr = Provider.getVocabulary(M);
271 if (!VocabOrErr) {
272 WithColor::error(OS&: errs(), Prefix: ToolName)
273 << "Failed to load MIR2Vec vocabulary - "
274 << toString(E: VocabOrErr.takeError()) << "\n";
275 return false;
276 }
277 Vocab = std::make_unique<MIRVocabulary>(args: std::move(*VocabOrErr));
278 return true;
279}
280
281bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
282 for (const Function &F : M.getFunctionDefs()) {
283 MachineFunction *MF = MMI.getMachineFunction(F);
284 if (!MF)
285 continue;
286
287 const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
288 const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
289 const MachineRegisterInfo &MRI = MF->getRegInfo();
290
291 auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, Dim: 1);
292 if (!VocabOrErr) {
293 WithColor::error(OS&: errs(), Prefix: ToolName)
294 << "Failed to create dummy vocabulary - "
295 << toString(E: VocabOrErr.takeError()) << "\n";
296 return false;
297 }
298 Vocab = std::make_unique<MIRVocabulary>(args: std::move(*VocabOrErr));
299 return true;
300 }
301
302 WithColor::error(OS&: errs(), Prefix: ToolName)
303 << "No machine functions found to initialize vocabulary\n";
304 return false;
305}
306
307TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
308 TripletResult Result;
309 Result.MaxRelation = MIRNextRelation;
310
311 if (!Vocab) {
312 WithColor::error(OS&: errs(), Prefix: ToolName)
313 << "MIR Vocabulary must be initialized for triplet generation.\n";
314 return Result;
315 }
316
317 unsigned PrevOpcode = 0;
318 bool HasPrevOpcode = false;
319 for (const MachineBasicBlock &MBB : MF) {
320 for (const MachineInstr &MI : MBB) {
321 // Skip debug instructions
322 if (MI.isDebugInstr())
323 continue;
324
325 // Get opcode entity ID
326 unsigned OpcodeID = Vocab->getEntityIDForOpcode(Opcode: MI.getOpcode());
327
328 // Add "Next" relationship with previous instruction
329 if (HasPrevOpcode) {
330 Result.Triplets.push_back(x: {.Head: PrevOpcode, .Tail: OpcodeID, .Relation: MIRNextRelation});
331 LLVM_DEBUG(dbgs() << Vocab->getStringKey(PrevOpcode) << '\t'
332 << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
333 }
334
335 // Add "Arg" relationships for operands
336 unsigned ArgIndex = 0;
337 for (const MachineOperand &MO : MI.operands()) {
338 auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
339 unsigned RelationID = MIRArgRelation + ArgIndex;
340 Result.Triplets.push_back(x: {.Head: OpcodeID, .Tail: OperandID, .Relation: RelationID});
341 LLVM_DEBUG({
342 std::string OperandStr = Vocab->getStringKey(OperandID);
343 dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr << '\t'
344 << "Arg" << ArgIndex << '\n';
345 });
346
347 ++ArgIndex;
348 }
349
350 // Update MaxRelation if there were operands
351 if (ArgIndex > 0)
352 Result.MaxRelation =
353 std::max(a: Result.MaxRelation, b: MIRArgRelation + ArgIndex - 1);
354
355 PrevOpcode = OpcodeID;
356 HasPrevOpcode = true;
357 }
358 }
359
360 return Result;
361}
362
363TripletResult MIR2VecTool::generateTriplets(const Module &M) const {
364 TripletResult Result;
365 Result.MaxRelation = MIRNextRelation;
366
367 for (const Function &F : M.getFunctionDefs()) {
368 MachineFunction *MF = MMI.getMachineFunction(F);
369 if (!MF) {
370 WithColor::warning(OS&: errs(), Prefix: ToolName)
371 << "No MachineFunction for " << F.getName() << "\n";
372 continue;
373 }
374
375 TripletResult FuncResult = generateTriplets(MF: *MF);
376 Result.MaxRelation = std::max(a: Result.MaxRelation, b: FuncResult.MaxRelation);
377 Result.Triplets.insert(position: Result.Triplets.end(), first: FuncResult.Triplets.begin(),
378 last: FuncResult.Triplets.end());
379 }
380
381 return Result;
382}
383
384void MIR2VecTool::writeTripletsToStream(const Module &M,
385 raw_ostream &OS) const {
386 auto Result = generateTriplets(M);
387 OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
388 for (const auto &T : Result.Triplets)
389 OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
390}
391
392EntityList MIR2VecTool::collectEntityMappings() const {
393 if (!Vocab) {
394 WithColor::error(OS&: errs(), Prefix: ToolName)
395 << "Vocabulary must be initialized for entity mappings.\n";
396 return {};
397 }
398
399 const unsigned EntityCount = Vocab->getCanonicalSize();
400 EntityList Result;
401 for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
402 Result.push_back(x: Vocab->getStringKey(Pos: EntityID));
403
404 return Result;
405}
406
407void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
408 auto Entities = collectEntityMappings();
409 if (Entities.empty())
410 return;
411
412 OS << Entities.size() << "\n";
413 for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
414 OS << Entities[EntityID] << '\t' << EntityID << '\n';
415}
416
417void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
418 EmbeddingLevel Level) const {
419 if (!Vocab) {
420 WithColor::error(OS&: errs(), Prefix: ToolName) << "Vocabulary not initialized.\n";
421 return;
422 }
423
424 for (const Function &F : M.getFunctionDefs()) {
425 MachineFunction *MF = MMI.getMachineFunction(F);
426 if (!MF) {
427 WithColor::warning(OS&: errs(), Prefix: ToolName)
428 << "No MachineFunction for " << F.getName() << "\n";
429 continue;
430 }
431
432 writeEmbeddingsToStream(MF&: *MF, OS, Level);
433 }
434}
435
436void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
437 EmbeddingLevel Level) const {
438 if (!Vocab) {
439 WithColor::error(OS&: errs(), Prefix: ToolName) << "Vocabulary not initialized.\n";
440 return;
441 }
442
443 auto Emb = MIREmbedder::create(Mode: MIR2VecKind::Symbolic, MF, Vocab: *Vocab);
444 if (!Emb) {
445 WithColor::error(OS&: errs(), Prefix: ToolName)
446 << "Failed to create embedder for " << MF.getName() << "\n";
447 return;
448 }
449
450 OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
451
452 // Generate embeddings based on the specified level
453 switch (Level) {
454 case FunctionLevel:
455 OS << "Function vector: ";
456 Emb->getMFunctionVector().print(OS);
457 break;
458 case BasicBlockLevel:
459 OS << "Basic block vectors:\n";
460 for (const MachineBasicBlock &MBB : MF) {
461 OS << "MBB " << MBB.getName() << ": ";
462 Emb->getMBBVector(MBB).print(OS);
463 }
464 break;
465 case InstructionLevel:
466 OS << "Instruction vectors:\n";
467 for (const MachineBasicBlock &MBB : MF) {
468 for (const MachineInstr &MI : MBB) {
469 OS << MI << " -> ";
470 Emb->getMInstVector(MI).print(OS);
471 }
472 }
473 break;
474 }
475}
476
477} // namespace mir2vec
478
479} // namespace llvm
480