//===- llvm-extract.cpp - LLVM function extraction utility ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This utility reduces the input module to the functions, globals, aliases,
// and basic blocks selected on the command line (or, with -delete, removes
// the selected globals instead). It is primarily used for debugging
// transformations.
//
//===----------------------------------------------------------------------===//
13
14#include "llvm/ADT/SetVector.h"
15#include "llvm/ADT/SmallPtrSet.h"
16#include "llvm/Bitcode/BitcodeWriterPass.h"
17#include "llvm/IR/DataLayout.h"
18#include "llvm/IR/IRPrintingPasses.h"
19#include "llvm/IR/Instructions.h"
20#include "llvm/IR/LLVMContext.h"
21#include "llvm/IR/Module.h"
22#include "llvm/IRPrinter/IRPrintingPasses.h"
23#include "llvm/IRReader/IRReader.h"
24#include "llvm/Passes/PassBuilder.h"
25#include "llvm/Support/CommandLine.h"
26#include "llvm/Support/Error.h"
27#include "llvm/Support/FileSystem.h"
28#include "llvm/Support/InitLLVM.h"
29#include "llvm/Support/Regex.h"
30#include "llvm/Support/SourceMgr.h"
31#include "llvm/Support/SystemUtils.h"
32#include "llvm/Support/ToolOutputFile.h"
33#include "llvm/Transforms/IPO.h"
34#include "llvm/Transforms/IPO/BlockExtractor.h"
35#include "llvm/Transforms/IPO/ExtractGV.h"
36#include "llvm/Transforms/IPO/GlobalDCE.h"
37#include "llvm/Transforms/IPO/StripDeadPrototypes.h"
38#include "llvm/Transforms/IPO/StripSymbols.h"
39#include <memory>
40#include <utility>
41
42using namespace llvm;
43
44static cl::OptionCategory ExtractCat("llvm-extract Options");
45
46// InputFilename - The filename to read from.
47static cl::opt<std::string> InputFilename(cl::Positional,
48 cl::desc("<input bitcode file>"),
49 cl::init(Val: "-"),
50 cl::value_desc("filename"));
51
52static cl::opt<std::string> OutputFilename("o",
53 cl::desc("Specify output filename"),
54 cl::value_desc("filename"),
55 cl::init(Val: "-"), cl::cat(ExtractCat));
56
57static cl::opt<bool> Force("f", cl::desc("Enable binary output on terminals"),
58 cl::cat(ExtractCat));
59
60static cl::opt<bool> DeleteFn("delete",
61 cl::desc("Delete specified Globals from Module"),
62 cl::cat(ExtractCat));
63
64static cl::opt<bool> KeepConstInit("keep-const-init",
65 cl::desc("Keep initializers of constants"),
66 cl::cat(ExtractCat));
67
68static cl::opt<bool>
69 Recursive("recursive", cl::desc("Recursively extract all called functions"),
70 cl::cat(ExtractCat));
71
72// ExtractFuncs - The functions to extract from the module.
73static cl::list<std::string>
74 ExtractFuncs("func", cl::desc("Specify function to extract"),
75 cl::value_desc("function"), cl::cat(ExtractCat));
76
77// ExtractRegExpFuncs - The functions, matched via regular expression, to
78// extract from the module.
79static cl::list<std::string>
80 ExtractRegExpFuncs("rfunc",
81 cl::desc("Specify function(s) to extract using a "
82 "regular expression"),
83 cl::value_desc("rfunction"), cl::cat(ExtractCat));
84
85// ExtractBlocks - The blocks to extract from the module.
86static cl::list<std::string> ExtractBlocks(
87 "bb",
88 cl::desc(
89 "Specify <function, basic block1[;basic block2...]> pairs to extract.\n"
90 "Each pair will create a function.\n"
91 "If multiple basic blocks are specified in one pair,\n"
92 "the first block in the sequence should dominate the rest.\n"
93 "If an unnamed basic block is to be extracted,\n"
94 "'%' should be added before the basic block variable names.\n"
95 "eg:\n"
96 " --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n"
97 " --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one "
98 "with bb2.\n"
99 " --bb=f:%1 will extract one function with basic block 1;"),
100 cl::value_desc("function:bb1[;bb2...]"), cl::cat(ExtractCat));
101
102// ExtractAlias - The alias to extract from the module.
103static cl::list<std::string>
104 ExtractAliases("alias", cl::desc("Specify alias to extract"),
105 cl::value_desc("alias"), cl::cat(ExtractCat));
106
107// ExtractRegExpAliases - The aliases, matched via regular expression, to
108// extract from the module.
109static cl::list<std::string>
110 ExtractRegExpAliases("ralias",
111 cl::desc("Specify alias(es) to extract using a "
112 "regular expression"),
113 cl::value_desc("ralias"), cl::cat(ExtractCat));
114
115// ExtractGlobals - The globals to extract from the module.
116static cl::list<std::string>
117 ExtractGlobals("glob", cl::desc("Specify global to extract"),
118 cl::value_desc("global"), cl::cat(ExtractCat));
119
120// ExtractRegExpGlobals - The globals, matched via regular expression, to
121// extract from the module...
122static cl::list<std::string>
123 ExtractRegExpGlobals("rglob",
124 cl::desc("Specify global(s) to extract using a "
125 "regular expression"),
126 cl::value_desc("rglobal"), cl::cat(ExtractCat));
127
128static cl::opt<bool> OutputAssembly("S",
129 cl::desc("Write output as LLVM assembly"),
130 cl::Hidden, cl::cat(ExtractCat));
131
132int main(int argc, char **argv) {
133 InitLLVM X(argc, argv);
134
135 LLVMContext Context;
136 cl::HideUnrelatedOptions(Category&: ExtractCat);
137 cl::ParseCommandLineOptions(argc, argv, Overview: "llvm extractor\n");
138
139 // Use lazy loading, since we only care about selected global values.
140 SMDiagnostic Err;
141 std::unique_ptr<Module> M = getLazyIRFileModule(Filename: InputFilename, Err, Context);
142
143 if (!M) {
144 Err.print(ProgName: argv[0], S&: errs());
145 return 1;
146 }
147
148 // Use SetVector to avoid duplicates.
149 SetVector<GlobalValue *> GVs;
150
151 // Figure out which aliases we should extract.
152 for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) {
153 GlobalAlias *GA = M->getNamedAlias(Name: ExtractAliases[i]);
154 if (!GA) {
155 errs() << argv[0] << ": program doesn't contain alias named '"
156 << ExtractAliases[i] << "'!\n";
157 return 1;
158 }
159 GVs.insert(X: GA);
160 }
161
162 // Extract aliases via regular expression matching.
163 for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) {
164 std::string Error;
165 Regex RegEx(ExtractRegExpAliases[i]);
166 if (!RegEx.isValid(Error)) {
167 errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' "
168 "invalid regex: " << Error;
169 }
170 bool match = false;
171 for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end();
172 GA != E; GA++) {
173 if (RegEx.match(String: GA->getName())) {
174 GVs.insert(X: &*GA);
175 match = true;
176 }
177 }
178 if (!match) {
179 errs() << argv[0] << ": program doesn't contain global named '"
180 << ExtractRegExpAliases[i] << "'!\n";
181 return 1;
182 }
183 }
184
185 // Figure out which globals we should extract.
186 for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) {
187 GlobalValue *GV = M->getNamedGlobal(Name: ExtractGlobals[i]);
188 if (!GV) {
189 errs() << argv[0] << ": program doesn't contain global named '"
190 << ExtractGlobals[i] << "'!\n";
191 return 1;
192 }
193 GVs.insert(X: GV);
194 }
195
196 // Extract globals via regular expression matching.
197 for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) {
198 std::string Error;
199 Regex RegEx(ExtractRegExpGlobals[i]);
200 if (!RegEx.isValid(Error)) {
201 errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' "
202 "invalid regex: " << Error;
203 }
204 bool match = false;
205 for (auto &GV : M->globals()) {
206 if (RegEx.match(String: GV.getName())) {
207 GVs.insert(X: &GV);
208 match = true;
209 }
210 }
211 if (!match) {
212 errs() << argv[0] << ": program doesn't contain global named '"
213 << ExtractRegExpGlobals[i] << "'!\n";
214 return 1;
215 }
216 }
217
218 // Figure out which functions we should extract.
219 for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) {
220 GlobalValue *GV = M->getFunction(Name: ExtractFuncs[i]);
221 if (!GV) {
222 errs() << argv[0] << ": program doesn't contain function named '"
223 << ExtractFuncs[i] << "'!\n";
224 return 1;
225 }
226 GVs.insert(X: GV);
227 }
228 // Extract functions via regular expression matching.
229 for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) {
230 std::string Error;
231 StringRef RegExStr = ExtractRegExpFuncs[i];
232 Regex RegEx(RegExStr);
233 if (!RegEx.isValid(Error)) {
234 errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' "
235 "invalid regex: " << Error;
236 }
237 bool match = false;
238 for (Module::iterator F = M->begin(), E = M->end(); F != E;
239 F++) {
240 if (RegEx.match(String: F->getName())) {
241 GVs.insert(X: &*F);
242 match = true;
243 }
244 }
245 if (!match) {
246 errs() << argv[0] << ": program doesn't contain global named '"
247 << ExtractRegExpFuncs[i] << "'!\n";
248 return 1;
249 }
250 }
251
252 // Figure out which BasicBlocks we should extract.
253 SmallVector<std::pair<Function *, SmallVector<StringRef, 16>>, 2> BBMap;
254 for (StringRef StrPair : ExtractBlocks) {
255 SmallVector<StringRef, 16> BBNames;
256 auto BBInfo = StrPair.split(Separator: ':');
257 // Get the function.
258 Function *F = M->getFunction(Name: BBInfo.first);
259 if (!F) {
260 errs() << argv[0] << ": program doesn't contain a function named '"
261 << BBInfo.first << "'!\n";
262 return 1;
263 }
264 // Add the function to the materialize list, and store the basic block names
265 // to check after materialization.
266 GVs.insert(X: F);
267 BBInfo.second.split(A&: BBNames, Separator: ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
268 BBMap.push_back(Elt: {F, std::move(BBNames)});
269 }
270
271 // Use *argv instead of argv[0] to work around a wrong GCC warning.
272 ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: ");
273
274 if (Recursive) {
275 std::vector<llvm::Function *> Workqueue;
276 for (GlobalValue *GV : GVs) {
277 if (auto *F = dyn_cast<Function>(Val: GV)) {
278 Workqueue.push_back(x: F);
279 }
280 }
281 while (!Workqueue.empty()) {
282 Function *F = &*Workqueue.back();
283 Workqueue.pop_back();
284 ExitOnErr(F->materialize());
285 for (auto &BB : *F) {
286 for (auto &I : BB) {
287 CallBase *CB = dyn_cast<CallBase>(Val: &I);
288 if (!CB)
289 continue;
290 Function *CF = CB->getCalledFunction();
291 if (!CF)
292 continue;
293 if (CF->isDeclaration() || !GVs.insert(X: CF))
294 continue;
295 Workqueue.push_back(x: CF);
296 }
297 }
298 }
299 }
300
301 auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); };
302
303 // Materialize requisite global values.
304 if (!DeleteFn) {
305 for (size_t i = 0, e = GVs.size(); i != e; ++i)
306 Materialize(*GVs[i]);
307 } else {
308 // Deleting. Materialize every GV that's *not* in GVs.
309 SmallPtrSet<GlobalValue *, 8> GVSet(llvm::from_range, GVs);
310 for (auto &F : *M) {
311 if (!GVSet.count(Ptr: &F))
312 Materialize(F);
313 }
314 }
315
316 {
317 std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end());
318 LoopAnalysisManager LAM;
319 FunctionAnalysisManager FAM;
320 CGSCCAnalysisManager CGAM;
321 ModuleAnalysisManager MAM;
322
323 PassBuilder PB;
324
325 PB.registerModuleAnalyses(MAM);
326 PB.registerCGSCCAnalyses(CGAM);
327 PB.registerFunctionAnalyses(FAM);
328 PB.registerLoopAnalyses(LAM);
329 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
330
331 ModulePassManager PM;
332 PM.addPass(Pass: ExtractGVPass(Gvs, DeleteFn, KeepConstInit));
333 PM.run(IR&: *M, AM&: MAM);
334
335 // Now that we have all the GVs we want, mark the module as fully
336 // materialized.
337 // FIXME: should the GVExtractionPass handle this?
338 ExitOnErr(M->materializeAll());
339 }
340
341 // Extract the specified basic blocks from the module and erase the existing
342 // functions.
343 if (!ExtractBlocks.empty()) {
344 // Figure out which BasicBlocks we should extract.
345 std::vector<std::vector<BasicBlock *>> GroupOfBBs;
346 for (auto &P : BBMap) {
347 std::vector<BasicBlock *> BBs;
348 for (StringRef BBName : P.second) {
349 // The function has been materialized, so add its matching basic blocks
350 // to the block extractor list, or fail if a name is not found.
351 auto Res = llvm::find_if(Range&: *P.first, P: [&](const BasicBlock &BB) {
352 return BB.getNameOrAsOperand() == BBName;
353 });
354 if (Res == P.first->end()) {
355 errs() << argv[0] << ": function " << P.first->getName()
356 << " doesn't contain a basic block named '" << BBName
357 << "'!\n";
358 return 1;
359 }
360 BBs.push_back(x: &*Res);
361 }
362 GroupOfBBs.push_back(x: BBs);
363 }
364
365 LoopAnalysisManager LAM;
366 FunctionAnalysisManager FAM;
367 CGSCCAnalysisManager CGAM;
368 ModuleAnalysisManager MAM;
369
370 PassBuilder PB;
371
372 PB.registerModuleAnalyses(MAM);
373 PB.registerCGSCCAnalyses(CGAM);
374 PB.registerFunctionAnalyses(FAM);
375 PB.registerLoopAnalyses(LAM);
376 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
377
378 ModulePassManager PM;
379 PM.addPass(Pass: BlockExtractorPass(std::move(GroupOfBBs), true));
380 PM.run(IR&: *M, AM&: MAM);
381 }
382
383 // In addition to deleting all other functions, we also want to spiff it
384 // up a little bit. Do this now.
385
386 LoopAnalysisManager LAM;
387 FunctionAnalysisManager FAM;
388 CGSCCAnalysisManager CGAM;
389 ModuleAnalysisManager MAM;
390
391 PassBuilder PB;
392
393 PB.registerModuleAnalyses(MAM);
394 PB.registerCGSCCAnalyses(CGAM);
395 PB.registerFunctionAnalyses(FAM);
396 PB.registerLoopAnalyses(LAM);
397 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
398
399 ModulePassManager PM;
400 if (!DeleteFn)
401 PM.addPass(Pass: GlobalDCEPass());
402 PM.addPass(Pass: StripDeadDebugInfoPass());
403 PM.addPass(Pass: StripDeadPrototypesPass());
404 PM.addPass(Pass: StripDeadCGProfilePass());
405
406 std::error_code EC;
407 ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None);
408 if (EC) {
409 errs() << EC.message() << '\n';
410 return 1;
411 }
412
413 if (OutputAssembly)
414 PM.addPass(
415 Pass: PrintModulePass(Out.os(), "", /* ShouldPreserveUseListOrder */ false));
416 else if (Force || !CheckBitcodeOutputToConsole(stream_to_check&: Out.os()))
417 PM.addPass(
418 Pass: BitcodeWriterPass(Out.os(), /* ShouldPreserveUseListOrder */ true));
419
420 PM.run(IR&: *M, AM&: MAM);
421
422 // Declare success.
423 Out.keep();
424
425 return 0;
426}
427