//===- IRUtils.h - IR2Vec Tool Class ----------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file contains the IR2VecTool class definition for generating
11/// embeddings and triplets from LLVM IR. It has no dependency on Machine IR.
12///
13//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_TOOLS_LLVM_IR2VEC_UTILS_IRUTILS_H
16#define LLVM_TOOLS_LLVM_IR2VEC_UTILS_IRUTILS_H
17
18#include "Common.h"
19#include "llvm/ADT/DenseMap.h"
20#include "llvm/Analysis/IR2Vec.h"
21#include "llvm/IR/Function.h"
22#include "llvm/IR/Module.h"
23#include "llvm/IR/PassInstrumentation.h"
24#include "llvm/IR/PassManager.h"
25#include "llvm/Support/Error.h"
26#include "llvm/Support/raw_ostream.h"
27#include <memory>
28
29#define DEBUG_TYPE "ir2vec"
30
31namespace llvm {
32
33/// Per-function embedding map: Function* -> Embedding
34using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
35
36namespace ir2vec {
37
/// Relation identifiers used when emitting IR triplets.
///
/// This is deliberately an unscoped enum: operand relations are encoded as
/// ArgRelation + N, which relies on the implicit conversion to an integer.
enum RelationType {
  /// Relates an instruction to its type.
  TypeRelation = 0,
  /// Relates an instruction to the instruction that follows it in sequence.
  NextRelation = 1,
  /// Base ID for instruction-to-operand relations; individual relations are
  /// encoded as ArgRelation + N.
  ArgRelation = 2
};
44
45/// Load an IR2Vec vocabulary from a JSON file on disk.
46Expected<std::shared_ptr<Vocabulary>> loadVocabulary(StringRef VocabPath);
47
48/// Helper class for collecting IR triplets and generating embeddings
49class IR2VecTool {
50private:
51 Module &M;
52 ModuleAnalysisManager MAM;
53
54 /// \note The API around vocab object is not thread-safe.
55 /// Specifically, calling setVocabulary() on an instance while
56 /// another thread reading the Vocab object with the same instance
57 /// can cause a data race on this internal shared_ptr<Vocabulary> member.
58 std::shared_ptr<Vocabulary> Vocab;
59
60public:
61 explicit IR2VecTool(Module &M) : M(M) {}
62
63 /// Creates the embedding object for downstream embedding streaming
64 Expected<std::unique_ptr<Embedder>>
65 createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const;
66
67 /// Sets the vocabulary for this tool instance.
68 /// This allows sharing the same vocabulary instance across multiple
69 /// IR2VecTool instances, which is useful for generating embeddings for
70 /// multiple functions without needing to reload the vocabulary each time.
71 Error setVocabulary(std::shared_ptr<Vocabulary> V);
72
73 /// Generate triplets for a single function
74 /// Returns a TripletResult with:
75 /// - Triplets: vector of all (subject, object, relation) tuples
76 /// - MaxRelation: highest Arg relation ID used, or NextRelation if none
77 TripletResult generateTriplets(const Function &F) const;
78
79 /// Get triplets for the entire module
80 TripletResult generateTriplets() const;
81
82 /// Collect triplets for the module and dump output to stream
83 /// Output format: MAX_RELATION=N header followed by relationships
84 void writeTripletsToStream(raw_ostream &OS) const;
85
86 /// Generate entity mappings for the entire vocabulary
87 /// Returns EntityList containing all entity strings
88 static EntityList collectEntityMappings();
89
90 /// Dump entity ID to string mappings
91 static void writeEntitiesToStream(raw_ostream &OS);
92
93 /// Get embedding for a single function
94 Expected<Embedding> getFunctionEmbedding(const Function &F,
95 IR2VecKind Kind) const;
96
97 /// Get embeddings for all functions in the module
98 Expected<FuncEmbMap> getFunctionEmbeddingsMap(IR2VecKind Kind) const;
99
100 /// Get embeddings for all basic blocks in a function
101 Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
102 IR2VecKind Kind) const;
103
104 /// Get embeddings for all instructions in a function
105 Expected<InstEmbeddingsMap> getInstEmbeddingsMap(const Function &F,
106 IR2VecKind Kind) const;
107
108 /// Generate embeddings for the entire module
109 void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
110
111 /// Generate embeddings for a single function
112 void writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
113 EmbeddingLevel Level) const;
114};
115
116} // namespace ir2vec
117} // namespace llvm
118
119#endif // LLVM_TOOLS_LLVM_IR2VEC_UTILS_IRUTILS_H
120