1//===- MIR2Vec.cpp - Implementation of MIR2Vec ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the MIR2Vec algorithm for Machine IR embeddings.
11///
12//===----------------------------------------------------------------------===//
13
14#include "llvm/CodeGen/MIR2Vec.h"
15#include "llvm/ADT/DepthFirstIterator.h"
16#include "llvm/ADT/Statistic.h"
17#include "llvm/CodeGen/TargetInstrInfo.h"
18#include "llvm/IR/Module.h"
19#include "llvm/InitializePasses.h"
20#include "llvm/Pass.h"
21#include "llvm/Support/Errc.h"
22#include "llvm/Support/MemoryBuffer.h"
23#include "llvm/Support/Regex.h"
24
25using namespace llvm;
26using namespace mir2vec;
27
28#define DEBUG_TYPE "mir2vec"
29
30STATISTIC(MIRVocabMissCounter,
31 "Number of lookups to MIR entities not present in the vocabulary");
32
33namespace llvm {
34namespace mir2vec {
35cl::OptionCategory MIR2VecCategory("MIR2Vec Options");
36
37// FIXME: Use a default vocab when not specified
38static cl::opt<std::string>
39 VocabFile("mir2vec-vocab-path", cl::Optional,
40 cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(Val: ""),
41 cl::cat(MIR2VecCategory));
42cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(Val: 1.0),
43 cl::desc("Weight for machine opcode embeddings"),
44 cl::cat(MIR2VecCategory));
45cl::opt<float> CommonOperandWeight(
46 "mir2vec-common-operand-weight", cl::Optional, cl::init(Val: 1.0),
47 cl::desc("Weight for common operand embeddings"), cl::cat(MIR2VecCategory));
48cl::opt<float>
49 RegOperandWeight("mir2vec-reg-operand-weight", cl::Optional, cl::init(Val: 1.0),
50 cl::desc("Weight for register operand embeddings"),
51 cl::cat(MIR2VecCategory));
52cl::opt<MIR2VecKind> MIR2VecEmbeddingKind(
53 "mir2vec-kind", cl::Optional,
54 cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic",
55 "Generate symbolic embeddings for MIR")),
56 cl::init(Val: MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"),
57 cl::cat(MIR2VecCategory));
58
59static cl::opt<bool> PrintAllVocabEntries(
60 "mir2vec-print-all-vocab-entries", cl::Optional, cl::init(Val: false),
61 cl::desc("Print all vocabulary entries including zero embeddings"),
62 cl::cat(MIR2VecCategory));
63
64} // namespace mir2vec
65} // namespace llvm
66
67//===----------------------------------------------------------------------===//
68// Vocabulary
69//===----------------------------------------------------------------------===//
70
71MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeMap, VocabMap &&CommonOperandMap,
72 VocabMap &&PhysicalRegisterMap,
73 VocabMap &&VirtualRegisterMap,
74 const TargetInstrInfo &TII,
75 const TargetRegisterInfo &TRI,
76 const MachineRegisterInfo &MRI)
77 : TII(TII), TRI(TRI), MRI(MRI) {
78 buildCanonicalOpcodeMapping();
79 unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
80 assert(CanonicalOpcodeCount > 0 &&
81 "No canonical opcodes found for target - invalid vocabulary");
82
83 buildRegisterOperandMapping();
84
85 // Define layout of vocabulary sections
86 Layout.OpcodeBase = 0;
87 Layout.CommonOperandBase = CanonicalOpcodeCount;
88 // We expect same classes for physical and virtual registers
89 Layout.PhyRegBase = Layout.CommonOperandBase + std::size(CommonOperandNames);
90 Layout.VirtRegBase = Layout.PhyRegBase + RegisterOperandNames.size();
91
92 generateStorage(OpcodeMap, CommonOperandMap, PhyRegMap: PhysicalRegisterMap,
93 VirtRegMap: VirtualRegisterMap);
94 Layout.TotalEntries = Storage.size();
95}
96
97Expected<MIRVocabulary>
98MIRVocabulary::create(VocabMap &&OpcodeMap, VocabMap &&CommonOperandMap,
99 VocabMap &&PhyRegMap, VocabMap &&VirtRegMap,
100 const TargetInstrInfo &TII, const TargetRegisterInfo &TRI,
101 const MachineRegisterInfo &MRI) {
102 if (OpcodeMap.empty() || CommonOperandMap.empty() || PhyRegMap.empty() ||
103 VirtRegMap.empty())
104 return createStringError(EC: errc::invalid_argument,
105 S: "Empty vocabulary entries provided");
106
107 MIRVocabulary Vocab(std::move(OpcodeMap), std::move(CommonOperandMap),
108 std::move(PhyRegMap), std::move(VirtRegMap), TII, TRI,
109 MRI);
110
111 // Validate Storage after construction
112 if (!Vocab.Storage.isValid())
113 return createStringError(EC: errc::invalid_argument,
114 S: "Failed to create valid vocabulary storage");
115 Vocab.ZeroEmbedding = Embedding(Vocab.Storage.getDimension(), 0.0);
116 return std::move(Vocab);
117}
118
119std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) {
120 // Extract base instruction name using regex to capture letters and
121 // underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
122 //
123 // TODO: Consider more sophisticated extraction:
124 // - Handle complex prefixes like "AVX1_SETALLONES" correctly (Currently, it
125 // would naively map to "AVX")
126 // - Extract width suffixes (8,16,32,64) as separate features
127 // - Capture addressing mode suffixes (r,i,m,ri,etc.) for better analysis
128 // (Currently, instances like "MOV32mi" map to "MOV", but "ADDPDrr" would map
129 // to "ADDPDrr")
130
131 assert(!InstrName.empty() && "Instruction name should not be empty");
132
133 // Use regex to extract initial sequence of letters and underscores
134 static const Regex BaseOpcodeRegex("([a-zA-Z_]+)");
135 SmallVector<StringRef, 2> Matches;
136
137 if (BaseOpcodeRegex.match(String: InstrName, Matches: &Matches) && Matches.size() > 1) {
138 StringRef Match = Matches[1];
139 // Trim trailing underscores
140 while (!Match.empty() && Match.back() == '_')
141 Match = Match.drop_back();
142 return Match.str();
143 }
144
145 // Fallback to original name if no pattern matches
146 return InstrName.str();
147}
148
149unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
150 assert(!UniqueBaseOpcodeNames.empty() && "Canonical mapping not built");
151 auto It = std::find(first: UniqueBaseOpcodeNames.begin(),
152 last: UniqueBaseOpcodeNames.end(), val: BaseName.str());
153 assert(It != UniqueBaseOpcodeNames.end() &&
154 "Base name not found in unique opcodes");
155 return std::distance(first: UniqueBaseOpcodeNames.begin(), last: It);
156}
157
158unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
159 auto BaseOpcode = extractBaseOpcodeName(InstrName: TII.getName(Opcode));
160 return getCanonicalIndexForBaseName(BaseName: BaseOpcode);
161}
162
163unsigned
164MIRVocabulary::getCanonicalIndexForOperandName(StringRef OperandName) const {
165 auto It = std::find(first: std::begin(arr: CommonOperandNames),
166 last: std::end(arr: CommonOperandNames), val: OperandName);
167 assert(It != std::end(CommonOperandNames) &&
168 "Operand name not found in common operands");
169 return Layout.CommonOperandBase +
170 std::distance(first: std::begin(arr: CommonOperandNames), last: It);
171}
172
173unsigned
174MIRVocabulary::getCanonicalIndexForRegisterClass(StringRef RegName,
175 bool IsPhysical) const {
176 auto It = std::find(first: RegisterOperandNames.begin(), last: RegisterOperandNames.end(),
177 val: RegName);
178 assert(It != RegisterOperandNames.end() &&
179 "Register name not found in register operands");
180 unsigned LocalIndex = std::distance(first: RegisterOperandNames.begin(), last: It);
181 return (IsPhysical ? Layout.PhyRegBase : Layout.VirtRegBase) + LocalIndex;
182}
183
184std::string MIRVocabulary::getStringKey(unsigned Pos) const {
185 assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
186
187 // Handle opcodes section
188 if (Pos < Layout.CommonOperandBase) {
189 // Convert canonical index back to base opcode name
190 auto It = UniqueBaseOpcodeNames.begin();
191 std::advance(i&: It, n: Pos);
192 assert(It != UniqueBaseOpcodeNames.end() &&
193 "Canonical index out of bounds in opcode section");
194 return *It;
195 }
196
197 auto getLocalIndex = [](unsigned Pos, size_t BaseOffset, size_t Bound,
198 const char *Msg) {
199 unsigned LocalIndex = Pos - BaseOffset;
200 assert(LocalIndex < Bound && Msg);
201 return LocalIndex;
202 };
203
204 // Handle common operands section
205 if (Pos < Layout.PhyRegBase) {
206 unsigned LocalIndex = getLocalIndex(
207 Pos, Layout.CommonOperandBase, std::size(CommonOperandNames),
208 "Local index out of bounds in common operands");
209 return CommonOperandNames[LocalIndex].str();
210 }
211
212 // Handle physical registers section
213 if (Pos < Layout.VirtRegBase) {
214 unsigned LocalIndex =
215 getLocalIndex(Pos, Layout.PhyRegBase, RegisterOperandNames.size(),
216 "Local index out of bounds in physical registers");
217 return "PhyReg_" + RegisterOperandNames[LocalIndex];
218 }
219
220 // Handle virtual registers section
221 unsigned LocalIndex =
222 getLocalIndex(Pos, Layout.VirtRegBase, RegisterOperandNames.size(),
223 "Local index out of bounds in virtual registers");
224 return "VirtReg_" + RegisterOperandNames[LocalIndex];
225}
226
227void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
228 const VocabMap &CommonOperandsMap,
229 const VocabMap &PhyRegMap,
230 const VocabMap &VirtRegMap) {
231
232 // Helper for handling missing entities in the vocabulary.
233 // Currently, we use a zero vector. In the future, we will throw an error to
234 // ensure that *all* known entities are present in the vocabulary.
235 auto handleMissingEntity = [](StringRef Key) {
236 LLVM_DEBUG(errs() << "MIR2Vec: Missing vocabulary entry for " << Key
237 << "; using zero vector. This will result in an error "
238 "in the future.\n");
239 ++MIRVocabMissCounter;
240 };
241
242 // Initialize opcode embeddings section
243 unsigned EmbeddingDim = OpcodeMap.begin()->second.size();
244 std::vector<Embedding> OpcodeEmbeddings(Layout.CommonOperandBase,
245 Embedding(EmbeddingDim));
246
247 // Populate opcode embeddings using canonical mapping
248 for (auto COpcodeName : UniqueBaseOpcodeNames) {
249 if (auto It = OpcodeMap.find(x: COpcodeName); It != OpcodeMap.end()) {
250 auto COpcodeIndex = getCanonicalIndexForBaseName(BaseName: COpcodeName);
251 assert(COpcodeIndex < Layout.CommonOperandBase &&
252 "Canonical index out of bounds");
253 OpcodeEmbeddings[COpcodeIndex] = It->second;
254 } else {
255 handleMissingEntity(COpcodeName);
256 }
257 }
258
259 // Initialize common operand embeddings section
260 std::vector<Embedding> CommonOperandEmbeddings(std::size(CommonOperandNames),
261 Embedding(EmbeddingDim));
262 unsigned OperandIndex = 0;
263 for (const auto &CommonOperandName : CommonOperandNames) {
264 if (auto It = CommonOperandsMap.find(x: CommonOperandName.str());
265 It != CommonOperandsMap.end()) {
266 CommonOperandEmbeddings[OperandIndex] = It->second;
267 } else {
268 handleMissingEntity(CommonOperandName);
269 }
270 ++OperandIndex;
271 }
272
273 // Helper lambda for creating register operand embeddings
274 auto createRegisterEmbeddings = [&](const VocabMap &RegMap) {
275 std::vector<Embedding> RegEmbeddings(TRI.getNumRegClasses(),
276 Embedding(EmbeddingDim));
277 unsigned RegOperandIndex = 0;
278 for (const auto &RegOperandName : RegisterOperandNames) {
279 if (auto It = RegMap.find(x: RegOperandName); It != RegMap.end())
280 RegEmbeddings[RegOperandIndex] = It->second;
281 else
282 handleMissingEntity(RegOperandName);
283 ++RegOperandIndex;
284 }
285 return RegEmbeddings;
286 };
287
288 // Initialize register operand embeddings sections
289 std::vector<Embedding> PhyRegEmbeddings = createRegisterEmbeddings(PhyRegMap);
290 std::vector<Embedding> VirtRegEmbeddings =
291 createRegisterEmbeddings(VirtRegMap);
292
293 // Scale the vocabulary sections based on the provided weights
294 auto scaleVocabSection = [](std::vector<Embedding> &Embeddings,
295 double Weight) {
296 for (auto &Embedding : Embeddings)
297 Embedding *= Weight;
298 };
299 scaleVocabSection(OpcodeEmbeddings, OpcWeight);
300 scaleVocabSection(CommonOperandEmbeddings, CommonOperandWeight);
301 scaleVocabSection(PhyRegEmbeddings, RegOperandWeight);
302 scaleVocabSection(VirtRegEmbeddings, RegOperandWeight);
303
304 std::vector<std::vector<Embedding>> Sections(
305 static_cast<unsigned>(Section::MaxSections));
306 Sections[static_cast<unsigned>(Section::Opcodes)] =
307 std::move(OpcodeEmbeddings);
308 Sections[static_cast<unsigned>(Section::CommonOperands)] =
309 std::move(CommonOperandEmbeddings);
310 Sections[static_cast<unsigned>(Section::PhyRegisters)] =
311 std::move(PhyRegEmbeddings);
312 Sections[static_cast<unsigned>(Section::VirtRegisters)] =
313 std::move(VirtRegEmbeddings);
314
315 Storage = ir2vec::VocabStorage(std::move(Sections));
316}
317
318void MIRVocabulary::buildCanonicalOpcodeMapping() {
319 // Check if already built
320 if (!UniqueBaseOpcodeNames.empty())
321 return;
322
323 // Build mapping from opcodes to canonical base opcode indices
324 for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) {
325 std::string BaseOpcode = extractBaseOpcodeName(InstrName: TII.getName(Opcode));
326 UniqueBaseOpcodeNames.insert(x: BaseOpcode);
327 }
328
329 LLVM_DEBUG(dbgs() << "MIR2Vec: Built canonical mapping for target with "
330 << UniqueBaseOpcodeNames.size()
331 << " unique base opcodes\n");
332}
333
334void MIRVocabulary::buildRegisterOperandMapping() {
335 // Check if already built
336 if (!RegisterOperandNames.empty())
337 return;
338
339 for (unsigned RC = 0; RC < TRI.getNumRegClasses(); ++RC) {
340 const TargetRegisterClass *RegClass = TRI.getRegClass(i: RC);
341 if (!RegClass)
342 continue;
343
344 // Get the register class name
345 StringRef ClassName = TRI.getRegClassName(Class: RegClass);
346 RegisterOperandNames.push_back(Elt: ClassName.str());
347 }
348}
349
350unsigned MIRVocabulary::getCommonOperandIndex(
351 MachineOperand::MachineOperandType OperandType) const {
352 assert(OperandType != MachineOperand::MO_Register &&
353 "Expected non-register operand type");
354 assert(OperandType > MachineOperand::MO_Register &&
355 OperandType < MachineOperand::MO_Last && "Operand type out of bounds");
356 return static_cast<unsigned>(OperandType) - 1;
357}
358
359unsigned MIRVocabulary::getRegisterOperandIndex(Register Reg) const {
360 assert(!RegisterOperandNames.empty() && "Register operand mapping not built");
361 assert(Reg.isValid() && "Invalid register; not expected here");
362 assert((Reg.isPhysical() || Reg.isVirtual()) &&
363 "Expected a physical or virtual register");
364
365 const TargetRegisterClass *RegClass = nullptr;
366
367 // For physical registers, use TRI to get minimal register class as a
368 // physical register can belong to multiple classes. For virtual
369 // registers, use MRI to uniquely identify the assigned register class.
370 if (Reg.isPhysical())
371 RegClass = TRI.getMinimalPhysRegClass(Reg);
372 else
373 RegClass = MRI.getRegClass(Reg);
374
375 if (RegClass)
376 return RegClass->getID();
377 // Fallback for registers without a class (shouldn't happen)
378 llvm_unreachable("Register operand without a valid register class");
379 return 0;
380}
381
382Expected<MIRVocabulary> MIRVocabulary::createDummyVocabForTest(
383 const TargetInstrInfo &TII, const TargetRegisterInfo &TRI,
384 const MachineRegisterInfo &MRI, unsigned Dim) {
385 assert(Dim > 0 && "Dimension must be greater than zero");
386
387 float DummyVal = 0.1f;
388
389 VocabMap DummyOpcMap, DummyOperandMap, DummyPhyRegMap, DummyVirtRegMap;
390
391 // Process opcodes directly without creating temporary vocabulary
392 for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) {
393 std::string BaseOpcode = extractBaseOpcodeName(InstrName: TII.getName(Opcode));
394 if (DummyOpcMap.count(x: BaseOpcode) == 0) { // Only add if not already present
395 DummyOpcMap[BaseOpcode] = Embedding(Dim, DummyVal);
396 DummyVal += 0.1f;
397 }
398 }
399
400 // Add common operands
401 for (const auto &CommonOperandName : CommonOperandNames) {
402 DummyOperandMap[CommonOperandName.str()] = Embedding(Dim, DummyVal);
403 DummyVal += 0.1f;
404 }
405
406 // Process register classes directly
407 for (unsigned RC = 0; RC < TRI.getNumRegClasses(); ++RC) {
408 const TargetRegisterClass *RegClass = TRI.getRegClass(i: RC);
409 if (!RegClass)
410 continue;
411
412 std::string ClassName = TRI.getRegClassName(Class: RegClass);
413 DummyPhyRegMap[ClassName] = Embedding(Dim, DummyVal);
414 DummyVirtRegMap[ClassName] = Embedding(Dim, DummyVal);
415 DummyVal += 0.1f;
416 }
417
418 // Create vocabulary directly without temporary instance
419 return MIRVocabulary::create(
420 OpcodeMap: std::move(DummyOpcMap), CommonOperandMap: std::move(DummyOperandMap),
421 PhyRegMap: std::move(DummyPhyRegMap), VirtRegMap: std::move(DummyVirtRegMap), TII, TRI, MRI);
422}
423
424//===----------------------------------------------------------------------===//
425// MIR2VecVocabProvider and MIR2VecVocabLegacyAnalysis
426//===----------------------------------------------------------------------===//
427
428Expected<mir2vec::MIRVocabulary>
429MIR2VecVocabProvider::getVocabulary(const Module &M) {
430 VocabMap OpcVocab, CommonOperandVocab, PhyRegVocabMap, VirtRegVocabMap;
431
432 if (Error Err = readVocabulary(OpcVocab, CommonOperandVocab, PhyRegVocabMap,
433 VirtRegVocabMap))
434 return std::move(Err);
435
436 for (const auto &F : M) {
437 if (F.isDeclaration())
438 continue;
439
440 if (auto *MF = MMI.getMachineFunction(F)) {
441 auto &Subtarget = MF->getSubtarget();
442 if (const auto *TII = Subtarget.getInstrInfo())
443 if (const auto *TRI = Subtarget.getRegisterInfo())
444 return mir2vec::MIRVocabulary::create(
445 OpcodeMap: std::move(OpcVocab), CommonOperandMap: std::move(CommonOperandVocab),
446 PhyRegMap: std::move(PhyRegVocabMap), VirtRegMap: std::move(VirtRegVocabMap), TII: *TII, TRI: *TRI,
447 MRI: MF->getRegInfo());
448 }
449 }
450 return createStringError(EC: errc::invalid_argument,
451 S: "No machine functions found in module");
452}
453
454Error MIR2VecVocabProvider::readVocabulary(VocabMap &OpcodeVocab,
455 VocabMap &CommonOperandVocab,
456 VocabMap &PhyRegVocabMap,
457 VocabMap &VirtRegVocabMap) {
458 if (VocabFile.empty())
459 return createStringError(
460 EC: errc::invalid_argument,
461 S: "MIR2Vec vocabulary file path not specified; set it "
462 "using --mir2vec-vocab-path");
463
464 auto BufOrError = MemoryBuffer::getFileOrSTDIN(Filename: VocabFile, /*IsText=*/true);
465 if (!BufOrError)
466 return createFileError(F: VocabFile, EC: BufOrError.getError());
467
468 auto Content = BufOrError.get()->getBuffer();
469
470 Expected<json::Value> ParsedVocabValue = json::parse(JSON: Content);
471 if (!ParsedVocabValue)
472 return ParsedVocabValue.takeError();
473
474 unsigned OpcodeDim = 0, CommonOperandDim = 0, PhyRegOperandDim = 0,
475 VirtRegOperandDim = 0;
476 if (auto Err = ir2vec::VocabStorage::parseVocabSection(
477 Key: "Opcodes", ParsedVocabValue: *ParsedVocabValue, TargetVocab&: OpcodeVocab, Dim&: OpcodeDim))
478 return Err;
479
480 if (auto Err = ir2vec::VocabStorage::parseVocabSection(
481 Key: "CommonOperands", ParsedVocabValue: *ParsedVocabValue, TargetVocab&: CommonOperandVocab,
482 Dim&: CommonOperandDim))
483 return Err;
484
485 if (auto Err = ir2vec::VocabStorage::parseVocabSection(
486 Key: "PhysicalRegisters", ParsedVocabValue: *ParsedVocabValue, TargetVocab&: PhyRegVocabMap,
487 Dim&: PhyRegOperandDim))
488 return Err;
489
490 if (auto Err = ir2vec::VocabStorage::parseVocabSection(
491 Key: "VirtualRegisters", ParsedVocabValue: *ParsedVocabValue, TargetVocab&: VirtRegVocabMap,
492 Dim&: VirtRegOperandDim))
493 return Err;
494
495 // All sections must have the same embedding dimension
496 if (!(OpcodeDim == CommonOperandDim && CommonOperandDim == PhyRegOperandDim &&
497 PhyRegOperandDim == VirtRegOperandDim)) {
498 return createStringError(
499 EC: errc::illegal_byte_sequence,
500 S: "MIR2Vec vocabulary sections have different dimensions");
501 }
502
503 return Error::success();
504}
505
506char MIR2VecVocabLegacyAnalysis::ID = 0;
507INITIALIZE_PASS_BEGIN(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis",
508 "MIR2Vec Vocabulary Analysis", false, true)
509INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
510INITIALIZE_PASS_END(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis",
511 "MIR2Vec Vocabulary Analysis", false, true)
512
513StringRef MIR2VecVocabLegacyAnalysis::getPassName() const {
514 return "MIR2Vec Vocabulary Analysis";
515}
516
517//===----------------------------------------------------------------------===//
518// MIREmbedder and its subclasses
519//===----------------------------------------------------------------------===//
520
521std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode,
522 const MachineFunction &MF,
523 const MIRVocabulary &Vocab) {
524 switch (Mode) {
525 case MIR2VecKind::Symbolic:
526 return std::make_unique<SymbolicMIREmbedder>(args: MF, args: Vocab);
527 }
528 return nullptr;
529}
530
531Embedding MIREmbedder::computeEmbeddings(const MachineBasicBlock &MBB) const {
532 Embedding MBBVector(Dimension, 0);
533
534 // Get instruction info for opcode name resolution
535 const auto &Subtarget = MF.getSubtarget();
536 const auto *TII = Subtarget.getInstrInfo();
537 if (!TII) {
538 MF.getFunction().getContext().emitError(
539 ErrorStr: "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
540 return MBBVector;
541 }
542
543 // Process each machine instruction in the basic block
544 for (const auto &MI : MBB) {
545 // Skip debug instructions and other metadata
546 if (MI.isDebugInstr())
547 continue;
548 MBBVector += computeEmbeddings(MI);
549 }
550
551 return MBBVector;
552}
553
554Embedding MIREmbedder::computeEmbeddings() const {
555 Embedding MFuncVector(Dimension, 0);
556
557 // Consider all reachable machine basic blocks in the function
558 for (const auto *MBB : depth_first(G: &MF))
559 MFuncVector += computeEmbeddings(MBB: *MBB);
560 return MFuncVector;
561}
562
563SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF,
564 const MIRVocabulary &Vocab)
565 : MIREmbedder(MF, Vocab) {}
566
567std::unique_ptr<SymbolicMIREmbedder>
568SymbolicMIREmbedder::create(const MachineFunction &MF,
569 const MIRVocabulary &Vocab) {
570 return std::make_unique<SymbolicMIREmbedder>(args: MF, args: Vocab);
571}
572
573Embedding SymbolicMIREmbedder::computeEmbeddings(const MachineInstr &MI) const {
574 // Skip debug instructions and other metadata
575 if (MI.isDebugInstr())
576 return Embedding(Dimension, 0);
577
578 // Opcode embedding
579 Embedding InstructionEmbedding = Vocab[MI.getOpcode()];
580
581 // Add operand contributions
582 for (const MachineOperand &MO : MI.operands())
583 InstructionEmbedding += Vocab[MO];
584
585 return InstructionEmbedding;
586}
587
588//===----------------------------------------------------------------------===//
589// Printer Passes
590//===----------------------------------------------------------------------===//
591
592char MIR2VecVocabPrinterLegacyPass::ID = 0;
593INITIALIZE_PASS_BEGIN(MIR2VecVocabPrinterLegacyPass, "print-mir2vec-vocab",
594 "MIR2Vec Vocabulary Printer Pass", false, true)
595INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
596INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
597INITIALIZE_PASS_END(MIR2VecVocabPrinterLegacyPass, "print-mir2vec-vocab",
598 "MIR2Vec Vocabulary Printer Pass", false, true)
599
600bool MIR2VecVocabPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
601 return false;
602}
603
604bool MIR2VecVocabPrinterLegacyPass::doFinalization(Module &M) {
605 auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
606 auto MIR2VecVocabOrErr = Analysis.getMIR2VecVocabulary(M);
607
608 if (!MIR2VecVocabOrErr) {
609 OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
610 << toString(E: MIR2VecVocabOrErr.takeError()) << "\n";
611 return false;
612 }
613
614 auto &MIR2VecVocab = *MIR2VecVocabOrErr;
615 unsigned Pos = 0;
616 for (const auto &Entry : MIR2VecVocab) {
617 // Skip zero embeddings to avoid printing entries not in the vocabulary.
618 // This makes the output stable across changes to the opcode list.
619 if (PrintAllVocabEntries || !Entry.isZero()) {
620 OS << "Key: " << MIR2VecVocab.getStringKey(Pos) << ": ";
621 Entry.print(OS);
622 }
623 ++Pos;
624 }
625
626 return false;
627}
628
629MachineFunctionPass *
630llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) {
631 return new MIR2VecVocabPrinterLegacyPass(OS);
632}
633
634char MIR2VecPrinterLegacyPass::ID = 0;
635INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec",
636 "MIR2Vec Embedder Printer Pass", false, true)
637INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
638INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
639INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec",
640 "MIR2Vec Embedder Printer Pass", false, true)
641
642bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
643 auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
644 auto VocabOrErr =
645 Analysis.getMIR2VecVocabulary(M: *MF.getFunction().getParent());
646 assert(VocabOrErr && "Failed to get MIR2Vec vocabulary");
647 auto &MIRVocab = *VocabOrErr;
648
649 auto Emb = mir2vec::MIREmbedder::create(Mode: MIR2VecEmbeddingKind, MF, Vocab: MIRVocab);
650 if (!Emb) {
651 OS << "Error creating MIR2Vec embeddings for function " << MF.getName()
652 << "\n";
653 return false;
654 }
655
656 OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
657 OS << "Machine Function vector: ";
658 Emb->getMFunctionVector().print(OS);
659
660 OS << "Machine basic block vectors:\n";
661 for (const MachineBasicBlock &MBB : MF) {
662 OS << "Machine basic block: " << MBB.getFullName() << ":\n";
663 Emb->getMBBVector(MBB).print(OS);
664 }
665
666 OS << "Machine instruction vectors:\n";
667 for (const MachineBasicBlock &MBB : MF) {
668 for (const MachineInstr &MI : MBB) {
669 // Skip debug instructions as they are not
670 // embedded
671 if (MI.isDebugInstr())
672 continue;
673
674 OS << "Machine instruction: ";
675 MI.print(OS);
676 Emb->getMInstVector(MI).print(OS);
677 }
678 }
679
680 return false;
681}
682
683MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) {
684 return new MIR2VecPrinterLegacyPass(OS);
685}
686