1//===- MIRUtils.cpp - MIR2Vec Embedding Generation --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the MIR2VecTool class for MIR2Vec embedding generation
11/// from LLVM Machine IR. It has no dependency on the IR2Vec embedding API.
12///
13//===----------------------------------------------------------------------===//
14
15#include "MIRUtils.h"
16#include "llvm/CodeGen/MIR2Vec.h"
17#include "llvm/CodeGen/MIRParser/MIRParser.h"
18#include "llvm/CodeGen/MachineFunction.h"
19#include "llvm/CodeGen/MachineModuleInfo.h"
20#include "llvm/CodeGen/TargetInstrInfo.h"
21#include "llvm/CodeGen/TargetRegisterInfo.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/WithColor.h"
24#include "llvm/Support/raw_ostream.h"
25#include "llvm/Target/TargetMachine.h"
26
27#define DEBUG_TYPE "ir2vec"
28
29namespace llvm {
30namespace mir2vec {
31
32bool MIR2VecTool::initializeVocabulary(const Module &M) {
33 MIR2VecVocabProvider Provider(MMI);
34 auto VocabOrErr = Provider.getVocabulary(M);
35 if (!VocabOrErr) {
36 WithColor::error(OS&: errs(), Prefix: ToolName)
37 << "Failed to load MIR2Vec vocabulary - "
38 << toString(E: VocabOrErr.takeError()) << "\n";
39 return false;
40 }
41 Vocab = std::make_unique<MIRVocabulary>(args: std::move(*VocabOrErr));
42 return true;
43}
44
45bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
46 for (const Function &F : M.getFunctionDefs()) {
47 MachineFunction *MF = MMI.getMachineFunction(F);
48 if (!MF)
49 continue;
50
51 const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
52 const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
53 const MachineRegisterInfo &MRI = MF->getRegInfo();
54
55 auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, Dim: 1);
56 if (!VocabOrErr) {
57 WithColor::error(OS&: errs(), Prefix: ToolName)
58 << "Failed to create dummy vocabulary - "
59 << toString(E: VocabOrErr.takeError()) << "\n";
60 return false;
61 }
62 Vocab = std::make_unique<MIRVocabulary>(args: std::move(*VocabOrErr));
63 return true;
64 }
65
66 WithColor::error(OS&: errs(), Prefix: ToolName)
67 << "No machine functions found to initialize vocabulary\n";
68 return false;
69}
70
71TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
72 TripletResult Result;
73 Result.MaxRelation = MIRNextRelation;
74
75 if (!Vocab) {
76 WithColor::error(OS&: errs(), Prefix: ToolName)
77 << "MIR Vocabulary must be initialized for triplet generation.\n";
78 return Result;
79 }
80
81 unsigned PrevOpcode = 0;
82 bool HasPrevOpcode = false;
83 for (const MachineBasicBlock &MBB : MF) {
84 for (const MachineInstr &MI : MBB) {
85 if (MI.isDebugInstr())
86 continue;
87
88 unsigned OpcodeID = Vocab->getEntityIDForOpcode(Opcode: MI.getOpcode());
89
90 if (HasPrevOpcode) {
91 Result.Triplets.push_back(x: {.Head: PrevOpcode, .Tail: OpcodeID, .Relation: MIRNextRelation});
92 LLVM_DEBUG(dbgs() << Vocab->getStringKey(PrevOpcode) << '\t'
93 << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
94 }
95
96 unsigned ArgIndex = 0;
97 for (const MachineOperand &MO : MI.operands()) {
98 auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
99 unsigned RelationID = MIRArgRelation + ArgIndex;
100 Result.Triplets.push_back(x: {.Head: OpcodeID, .Tail: OperandID, .Relation: RelationID});
101 LLVM_DEBUG({
102 std::string OperandStr = Vocab->getStringKey(OperandID);
103 dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr << '\t'
104 << "Arg" << ArgIndex << '\n';
105 });
106
107 ++ArgIndex;
108 }
109
110 if (ArgIndex > 0)
111 Result.MaxRelation =
112 std::max(a: Result.MaxRelation, b: MIRArgRelation + ArgIndex - 1);
113
114 PrevOpcode = OpcodeID;
115 HasPrevOpcode = true;
116 }
117 }
118
119 return Result;
120}
121
122TripletResult MIR2VecTool::generateTriplets(const Module &M) const {
123 TripletResult Result;
124 Result.MaxRelation = MIRNextRelation;
125
126 for (const Function &F : M.getFunctionDefs()) {
127 MachineFunction *MF = MMI.getMachineFunction(F);
128 if (!MF) {
129 WithColor::warning(OS&: errs(), Prefix: ToolName)
130 << "No MachineFunction for " << F.getName() << "\n";
131 continue;
132 }
133
134 TripletResult FuncResult = generateTriplets(MF: *MF);
135 Result.MaxRelation = std::max(a: Result.MaxRelation, b: FuncResult.MaxRelation);
136 Result.Triplets.insert(position: Result.Triplets.end(), first: FuncResult.Triplets.begin(),
137 last: FuncResult.Triplets.end());
138 }
139
140 return Result;
141}
142
143void MIR2VecTool::writeTripletsToStream(const Module &M,
144 raw_ostream &OS) const {
145 auto Result = generateTriplets(M);
146 OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
147 for (const auto &T : Result.Triplets)
148 OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
149}
150
151EntityList MIR2VecTool::collectEntityMappings() const {
152 if (!Vocab) {
153 WithColor::error(OS&: errs(), Prefix: ToolName)
154 << "Vocabulary must be initialized for entity mappings.\n";
155 return {};
156 }
157
158 const unsigned EntityCount = Vocab->getCanonicalSize();
159 EntityList Result;
160 for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
161 Result.push_back(x: Vocab->getStringKey(Pos: EntityID));
162
163 return Result;
164}
165
166void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
167 auto Entities = collectEntityMappings();
168 if (Entities.empty())
169 return;
170
171 OS << Entities.size() << "\n";
172 for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID)
173 OS << Entities[EntityID] << '\t' << EntityID << '\n';
174}
175
176void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
177 EmbeddingLevel Level) const {
178 if (!Vocab) {
179 WithColor::error(OS&: errs(), Prefix: ToolName) << "Vocabulary not initialized.\n";
180 return;
181 }
182
183 for (const Function &F : M.getFunctionDefs()) {
184 MachineFunction *MF = MMI.getMachineFunction(F);
185 if (!MF) {
186 WithColor::warning(OS&: errs(), Prefix: ToolName)
187 << "No MachineFunction for " << F.getName() << "\n";
188 continue;
189 }
190
191 writeEmbeddingsToStream(MF&: *MF, OS, Level);
192 }
193}
194
195void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
196 EmbeddingLevel Level) const {
197 if (!Vocab) {
198 WithColor::error(OS&: errs(), Prefix: ToolName) << "Vocabulary not initialized.\n";
199 return;
200 }
201
202 auto Emb = MIREmbedder::create(Mode: MIR2VecKind::Symbolic, MF, Vocab: *Vocab);
203 if (!Emb) {
204 WithColor::error(OS&: errs(), Prefix: ToolName)
205 << "Failed to create embedder for " << MF.getName() << "\n";
206 return;
207 }
208
209 OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
210
211 switch (Level) {
212 case FunctionLevel:
213 OS << "Function vector: ";
214 Emb->getMFunctionVector().print(OS);
215 break;
216 case BasicBlockLevel:
217 OS << "Basic block vectors:\n";
218 for (const MachineBasicBlock &MBB : MF) {
219 OS << "MBB " << MBB.getName() << ": ";
220 Emb->getMBBVector(MBB).print(OS);
221 }
222 break;
223 case InstructionLevel:
224 OS << "Instruction vectors:\n";
225 for (const MachineBasicBlock &MBB : MF) {
226 for (const MachineInstr &MI : MBB) {
227 OS << MI << " -> ";
228 Emb->getMInstVector(MI).print(OS);
229 }
230 }
231 break;
232 }
233}
234
235} // namespace mir2vec
236} // namespace llvm
237