1//===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the IR2Vec algorithm.
11///
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Analysis/IR2Vec.h"
15
16#include "llvm/ADT/DepthFirstIterator.h"
17#include "llvm/ADT/Statistic.h"
18#include "llvm/IR/CFG.h"
19#include "llvm/IR/Module.h"
20#include "llvm/IR/PassManager.h"
21#include "llvm/Support/Debug.h"
22#include "llvm/Support/Errc.h"
23#include "llvm/Support/Error.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/Format.h"
26#include "llvm/Support/MemoryBuffer.h"
27
28using namespace llvm;
29using namespace ir2vec;
30
31#define DEBUG_TYPE "ir2vec"
32
33STATISTIC(VocabMissCounter,
34 "Number of lookups to entites not present in the vocabulary");
35
36namespace llvm {
37namespace ir2vec {
38static cl::OptionCategory IR2VecCategory("IR2Vec Options");
39
40// FIXME: Use a default vocab when not specified
41static cl::opt<std::string>
42 VocabFile("ir2vec-vocab-path", cl::Optional,
43 cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(Val: ""),
44 cl::cat(IR2VecCategory));
45cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(Val: 1.0),
46 cl::desc("Weight for opcode embeddings"),
47 cl::cat(IR2VecCategory));
48cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(Val: 0.5),
49 cl::desc("Weight for type embeddings"),
50 cl::cat(IR2VecCategory));
51cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(Val: 0.2),
52 cl::desc("Weight for argument embeddings"),
53 cl::cat(IR2VecCategory));
54} // namespace ir2vec
55} // namespace llvm
56
57AnalysisKey IR2VecVocabAnalysis::Key;
58
59namespace llvm::json {
60inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
61 llvm::json::Path P) {
62 std::vector<double> TempOut;
63 if (!llvm::json::fromJSON(E, Out&: TempOut, P))
64 return false;
65 Out = Embedding(std::move(TempOut));
66 return true;
67}
68} // namespace llvm::json
69
70// ==----------------------------------------------------------------------===//
71// Embedding
72//===----------------------------------------------------------------------===//
73Embedding &Embedding::operator+=(const Embedding &RHS) {
74 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
75 std::transform(first1: this->begin(), last1: this->end(), first2: RHS.begin(), result: this->begin(),
76 binary_op: std::plus<double>());
77 return *this;
78}
79
80Embedding Embedding::operator+(const Embedding &RHS) const {
81 Embedding Result(*this);
82 Result += RHS;
83 return Result;
84}
85
86Embedding &Embedding::operator-=(const Embedding &RHS) {
87 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
88 std::transform(first1: this->begin(), last1: this->end(), first2: RHS.begin(), result: this->begin(),
89 binary_op: std::minus<double>());
90 return *this;
91}
92
93Embedding Embedding::operator-(const Embedding &RHS) const {
94 Embedding Result(*this);
95 Result -= RHS;
96 return Result;
97}
98
99Embedding &Embedding::operator*=(double Factor) {
100 std::transform(first: this->begin(), last: this->end(), result: this->begin(),
101 unary_op: [Factor](double Elem) { return Elem * Factor; });
102 return *this;
103}
104
105Embedding Embedding::operator*(double Factor) const {
106 Embedding Result(*this);
107 Result *= Factor;
108 return Result;
109}
110
111Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
112 assert(this->size() == Src.size() && "Vectors must have the same dimension");
113 for (size_t Itr = 0; Itr < this->size(); ++Itr)
114 (*this)[Itr] += Src[Itr] * Factor;
115 return *this;
116}
117
118bool Embedding::approximatelyEquals(const Embedding &RHS,
119 double Tolerance) const {
120 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
121 for (size_t Itr = 0; Itr < this->size(); ++Itr)
122 if (std::abs(x: (*this)[Itr] - RHS[Itr]) > Tolerance)
123 return false;
124 return true;
125}
126
127void Embedding::print(raw_ostream &OS) const {
128 OS << " [";
129 for (const auto &Elem : Data)
130 OS << " " << format(Fmt: "%.2f", Vals: Elem) << " ";
131 OS << "]\n";
132}
133
134// ==----------------------------------------------------------------------===//
135// Embedder and its subclasses
136//===----------------------------------------------------------------------===//
137
138Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
139 : F(F), Vocabulary(Vocabulary),
140 Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
141 TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
142
143std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
144 const Vocab &Vocabulary) {
145 switch (Mode) {
146 case IR2VecKind::Symbolic:
147 return std::make_unique<SymbolicEmbedder>(args: F, args: Vocabulary);
148 }
149 return nullptr;
150}
151
152// FIXME: Currently lookups are string based. Use numeric Keys
153// for efficiency
154Embedding Embedder::lookupVocab(const std::string &Key) const {
155 Embedding Vec(Dimension, 0);
156 // FIXME: Use zero vectors in vocab and assert failure for
157 // unknown entities rather than silently returning zeroes here.
158 auto It = Vocabulary.find(x: Key);
159 if (It != Vocabulary.end())
160 return It->second;
161 LLVM_DEBUG(errs() << "cannot find key in map : " << Key << "\n");
162 ++VocabMissCounter;
163 return Vec;
164}
165
166const InstEmbeddingsMap &Embedder::getInstVecMap() const {
167 if (InstVecMap.empty())
168 computeEmbeddings();
169 return InstVecMap;
170}
171
172const BBEmbeddingsMap &Embedder::getBBVecMap() const {
173 if (BBVecMap.empty())
174 computeEmbeddings();
175 return BBVecMap;
176}
177
178const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
179 auto It = BBVecMap.find(Val: &BB);
180 if (It != BBVecMap.end())
181 return It->second;
182 computeEmbeddings(BB);
183 return BBVecMap[&BB];
184}
185
186const Embedding &Embedder::getFunctionVector() const {
187 // Currently, we always (re)compute the embeddings for the function.
188 // This is cheaper than caching the vector.
189 computeEmbeddings();
190 return FuncVector;
191}
192
// Returns lookupVocab(KEY_STR) from the enclosing function when CONDITION
// holds. The body is wrapped in do { ... } while (false) so that
// `RETURN_LOOKUP_IF(...);` is exactly one statement and stays correct inside
// unbraced if/else bodies (no dangling else, no stray empty statement).
#define RETURN_LOOKUP_IF(CONDITION, KEY_STR)                                   \
  do {                                                                         \
    if (CONDITION)                                                             \
      return lookupVocab(KEY_STR);                                             \
  } while (false)
