1//===- llvm-ir2vec.cpp - IR2Vec/MIR2Vec Embedding Generation Tool --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the IR2Vec and MIR2Vec embedding generation tool.
11///
12/// This tool supports two modes:
13/// - LLVM IR mode (-mode=llvm): Process LLVM IR
14/// - Machine IR mode (-mode=mir): Process Machine IR
15///
16/// Available subcommands:
17///
18/// 1. Triplet Generation (triplets):
19/// Generates numeric triplets (head, tail, relation) for vocabulary
20/// training. Output format: MAX_RELATION=N header followed by
21/// head\ttail\trelation lines. Relations: 0=Type, 1=Next, 2+=Arg0,Arg1,...
22///
23/// For LLVM IR:
24/// llvm-ir2vec triplets input.bc -o train2id.txt
25///
26/// For Machine IR:
27/// llvm-ir2vec triplets -mode=mir input.mir -o train2id.txt
28///
29/// 2. Entity Mappings (entities):
30/// Generates entity mappings for vocabulary training.
31/// Output format: <total_entities> header followed by entity\tid lines.
32///
33/// For LLVM IR:
34/// llvm-ir2vec entities input.bc -o entity2id.txt
35///
36/// For Machine IR:
37/// llvm-ir2vec entities -mode=mir input.mir -o entity2id.txt
38///
39/// 3. Embedding Generation (embeddings):
40/// Generates IR2Vec/MIR2Vec embeddings using a trained vocabulary.
41///
42/// For LLVM IR:
43/// llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json
44/// --ir2vec-kind=<kind> --level=<level> input.bc -o embeddings.txt
45/// Kind: --ir2vec-kind=symbolic (default), --ir2vec-kind=flow-aware
46///
47/// For Machine IR:
48/// llvm-ir2vec embeddings -mode=mir --mir2vec-vocab-path=vocab.json
49/// --level=<level> input.mir -o embeddings.txt
50///
51/// Levels: --level=inst (instructions), --level=bb (basic blocks),
52/// --level=func (functions) (See IR2Vec.cpp/MIR2Vec.cpp for more embedding
53/// generation options)
54///
55//===----------------------------------------------------------------------===//
56
57#include "IRUtils/IRUtils.h"
58#include "MIRUtils/MIRUtils.h"
59#include "llvm/ADT/ArrayRef.h"
60#include "llvm/Analysis/IR2Vec.h"
61#include "llvm/CodeGen/CommandFlags.h"
62#include "llvm/CodeGen/MIR2Vec.h"
63#include "llvm/CodeGen/MIRParser/MIRParser.h"
64#include "llvm/CodeGen/MachineFunction.h"
65#include "llvm/CodeGen/MachineModuleInfo.h"
66#include "llvm/CodeGen/TargetInstrInfo.h"
67#include "llvm/CodeGen/TargetRegisterInfo.h"
68#include "llvm/IR/BasicBlock.h"
69#include "llvm/IR/Function.h"
70#include "llvm/IR/InstIterator.h"
71#include "llvm/IR/Instructions.h"
72#include "llvm/IR/LLVMContext.h"
73#include "llvm/IR/Module.h"
74#include "llvm/IR/PassInstrumentation.h"
75#include "llvm/IR/PassManager.h"
76#include "llvm/IR/Type.h"
77#include "llvm/IRReader/IRReader.h"
78#include "llvm/MC/TargetRegistry.h"
79#include "llvm/Support/CommandLine.h"
80#include "llvm/Support/Debug.h"
81#include "llvm/Support/Errc.h"
82#include "llvm/Support/InitLLVM.h"
83#include "llvm/Support/SourceMgr.h"
84#include "llvm/Support/TargetSelect.h"
85#include "llvm/Support/WithColor.h"
86#include "llvm/Support/raw_ostream.h"
87#include "llvm/Target/TargetMachine.h"
88#include "llvm/TargetParser/Host.h"
89
90#define DEBUG_TYPE "ir2vec"
91
92namespace llvm {
93
94// Common option category for options shared between IR2Vec and MIR2Vec
95static cl::OptionCategory CommonCategory("Common Options",
96 "Options applicable to both IR2Vec "
97 "and MIR2Vec modes");
98
99enum IRKind {
100 LLVMIR = 0, ///< LLVM IR
101 MIR ///< Machine IR
102};
103
104static cl::opt<IRKind>
105 IRMode("mode", cl::desc("Tool operation mode:"),
106 cl::values(clEnumValN(LLVMIR, "llvm", "Process LLVM IR"),
107 clEnumValN(MIR, "mir", "Process Machine IR")),
108 cl::init(Val: LLVMIR), cl::cat(CommonCategory));
109
110// Subcommands
111static cl::SubCommand
112 TripletsSubCmd("triplets", "Generate triplets for vocabulary training");
113static cl::SubCommand
114 EntitiesSubCmd("entities",
115 "Generate entity mappings for vocabulary training");
116static cl::SubCommand
117 EmbeddingsSubCmd("embeddings",
118 "Generate embeddings using trained vocabulary");
119
120// Common options
121static cl::opt<std::string> InputFilename(
122 cl::Positional, cl::desc("<input bitcode/MIR file or '-' for stdin>"),
123 cl::init(Val: "-"), cl::sub(TripletsSubCmd), cl::sub(EntitiesSubCmd),
124 cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
125
126static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
127 cl::value_desc("filename"),
128 cl::init(Val: "-"),
129 cl::cat(CommonCategory));
130
131// Embedding-specific options
132static cl::opt<std::string>
133 FunctionName("function", cl::desc("Process specific function only"),
134 cl::value_desc("name"), cl::Optional, cl::init(Val: ""),
135 cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
136
137static cl::opt<EmbeddingLevel>
138 Level("level", cl::desc("Embedding generation level:"),
139 cl::values(clEnumValN(InstructionLevel, "inst",
140 "Generate instruction-level embeddings"),
141 clEnumValN(BasicBlockLevel, "bb",
142 "Generate basic block-level embeddings"),
143 clEnumValN(FunctionLevel, "func",
144 "Generate function-level embeddings")),
145 cl::init(Val: FunctionLevel), cl::sub(EmbeddingsSubCmd),
146 cl::cat(CommonCategory));
147
148namespace ir2vec {
149
150/// Process the module and generate output based on selected subcommand
151static Error processModule(Module &M, raw_ostream &OS) {
152 IR2VecTool Tool(M);
153
154 if (EmbeddingsSubCmd) {
155 // Initialize vocabulary for embedding generation
156 // Note: Requires --ir2vec-vocab-path option to be set
157 // and this value will be populated in the var VocabFile
158 if (VocabFile.empty()) {
159 return createStringError(
160 EC: errc::invalid_argument,
161 S: "IR2Vec vocabulary file path not specified; "
162 "You may need to set it using --ir2vec-vocab-path");
163 }
164
165 std::shared_ptr<Vocabulary> Vocab;
166 if (auto Err = ir2vec::loadVocabulary(VocabPath: VocabFile).moveInto(Value&: Vocab))
167 return Err;
168 if (auto Err = Tool.setVocabulary(std::move(Vocab)))
169 return Err;
170
171 if (!FunctionName.empty()) {
172 // Process single function
173 if (const Function *F = M.getFunction(Name: FunctionName))
174 Tool.writeEmbeddingsToStream(F: *F, OS, Level);
175 else
176 return createStringError(EC: errc::invalid_argument,
177 Fmt: "Function '%s' not found",
178 Vals: FunctionName.c_str());
179 } else {
180 // Process all functions
181 Tool.writeEmbeddingsToStream(OS, Level);
182 }
183 } else {
184 // Both triplets and entities use triplet generation
185 Tool.writeTripletsToStream(OS);
186 }
187 return Error::success();
188}
189} // namespace ir2vec
190
191namespace mir2vec {
192
193/// Setup MIR context from input file
194static Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) {
195 SMDiagnostic Err;
196
197 auto MIR = createMIRParserFromFile(Filename: InputFile, Error&: Err, Context&: Ctx.Context);
198 if (!MIR) {
199 Err.print(ProgName: ToolName, S&: errs());
200 return createStringError(EC: errc::invalid_argument,
201 S: "Failed to parse MIR file");
202 }
203
204 auto SetDataLayout = [&](StringRef DataLayoutTargetTriple,
205 StringRef OldDLStr) -> std::optional<std::string> {
206 std::string IRTargetTriple = DataLayoutTargetTriple.str();
207 Triple TheTriple = Triple(IRTargetTriple);
208 if (TheTriple.getTriple().empty())
209 TheTriple.setTriple(sys::getDefaultTargetTriple());
210
211 auto TMOrErr = codegen::createTargetMachineForTriple(TargetTriple: TheTriple);
212 if (!TMOrErr) {
213 Err.print(ProgName: ToolName, S&: errs());
214 exit(status: 1); // Match original behavior
215 }
216 Ctx.TM = std::move(*TMOrErr);
217 return Ctx.TM->createDataLayout().getStringRepresentation();
218 };
219
220 Ctx.M = MIR->parseIRModule(DataLayoutCallback: SetDataLayout);
221 if (!Ctx.M) {
222 Err.print(ProgName: ToolName, S&: errs());
223 return createStringError(EC: errc::invalid_argument,
224 S: "Failed to parse IR module");
225 }
226
227 Ctx.MMI = std::make_unique<MachineModuleInfo>(args: Ctx.TM.get());
228 if (!Ctx.MMI || MIR->parseMachineFunctions(M&: *Ctx.M, MMI&: *Ctx.MMI)) {
229 Err.print(ProgName: ToolName, S&: errs());
230 return createStringError(EC: errc::invalid_argument,
231 S: "Failed to parse machine functions");
232 }
233
234 return Error::success();
235}
236
237/// Generic vocabulary initialization and processing
238template <typename ProcessFunc>
239static Error processWithVocabulary(MIRContext &Ctx, raw_ostream &OS,
240 bool useLayoutVocab, ProcessFunc processFn) {
241 MIR2VecTool Tool(*Ctx.MMI);
242
243 // Initialize appropriate vocabulary type
244 bool success = useLayoutVocab ? Tool.initializeVocabularyForLayout(M: *Ctx.M)
245 : Tool.initializeVocabulary(M: *Ctx.M);
246
247 if (!success) {
248 WithColor::error(OS&: errs(), Prefix: ToolName)
249 << "Failed to initialize MIR2Vec vocabulary"
250 << (useLayoutVocab ? " for layout" : "") << ".\n";
251 return createStringError(EC: errc::invalid_argument,
252 S: "Vocabulary initialization failed");
253 }
254
255 assert(Tool.getVocabulary() &&
256 "MIR2Vec vocabulary should be initialized at this point");
257
258 LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
259 << "Vocabulary dimension: "
260 << Tool.getVocabulary()->getDimension() << "\n"
261 << "Vocabulary size: "
262 << Tool.getVocabulary()->getCanonicalSize() << "\n");
263
264 // Execute the specific processing logic
265 return processFn(Tool);
266}
267
268/// Process module for triplet generation
269static Error processModuleForTriplets(MIRContext &Ctx, raw_ostream &OS) {
270 return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true,
271 processFn: [&](MIR2VecTool &Tool) -> Error {
272 Tool.writeTripletsToStream(M: *Ctx.M, OS);
273 return Error::success();
274 });
275}
276
277/// Process module for entity generation
278static Error processModuleForEntities(MIRContext &Ctx, raw_ostream &OS) {
279 return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true,
280 processFn: [&](MIR2VecTool &Tool) -> Error {
281 Tool.writeEntitiesToStream(OS);
282 return Error::success();
283 });
284}
285
286/// Process module for embedding generation
287static Error processModuleForEmbeddings(MIRContext &Ctx, raw_ostream &OS) {
288 return processWithVocabulary(
289 Ctx, OS, /*useLayoutVocab=*/false, processFn: [&](MIR2VecTool &Tool) -> Error {
290 if (!FunctionName.empty()) {
291 // Process single function
292 Function *F = Ctx.M->getFunction(Name: FunctionName);
293 if (!F) {
294 WithColor::error(OS&: errs(), Prefix: ToolName)
295 << "Function '" << FunctionName << "' not found\n";
296 return createStringError(EC: errc::invalid_argument,
297 S: "Function not found");
298 }
299
300 MachineFunction *MF = Ctx.MMI->getMachineFunction(F: *F);
301 if (!MF) {
302 WithColor::error(OS&: errs(), Prefix: ToolName)
303 << "No MachineFunction for " << FunctionName << "\n";
304 return createStringError(EC: errc::invalid_argument,
305 S: "No MachineFunction");
306 }
307
308 Tool.writeEmbeddingsToStream(MF&: *MF, OS, Level);
309 } else {
310 // Process all functions
311 Tool.writeEmbeddingsToStream(M: *Ctx.M, OS, Level);
312 }
313 return Error::success();
314 });
315}
316
317/// Main entry point for MIR processing
318static Error processModule(const std::string &InputFile, raw_ostream &OS) {
319 MIRContext Ctx;
320
321 // Setup MIR context (parse file, setup target machine, etc.)
322 if (auto Err = setupMIRContext(InputFile, Ctx))
323 return Err;
324
325 // Process based on subcommand
326 if (TripletsSubCmd)
327 return processModuleForTriplets(Ctx, OS);
328 else if (EntitiesSubCmd)
329 return processModuleForEntities(Ctx, OS);
330 else if (EmbeddingsSubCmd)
331 return processModuleForEmbeddings(Ctx, OS);
332 else {
333 WithColor::error(OS&: errs(), Prefix: ToolName)
334 << "Please specify a subcommand: triplets, entities, or embeddings\n";
335 return createStringError(EC: errc::invalid_argument, S: "No subcommand specified");
336 }
337}
338
339} // namespace mir2vec
340
341} // namespace llvm
342
343int main(int argc, char **argv) {
344 using namespace llvm;
345 using namespace llvm::ir2vec;
346 using namespace llvm::mir2vec;
347
348 InitLLVM X(argc, argv);
349 // Show Common, IR2Vec and MIR2Vec option categories
350 cl::HideUnrelatedOptions(Categories: ArrayRef<const cl::OptionCategory *>{
351 &CommonCategory, &ir2vec::IR2VecCategory, &mir2vec::MIR2VecCategory});
352 cl::ParseCommandLineOptions(
353 argc, argv,
354 Overview: "IR2Vec/MIR2Vec - Embedding Generation Tool\n"
355 "Generates embeddings for a given LLVM IR or MIR and "
356 "supports triplet generation for vocabulary "
357 "training and embedding generation.\n\n"
358 "See https://llvm.org/docs/CommandGuide/llvm-ir2vec.html for more "
359 "information.\n");
360
361 std::error_code EC;
362 raw_fd_ostream OS(OutputFilename, EC);
363 if (EC) {
364 WithColor::error(OS&: errs(), Prefix: ToolName)
365 << "opening output file: " << EC.message() << "\n";
366 return 1;
367 }
368
369 if (IRMode == IRKind::LLVMIR) {
370 if (EntitiesSubCmd) {
371 // Just dump entity mappings without processing any IR
372 IR2VecTool::writeEntitiesToStream(OS);
373 return 0;
374 }
375
376 // Parse the input LLVM IR file or stdin
377 SMDiagnostic Err;
378 LLVMContext Context;
379 std::unique_ptr<Module> M = parseIRFile(Filename: InputFilename, Err, Context);
380 if (!M) {
381 Err.print(ProgName: ToolName, S&: errs());
382 return 1;
383 }
384
385 if (Error Err = processModule(M&: *M, OS)) {
386 handleAllErrors(E: std::move(Err), Handlers: [&](const ErrorInfoBase &EIB) {
387 WithColor::error(OS&: errs(), Prefix: ToolName) << EIB.message() << "\n";
388 });
389 return 1;
390 }
391 return 0;
392 }
393 if (IRMode == IRKind::MIR) {
394 // Initialize targets for Machine IR processing
395 InitializeAllTargets();
396 InitializeAllTargetMCs();
397 InitializeAllAsmParsers();
398 InitializeAllAsmPrinters();
399 static codegen::RegisterCodeGenFlags CGF;
400
401 if (Error Err = mir2vec::processModule(InputFile: InputFilename, OS)) {
402 handleAllErrors(E: std::move(Err), Handlers: [&](const ErrorInfoBase &EIB) {
403 WithColor::error(OS&: errs(), Prefix: ToolName) << EIB.message() << "\n";
404 });
405 return 1;
406 }
407
408 return 0;
409 }
410
411 return 0;
412}
413