| 1 | //===- MIR2Vec.cpp - Implementation of MIR2Vec ---------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM |
| 4 | // Exceptions. See the LICENSE file for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | /// |
| 9 | /// \file |
| 10 | /// This file implements the MIR2Vec algorithm for Machine IR embeddings. |
| 11 | /// |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "llvm/CodeGen/MIR2Vec.h" |
| 15 | #include "llvm/ADT/DepthFirstIterator.h" |
| 16 | #include "llvm/ADT/Statistic.h" |
| 17 | #include "llvm/CodeGen/TargetInstrInfo.h" |
| 18 | #include "llvm/IR/Module.h" |
| 19 | #include "llvm/InitializePasses.h" |
| 20 | #include "llvm/Pass.h" |
| 21 | #include "llvm/Support/Errc.h" |
| 22 | #include "llvm/Support/MemoryBuffer.h" |
| 23 | #include "llvm/Support/Regex.h" |
| 24 | |
| 25 | using namespace llvm; |
| 26 | using namespace mir2vec; |
| 27 | |
| 28 | #define DEBUG_TYPE "mir2vec" |
| 29 | |
| 30 | STATISTIC(MIRVocabMissCounter, |
| 31 | "Number of lookups to MIR entities not present in the vocabulary" ); |
| 32 | |
| 33 | namespace llvm { |
| 34 | namespace mir2vec { |
| 35 | cl::OptionCategory MIR2VecCategory("MIR2Vec Options" ); |
| 36 | |
| 37 | // FIXME: Use a default vocab when not specified |
| 38 | static cl::opt<std::string> |
| 39 | VocabFile("mir2vec-vocab-path" , cl::Optional, |
| 40 | cl::desc("Path to the vocabulary file for MIR2Vec" ), cl::init(Val: "" ), |
| 41 | cl::cat(MIR2VecCategory)); |
| 42 | cl::opt<float> OpcWeight("mir2vec-opc-weight" , cl::Optional, cl::init(Val: 1.0), |
| 43 | cl::desc("Weight for machine opcode embeddings" ), |
| 44 | cl::cat(MIR2VecCategory)); |
| 45 | cl::opt<float> CommonOperandWeight( |
| 46 | "mir2vec-common-operand-weight" , cl::Optional, cl::init(Val: 1.0), |
| 47 | cl::desc("Weight for common operand embeddings" ), cl::cat(MIR2VecCategory)); |
| 48 | cl::opt<float> |
| 49 | RegOperandWeight("mir2vec-reg-operand-weight" , cl::Optional, cl::init(Val: 1.0), |
| 50 | cl::desc("Weight for register operand embeddings" ), |
| 51 | cl::cat(MIR2VecCategory)); |
| 52 | cl::opt<MIR2VecKind> MIR2VecEmbeddingKind( |
| 53 | "mir2vec-kind" , cl::Optional, |
| 54 | cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic" , |
| 55 | "Generate symbolic embeddings for MIR" )), |
| 56 | cl::init(Val: MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind" ), |
| 57 | cl::cat(MIR2VecCategory)); |
| 58 | |
| 59 | static cl::opt<bool> PrintAllVocabEntries( |
| 60 | "mir2vec-print-all-vocab-entries" , cl::Optional, cl::init(Val: false), |
| 61 | cl::desc("Print all vocabulary entries including zero embeddings" ), |
| 62 | cl::cat(MIR2VecCategory)); |
| 63 | |
| 64 | } // namespace mir2vec |
| 65 | } // namespace llvm |
| 66 | |
| 67 | //===----------------------------------------------------------------------===// |
| 68 | // Vocabulary |
| 69 | //===----------------------------------------------------------------------===// |
| 70 | |
| 71 | MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeMap, VocabMap &&CommonOperandMap, |
| 72 | VocabMap &&PhysicalRegisterMap, |
| 73 | VocabMap &&VirtualRegisterMap, |
| 74 | const TargetInstrInfo &TII, |
| 75 | const TargetRegisterInfo &TRI, |
| 76 | const MachineRegisterInfo &MRI) |
| 77 | : TII(TII), TRI(TRI), MRI(MRI) { |
| 78 | buildCanonicalOpcodeMapping(); |
| 79 | unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size(); |
| 80 | assert(CanonicalOpcodeCount > 0 && |
| 81 | "No canonical opcodes found for target - invalid vocabulary" ); |
| 82 | |
| 83 | buildRegisterOperandMapping(); |
| 84 | |
| 85 | // Define layout of vocabulary sections |
| 86 | Layout.OpcodeBase = 0; |
| 87 | Layout.CommonOperandBase = CanonicalOpcodeCount; |
| 88 | // We expect same classes for physical and virtual registers |
| 89 | Layout.PhyRegBase = Layout.CommonOperandBase + std::size(CommonOperandNames); |
| 90 | Layout.VirtRegBase = Layout.PhyRegBase + RegisterOperandNames.size(); |
| 91 | |
| 92 | generateStorage(OpcodeMap, CommonOperandMap, PhyRegMap: PhysicalRegisterMap, |
| 93 | VirtRegMap: VirtualRegisterMap); |
| 94 | Layout.TotalEntries = Storage.size(); |
| 95 | } |
| 96 | |
| 97 | Expected<MIRVocabulary> |
| 98 | MIRVocabulary::create(VocabMap &&OpcodeMap, VocabMap &&CommonOperandMap, |
| 99 | VocabMap &&PhyRegMap, VocabMap &&VirtRegMap, |
| 100 | const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, |
| 101 | const MachineRegisterInfo &MRI) { |
| 102 | if (OpcodeMap.empty() || CommonOperandMap.empty() || PhyRegMap.empty() || |
| 103 | VirtRegMap.empty()) |
| 104 | return createStringError(EC: errc::invalid_argument, |
| 105 | S: "Empty vocabulary entries provided" ); |
| 106 | |
| 107 | MIRVocabulary Vocab(std::move(OpcodeMap), std::move(CommonOperandMap), |
| 108 | std::move(PhyRegMap), std::move(VirtRegMap), TII, TRI, |
| 109 | MRI); |
| 110 | |
| 111 | // Validate Storage after construction |
| 112 | if (!Vocab.Storage.isValid()) |
| 113 | return createStringError(EC: errc::invalid_argument, |
| 114 | S: "Failed to create valid vocabulary storage" ); |
| 115 | Vocab.ZeroEmbedding = Embedding(Vocab.Storage.getDimension(), 0.0); |
| 116 | return std::move(Vocab); |
| 117 | } |
| 118 | |
| 119 | std::string MIRVocabulary::(StringRef InstrName) { |
| 120 | // Extract base instruction name using regex to capture letters and |
| 121 | // underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE" |
| 122 | // |
| 123 | // TODO: Consider more sophisticated extraction: |
| 124 | // - Handle complex prefixes like "AVX1_SETALLONES" correctly (Currently, it |
| 125 | // would naively map to "AVX") |
| 126 | // - Extract width suffixes (8,16,32,64) as separate features |
| 127 | // - Capture addressing mode suffixes (r,i,m,ri,etc.) for better analysis |
| 128 | // (Currently, instances like "MOV32mi" map to "MOV", but "ADDPDrr" would map |
| 129 | // to "ADDPDrr") |
| 130 | |
| 131 | assert(!InstrName.empty() && "Instruction name should not be empty" ); |
| 132 | |
| 133 | // Use regex to extract initial sequence of letters and underscores |
| 134 | static const Regex BaseOpcodeRegex("([a-zA-Z_]+)" ); |
| 135 | SmallVector<StringRef, 2> Matches; |
| 136 | |
| 137 | if (BaseOpcodeRegex.match(String: InstrName, Matches: &Matches) && Matches.size() > 1) { |
| 138 | StringRef Match = Matches[1]; |
| 139 | // Trim trailing underscores |
| 140 | while (!Match.empty() && Match.back() == '_') |
| 141 | Match = Match.drop_back(); |
| 142 | return Match.str(); |
| 143 | } |
| 144 | |
| 145 | // Fallback to original name if no pattern matches |
| 146 | return InstrName.str(); |
| 147 | } |
| 148 | |
| 149 | unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const { |
| 150 | assert(!UniqueBaseOpcodeNames.empty() && "Canonical mapping not built" ); |
| 151 | auto It = std::find(first: UniqueBaseOpcodeNames.begin(), |
| 152 | last: UniqueBaseOpcodeNames.end(), val: BaseName.str()); |
| 153 | assert(It != UniqueBaseOpcodeNames.end() && |
| 154 | "Base name not found in unique opcodes" ); |
| 155 | return std::distance(first: UniqueBaseOpcodeNames.begin(), last: It); |
| 156 | } |
| 157 | |
| 158 | unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const { |
| 159 | auto BaseOpcode = extractBaseOpcodeName(InstrName: TII.getName(Opcode)); |
| 160 | return getCanonicalIndexForBaseName(BaseName: BaseOpcode); |
| 161 | } |
| 162 | |
| 163 | unsigned |
| 164 | MIRVocabulary::getCanonicalIndexForOperandName(StringRef OperandName) const { |
| 165 | auto It = std::find(first: std::begin(arr: CommonOperandNames), |
| 166 | last: std::end(arr: CommonOperandNames), val: OperandName); |
| 167 | assert(It != std::end(CommonOperandNames) && |
| 168 | "Operand name not found in common operands" ); |
| 169 | return Layout.CommonOperandBase + |
| 170 | std::distance(first: std::begin(arr: CommonOperandNames), last: It); |
| 171 | } |
| 172 | |
| 173 | unsigned |
| 174 | MIRVocabulary::getCanonicalIndexForRegisterClass(StringRef RegName, |
| 175 | bool IsPhysical) const { |
| 176 | auto It = std::find(first: RegisterOperandNames.begin(), last: RegisterOperandNames.end(), |
| 177 | val: RegName); |
| 178 | assert(It != RegisterOperandNames.end() && |
| 179 | "Register name not found in register operands" ); |
| 180 | unsigned LocalIndex = std::distance(first: RegisterOperandNames.begin(), last: It); |
| 181 | return (IsPhysical ? Layout.PhyRegBase : Layout.VirtRegBase) + LocalIndex; |
| 182 | } |
| 183 | |
| 184 | std::string MIRVocabulary::getStringKey(unsigned Pos) const { |
| 185 | assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary" ); |
| 186 | |
| 187 | // Handle opcodes section |
| 188 | if (Pos < Layout.CommonOperandBase) { |
| 189 | // Convert canonical index back to base opcode name |
| 190 | auto It = UniqueBaseOpcodeNames.begin(); |
| 191 | std::advance(i&: It, n: Pos); |
| 192 | assert(It != UniqueBaseOpcodeNames.end() && |
| 193 | "Canonical index out of bounds in opcode section" ); |
| 194 | return *It; |
| 195 | } |
| 196 | |
| 197 | auto getLocalIndex = [](unsigned Pos, size_t BaseOffset, size_t Bound, |
| 198 | const char *Msg) { |
| 199 | unsigned LocalIndex = Pos - BaseOffset; |
| 200 | assert(LocalIndex < Bound && Msg); |
| 201 | return LocalIndex; |
| 202 | }; |
| 203 | |
| 204 | // Handle common operands section |
| 205 | if (Pos < Layout.PhyRegBase) { |
| 206 | unsigned LocalIndex = getLocalIndex( |
| 207 | Pos, Layout.CommonOperandBase, std::size(CommonOperandNames), |
| 208 | "Local index out of bounds in common operands" ); |
| 209 | return CommonOperandNames[LocalIndex].str(); |
| 210 | } |
| 211 | |
| 212 | // Handle physical registers section |
| 213 | if (Pos < Layout.VirtRegBase) { |
| 214 | unsigned LocalIndex = |
| 215 | getLocalIndex(Pos, Layout.PhyRegBase, RegisterOperandNames.size(), |
| 216 | "Local index out of bounds in physical registers" ); |
| 217 | return "PhyReg_" + RegisterOperandNames[LocalIndex]; |
| 218 | } |
| 219 | |
| 220 | // Handle virtual registers section |
| 221 | unsigned LocalIndex = |
| 222 | getLocalIndex(Pos, Layout.VirtRegBase, RegisterOperandNames.size(), |
| 223 | "Local index out of bounds in virtual registers" ); |
| 224 | return "VirtReg_" + RegisterOperandNames[LocalIndex]; |
| 225 | } |
| 226 | |
| 227 | void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap, |
| 228 | const VocabMap &CommonOperandsMap, |
| 229 | const VocabMap &PhyRegMap, |
| 230 | const VocabMap &VirtRegMap) { |
| 231 | |
| 232 | // Helper for handling missing entities in the vocabulary. |
| 233 | // Currently, we use a zero vector. In the future, we will throw an error to |
| 234 | // ensure that *all* known entities are present in the vocabulary. |
| 235 | auto handleMissingEntity = [](StringRef Key) { |
| 236 | LLVM_DEBUG(errs() << "MIR2Vec: Missing vocabulary entry for " << Key |
| 237 | << "; using zero vector. This will result in an error " |
| 238 | "in the future.\n" ); |
| 239 | ++MIRVocabMissCounter; |
| 240 | }; |
| 241 | |
| 242 | // Initialize opcode embeddings section |
| 243 | unsigned EmbeddingDim = OpcodeMap.begin()->second.size(); |
| 244 | std::vector<Embedding> OpcodeEmbeddings(Layout.CommonOperandBase, |
| 245 | Embedding(EmbeddingDim)); |
| 246 | |
| 247 | // Populate opcode embeddings using canonical mapping |
| 248 | for (auto COpcodeName : UniqueBaseOpcodeNames) { |
| 249 | if (auto It = OpcodeMap.find(x: COpcodeName); It != OpcodeMap.end()) { |
| 250 | auto COpcodeIndex = getCanonicalIndexForBaseName(BaseName: COpcodeName); |
| 251 | assert(COpcodeIndex < Layout.CommonOperandBase && |
| 252 | "Canonical index out of bounds" ); |
| 253 | OpcodeEmbeddings[COpcodeIndex] = It->second; |
| 254 | } else { |
| 255 | handleMissingEntity(COpcodeName); |
| 256 | } |
| 257 | } |
| 258 | |
| 259 | // Initialize common operand embeddings section |
| 260 | std::vector<Embedding> CommonOperandEmbeddings(std::size(CommonOperandNames), |
| 261 | Embedding(EmbeddingDim)); |
| 262 | unsigned OperandIndex = 0; |
| 263 | for (const auto &CommonOperandName : CommonOperandNames) { |
| 264 | if (auto It = CommonOperandsMap.find(x: CommonOperandName.str()); |
| 265 | It != CommonOperandsMap.end()) { |
| 266 | CommonOperandEmbeddings[OperandIndex] = It->second; |
| 267 | } else { |
| 268 | handleMissingEntity(CommonOperandName); |
| 269 | } |
| 270 | ++OperandIndex; |
| 271 | } |
| 272 | |
| 273 | // Helper lambda for creating register operand embeddings |
| 274 | auto createRegisterEmbeddings = [&](const VocabMap &RegMap) { |
| 275 | std::vector<Embedding> RegEmbeddings(TRI.getNumRegClasses(), |
| 276 | Embedding(EmbeddingDim)); |
| 277 | unsigned RegOperandIndex = 0; |
| 278 | for (const auto &RegOperandName : RegisterOperandNames) { |
| 279 | if (auto It = RegMap.find(x: RegOperandName); It != RegMap.end()) |
| 280 | RegEmbeddings[RegOperandIndex] = It->second; |
| 281 | else |
| 282 | handleMissingEntity(RegOperandName); |
| 283 | ++RegOperandIndex; |
| 284 | } |
| 285 | return RegEmbeddings; |
| 286 | }; |
| 287 | |
| 288 | // Initialize register operand embeddings sections |
| 289 | std::vector<Embedding> PhyRegEmbeddings = createRegisterEmbeddings(PhyRegMap); |
| 290 | std::vector<Embedding> VirtRegEmbeddings = |
| 291 | createRegisterEmbeddings(VirtRegMap); |
| 292 | |
| 293 | // Scale the vocabulary sections based on the provided weights |
| 294 | auto scaleVocabSection = [](std::vector<Embedding> &Embeddings, |
| 295 | double Weight) { |
| 296 | for (auto &Embedding : Embeddings) |
| 297 | Embedding *= Weight; |
| 298 | }; |
| 299 | scaleVocabSection(OpcodeEmbeddings, OpcWeight); |
| 300 | scaleVocabSection(CommonOperandEmbeddings, CommonOperandWeight); |
| 301 | scaleVocabSection(PhyRegEmbeddings, RegOperandWeight); |
| 302 | scaleVocabSection(VirtRegEmbeddings, RegOperandWeight); |
| 303 | |
| 304 | std::vector<std::vector<Embedding>> Sections( |
| 305 | static_cast<unsigned>(Section::MaxSections)); |
| 306 | Sections[static_cast<unsigned>(Section::Opcodes)] = |
| 307 | std::move(OpcodeEmbeddings); |
| 308 | Sections[static_cast<unsigned>(Section::CommonOperands)] = |
| 309 | std::move(CommonOperandEmbeddings); |
| 310 | Sections[static_cast<unsigned>(Section::PhyRegisters)] = |
| 311 | std::move(PhyRegEmbeddings); |
| 312 | Sections[static_cast<unsigned>(Section::VirtRegisters)] = |
| 313 | std::move(VirtRegEmbeddings); |
| 314 | |
| 315 | Storage = ir2vec::VocabStorage(std::move(Sections)); |
| 316 | } |
| 317 | |
| 318 | void MIRVocabulary::buildCanonicalOpcodeMapping() { |
| 319 | // Check if already built |
| 320 | if (!UniqueBaseOpcodeNames.empty()) |
| 321 | return; |
| 322 | |
| 323 | // Build mapping from opcodes to canonical base opcode indices |
| 324 | for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) { |
| 325 | std::string BaseOpcode = extractBaseOpcodeName(InstrName: TII.getName(Opcode)); |
| 326 | UniqueBaseOpcodeNames.insert(x: BaseOpcode); |
| 327 | } |
| 328 | |
| 329 | LLVM_DEBUG(dbgs() << "MIR2Vec: Built canonical mapping for target with " |
| 330 | << UniqueBaseOpcodeNames.size() |
| 331 | << " unique base opcodes\n" ); |
| 332 | } |
| 333 | |
| 334 | void MIRVocabulary::buildRegisterOperandMapping() { |
| 335 | // Check if already built |
| 336 | if (!RegisterOperandNames.empty()) |
| 337 | return; |
| 338 | |
| 339 | for (unsigned RC = 0; RC < TRI.getNumRegClasses(); ++RC) { |
| 340 | const TargetRegisterClass *RegClass = TRI.getRegClass(i: RC); |
| 341 | if (!RegClass) |
| 342 | continue; |
| 343 | |
| 344 | // Get the register class name |
| 345 | StringRef ClassName = TRI.getRegClassName(Class: RegClass); |
| 346 | RegisterOperandNames.push_back(Elt: ClassName.str()); |
| 347 | } |
| 348 | } |
| 349 | |
| 350 | unsigned MIRVocabulary::getCommonOperandIndex( |
| 351 | MachineOperand::MachineOperandType OperandType) const { |
| 352 | assert(OperandType != MachineOperand::MO_Register && |
| 353 | "Expected non-register operand type" ); |
| 354 | assert(OperandType > MachineOperand::MO_Register && |
| 355 | OperandType < MachineOperand::MO_Last && "Operand type out of bounds" ); |
| 356 | return static_cast<unsigned>(OperandType) - 1; |
| 357 | } |
| 358 | |
| 359 | unsigned MIRVocabulary::getRegisterOperandIndex(Register Reg) const { |
| 360 | assert(!RegisterOperandNames.empty() && "Register operand mapping not built" ); |
| 361 | assert(Reg.isValid() && "Invalid register; not expected here" ); |
| 362 | assert((Reg.isPhysical() || Reg.isVirtual()) && |
| 363 | "Expected a physical or virtual register" ); |
| 364 | |
| 365 | const TargetRegisterClass *RegClass = nullptr; |
| 366 | |
| 367 | // For physical registers, use TRI to get minimal register class as a |
| 368 | // physical register can belong to multiple classes. For virtual |
| 369 | // registers, use MRI to uniquely identify the assigned register class. |
| 370 | if (Reg.isPhysical()) |
| 371 | RegClass = TRI.getMinimalPhysRegClass(Reg); |
| 372 | else |
| 373 | RegClass = MRI.getRegClass(Reg); |
| 374 | |
| 375 | if (RegClass) |
| 376 | return RegClass->getID(); |
| 377 | // Fallback for registers without a class (shouldn't happen) |
| 378 | llvm_unreachable("Register operand without a valid register class" ); |
| 379 | return 0; |
| 380 | } |
| 381 | |
| 382 | Expected<MIRVocabulary> MIRVocabulary::createDummyVocabForTest( |
| 383 | const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, |
| 384 | const MachineRegisterInfo &MRI, unsigned Dim) { |
| 385 | assert(Dim > 0 && "Dimension must be greater than zero" ); |
| 386 | |
| 387 | float DummyVal = 0.1f; |
| 388 | |
| 389 | VocabMap DummyOpcMap, DummyOperandMap, DummyPhyRegMap, DummyVirtRegMap; |
| 390 | |
| 391 | // Process opcodes directly without creating temporary vocabulary |
| 392 | for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) { |
| 393 | std::string BaseOpcode = extractBaseOpcodeName(InstrName: TII.getName(Opcode)); |
| 394 | if (DummyOpcMap.count(x: BaseOpcode) == 0) { // Only add if not already present |
| 395 | DummyOpcMap[BaseOpcode] = Embedding(Dim, DummyVal); |
| 396 | DummyVal += 0.1f; |
| 397 | } |
| 398 | } |
| 399 | |
| 400 | // Add common operands |
| 401 | for (const auto &CommonOperandName : CommonOperandNames) { |
| 402 | DummyOperandMap[CommonOperandName.str()] = Embedding(Dim, DummyVal); |
| 403 | DummyVal += 0.1f; |
| 404 | } |
| 405 | |
| 406 | // Process register classes directly |
| 407 | for (unsigned RC = 0; RC < TRI.getNumRegClasses(); ++RC) { |
| 408 | const TargetRegisterClass *RegClass = TRI.getRegClass(i: RC); |
| 409 | if (!RegClass) |
| 410 | continue; |
| 411 | |
| 412 | std::string ClassName = TRI.getRegClassName(Class: RegClass); |
| 413 | DummyPhyRegMap[ClassName] = Embedding(Dim, DummyVal); |
| 414 | DummyVirtRegMap[ClassName] = Embedding(Dim, DummyVal); |
| 415 | DummyVal += 0.1f; |
| 416 | } |
| 417 | |
| 418 | // Create vocabulary directly without temporary instance |
| 419 | return MIRVocabulary::create( |
| 420 | OpcodeMap: std::move(DummyOpcMap), CommonOperandMap: std::move(DummyOperandMap), |
| 421 | PhyRegMap: std::move(DummyPhyRegMap), VirtRegMap: std::move(DummyVirtRegMap), TII, TRI, MRI); |
| 422 | } |
| 423 | |
| 424 | //===----------------------------------------------------------------------===// |
| 425 | // MIR2VecVocabProvider and MIR2VecVocabLegacyAnalysis |
| 426 | //===----------------------------------------------------------------------===// |
| 427 | |
| 428 | Expected<mir2vec::MIRVocabulary> |
| 429 | MIR2VecVocabProvider::getVocabulary(const Module &M) { |
| 430 | VocabMap OpcVocab, CommonOperandVocab, PhyRegVocabMap, VirtRegVocabMap; |
| 431 | |
| 432 | if (Error Err = readVocabulary(OpcVocab, CommonOperandVocab, PhyRegVocabMap, |
| 433 | VirtRegVocabMap)) |
| 434 | return std::move(Err); |
| 435 | |
| 436 | for (const auto &F : M) { |
| 437 | if (F.isDeclaration()) |
| 438 | continue; |
| 439 | |
| 440 | if (auto *MF = MMI.getMachineFunction(F)) { |
| 441 | auto &Subtarget = MF->getSubtarget(); |
| 442 | if (const auto *TII = Subtarget.getInstrInfo()) |
| 443 | if (const auto *TRI = Subtarget.getRegisterInfo()) |
| 444 | return mir2vec::MIRVocabulary::create( |
| 445 | OpcodeMap: std::move(OpcVocab), CommonOperandMap: std::move(CommonOperandVocab), |
| 446 | PhyRegMap: std::move(PhyRegVocabMap), VirtRegMap: std::move(VirtRegVocabMap), TII: *TII, TRI: *TRI, |
| 447 | MRI: MF->getRegInfo()); |
| 448 | } |
| 449 | } |
| 450 | return createStringError(EC: errc::invalid_argument, |
| 451 | S: "No machine functions found in module" ); |
| 452 | } |
| 453 | |
| 454 | Error MIR2VecVocabProvider::readVocabulary(VocabMap &OpcodeVocab, |
| 455 | VocabMap &CommonOperandVocab, |
| 456 | VocabMap &PhyRegVocabMap, |
| 457 | VocabMap &VirtRegVocabMap) { |
| 458 | if (VocabFile.empty()) |
| 459 | return createStringError( |
| 460 | EC: errc::invalid_argument, |
| 461 | S: "MIR2Vec vocabulary file path not specified; set it " |
| 462 | "using --mir2vec-vocab-path" ); |
| 463 | |
| 464 | auto BufOrError = MemoryBuffer::getFileOrSTDIN(Filename: VocabFile, /*IsText=*/true); |
| 465 | if (!BufOrError) |
| 466 | return createFileError(F: VocabFile, EC: BufOrError.getError()); |
| 467 | |
| 468 | auto Content = BufOrError.get()->getBuffer(); |
| 469 | |
| 470 | Expected<json::Value> ParsedVocabValue = json::parse(JSON: Content); |
| 471 | if (!ParsedVocabValue) |
| 472 | return ParsedVocabValue.takeError(); |
| 473 | |
| 474 | unsigned OpcodeDim = 0, CommonOperandDim = 0, PhyRegOperandDim = 0, |
| 475 | VirtRegOperandDim = 0; |
| 476 | if (auto Err = ir2vec::VocabStorage::parseVocabSection( |
| 477 | Key: "Opcodes" , ParsedVocabValue: *ParsedVocabValue, TargetVocab&: OpcodeVocab, Dim&: OpcodeDim)) |
| 478 | return Err; |
| 479 | |
| 480 | if (auto Err = ir2vec::VocabStorage::parseVocabSection( |
| 481 | Key: "CommonOperands" , ParsedVocabValue: *ParsedVocabValue, TargetVocab&: CommonOperandVocab, |
| 482 | Dim&: CommonOperandDim)) |
| 483 | return Err; |
| 484 | |
| 485 | if (auto Err = ir2vec::VocabStorage::parseVocabSection( |
| 486 | Key: "PhysicalRegisters" , ParsedVocabValue: *ParsedVocabValue, TargetVocab&: PhyRegVocabMap, |
| 487 | Dim&: PhyRegOperandDim)) |
| 488 | return Err; |
| 489 | |
| 490 | if (auto Err = ir2vec::VocabStorage::parseVocabSection( |
| 491 | Key: "VirtualRegisters" , ParsedVocabValue: *ParsedVocabValue, TargetVocab&: VirtRegVocabMap, |
| 492 | Dim&: VirtRegOperandDim)) |
| 493 | return Err; |
| 494 | |
| 495 | // All sections must have the same embedding dimension |
| 496 | if (!(OpcodeDim == CommonOperandDim && CommonOperandDim == PhyRegOperandDim && |
| 497 | PhyRegOperandDim == VirtRegOperandDim)) { |
| 498 | return createStringError( |
| 499 | EC: errc::illegal_byte_sequence, |
| 500 | S: "MIR2Vec vocabulary sections have different dimensions" ); |
| 501 | } |
| 502 | |
| 503 | return Error::success(); |
| 504 | } |
| 505 | |
| 506 | char MIR2VecVocabLegacyAnalysis::ID = 0; |
| 507 | INITIALIZE_PASS_BEGIN(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis" , |
| 508 | "MIR2Vec Vocabulary Analysis" , false, true) |
| 509 | INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass) |
| 510 | INITIALIZE_PASS_END(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis" , |
| 511 | "MIR2Vec Vocabulary Analysis" , false, true) |
| 512 | |
| 513 | StringRef MIR2VecVocabLegacyAnalysis::getPassName() const { |
| 514 | return "MIR2Vec Vocabulary Analysis" ; |
| 515 | } |
| 516 | |
| 517 | //===----------------------------------------------------------------------===// |
| 518 | // MIREmbedder and its subclasses |
| 519 | //===----------------------------------------------------------------------===// |
| 520 | |
| 521 | std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode, |
| 522 | const MachineFunction &MF, |
| 523 | const MIRVocabulary &Vocab) { |
| 524 | switch (Mode) { |
| 525 | case MIR2VecKind::Symbolic: |
| 526 | return std::make_unique<SymbolicMIREmbedder>(args: MF, args: Vocab); |
| 527 | } |
| 528 | return nullptr; |
| 529 | } |
| 530 | |
| 531 | Embedding MIREmbedder::computeEmbeddings(const MachineBasicBlock &MBB) const { |
| 532 | Embedding MBBVector(Dimension, 0); |
| 533 | |
| 534 | // Get instruction info for opcode name resolution |
| 535 | const auto &Subtarget = MF.getSubtarget(); |
| 536 | const auto *TII = Subtarget.getInstrInfo(); |
| 537 | if (!TII) { |
| 538 | MF.getFunction().getContext().emitError( |
| 539 | ErrorStr: "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings" ); |
| 540 | return MBBVector; |
| 541 | } |
| 542 | |
| 543 | // Process each machine instruction in the basic block |
| 544 | for (const auto &MI : MBB) { |
| 545 | // Skip debug instructions and other metadata |
| 546 | if (MI.isDebugInstr()) |
| 547 | continue; |
| 548 | MBBVector += computeEmbeddings(MI); |
| 549 | } |
| 550 | |
| 551 | return MBBVector; |
| 552 | } |
| 553 | |
| 554 | Embedding MIREmbedder::computeEmbeddings() const { |
| 555 | Embedding MFuncVector(Dimension, 0); |
| 556 | |
| 557 | // Consider all reachable machine basic blocks in the function |
| 558 | for (const auto *MBB : depth_first(G: &MF)) |
| 559 | MFuncVector += computeEmbeddings(MBB: *MBB); |
| 560 | return MFuncVector; |
| 561 | } |
| 562 | |
| 563 | SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF, |
| 564 | const MIRVocabulary &Vocab) |
| 565 | : MIREmbedder(MF, Vocab) {} |
| 566 | |
| 567 | std::unique_ptr<SymbolicMIREmbedder> |
| 568 | SymbolicMIREmbedder::create(const MachineFunction &MF, |
| 569 | const MIRVocabulary &Vocab) { |
| 570 | return std::make_unique<SymbolicMIREmbedder>(args: MF, args: Vocab); |
| 571 | } |
| 572 | |
| 573 | Embedding SymbolicMIREmbedder::computeEmbeddings(const MachineInstr &MI) const { |
| 574 | // Skip debug instructions and other metadata |
| 575 | if (MI.isDebugInstr()) |
| 576 | return Embedding(Dimension, 0); |
| 577 | |
| 578 | // Opcode embedding |
| 579 | Embedding InstructionEmbedding = Vocab[MI.getOpcode()]; |
| 580 | |
| 581 | // Add operand contributions |
| 582 | for (const MachineOperand &MO : MI.operands()) |
| 583 | InstructionEmbedding += Vocab[MO]; |
| 584 | |
| 585 | return InstructionEmbedding; |
| 586 | } |
| 587 | |
| 588 | //===----------------------------------------------------------------------===// |
| 589 | // Printer Passes |
| 590 | //===----------------------------------------------------------------------===// |
| 591 | |
| 592 | char MIR2VecVocabPrinterLegacyPass::ID = 0; |
| 593 | INITIALIZE_PASS_BEGIN(MIR2VecVocabPrinterLegacyPass, "print-mir2vec-vocab" , |
| 594 | "MIR2Vec Vocabulary Printer Pass" , false, true) |
| 595 | INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis) |
| 596 | INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass) |
| 597 | INITIALIZE_PASS_END(MIR2VecVocabPrinterLegacyPass, "print-mir2vec-vocab" , |
| 598 | "MIR2Vec Vocabulary Printer Pass" , false, true) |
| 599 | |
| 600 | bool MIR2VecVocabPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) { |
| 601 | return false; |
| 602 | } |
| 603 | |
| 604 | bool MIR2VecVocabPrinterLegacyPass::doFinalization(Module &M) { |
| 605 | auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>(); |
| 606 | auto MIR2VecVocabOrErr = Analysis.getMIR2VecVocabulary(M); |
| 607 | |
| 608 | if (!MIR2VecVocabOrErr) { |
| 609 | OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - " |
| 610 | << toString(E: MIR2VecVocabOrErr.takeError()) << "\n" ; |
| 611 | return false; |
| 612 | } |
| 613 | |
| 614 | auto &MIR2VecVocab = *MIR2VecVocabOrErr; |
| 615 | unsigned Pos = 0; |
| 616 | for (const auto &Entry : MIR2VecVocab) { |
| 617 | // Skip zero embeddings to avoid printing entries not in the vocabulary. |
| 618 | // This makes the output stable across changes to the opcode list. |
| 619 | if (PrintAllVocabEntries || !Entry.isZero()) { |
| 620 | OS << "Key: " << MIR2VecVocab.getStringKey(Pos) << ": " ; |
| 621 | Entry.print(OS); |
| 622 | } |
| 623 | ++Pos; |
| 624 | } |
| 625 | |
| 626 | return false; |
| 627 | } |
| 628 | |
| 629 | MachineFunctionPass * |
| 630 | llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) { |
| 631 | return new MIR2VecVocabPrinterLegacyPass(OS); |
| 632 | } |
| 633 | |
| 634 | char MIR2VecPrinterLegacyPass::ID = 0; |
| 635 | INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec" , |
| 636 | "MIR2Vec Embedder Printer Pass" , false, true) |
| 637 | INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis) |
| 638 | INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass) |
| 639 | INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec" , |
| 640 | "MIR2Vec Embedder Printer Pass" , false, true) |
| 641 | |
| 642 | bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) { |
| 643 | auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>(); |
| 644 | auto VocabOrErr = |
| 645 | Analysis.getMIR2VecVocabulary(M: *MF.getFunction().getParent()); |
| 646 | assert(VocabOrErr && "Failed to get MIR2Vec vocabulary" ); |
| 647 | auto &MIRVocab = *VocabOrErr; |
| 648 | |
| 649 | auto Emb = mir2vec::MIREmbedder::create(Mode: MIR2VecEmbeddingKind, MF, Vocab: MIRVocab); |
| 650 | if (!Emb) { |
| 651 | OS << "Error creating MIR2Vec embeddings for function " << MF.getName() |
| 652 | << "\n" ; |
| 653 | return false; |
| 654 | } |
| 655 | |
| 656 | OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n" ; |
| 657 | OS << "Machine Function vector: " ; |
| 658 | Emb->getMFunctionVector().print(OS); |
| 659 | |
| 660 | OS << "Machine basic block vectors:\n" ; |
| 661 | for (const MachineBasicBlock &MBB : MF) { |
| 662 | OS << "Machine basic block: " << MBB.getFullName() << ":\n" ; |
| 663 | Emb->getMBBVector(MBB).print(OS); |
| 664 | } |
| 665 | |
| 666 | OS << "Machine instruction vectors:\n" ; |
| 667 | for (const MachineBasicBlock &MBB : MF) { |
| 668 | for (const MachineInstr &MI : MBB) { |
| 669 | // Skip debug instructions as they are not |
| 670 | // embedded |
| 671 | if (MI.isDebugInstr()) |
| 672 | continue; |
| 673 | |
| 674 | OS << "Machine instruction: " ; |
| 675 | MI.print(OS); |
| 676 | Emb->getMInstVector(MI).print(OS); |
| 677 | } |
| 678 | } |
| 679 | |
| 680 | return false; |
| 681 | } |
| 682 | |
| 683 | MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) { |
| 684 | return new MIR2VecPrinterLegacyPass(OS); |
| 685 | } |
| 686 | |