196
197Embedding SymbolicEmbedder::getTypeEmbedding(const Type *Ty) const {
198 RETURN_LOOKUP_IF(Ty->isVoidTy(), "voidTy");
199 RETURN_LOOKUP_IF(Ty->isFloatingPointTy(), "floatTy");
200 RETURN_LOOKUP_IF(Ty->isIntegerTy(), "integerTy");
201 RETURN_LOOKUP_IF(Ty->isFunctionTy(), "functionTy");
202 RETURN_LOOKUP_IF(Ty->isStructTy(), "structTy");
203 RETURN_LOOKUP_IF(Ty->isArrayTy(), "arrayTy");
204 RETURN_LOOKUP_IF(Ty->isPointerTy(), "pointerTy");
205 RETURN_LOOKUP_IF(Ty->isVectorTy(), "vectorTy");
206 RETURN_LOOKUP_IF(Ty->isEmptyTy(), "emptyTy");
207 RETURN_LOOKUP_IF(Ty->isLabelTy(), "labelTy");
208 RETURN_LOOKUP_IF(Ty->isTokenTy(), "tokenTy");
209 RETURN_LOOKUP_IF(Ty->isMetadataTy(), "metadataTy");
210 return lookupVocab(Key: "unknownTy");
211}
212
213Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
214 RETURN_LOOKUP_IF(isa<Function>(Op), "function");
215 RETURN_LOOKUP_IF(isa<PointerType>(Op->getType()), "pointer");
216 RETURN_LOOKUP_IF(isa<Constant>(Op), "constant");
217 return lookupVocab(Key: "variable");
218}
219
220#undef RETURN_LOOKUP_IF
221
222void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
223 Embedding BBVector(Dimension, 0);
224
225 // We consider only the non-debug and non-pseudo instructions
226 for (const auto &I : BB.instructionsWithoutDebug()) {
227 Embedding InstVector(Dimension, 0);
228
229 // FIXME: Currently lookups are string based. Use numeric Keys
230 // for efficiency.
231 InstVector += lookupVocab(Key: I.getOpcodeName());
232 InstVector += getTypeEmbedding(Ty: I.getType());
233 for (const auto &Op : I.operands()) {
234 InstVector += getOperandEmbedding(Op: Op.get());
235 }
236 InstVecMap[&I] = InstVector;
237 BBVector += InstVector;
238 }
239 BBVecMap[&BB] = BBVector;
240}
241
242void SymbolicEmbedder::computeEmbeddings() const {
243 if (F.isDeclaration())
244 return;
245
246 // Consider only the basic blocks that are reachable from entry
247 for (const BasicBlock *BB : depth_first(G: &F)) {
248 computeEmbeddings(BB: *BB);
249 FuncVector += BBVecMap[BB];
250 }
251}
252
253// ==----------------------------------------------------------------------===//
254// IR2VecVocabResult and IR2VecVocabAnalysis
255//===----------------------------------------------------------------------===//
256
257IR2VecVocabResult::IR2VecVocabResult(ir2vec::Vocab &&Vocabulary)
258 : Vocabulary(std::move(Vocabulary)), Valid(true) {}
259
260const ir2vec::Vocab &IR2VecVocabResult::getVocabulary() const {
261 assert(Valid && "IR2Vec Vocabulary is invalid");
262 return Vocabulary;
263}
264
265unsigned IR2VecVocabResult::getDimension() const {
266 assert(Valid && "IR2Vec Vocabulary is invalid");
267 return Vocabulary.begin()->second.size();
268}
269
270// For now, assume vocabulary is stable unless explicitly invalidated.
271bool IR2VecVocabResult::invalidate(
272 Module &M, const PreservedAnalyses &PA,
273 ModuleAnalysisManager::Invalidator &Inv) const {
274 auto PAC = PA.getChecker<IR2VecVocabAnalysis>();
275 return !(PAC.preservedWhenStateless());
276}
277
278Error IR2VecVocabAnalysis::parseVocabSection(
279 StringRef Key, const json::Value &ParsedVocabValue,
280 ir2vec::Vocab &TargetVocab, unsigned &Dim) {
281 json::Path::Root Path("");
282 const json::Object *RootObj = ParsedVocabValue.getAsObject();
283 if (!RootObj)
284 return createStringError(EC: errc::invalid_argument,
285 S: "JSON root is not an object");
286
287 const json::Value *SectionValue = RootObj->get(K: Key);
288 if (!SectionValue)
289 return createStringError(EC: errc::invalid_argument,
290 S: "Missing '" + std::string(Key) +
291 "' section in vocabulary file");
292 if (!json::fromJSON(E: *SectionValue, Out&: TargetVocab, P: Path))
293 return createStringError(EC: errc::illegal_byte_sequence,
294 S: "Unable to parse '" + std::string(Key) +
295 "' section from vocabulary");
296
297 Dim = TargetVocab.begin()->second.size();
298 if (Dim == 0)
299 return createStringError(EC: errc::illegal_byte_sequence,
300 S: "Dimension of '" + std::string(Key) +
301 "' section of the vocabulary is zero");
302
303 if (!std::all_of(first: TargetVocab.begin(), last: TargetVocab.end(),
304 pred: [Dim](const std::pair<StringRef, Embedding> &Entry) {
305 return Entry.second.size() == Dim;
306 }))
307 return createStringError(
308 EC: errc::illegal_byte_sequence,
309 S: "All vectors in the '" + std::string(Key) +
310 "' section of the vocabulary are not of the same dimension");
311
312 return Error::success();
313}
314
315// FIXME: Make this optional. We can avoid file reads
316// by auto-generating a default vocabulary during the build time.
317Error IR2VecVocabAnalysis::readVocabulary() {
318 auto BufOrError = MemoryBuffer::getFileOrSTDIN(Filename: VocabFile, /*IsText=*/true);
319 if (!BufOrError)
320 return createFileError(F: VocabFile, EC: BufOrError.getError());
321
322 auto Content = BufOrError.get()->getBuffer();
323
324 Expected<json::Value> ParsedVocabValue = json::parse(JSON: Content);
325 if (!ParsedVocabValue)
326 return ParsedVocabValue.takeError();
327
328 ir2vec::Vocab OpcodeVocab, TypeVocab, ArgVocab;
329 unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
330 if (auto Err = parseVocabSection(Key: "Opcodes", ParsedVocabValue: *ParsedVocabValue, TargetVocab&: OpcodeVocab,
331 Dim&: OpcodeDim))
332 return Err;
333
334 if (auto Err =
335 parseVocabSection(Key: "Types", ParsedVocabValue: *ParsedVocabValue, TargetVocab&: TypeVocab, Dim&: TypeDim))
336 return Err;
337
338 if (auto Err =
339 parseVocabSection(Key: "Arguments", ParsedVocabValue: *ParsedVocabValue, TargetVocab&: ArgVocab, Dim&: ArgDim))
340 return Err;
341
342 if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
343 return createStringError(EC: errc::illegal_byte_sequence,
344 S: "Vocabulary sections have different dimensions");
345
346 auto scaleVocabSection = [](ir2vec::Vocab &Vocab, double Weight) {
347 for (auto &Entry : Vocab)
348 Entry.second *= Weight;
349 };
350 scaleVocabSection(OpcodeVocab, OpcWeight);
351 scaleVocabSection(TypeVocab, TypeWeight);
352 scaleVocabSection(ArgVocab, ArgWeight);
353
354 Vocabulary.insert(first: OpcodeVocab.begin(), last: OpcodeVocab.end());
355 Vocabulary.insert(first: TypeVocab.begin(), last: TypeVocab.end());
356 Vocabulary.insert(first: ArgVocab.begin(), last: ArgVocab.end());
357
358 return Error::success();
359}
360
361IR2VecVocabAnalysis::IR2VecVocabAnalysis(const Vocab &Vocabulary)
362 : Vocabulary(Vocabulary) {}
363
364IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
365 : Vocabulary(std::move(Vocabulary)) {}
366
367void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
368 handleAllErrors(E: std::move(Err), Handlers: [&](const ErrorInfoBase &EI) {
369 Ctx.emitError(ErrorStr: "Error reading vocabulary: " + EI.message());
370 });
371}
372
373IR2VecVocabAnalysis::Result
374IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
375 auto Ctx = &M.getContext();
376 // If vocabulary is already populated by the constructor, use it.
377 if (!Vocabulary.empty())
378 return IR2VecVocabResult(std::move(Vocabulary));
379
380 // Otherwise, try to read from the vocabulary file.
381 if (VocabFile.empty()) {
382 // FIXME: Use default vocabulary
383 Ctx->emitError(ErrorStr: "IR2Vec vocabulary file path not specified");
384 return IR2VecVocabResult(); // Return invalid result
385 }
386 if (auto Err = readVocabulary()) {
387 emitError(Err: std::move(Err), Ctx&: *Ctx);
388 return IR2VecVocabResult();
389 }
390 return IR2VecVocabResult(std::move(Vocabulary));
391}
392
393// ==----------------------------------------------------------------------===//
394// Printer Passes
395//===----------------------------------------------------------------------===//
396
397PreservedAnalyses IR2VecPrinterPass::run(Module &M,
398 ModuleAnalysisManager &MAM) {
399 auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(IR&: M);
400 assert(IR2VecVocabResult.isValid() && "IR2Vec Vocabulary is invalid");
401
402 auto Vocab = IR2VecVocabResult.getVocabulary();
403 for (Function &F : M) {
404 std::unique_ptr<Embedder> Emb =
405 Embedder::create(Mode: IR2VecKind::Symbolic, F, Vocabulary: Vocab);
406 if (!Emb) {
407 OS << "Error creating IR2Vec embeddings \n";
408 continue;
409 }
410
411 OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
412 OS << "Function vector: ";
413 Emb->getFunctionVector().print(OS);
414
415 OS << "Basic block vectors:\n";
416 const auto &BBMap = Emb->getBBVecMap();
417 for (const BasicBlock &BB : F) {
418 auto It = BBMap.find(Val: &BB);
419 if (It != BBMap.end()) {
420 OS << "Basic block: " << BB.getName() << ":\n";
421 It->second.print(OS);
422 }
423 }
424
425 OS << "Instruction vectors:\n";
426 const auto &InstMap = Emb->getInstVecMap();
427 for (const BasicBlock &BB : F) {
428 for (const Instruction &I : BB) {
429 auto It = InstMap.find(Val: &I);
430 if (It != InstMap.end()) {
431 OS << "Instruction: ";
432 I.print(O&: OS);
433 It->second.print(OS);
434 }
435 }
436 }
437 }
438 return PreservedAnalyses::all();
439}
440
441PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
442 ModuleAnalysisManager &MAM) {
443 auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(IR&: M);
444 assert(IR2VecVocabResult.isValid() && "IR2Vec Vocabulary is invalid");
445
446 auto Vocab = IR2VecVocabResult.getVocabulary();
447 for (const auto &Entry : Vocab) {
448 OS << "Key: " << Entry.first << ": ";
449 Entry.second.print(OS);
450 }
451
452 return PreservedAnalyses::all();
453}
454