//===- llvm-ir2vec.cpp - IR2Vec/MIR2Vec Embedding Generation Tool --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the IR2Vec and MIR2Vec embedding generation tool.
///
/// This tool supports two modes:
/// - LLVM IR mode (-mode=llvm): Process LLVM IR
/// - Machine IR mode (-mode=mir): Process Machine IR
///
/// Available subcommands:
///
/// 1. Triplet Generation (triplets):
///    Generates numeric triplets (head, tail, relation) for vocabulary
///    training. Output format: MAX_RELATION=N header followed by
///    head\ttail\trelation lines. Relations: 0=Type, 1=Next, 2+=Arg0,Arg1,...
///
///    For LLVM IR:
///      llvm-ir2vec triplets input.bc -o train2id.txt
///
///    For Machine IR:
///      llvm-ir2vec triplets -mode=mir input.mir -o train2id.txt
///
/// 2. Entity Mappings (entities):
///    Generates entity mappings for vocabulary training.
///    Output format: <total_entities> header followed by entity\tid lines.
///
///    For LLVM IR:
///      llvm-ir2vec entities input.bc -o entity2id.txt
///
///    For Machine IR:
///      llvm-ir2vec entities -mode=mir input.mir -o entity2id.txt
///
/// 3. Embedding Generation (embeddings):
///    Generates IR2Vec/MIR2Vec embeddings using a trained vocabulary.
///
///    For LLVM IR:
///      llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json
///        --ir2vec-kind=<kind> --level=<level> input.bc -o embeddings.txt
///      Kind: --ir2vec-kind=symbolic (default), --ir2vec-kind=flow-aware
///
///    For Machine IR:
///      llvm-ir2vec embeddings -mode=mir --mir2vec-vocab-path=vocab.json
///        --level=<level> input.mir -o embeddings.txt
///
///    Levels: --level=inst (instructions), --level=bb (basic blocks),
///    --level=func (functions) (See IR2Vec.cpp/MIR2Vec.cpp for more embedding
///    generation options)
///
//===----------------------------------------------------------------------===//
56
#include "lib/Utils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/CodeGen/MIR2Vec.h"
#include "llvm/CodeGen/MIRParser/MIRParser.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Host.h"
88
89#define DEBUG_TYPE "ir2vec"
90
91namespace llvm {
92
93// Common option category for options shared between IR2Vec and MIR2Vec
94static cl::OptionCategory CommonCategory("Common Options",
95 "Options applicable to both IR2Vec "
96 "and MIR2Vec modes");
97
98enum IRKind {
99 LLVMIR = 0, ///< LLVM IR
100 MIR ///< Machine IR
101};
102
103static cl::opt<IRKind>
104 IRMode("mode", cl::desc("Tool operation mode:"),
105 cl::values(clEnumValN(LLVMIR, "llvm", "Process LLVM IR"),
106 clEnumValN(MIR, "mir", "Process Machine IR")),
107 cl::init(Val: LLVMIR), cl::cat(CommonCategory));
108
109// Subcommands
110static cl::SubCommand
111 TripletsSubCmd("triplets", "Generate triplets for vocabulary training");
112static cl::SubCommand
113 EntitiesSubCmd("entities",
114 "Generate entity mappings for vocabulary training");
115static cl::SubCommand
116 EmbeddingsSubCmd("embeddings",
117 "Generate embeddings using trained vocabulary");
118
119// Common options
120static cl::opt<std::string> InputFilename(
121 cl::Positional, cl::desc("<input bitcode/MIR file or '-' for stdin>"),
122 cl::init(Val: "-"), cl::sub(TripletsSubCmd), cl::sub(EntitiesSubCmd),
123 cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
124
125static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
126 cl::value_desc("filename"),
127 cl::init(Val: "-"),
128 cl::cat(CommonCategory));
129
130// Embedding-specific options
131static cl::opt<std::string>
132 FunctionName("function", cl::desc("Process specific function only"),
133 cl::value_desc("name"), cl::Optional, cl::init(Val: ""),
134 cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
135
136static cl::opt<EmbeddingLevel>
137 Level("level", cl::desc("Embedding generation level:"),
138 cl::values(clEnumValN(InstructionLevel, "inst",
139 "Generate instruction-level embeddings"),
140 clEnumValN(BasicBlockLevel, "bb",
141 "Generate basic block-level embeddings"),
142 clEnumValN(FunctionLevel, "func",
143 "Generate function-level embeddings")),
144 cl::init(Val: FunctionLevel), cl::sub(EmbeddingsSubCmd),
145 cl::cat(CommonCategory));
146
147namespace ir2vec {
148
149/// Process the module and generate output based on selected subcommand
150static Error processModule(Module &M, raw_ostream &OS) {
151 IR2VecTool Tool(M);
152
153 if (EmbeddingsSubCmd) {
154 // Initialize vocabulary for embedding generation
155 // Note: Requires --ir2vec-vocab-path option to be set
156 // and this value will be populated in the var VocabFile
157 if (VocabFile.empty()) {
158 return createStringError(
159 EC: errc::invalid_argument,
160 S: "IR2Vec vocabulary file path not specified; "
161 "You may need to set it using --ir2vec-vocab-path");
162 }
163
164 if (Error Err = Tool.initializeVocabulary(VocabPath: VocabFile))
165 return Err;
166
167 if (!FunctionName.empty()) {
168 // Process single function
169 if (const Function *F = M.getFunction(Name: FunctionName))
170 Tool.writeEmbeddingsToStream(F: *F, OS, Level);
171 else
172 return createStringError(EC: errc::invalid_argument,
173 Fmt: "Function '%s' not found",
174 Vals: FunctionName.c_str());
175 } else {
176 // Process all functions
177 Tool.writeEmbeddingsToStream(OS, Level);
178 }
179 } else {
180 // Both triplets and entities use triplet generation
181 Tool.writeTripletsToStream(OS);
182 }
183 return Error::success();
184}
185} // namespace ir2vec
186
187namespace mir2vec {
188
189/// Setup MIR context from input file
190static Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) {
191 SMDiagnostic Err;
192
193 auto MIR = createMIRParserFromFile(Filename: InputFile, Error&: Err, Context&: Ctx.Context);
194 if (!MIR) {
195 Err.print(ProgName: ToolName, S&: errs());
196 return createStringError(EC: errc::invalid_argument,
197 S: "Failed to parse MIR file");
198 }
199
200 auto SetDataLayout = [&](StringRef DataLayoutTargetTriple,
201 StringRef OldDLStr) -> std::optional<std::string> {
202 std::string IRTargetTriple = DataLayoutTargetTriple.str();
203 Triple TheTriple = Triple(IRTargetTriple);
204 if (TheTriple.getTriple().empty())
205 TheTriple.setTriple(sys::getDefaultTargetTriple());
206
207 auto TMOrErr = codegen::createTargetMachineForTriple(TargetTriple: TheTriple.str());
208 if (!TMOrErr) {
209 Err.print(ProgName: ToolName, S&: errs());
210 exit(status: 1); // Match original behavior
211 }
212 Ctx.TM = std::move(*TMOrErr);
213 return Ctx.TM->createDataLayout().getStringRepresentation();
214 };
215
216 Ctx.M = MIR->parseIRModule(DataLayoutCallback: SetDataLayout);
217 if (!Ctx.M) {
218 Err.print(ProgName: ToolName, S&: errs());
219 return createStringError(EC: errc::invalid_argument,
220 S: "Failed to parse IR module");
221 }
222
223 Ctx.MMI = std::make_unique<MachineModuleInfo>(args: Ctx.TM.get());
224 if (!Ctx.MMI || MIR->parseMachineFunctions(M&: *Ctx.M, MMI&: *Ctx.MMI)) {
225 Err.print(ProgName: ToolName, S&: errs());
226 return createStringError(EC: errc::invalid_argument,
227 S: "Failed to parse machine functions");
228 }
229
230 return Error::success();
231}
232
233/// Generic vocabulary initialization and processing
234template <typename ProcessFunc>
235static Error processWithVocabulary(MIRContext &Ctx, raw_ostream &OS,
236 bool useLayoutVocab, ProcessFunc processFn) {
237 MIR2VecTool Tool(*Ctx.MMI);
238
239 // Initialize appropriate vocabulary type
240 bool success = useLayoutVocab ? Tool.initializeVocabularyForLayout(M: *Ctx.M)
241 : Tool.initializeVocabulary(M: *Ctx.M);
242
243 if (!success) {
244 WithColor::error(OS&: errs(), Prefix: ToolName)
245 << "Failed to initialize MIR2Vec vocabulary"
246 << (useLayoutVocab ? " for layout" : "") << ".\n";
247 return createStringError(EC: errc::invalid_argument,
248 S: "Vocabulary initialization failed");
249 }
250
251 assert(Tool.getVocabulary() &&
252 "MIR2Vec vocabulary should be initialized at this point");
253
254 LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
255 << "Vocabulary dimension: "
256 << Tool.getVocabulary()->getDimension() << "\n"
257 << "Vocabulary size: "
258 << Tool.getVocabulary()->getCanonicalSize() << "\n");
259
260 // Execute the specific processing logic
261 return processFn(Tool);
262}
263
264/// Process module for triplet generation
265static Error processModuleForTriplets(MIRContext &Ctx, raw_ostream &OS) {
266 return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true,
267 processFn: [&](MIR2VecTool &Tool) -> Error {
268 Tool.writeTripletsToStream(M: *Ctx.M, OS);
269 return Error::success();
270 });
271}
272
273/// Process module for entity generation
274static Error processModuleForEntities(MIRContext &Ctx, raw_ostream &OS) {
275 return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true,
276 processFn: [&](MIR2VecTool &Tool) -> Error {
277 Tool.writeEntitiesToStream(OS);
278 return Error::success();
279 });
280}
281
282/// Process module for embedding generation
283static Error processModuleForEmbeddings(MIRContext &Ctx, raw_ostream &OS) {
284 return processWithVocabulary(
285 Ctx, OS, /*useLayoutVocab=*/false, processFn: [&](MIR2VecTool &Tool) -> Error {
286 if (!FunctionName.empty()) {
287 // Process single function
288 Function *F = Ctx.M->getFunction(Name: FunctionName);
289 if (!F) {
290 WithColor::error(OS&: errs(), Prefix: ToolName)
291 << "Function '" << FunctionName << "' not found\n";
292 return createStringError(EC: errc::invalid_argument,
293 S: "Function not found");
294 }
295
296 MachineFunction *MF = Ctx.MMI->getMachineFunction(F: *F);
297 if (!MF) {
298 WithColor::error(OS&: errs(), Prefix: ToolName)
299 << "No MachineFunction for " << FunctionName << "\n";
300 return createStringError(EC: errc::invalid_argument,
301 S: "No MachineFunction");
302 }
303
304 Tool.writeEmbeddingsToStream(MF&: *MF, OS, Level);
305 } else {
306 // Process all functions
307 Tool.writeEmbeddingsToStream(M: *Ctx.M, OS, Level);
308 }
309 return Error::success();
310 });
311}
312
313/// Main entry point for MIR processing
314static Error processModule(const std::string &InputFile, raw_ostream &OS) {
315 MIRContext Ctx;
316
317 // Setup MIR context (parse file, setup target machine, etc.)
318 if (auto Err = setupMIRContext(InputFile, Ctx))
319 return Err;
320
321 // Process based on subcommand
322 if (TripletsSubCmd)
323 return processModuleForTriplets(Ctx, OS);
324 else if (EntitiesSubCmd)
325 return processModuleForEntities(Ctx, OS);
326 else if (EmbeddingsSubCmd)
327 return processModuleForEmbeddings(Ctx, OS);
328 else {
329 WithColor::error(OS&: errs(), Prefix: ToolName)
330 << "Please specify a subcommand: triplets, entities, or embeddings\n";
331 return createStringError(EC: errc::invalid_argument, S: "No subcommand specified");
332 }
333}
334
335} // namespace mir2vec
336
337} // namespace llvm
338
339int main(int argc, char **argv) {
340 using namespace llvm;
341 using namespace llvm::ir2vec;
342 using namespace llvm::mir2vec;
343
344 InitLLVM X(argc, argv);
345 // Show Common, IR2Vec and MIR2Vec option categories
346 cl::HideUnrelatedOptions(Categories: ArrayRef<const cl::OptionCategory *>{
347 &CommonCategory, &ir2vec::IR2VecCategory, &mir2vec::MIR2VecCategory});
348 cl::ParseCommandLineOptions(
349 argc, argv,
350 Overview: "IR2Vec/MIR2Vec - Embedding Generation Tool\n"
351 "Generates embeddings for a given LLVM IR or MIR and "
352 "supports triplet generation for vocabulary "
353 "training and embedding generation.\n\n"
354 "See https://llvm.org/docs/CommandGuide/llvm-ir2vec.html for more "
355 "information.\n");
356
357 std::error_code EC;
358 raw_fd_ostream OS(OutputFilename, EC);
359 if (EC) {
360 WithColor::error(OS&: errs(), Prefix: ToolName)
361 << "opening output file: " << EC.message() << "\n";
362 return 1;
363 }
364
365 if (IRMode == IRKind::LLVMIR) {
366 if (EntitiesSubCmd) {
367 // Just dump entity mappings without processing any IR
368 IR2VecTool::writeEntitiesToStream(OS);
369 return 0;
370 }
371
372 // Parse the input LLVM IR file or stdin
373 SMDiagnostic Err;
374 LLVMContext Context;
375 std::unique_ptr<Module> M = parseIRFile(Filename: InputFilename, Err, Context);
376 if (!M) {
377 Err.print(ProgName: ToolName, S&: errs());
378 return 1;
379 }
380
381 if (Error Err = processModule(M&: *M, OS)) {
382 handleAllErrors(E: std::move(Err), Handlers: [&](const ErrorInfoBase &EIB) {
383 WithColor::error(OS&: errs(), Prefix: ToolName) << EIB.message() << "\n";
384 });
385 return 1;
386 }
387 return 0;
388 }
389 if (IRMode == IRKind::MIR) {
390 // Initialize targets for Machine IR processing
391 InitializeAllTargets();
392 InitializeAllTargetMCs();
393 InitializeAllAsmParsers();
394 InitializeAllAsmPrinters();
395 static codegen::RegisterCodeGenFlags CGF;
396
397 if (Error Err = mir2vec::processModule(InputFile: InputFilename, OS)) {
398 handleAllErrors(E: std::move(Err), Handlers: [&](const ErrorInfoBase &EIB) {
399 WithColor::error(OS&: errs(), Prefix: ToolName) << EIB.message() << "\n";
400 });
401 return 1;
402 }
403
404 return 0;
405 }
406
407 return 0;
408}
409