1//===- IRUtils.cpp - IR2Vec Embedding Generation ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the IR2VecTool class for IR2Vec embedding generation
11/// from LLVM IR. It has no dependency on Machine IR.
12///
13//===----------------------------------------------------------------------===//
14
15#include "IRUtils.h"
16#include "llvm/Analysis/IR2Vec.h"
17#include "llvm/IR/BasicBlock.h"
18#include "llvm/IR/Function.h"
19#include "llvm/IR/InstIterator.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/LLVMContext.h"
22#include "llvm/IR/Module.h"
23#include "llvm/IR/PassInstrumentation.h"
24#include "llvm/IR/PassManager.h"
25#include "llvm/IR/Type.h"
26#include "llvm/Support/Debug.h"
27#include "llvm/Support/Errc.h"
28#include "llvm/Support/Error.h"
29#include "llvm/Support/WithColor.h"
30#include "llvm/Support/raw_ostream.h"
31
32#define DEBUG_TYPE "ir2vec"
33
34namespace llvm {
35namespace ir2vec {
36
37Expected<std::shared_ptr<Vocabulary>> loadVocabulary(StringRef VocabPath) {
38 auto VocabOrErr = Vocabulary::fromFile(VocabFilePath: VocabPath);
39 if (!VocabOrErr)
40 return VocabOrErr.takeError();
41
42 auto V = std::make_shared<Vocabulary>(args: std::move(*VocabOrErr));
43
44 if (!V->isValid())
45 return createStringError(EC: errc::invalid_argument,
46 S: "Failed to initialize IR2Vec vocabulary");
47 return V;
48}
49
50Error IR2VecTool::setVocabulary(std::shared_ptr<Vocabulary> V) {
51 if (!V)
52 return createStringError(EC: errc::invalid_argument,
53 S: "Null pointer provided for vocabulary. Will not "
54 "set IR2VecTool vocabulary.");
55 if (!V->isValid())
56 return createStringError(
57 EC: errc::invalid_argument,
58 S: "Vocabulary is not valid. Will not set IR2VecTool vocabulary.");
59 Vocab = std::move(V);
60 return Error::success();
61}
62
63TripletResult IR2VecTool::generateTriplets(const Function &F) const {
64 if (F.isDeclaration())
65 return {};
66
67 TripletResult Result;
68 Result.MaxRelation = 0;
69
70 unsigned MaxRelation = NextRelation;
71 unsigned PrevOpcode = 0;
72 bool HasPrevOpcode = false;
73
74 for (const BasicBlock &BB : F) {
75 for (const auto &I : BB) {
76 if (I.isDebugOrPseudoInst())
77 continue;
78 unsigned Opcode = Vocabulary::getIndex(Opcode: I.getOpcode());
79 unsigned TypeID = Vocabulary::getIndex(TypeID: I.getType()->getTypeID());
80
81 // Add "Next" relationship with previous instruction
82 if (HasPrevOpcode) {
83 Result.Triplets.push_back(x: {.Head: PrevOpcode, .Tail: Opcode, .Relation: NextRelation});
84 LLVM_DEBUG(dbgs() << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1)
85 << '\t'
86 << Vocabulary::getVocabKeyForOpcode(Opcode + 1)
87 << '\t' << "Next\n");
88 }
89
90 // Add "Type" relationship
91 Result.Triplets.push_back(x: {.Head: Opcode, .Tail: TypeID, .Relation: TypeRelation});
92 LLVM_DEBUG(
93 dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
94 << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
95 << '\t' << "Type\n");
96
97 // Add "Arg" relationships
98 unsigned ArgIndex = 0;
99 for (const Use &U : I.operands()) {
100 unsigned OperandID = Vocabulary::getIndex(Op: *U.get());
101 unsigned RelationID = ArgRelation + ArgIndex;
102 Result.Triplets.push_back(x: {.Head: Opcode, .Tail: OperandID, .Relation: RelationID});
103
104 LLVM_DEBUG({
105 StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
106 Vocabulary::getOperandKind(U.get()));
107 dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
108 << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
109 });
110
111 ++ArgIndex;
112 }
113 // Only update MaxRelation if there were operands
114 if (ArgIndex > 0)
115 MaxRelation = std::max(a: MaxRelation, b: ArgRelation + ArgIndex - 1);
116 PrevOpcode = Opcode;
117 HasPrevOpcode = true;
118 }
119 }
120
121 Result.MaxRelation = MaxRelation;
122 return Result;
123}
124
125TripletResult IR2VecTool::generateTriplets() const {
126 TripletResult Result;
127 Result.MaxRelation = NextRelation;
128
129 for (const Function &F : M.getFunctionDefs()) {
130 TripletResult FuncResult = generateTriplets(F);
131 Result.MaxRelation = std::max(a: Result.MaxRelation, b: FuncResult.MaxRelation);
132 Result.Triplets.insert(position: Result.Triplets.end(), first: FuncResult.Triplets.begin(),
133 last: FuncResult.Triplets.end());
134 }
135
136 return Result;
137}
138
139void IR2VecTool::writeTripletsToStream(raw_ostream &OS) const {
140 auto Result = generateTriplets();
141 OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
142 for (const auto &T : Result.Triplets)
143 OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
144}
145
146EntityList IR2VecTool::collectEntityMappings() {
147 auto EntityLen = Vocabulary::getCanonicalSize();
148 EntityList Result;
149 for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
150 Result.push_back(x: Vocabulary::getStringKey(Pos: EntityID).str());
151 return Result;
152}
153
154void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
155 auto Entities = collectEntityMappings();
156 OS << Entities.size() << "\n";
157 for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
158 OS << Entities[EntityID] << '\t' << EntityID << '\n';
159}
160
161Expected<std::unique_ptr<Embedder>>
162IR2VecTool::createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const {
163 if (!Vocab || !Vocab->isValid())
164 return createStringError(
165 EC: errc::invalid_argument,
166 S: "Vocabulary is not valid. IR2VecTool not initialized.");
167
168 if (F.isDeclaration())
169 return createStringError(EC: errc::invalid_argument,
170 S: "Function is a declaration.");
171
172 auto Emb = Embedder::create(Mode: Kind, F, Vocab: *Vocab);
173 if (!Emb)
174 return createStringError(EC: errc::invalid_argument,
175 Fmt: "Failed to create embedder for function '%s'.",
176 Vals: F.getName().str().c_str());
177
178 return std::move(Emb);
179}
180
181Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
182 IR2VecKind Kind) const {
183 auto Emb = createIR2VecEmbedder(F, Kind);
184 if (!Emb)
185 return Emb.takeError();
186
187 return (*Emb)->getFunctionVector();
188}
189
190Expected<FuncEmbMap>
191IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
192 FuncEmbMap Result;
193
194 for (const Function &F : M.getFunctionDefs()) {
195 auto Emb = getFunctionEmbedding(F, Kind);
196 if (!Emb)
197 return Emb.takeError();
198 Result.try_emplace(Key: &F, Args: std::move(*Emb));
199 }
200
201 return Result;
202}
203
204Expected<BBEmbeddingsMap>
205IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
206 auto Emb = createIR2VecEmbedder(F, Kind);
207 if (!Emb)
208 return Emb.takeError();
209
210 BBEmbeddingsMap Result;
211
212 for (const BasicBlock &BB : F)
213 Result.try_emplace(Key: &BB, Args: (*Emb)->getBBVector(BB));
214
215 return Result;
216}
217
218Expected<InstEmbeddingsMap>
219IR2VecTool::getInstEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
220 auto Emb = createIR2VecEmbedder(F, Kind);
221 if (!Emb)
222 return Emb.takeError();
223
224 InstEmbeddingsMap Result;
225
226 for (const Instruction &I : instructions(F))
227 Result.try_emplace(Key: &I, Args: (*Emb)->getInstVector(I));
228
229 return Result;
230}
231
232void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
233 EmbeddingLevel Level) const {
234 for (const Function &F : M.getFunctionDefs())
235 writeEmbeddingsToStream(F, OS, Level);
236}
237
238void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
239 EmbeddingLevel Level) const {
240 auto IR2VecEmbedderObj = createIR2VecEmbedder(F, Kind: IR2VecEmbeddingKind);
241 if (!IR2VecEmbedderObj) {
242 WithColor::error(OS&: errs(), Prefix: ToolName)
243 << toString(E: IR2VecEmbedderObj.takeError()) << "\n";
244 return;
245 }
246 auto Emb = std::move(*IR2VecEmbedderObj);
247
248 OS << "Function: " << F.getName() << "\n";
249
250 // Generate embeddings based on the specified level
251 switch (Level) {
252 case FunctionLevel:
253 Emb->getFunctionVector().print(OS);
254 break;
255 case BasicBlockLevel:
256 for (const BasicBlock &BB : F) {
257 OS << BB.getName() << ":";
258 Emb->getBBVector(BB).print(OS);
259 }
260 break;
261 case InstructionLevel:
262 for (const Instruction &I : instructions(F)) {
263 OS << I;
264 Emb->getInstVector(I).print(OS);
265 }
266 break;
267 }
268}
269
270} // namespace ir2vec
271} // namespace llvm
272