1//===- MIRUtils.h - MIR2Vec Tool Class ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file contains the MIR2VecTool class definition for generating
11/// embeddings and triplets from LLVM Machine IR. It has no dependency on
12/// the LLVM IR embedding API (IR2VecTool).
13///
14//===----------------------------------------------------------------------===//
15
16#ifndef LLVM_TOOLS_LLVM_IR2VEC_UTILS_MIRUTILS_H
17#define LLVM_TOOLS_LLVM_IR2VEC_UTILS_MIRUTILS_H
18
19#include "Common.h"
20#include "llvm/CodeGen/MIR2Vec.h"
21#include "llvm/CodeGen/MIRParser/MIRParser.h"
22#include "llvm/CodeGen/MachineFunction.h"
23#include "llvm/CodeGen/MachineModuleInfo.h"
24#include "llvm/CodeGen/TargetInstrInfo.h"
25#include "llvm/CodeGen/TargetRegisterInfo.h"
26#include "llvm/IR/LLVMContext.h"
27#include "llvm/IR/Module.h"
28#include "llvm/Support/Error.h"
29#include "llvm/Support/raw_ostream.h"
30#include "llvm/Target/TargetMachine.h"
31#include <memory>
32
33namespace llvm {
34
35namespace mir2vec {
36
/// Relation kinds used when emitting MIR2Vec triplets.
enum MIRRelationType {
  /// Sequential instruction relationship.
  MIRNextRelation = 0,

  /// Instruction to operand relationship. Acts as a base value: operand
  /// relations are expressed as ArgRelation + N.
  MIRArgRelation = 1
};
42
43/// Helper class for MIR2Vec embedding generation
44class MIR2VecTool {
45private:
46 MachineModuleInfo &MMI;
47 std::unique_ptr<MIRVocabulary> Vocab;
48
49public:
50 explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
51
52 /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
53 bool initializeVocabulary(const Module &M);
54
55 /// Initialize vocabulary with layout information only.
56 /// This creates a minimal vocabulary with correct layout but no actual
57 /// embeddings. Sufficient for generating training data and entity mappings.
58 ///
59 /// Note: Requires target-specific information from the first machine function
60 /// to determine the vocabulary layout (number of opcodes, register classes).
61 ///
62 /// FIXME: Use --target option to get target info directly, avoiding the need
63 /// to parse machine functions for pre-training operations.
64 bool initializeVocabularyForLayout(const Module &M);
65
66 /// Get triplets for a single machine function
67 /// Returns TripletResult containing MaxRelation and vector of Triplets
68 TripletResult generateTriplets(const MachineFunction &MF) const;
69
70 /// Get triplets for the entire module
71 /// Returns TripletResult containing aggregated MaxRelation and all Triplets
72 TripletResult generateTriplets(const Module &M) const;
73
74 /// Collect triplets for the module and write to output stream
75 /// Output format: MAX_RELATION=N header followed by relationships
76 void writeTripletsToStream(const Module &M, raw_ostream &OS) const;
77
78 /// Generate entity mappings for the entire vocabulary
79 EntityList collectEntityMappings() const;
80
81 /// Generate entity mappings and write to output stream
82 void writeEntitiesToStream(raw_ostream &OS) const;
83
84 /// Generate embeddings for all machine functions in the module
85 void writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
86 EmbeddingLevel Level) const;
87
88 /// Generate embeddings for a specific machine function
89 void writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
90 EmbeddingLevel Level) const;
91
92 /// Get the MIR vocabulary instance
93 const MIRVocabulary *getVocabulary() const { return Vocab.get(); }
94};
95
96/// Helper structure to hold MIR context.
97/// CRITICAL: Member declaration order matters for correct destruction.
98struct MIRContext {
99 LLVMContext Context; // Must be first: other members hold references into it
100 std::unique_ptr<Module> M;
101 std::unique_ptr<MachineModuleInfo> MMI;
102 std::unique_ptr<TargetMachine> TM;
103};
104
105} // namespace mir2vec
106} // namespace llvm
107
108#endif // LLVM_TOOLS_LLVM_IR2VEC_UTILS_MIRUTILS_H
109