| 1 | //===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===// | 
|---|
| 2 | // | 
|---|
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM | 
|---|
| 4 | // Exceptions. See the LICENSE file for license information. | 
|---|
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|---|
| 6 | // | 
|---|
| 7 | //===----------------------------------------------------------------------===// | 
|---|
| 8 | /// | 
|---|
| 9 | /// \file | 
|---|
| 10 | /// This file implements the IR2Vec algorithm. | 
|---|
| 11 | /// | 
|---|
| 12 | //===----------------------------------------------------------------------===// | 
|---|
| 13 |  | 
|---|
| 14 | #include "llvm/Analysis/IR2Vec.h" | 
|---|
| 15 |  | 
|---|
| 16 | #include "llvm/ADT/DepthFirstIterator.h" | 
|---|
| 17 | #include "llvm/ADT/Statistic.h" | 
|---|
| 18 | #include "llvm/IR/CFG.h" | 
|---|
| 19 | #include "llvm/IR/Module.h" | 
|---|
| 20 | #include "llvm/IR/PassManager.h" | 
|---|
| 21 | #include "llvm/Support/Debug.h" | 
|---|
| 22 | #include "llvm/Support/Errc.h" | 
|---|
| 23 | #include "llvm/Support/Error.h" | 
|---|
| 24 | #include "llvm/Support/ErrorHandling.h" | 
|---|
| 25 | #include "llvm/Support/Format.h" | 
|---|
| 26 | #include "llvm/Support/MemoryBuffer.h" | 
|---|
| 27 |  | 
|---|
| 28 | using namespace llvm; | 
|---|
| 29 | using namespace ir2vec; | 
|---|
| 30 |  | 
|---|
| 31 | #define DEBUG_TYPE "ir2vec" | 
|---|
| 32 |  | 
|---|
| 33 | STATISTIC(VocabMissCounter, | 
|---|
| 34 | "Number of lookups to entites not present in the vocabulary"); | 
|---|
| 35 |  | 
|---|
| 36 | namespace llvm { | 
|---|
| 37 | namespace ir2vec { | 
|---|
| 38 | static cl::OptionCategory IR2VecCategory( "IR2Vec Options"); | 
|---|
| 39 |  | 
|---|
| 40 | // FIXME: Use a default vocab when not specified | 
|---|
| 41 | static cl::opt<std::string> | 
|---|
| 42 | VocabFile( "ir2vec-vocab-path", cl::Optional, | 
|---|
| 43 | cl::desc( "Path to the vocabulary file for IR2Vec"), cl::init(Val: ""), | 
|---|
| 44 | cl::cat(IR2VecCategory)); | 
|---|
| 45 | cl::opt<float> OpcWeight( "ir2vec-opc-weight", cl::Optional, cl::init(Val: 1.0), | 
|---|
| 46 | cl::desc( "Weight for opcode embeddings"), | 
|---|
| 47 | cl::cat(IR2VecCategory)); | 
|---|
| 48 | cl::opt<float> TypeWeight( "ir2vec-type-weight", cl::Optional, cl::init(Val: 0.5), | 
|---|
| 49 | cl::desc( "Weight for type embeddings"), | 
|---|
| 50 | cl::cat(IR2VecCategory)); | 
|---|
| 51 | cl::opt<float> ArgWeight( "ir2vec-arg-weight", cl::Optional, cl::init(Val: 0.2), | 
|---|
| 52 | cl::desc( "Weight for argument embeddings"), | 
|---|
| 53 | cl::cat(IR2VecCategory)); | 
|---|
| 54 | } // namespace ir2vec | 
|---|
| 55 | } // namespace llvm | 
|---|
| 56 |  | 
|---|
| 57 | AnalysisKey IR2VecVocabAnalysis::Key; | 
|---|
| 58 |  | 
|---|
| 59 | namespace llvm::json { | 
|---|
| 60 | inline bool fromJSON(const llvm::json::Value &E, Embedding &Out, | 
|---|
| 61 | llvm::json::Path P) { | 
|---|
| 62 | std::vector<double> TempOut; | 
|---|
| 63 | if (!llvm::json::fromJSON(E, Out&: TempOut, P)) | 
|---|
| 64 | return false; | 
|---|
| 65 | Out = Embedding(std::move(TempOut)); | 
|---|
| 66 | return true; | 
|---|
| 67 | } | 
|---|
| 68 | } // namespace llvm::json | 
|---|
| 69 |  | 
|---|
| 70 | //===----------------------------------------------------------------------===// | 
|---|
| 71 | // Embedding | 
|---|
| 72 | //===----------------------------------------------------------------------===// | 
|---|
| 73 | Embedding &Embedding::operator+=(const Embedding &RHS) { | 
|---|
| 74 | assert(this->size() == RHS.size() && "Vectors must have the same dimension"); | 
|---|
| 75 | std::transform(first1: this->begin(), last1: this->end(), first2: RHS.begin(), result: this->begin(), | 
|---|
| 76 | binary_op: std::plus<double>()); | 
|---|
| 77 | return *this; | 
|---|
| 78 | } | 
|---|
| 79 |  | 
|---|
| 80 | Embedding Embedding::operator+(const Embedding &RHS) const { | 
|---|
| 81 | Embedding Result(*this); | 
|---|
| 82 | Result += RHS; | 
|---|
| 83 | return Result; | 
|---|
| 84 | } | 
|---|
| 85 |  | 
|---|
| 86 | Embedding &Embedding::operator-=(const Embedding &RHS) { | 
|---|
| 87 | assert(this->size() == RHS.size() && "Vectors must have the same dimension"); | 
|---|
| 88 | std::transform(first1: this->begin(), last1: this->end(), first2: RHS.begin(), result: this->begin(), | 
|---|
| 89 | binary_op: std::minus<double>()); | 
|---|
| 90 | return *this; | 
|---|
| 91 | } | 
|---|
| 92 |  | 
|---|
| 93 | Embedding Embedding::operator-(const Embedding &RHS) const { | 
|---|
| 94 | Embedding Result(*this); | 
|---|
| 95 | Result -= RHS; | 
|---|
| 96 | return Result; | 
|---|
| 97 | } | 
|---|
| 98 |  | 
|---|
| 99 | Embedding &Embedding::operator*=(double Factor) { | 
|---|
| 100 | std::transform(first: this->begin(), last: this->end(), result: this->begin(), | 
|---|
| 101 | unary_op: [Factor](double Elem) { return Elem * Factor; }); | 
|---|
| 102 | return *this; | 
|---|
| 103 | } | 
|---|
| 104 |  | 
|---|
| 105 | Embedding Embedding::operator*(double Factor) const { | 
|---|
| 106 | Embedding Result(*this); | 
|---|
| 107 | Result *= Factor; | 
|---|
| 108 | return Result; | 
|---|
| 109 | } | 
|---|
| 110 |  | 
|---|
| 111 | Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) { | 
|---|
| 112 | assert(this->size() == Src.size() && "Vectors must have the same dimension"); | 
|---|
| 113 | for (size_t Itr = 0; Itr < this->size(); ++Itr) | 
|---|
| 114 | (*this)[Itr] += Src[Itr] * Factor; | 
|---|
| 115 | return *this; | 
|---|
| 116 | } | 
|---|
| 117 |  | 
|---|
| 118 | bool Embedding::approximatelyEquals(const Embedding &RHS, | 
|---|
| 119 | double Tolerance) const { | 
|---|
| 120 | assert(this->size() == RHS.size() && "Vectors must have the same dimension"); | 
|---|
| 121 | for (size_t Itr = 0; Itr < this->size(); ++Itr) | 
|---|
| 122 | if (std::abs(x: (*this)[Itr] - RHS[Itr]) > Tolerance) | 
|---|
| 123 | return false; | 
|---|
| 124 | return true; | 
|---|
| 125 | } | 
|---|
| 126 |  | 
|---|
| 127 | void Embedding::print(raw_ostream &OS) const { | 
|---|
| 128 | OS << " ["; | 
|---|
| 129 | for (const auto &Elem : Data) | 
|---|
| 130 | OS << " "<< format(Fmt: "%.2f", Vals: Elem) << " "; | 
|---|
| 131 | OS << "]\n"; | 
|---|
| 132 | } | 
|---|
| 133 |  | 
|---|
| 134 | //===----------------------------------------------------------------------===// | 
|---|
| 135 | // Embedder and its subclasses | 
|---|
| 136 | //===----------------------------------------------------------------------===// | 
|---|
| 137 |  | 
|---|
| 138 | Embedder::Embedder(const Function &F, const Vocab &Vocabulary) | 
|---|
| 139 | : F(F), Vocabulary(Vocabulary), | 
|---|
| 140 | Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight), | 
|---|
| 141 | TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {} | 
|---|
| 142 |  | 
|---|
| 143 | std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F, | 
|---|
| 144 | const Vocab &Vocabulary) { | 
|---|
| 145 | switch (Mode) { | 
|---|
| 146 | case IR2VecKind::Symbolic: | 
|---|
| 147 | return std::make_unique<SymbolicEmbedder>(args: F, args: Vocabulary); | 
|---|
| 148 | } | 
|---|
| 149 | return nullptr; | 
|---|
| 150 | } | 
|---|
| 151 |  | 
|---|
| 152 | // FIXME: Currently lookups are string based. Use numeric Keys | 
|---|
| 153 | // for efficiency | 
|---|
| 154 | Embedding Embedder::lookupVocab(const std::string &Key) const { | 
|---|
| 155 | Embedding Vec(Dimension, 0); | 
|---|
| 156 | // FIXME: Use zero vectors in vocab and assert failure for | 
|---|
| 157 | // unknown entities rather than silently returning zeroes here. | 
|---|
| 158 | auto It = Vocabulary.find(x: Key); | 
|---|
| 159 | if (It != Vocabulary.end()) | 
|---|
| 160 | return It->second; | 
|---|
| 161 | LLVM_DEBUG(errs() << "cannot find key in map : "<< Key << "\n"); | 
|---|
| 162 | ++VocabMissCounter; | 
|---|
| 163 | return Vec; | 
|---|
| 164 | } | 
|---|
| 165 |  | 
|---|
| 166 | const InstEmbeddingsMap &Embedder::getInstVecMap() const { | 
|---|
| 167 | if (InstVecMap.empty()) | 
|---|
| 168 | computeEmbeddings(); | 
|---|
| 169 | return InstVecMap; | 
|---|
| 170 | } | 
|---|
| 171 |  | 
|---|
| 172 | const BBEmbeddingsMap &Embedder::getBBVecMap() const { | 
|---|
| 173 | if (BBVecMap.empty()) | 
|---|
| 174 | computeEmbeddings(); | 
|---|
| 175 | return BBVecMap; | 
|---|
| 176 | } | 
|---|
| 177 |  | 
|---|
| 178 | const Embedding &Embedder::getBBVector(const BasicBlock &BB) const { | 
|---|
| 179 | auto It = BBVecMap.find(Val: &BB); | 
|---|
| 180 | if (It != BBVecMap.end()) | 
|---|
| 181 | return It->second; | 
|---|
| 182 | computeEmbeddings(BB); | 
|---|
| 183 | return BBVecMap[&BB]; | 
|---|
| 184 | } | 
|---|
| 185 |  | 
|---|
| 186 | const Embedding &Embedder::getFunctionVector() const { | 
|---|
| 187 | // Currently, we always (re)compute the embeddings for the function. | 
|---|
| 188 | // This is cheaper than caching the vector. | 
|---|
| 189 | computeEmbeddings(); | 
|---|
| 190 | return FuncVector; | 
|---|
| 191 | } | 
|---|
| 192 |  | 
|---|
| 193 | #define RETURN_LOOKUP_IF(CONDITION, KEY_STR)                                   \ | 
|---|
| 194 | if (CONDITION)                                                               \ | 
|---|
| 195 | return lookupVocab(KEY_STR); | 
|---|
| 196 |  | 
|---|
| 197 | Embedding SymbolicEmbedder::getTypeEmbedding(const Type *Ty) const { | 
|---|
| 198 | RETURN_LOOKUP_IF(Ty->isVoidTy(), "voidTy"); | 
|---|
| 199 | RETURN_LOOKUP_IF(Ty->isFloatingPointTy(), "floatTy"); | 
|---|
| 200 | RETURN_LOOKUP_IF(Ty->isIntegerTy(), "integerTy"); | 
|---|
| 201 | RETURN_LOOKUP_IF(Ty->isFunctionTy(), "functionTy"); | 
|---|
| 202 | RETURN_LOOKUP_IF(Ty->isStructTy(), "structTy"); | 
|---|
| 203 | RETURN_LOOKUP_IF(Ty->isArrayTy(), "arrayTy"); | 
|---|
| 204 | RETURN_LOOKUP_IF(Ty->isPointerTy(), "pointerTy"); | 
|---|
| 205 | RETURN_LOOKUP_IF(Ty->isVectorTy(), "vectorTy"); | 
|---|
| 206 | RETURN_LOOKUP_IF(Ty->isEmptyTy(), "emptyTy"); | 
|---|
| 207 | RETURN_LOOKUP_IF(Ty->isLabelTy(), "labelTy"); | 
|---|
| 208 | RETURN_LOOKUP_IF(Ty->isTokenTy(), "tokenTy"); | 
|---|
| 209 | RETURN_LOOKUP_IF(Ty->isMetadataTy(), "metadataTy"); | 
|---|
| 210 | return lookupVocab(Key: "unknownTy"); | 
|---|
| 211 | } | 
|---|
| 212 |  | 
|---|
| 213 | Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const { | 
|---|
| 214 | RETURN_LOOKUP_IF(isa<Function>(Op), "function"); | 
|---|
| 215 | RETURN_LOOKUP_IF(isa<PointerType>(Op->getType()), "pointer"); | 
|---|
| 216 | RETURN_LOOKUP_IF(isa<Constant>(Op), "constant"); | 
|---|
| 217 | return lookupVocab(Key: "variable"); | 
|---|
| 218 | } | 
|---|
| 219 |  | 
|---|
| 220 | #undef RETURN_LOOKUP_IF | 
|---|
| 221 |  | 
|---|
| 222 | void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { | 
|---|
| 223 | Embedding BBVector(Dimension, 0); | 
|---|
| 224 |  | 
|---|
| 225 | // We consider only the non-debug and non-pseudo instructions | 
|---|
| 226 | for (const auto &I : BB.instructionsWithoutDebug()) { | 
|---|
| 227 | Embedding InstVector(Dimension, 0); | 
|---|
| 228 |  | 
|---|
| 229 | // FIXME: Currently lookups are string based. Use numeric Keys | 
|---|
| 230 | // for efficiency. | 
|---|
| 231 | InstVector += lookupVocab(Key: I.getOpcodeName()); | 
|---|
| 232 | InstVector += getTypeEmbedding(Ty: I.getType()); | 
|---|
| 233 | for (const auto &Op : I.operands()) { | 
|---|
| 234 | InstVector += getOperandEmbedding(Op: Op.get()); | 
|---|
| 235 | } | 
|---|
| 236 | InstVecMap[&I] = InstVector; | 
|---|
| 237 | BBVector += InstVector; | 
|---|
| 238 | } | 
|---|
| 239 | BBVecMap[&BB] = BBVector; | 
|---|
| 240 | } | 
|---|
| 241 |  | 
|---|
| 242 | void SymbolicEmbedder::computeEmbeddings() const { | 
|---|
| 243 | if (F.isDeclaration()) | 
|---|
| 244 | return; | 
|---|
| 245 |  | 
|---|
| 246 | // Consider only the basic blocks that are reachable from entry | 
|---|
| 247 | for (const BasicBlock *BB : depth_first(G: &F)) { | 
|---|
| 248 | computeEmbeddings(BB: *BB); | 
|---|
| 249 | FuncVector += BBVecMap[BB]; | 
|---|
| 250 | } | 
|---|
| 251 | } | 
|---|
| 252 |  | 
|---|
| 253 | //===----------------------------------------------------------------------===// | 
|---|
| 254 | // IR2VecVocabResult and IR2VecVocabAnalysis | 
|---|
| 255 | //===----------------------------------------------------------------------===// | 
|---|
| 256 |  | 
|---|
| 257 | IR2VecVocabResult::IR2VecVocabResult(ir2vec::Vocab &&Vocabulary) | 
|---|
| 258 | : Vocabulary(std::move(Vocabulary)), Valid(true) {} | 
|---|
| 259 |  | 
|---|
| 260 | const ir2vec::Vocab &IR2VecVocabResult::getVocabulary() const { | 
|---|
| 261 | assert(Valid && "IR2Vec Vocabulary is invalid"); | 
|---|
| 262 | return Vocabulary; | 
|---|
| 263 | } | 
|---|
| 264 |  | 
|---|
| 265 | unsigned IR2VecVocabResult::getDimension() const { | 
|---|
| 266 | assert(Valid && "IR2Vec Vocabulary is invalid"); | 
|---|
| 267 | return Vocabulary.begin()->second.size(); | 
|---|
| 268 | } | 
|---|
| 269 |  | 
|---|
| 270 | // For now, assume vocabulary is stable unless explicitly invalidated. | 
|---|
| 271 | bool IR2VecVocabResult::invalidate( | 
|---|
| 272 | Module &M, const PreservedAnalyses &PA, | 
|---|
| 273 | ModuleAnalysisManager::Invalidator &Inv) const { | 
|---|
| 274 | auto PAC = PA.getChecker<IR2VecVocabAnalysis>(); | 
|---|
| 275 | return !(PAC.preservedWhenStateless()); | 
|---|
| 276 | } | 
|---|
| 277 |  | 
|---|
| 278 | Error IR2VecVocabAnalysis::parseVocabSection( | 
|---|
| 279 | StringRef Key, const json::Value &ParsedVocabValue, | 
|---|
| 280 | ir2vec::Vocab &TargetVocab, unsigned &Dim) { | 
|---|
| 281 | json::Path::Root Path( ""); | 
|---|
| 282 | const json::Object *RootObj = ParsedVocabValue.getAsObject(); | 
|---|
| 283 | if (!RootObj) | 
|---|
| 284 | return createStringError(EC: errc::invalid_argument, | 
|---|
| 285 | S: "JSON root is not an object"); | 
|---|
| 286 |  | 
|---|
| 287 | const json::Value *SectionValue = RootObj->get(K: Key); | 
|---|
| 288 | if (!SectionValue) | 
|---|
| 289 | return createStringError(EC: errc::invalid_argument, | 
|---|
| 290 | S: "Missing '"+ std::string(Key) + | 
|---|
| 291 | "' section in vocabulary file"); | 
|---|
| 292 | if (!json::fromJSON(E: *SectionValue, Out&: TargetVocab, P: Path)) | 
|---|
| 293 | return createStringError(EC: errc::illegal_byte_sequence, | 
|---|
| 294 | S: "Unable to parse '"+ std::string(Key) + | 
|---|
| 295 | "' section from vocabulary"); | 
|---|
| 296 |  | 
|---|
| 297 | Dim = TargetVocab.begin()->second.size(); | 
|---|
| 298 | if (Dim == 0) | 
|---|
| 299 | return createStringError(EC: errc::illegal_byte_sequence, | 
|---|
| 300 | S: "Dimension of '"+ std::string(Key) + | 
|---|
| 301 | "' section of the vocabulary is zero"); | 
|---|
| 302 |  | 
|---|
| 303 | if (!std::all_of(first: TargetVocab.begin(), last: TargetVocab.end(), | 
|---|
| 304 | pred: [Dim](const std::pair<StringRef, Embedding> &Entry) { | 
|---|
| 305 | return Entry.second.size() == Dim; | 
|---|
| 306 | })) | 
|---|
| 307 | return createStringError( | 
|---|
| 308 | EC: errc::illegal_byte_sequence, | 
|---|
| 309 | S: "All vectors in the '"+ std::string(Key) + | 
|---|
| 310 | "' section of the vocabulary are not of the same dimension"); | 
|---|
| 311 |  | 
|---|
| 312 | return Error::success(); | 
|---|
| 313 | } | 
|---|
| 314 |  | 
|---|
| 315 | // FIXME: Make this optional. We can avoid file reads | 
|---|
| 316 | // by auto-generating a default vocabulary during the build time. | 
|---|
| 317 | Error IR2VecVocabAnalysis::readVocabulary() { | 
|---|
| 318 | auto BufOrError = MemoryBuffer::getFileOrSTDIN(Filename: VocabFile, /*IsText=*/true); | 
|---|
| 319 | if (!BufOrError) | 
|---|
| 320 | return createFileError(F: VocabFile, EC: BufOrError.getError()); | 
|---|
| 321 |  | 
|---|
| 322 | auto Content = BufOrError.get()->getBuffer(); | 
|---|
| 323 |  | 
|---|
| 324 | Expected<json::Value> ParsedVocabValue = json::parse(JSON: Content); | 
|---|
| 325 | if (!ParsedVocabValue) | 
|---|
| 326 | return ParsedVocabValue.takeError(); | 
|---|
| 327 |  | 
|---|
| 328 | ir2vec::Vocab OpcodeVocab, TypeVocab, ArgVocab; | 
|---|
| 329 | unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0; | 
|---|
| 330 | if (auto Err = parseVocabSection(Key: "Opcodes", ParsedVocabValue: *ParsedVocabValue, TargetVocab&: OpcodeVocab, | 
|---|
| 331 | Dim&: OpcodeDim)) | 
|---|
| 332 | return Err; | 
|---|
| 333 |  | 
|---|
| 334 | if (auto Err = | 
|---|
| 335 | parseVocabSection(Key: "Types", ParsedVocabValue: *ParsedVocabValue, TargetVocab&: TypeVocab, Dim&: TypeDim)) | 
|---|
| 336 | return Err; | 
|---|
| 337 |  | 
|---|
| 338 | if (auto Err = | 
|---|
| 339 | parseVocabSection(Key: "Arguments", ParsedVocabValue: *ParsedVocabValue, TargetVocab&: ArgVocab, Dim&: ArgDim)) | 
|---|
| 340 | return Err; | 
|---|
| 341 |  | 
|---|
| 342 | if (!(OpcodeDim == TypeDim && TypeDim == ArgDim)) | 
|---|
| 343 | return createStringError(EC: errc::illegal_byte_sequence, | 
|---|
| 344 | S: "Vocabulary sections have different dimensions"); | 
|---|
| 345 |  | 
|---|
| 346 | auto scaleVocabSection = [](ir2vec::Vocab &Vocab, double Weight) { | 
|---|
| 347 | for (auto &Entry : Vocab) | 
|---|
| 348 | Entry.second *= Weight; | 
|---|
| 349 | }; | 
|---|
| 350 | scaleVocabSection(OpcodeVocab, OpcWeight); | 
|---|
| 351 | scaleVocabSection(TypeVocab, TypeWeight); | 
|---|
| 352 | scaleVocabSection(ArgVocab, ArgWeight); | 
|---|
| 353 |  | 
|---|
| 354 | Vocabulary.insert(first: OpcodeVocab.begin(), last: OpcodeVocab.end()); | 
|---|
| 355 | Vocabulary.insert(first: TypeVocab.begin(), last: TypeVocab.end()); | 
|---|
| 356 | Vocabulary.insert(first: ArgVocab.begin(), last: ArgVocab.end()); | 
|---|
| 357 |  | 
|---|
| 358 | return Error::success(); | 
|---|
| 359 | } | 
|---|
| 360 |  | 
|---|
| 361 | IR2VecVocabAnalysis::IR2VecVocabAnalysis(const Vocab &Vocabulary) | 
|---|
| 362 | : Vocabulary(Vocabulary) {} | 
|---|
| 363 |  | 
|---|
| 364 | IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary) | 
|---|
| 365 | : Vocabulary(std::move(Vocabulary)) {} | 
|---|
| 366 |  | 
|---|
| 367 | void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) { | 
|---|
| 368 | handleAllErrors(E: std::move(Err), Handlers: [&](const ErrorInfoBase &EI) { | 
|---|
| 369 | Ctx.emitError(ErrorStr: "Error reading vocabulary: "+ EI.message()); | 
|---|
| 370 | }); | 
|---|
| 371 | } | 
|---|
| 372 |  | 
|---|
| 373 | IR2VecVocabAnalysis::Result | 
|---|
| 374 | IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { | 
|---|
| 375 | auto Ctx = &M.getContext(); | 
|---|
| 376 | // If vocabulary is already populated by the constructor, use it. | 
|---|
| 377 | if (!Vocabulary.empty()) | 
|---|
| 378 | return IR2VecVocabResult(std::move(Vocabulary)); | 
|---|
| 379 |  | 
|---|
| 380 | // Otherwise, try to read from the vocabulary file. | 
|---|
| 381 | if (VocabFile.empty()) { | 
|---|
| 382 | // FIXME: Use default vocabulary | 
|---|
| 383 | Ctx->emitError(ErrorStr: "IR2Vec vocabulary file path not specified"); | 
|---|
| 384 | return IR2VecVocabResult(); // Return invalid result | 
|---|
| 385 | } | 
|---|
| 386 | if (auto Err = readVocabulary()) { | 
|---|
| 387 | emitError(Err: std::move(Err), Ctx&: *Ctx); | 
|---|
| 388 | return IR2VecVocabResult(); | 
|---|
| 389 | } | 
|---|
| 390 | return IR2VecVocabResult(std::move(Vocabulary)); | 
|---|
| 391 | } | 
|---|
| 392 |  | 
|---|
| 393 | //===----------------------------------------------------------------------===// | 
|---|
| 394 | // Printer Passes | 
|---|
| 395 | //===----------------------------------------------------------------------===// | 
|---|
| 396 |  | 
|---|
| 397 | PreservedAnalyses IR2VecPrinterPass::run(Module &M, | 
|---|
| 398 | ModuleAnalysisManager &MAM) { | 
|---|
| 399 | auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(IR&: M); | 
|---|
| 400 | assert(IR2VecVocabResult.isValid() && "IR2Vec Vocabulary is invalid"); | 
|---|
| 401 |  | 
|---|
| 402 | auto Vocab = IR2VecVocabResult.getVocabulary(); | 
|---|
| 403 | for (Function &F : M) { | 
|---|
| 404 | std::unique_ptr<Embedder> Emb = | 
|---|
| 405 | Embedder::create(Mode: IR2VecKind::Symbolic, F, Vocabulary: Vocab); | 
|---|
| 406 | if (!Emb) { | 
|---|
| 407 | OS << "Error creating IR2Vec embeddings \n"; | 
|---|
| 408 | continue; | 
|---|
| 409 | } | 
|---|
| 410 |  | 
|---|
| 411 | OS << "IR2Vec embeddings for function "<< F.getName() << ":\n"; | 
|---|
| 412 | OS << "Function vector: "; | 
|---|
| 413 | Emb->getFunctionVector().print(OS); | 
|---|
| 414 |  | 
|---|
| 415 | OS << "Basic block vectors:\n"; | 
|---|
| 416 | const auto &BBMap = Emb->getBBVecMap(); | 
|---|
| 417 | for (const BasicBlock &BB : F) { | 
|---|
| 418 | auto It = BBMap.find(Val: &BB); | 
|---|
| 419 | if (It != BBMap.end()) { | 
|---|
| 420 | OS << "Basic block: "<< BB.getName() << ":\n"; | 
|---|
| 421 | It->second.print(OS); | 
|---|
| 422 | } | 
|---|
| 423 | } | 
|---|
| 424 |  | 
|---|
| 425 | OS << "Instruction vectors:\n"; | 
|---|
| 426 | const auto &InstMap = Emb->getInstVecMap(); | 
|---|
| 427 | for (const BasicBlock &BB : F) { | 
|---|
| 428 | for (const Instruction &I : BB) { | 
|---|
| 429 | auto It = InstMap.find(Val: &I); | 
|---|
| 430 | if (It != InstMap.end()) { | 
|---|
| 431 | OS << "Instruction: "; | 
|---|
| 432 | I.print(O&: OS); | 
|---|
| 433 | It->second.print(OS); | 
|---|
| 434 | } | 
|---|
| 435 | } | 
|---|
| 436 | } | 
|---|
| 437 | } | 
|---|
| 438 | return PreservedAnalyses::all(); | 
|---|
| 439 | } | 
|---|
| 440 |  | 
|---|
| 441 | PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M, | 
|---|
| 442 | ModuleAnalysisManager &MAM) { | 
|---|
| 443 | auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(IR&: M); | 
|---|
| 444 | assert(IR2VecVocabResult.isValid() && "IR2Vec Vocabulary is invalid"); | 
|---|
| 445 |  | 
|---|
| 446 | auto Vocab = IR2VecVocabResult.getVocabulary(); | 
|---|
| 447 | for (const auto &Entry : Vocab) { | 
|---|
| 448 | OS << "Key: "<< Entry.first << ": "; | 
|---|
| 449 | Entry.second.print(OS); | 
|---|
| 450 | } | 
|---|
| 451 |  | 
|---|
| 452 | return PreservedAnalyses::all(); | 
|---|
| 453 | } | 
|---|
| 454 |  | 
|---|