1 | //===- llvm-extract.cpp - LLVM function extraction utility ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This utility reduces the input module to the selected functions, global
// variables, aliases, and/or basic blocks (or, with -delete, removes the
// selected globals instead). It is primarily used for debugging
// transformations.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "llvm/ADT/SetVector.h" |
15 | #include "llvm/ADT/SmallPtrSet.h" |
16 | #include "llvm/Bitcode/BitcodeWriterPass.h" |
17 | #include "llvm/IR/DataLayout.h" |
18 | #include "llvm/IR/IRPrintingPasses.h" |
19 | #include "llvm/IR/Instructions.h" |
20 | #include "llvm/IR/LLVMContext.h" |
21 | #include "llvm/IR/Module.h" |
22 | #include "llvm/IRPrinter/IRPrintingPasses.h" |
23 | #include "llvm/IRReader/IRReader.h" |
24 | #include "llvm/Passes/PassBuilder.h" |
25 | #include "llvm/Support/CommandLine.h" |
26 | #include "llvm/Support/Error.h" |
27 | #include "llvm/Support/FileSystem.h" |
28 | #include "llvm/Support/InitLLVM.h" |
29 | #include "llvm/Support/Regex.h" |
30 | #include "llvm/Support/SourceMgr.h" |
31 | #include "llvm/Support/SystemUtils.h" |
32 | #include "llvm/Support/ToolOutputFile.h" |
33 | #include "llvm/Transforms/IPO.h" |
34 | #include "llvm/Transforms/IPO/BlockExtractor.h" |
35 | #include "llvm/Transforms/IPO/ExtractGV.h" |
36 | #include "llvm/Transforms/IPO/GlobalDCE.h" |
37 | #include "llvm/Transforms/IPO/StripDeadPrototypes.h" |
38 | #include "llvm/Transforms/IPO/StripSymbols.h" |
39 | #include <memory> |
40 | #include <utility> |
41 | |
42 | using namespace llvm; |
43 | |
44 | cl::OptionCategory ("llvm-extract Options" ); |
45 | |
46 | // InputFilename - The filename to read from. |
47 | static cl::opt<std::string> InputFilename(cl::Positional, |
48 | cl::desc("<input bitcode file>" ), |
49 | cl::init(Val: "-" ), |
50 | cl::value_desc("filename" )); |
51 | |
52 | static cl::opt<std::string> OutputFilename("o" , |
53 | cl::desc("Specify output filename" ), |
54 | cl::value_desc("filename" ), |
55 | cl::init(Val: "-" ), cl::cat(ExtractCat)); |
56 | |
57 | static cl::opt<bool> Force("f" , cl::desc("Enable binary output on terminals" ), |
58 | cl::cat(ExtractCat)); |
59 | |
60 | static cl::opt<bool> DeleteFn("delete" , |
61 | cl::desc("Delete specified Globals from Module" ), |
62 | cl::cat(ExtractCat)); |
63 | |
64 | static cl::opt<bool> KeepConstInit("keep-const-init" , |
65 | cl::desc("Keep initializers of constants" ), |
66 | cl::cat(ExtractCat)); |
67 | |
68 | static cl::opt<bool> |
69 | Recursive("recursive" , cl::desc("Recursively extract all called functions" ), |
70 | cl::cat(ExtractCat)); |
71 | |
72 | // ExtractFuncs - The functions to extract from the module. |
73 | static cl::list<std::string> |
74 | ("func" , cl::desc("Specify function to extract" ), |
75 | cl::value_desc("function" ), cl::cat(ExtractCat)); |
76 | |
77 | // ExtractRegExpFuncs - The functions, matched via regular expression, to |
78 | // extract from the module. |
79 | static cl::list<std::string> |
80 | ("rfunc" , |
81 | cl::desc("Specify function(s) to extract using a " |
82 | "regular expression" ), |
83 | cl::value_desc("rfunction" ), cl::cat(ExtractCat)); |
84 | |
85 | // ExtractBlocks - The blocks to extract from the module. |
86 | static cl::list<std::string> ( |
87 | "bb" , |
88 | cl::desc( |
89 | "Specify <function, basic block1[;basic block2...]> pairs to extract.\n" |
90 | "Each pair will create a function.\n" |
91 | "If multiple basic blocks are specified in one pair,\n" |
92 | "the first block in the sequence should dominate the rest.\n" |
93 | "eg:\n" |
94 | " --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n" |
95 | " --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one " |
96 | "with bb2." ), |
97 | cl::value_desc("function:bb1[;bb2...]" ), cl::cat(ExtractCat)); |
98 | |
99 | // ExtractAlias - The alias to extract from the module. |
100 | static cl::list<std::string> |
101 | ("alias" , cl::desc("Specify alias to extract" ), |
102 | cl::value_desc("alias" ), cl::cat(ExtractCat)); |
103 | |
104 | // ExtractRegExpAliases - The aliases, matched via regular expression, to |
105 | // extract from the module. |
106 | static cl::list<std::string> |
107 | ("ralias" , |
108 | cl::desc("Specify alias(es) to extract using a " |
109 | "regular expression" ), |
110 | cl::value_desc("ralias" ), cl::cat(ExtractCat)); |
111 | |
112 | // ExtractGlobals - The globals to extract from the module. |
113 | static cl::list<std::string> |
114 | ("glob" , cl::desc("Specify global to extract" ), |
115 | cl::value_desc("global" ), cl::cat(ExtractCat)); |
116 | |
117 | // ExtractRegExpGlobals - The globals, matched via regular expression, to |
118 | // extract from the module... |
119 | static cl::list<std::string> |
120 | ("rglob" , |
121 | cl::desc("Specify global(s) to extract using a " |
122 | "regular expression" ), |
123 | cl::value_desc("rglobal" ), cl::cat(ExtractCat)); |
124 | |
125 | static cl::opt<bool> OutputAssembly("S" , |
126 | cl::desc("Write output as LLVM assembly" ), |
127 | cl::Hidden, cl::cat(ExtractCat)); |
128 | |
129 | static cl::opt<bool> PreserveBitcodeUseListOrder( |
130 | "preserve-bc-uselistorder" , |
131 | cl::desc("Preserve use-list order when writing LLVM bitcode." ), |
132 | cl::init(Val: true), cl::Hidden, cl::cat(ExtractCat)); |
133 | |
134 | static cl::opt<bool> PreserveAssemblyUseListOrder( |
135 | "preserve-ll-uselistorder" , |
136 | cl::desc("Preserve use-list order when writing LLVM assembly." ), |
137 | cl::init(Val: false), cl::Hidden, cl::cat(ExtractCat)); |
138 | |
139 | int main(int argc, char **argv) { |
140 | InitLLVM X(argc, argv); |
141 | |
142 | LLVMContext Context; |
143 | cl::HideUnrelatedOptions(Category&: ExtractCat); |
144 | cl::ParseCommandLineOptions(argc, argv, Overview: "llvm extractor\n" ); |
145 | |
146 | // Use lazy loading, since we only care about selected global values. |
147 | SMDiagnostic Err; |
148 | std::unique_ptr<Module> M = getLazyIRFileModule(Filename: InputFilename, Err, Context); |
149 | |
150 | if (!M) { |
151 | Err.print(ProgName: argv[0], S&: errs()); |
152 | return 1; |
153 | } |
154 | |
155 | // Use SetVector to avoid duplicates. |
156 | SetVector<GlobalValue *> GVs; |
157 | |
158 | // Figure out which aliases we should extract. |
159 | for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) { |
160 | GlobalAlias *GA = M->getNamedAlias(Name: ExtractAliases[i]); |
161 | if (!GA) { |
162 | errs() << argv[0] << ": program doesn't contain alias named '" |
163 | << ExtractAliases[i] << "'!\n" ; |
164 | return 1; |
165 | } |
166 | GVs.insert(X: GA); |
167 | } |
168 | |
169 | // Extract aliases via regular expression matching. |
170 | for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) { |
171 | std::string Error; |
172 | Regex RegEx(ExtractRegExpAliases[i]); |
173 | if (!RegEx.isValid(Error)) { |
174 | errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' " |
175 | "invalid regex: " << Error; |
176 | } |
177 | bool match = false; |
178 | for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end(); |
179 | GA != E; GA++) { |
180 | if (RegEx.match(String: GA->getName())) { |
181 | GVs.insert(X: &*GA); |
182 | match = true; |
183 | } |
184 | } |
185 | if (!match) { |
186 | errs() << argv[0] << ": program doesn't contain global named '" |
187 | << ExtractRegExpAliases[i] << "'!\n" ; |
188 | return 1; |
189 | } |
190 | } |
191 | |
192 | // Figure out which globals we should extract. |
193 | for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) { |
194 | GlobalValue *GV = M->getNamedGlobal(Name: ExtractGlobals[i]); |
195 | if (!GV) { |
196 | errs() << argv[0] << ": program doesn't contain global named '" |
197 | << ExtractGlobals[i] << "'!\n" ; |
198 | return 1; |
199 | } |
200 | GVs.insert(X: GV); |
201 | } |
202 | |
203 | // Extract globals via regular expression matching. |
204 | for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) { |
205 | std::string Error; |
206 | Regex RegEx(ExtractRegExpGlobals[i]); |
207 | if (!RegEx.isValid(Error)) { |
208 | errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' " |
209 | "invalid regex: " << Error; |
210 | } |
211 | bool match = false; |
212 | for (auto &GV : M->globals()) { |
213 | if (RegEx.match(String: GV.getName())) { |
214 | GVs.insert(X: &GV); |
215 | match = true; |
216 | } |
217 | } |
218 | if (!match) { |
219 | errs() << argv[0] << ": program doesn't contain global named '" |
220 | << ExtractRegExpGlobals[i] << "'!\n" ; |
221 | return 1; |
222 | } |
223 | } |
224 | |
225 | // Figure out which functions we should extract. |
226 | for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) { |
227 | GlobalValue *GV = M->getFunction(Name: ExtractFuncs[i]); |
228 | if (!GV) { |
229 | errs() << argv[0] << ": program doesn't contain function named '" |
230 | << ExtractFuncs[i] << "'!\n" ; |
231 | return 1; |
232 | } |
233 | GVs.insert(X: GV); |
234 | } |
235 | // Extract functions via regular expression matching. |
236 | for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) { |
237 | std::string Error; |
238 | StringRef RegExStr = ExtractRegExpFuncs[i]; |
239 | Regex RegEx(RegExStr); |
240 | if (!RegEx.isValid(Error)) { |
241 | errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' " |
242 | "invalid regex: " << Error; |
243 | } |
244 | bool match = false; |
245 | for (Module::iterator F = M->begin(), E = M->end(); F != E; |
246 | F++) { |
247 | if (RegEx.match(String: F->getName())) { |
248 | GVs.insert(X: &*F); |
249 | match = true; |
250 | } |
251 | } |
252 | if (!match) { |
253 | errs() << argv[0] << ": program doesn't contain global named '" |
254 | << ExtractRegExpFuncs[i] << "'!\n" ; |
255 | return 1; |
256 | } |
257 | } |
258 | |
259 | // Figure out which BasicBlocks we should extract. |
260 | SmallVector<std::pair<Function *, SmallVector<StringRef, 16>>, 2> BBMap; |
261 | for (StringRef StrPair : ExtractBlocks) { |
262 | SmallVector<StringRef, 16> BBNames; |
263 | auto BBInfo = StrPair.split(Separator: ':'); |
264 | // Get the function. |
265 | Function *F = M->getFunction(Name: BBInfo.first); |
266 | if (!F) { |
267 | errs() << argv[0] << ": program doesn't contain a function named '" |
268 | << BBInfo.first << "'!\n" ; |
269 | return 1; |
270 | } |
271 | // Add the function to the materialize list, and store the basic block names |
272 | // to check after materialization. |
273 | GVs.insert(X: F); |
274 | BBInfo.second.split(A&: BBNames, Separator: ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false); |
275 | BBMap.push_back(Elt: {F, std::move(BBNames)}); |
276 | } |
277 | |
278 | // Use *argv instead of argv[0] to work around a wrong GCC warning. |
279 | ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: " ); |
280 | |
281 | if (Recursive) { |
282 | std::vector<llvm::Function *> Workqueue; |
283 | for (GlobalValue *GV : GVs) { |
284 | if (auto *F = dyn_cast<Function>(Val: GV)) { |
285 | Workqueue.push_back(x: F); |
286 | } |
287 | } |
288 | while (!Workqueue.empty()) { |
289 | Function *F = &*Workqueue.back(); |
290 | Workqueue.pop_back(); |
291 | ExitOnErr(F->materialize()); |
292 | for (auto &BB : *F) { |
293 | for (auto &I : BB) { |
294 | CallBase *CB = dyn_cast<CallBase>(Val: &I); |
295 | if (!CB) |
296 | continue; |
297 | Function *CF = CB->getCalledFunction(); |
298 | if (!CF) |
299 | continue; |
300 | if (CF->isDeclaration() || GVs.count(key: CF)) |
301 | continue; |
302 | GVs.insert(X: CF); |
303 | Workqueue.push_back(x: CF); |
304 | } |
305 | } |
306 | } |
307 | } |
308 | |
309 | auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); }; |
310 | |
311 | // Materialize requisite global values. |
312 | if (!DeleteFn) { |
313 | for (size_t i = 0, e = GVs.size(); i != e; ++i) |
314 | Materialize(*GVs[i]); |
315 | } else { |
316 | // Deleting. Materialize every GV that's *not* in GVs. |
317 | SmallPtrSet<GlobalValue *, 8> GVSet(GVs.begin(), GVs.end()); |
318 | for (auto &F : *M) { |
319 | if (!GVSet.count(Ptr: &F)) |
320 | Materialize(F); |
321 | } |
322 | } |
323 | |
324 | { |
325 | std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end()); |
326 | LoopAnalysisManager LAM; |
327 | FunctionAnalysisManager FAM; |
328 | CGSCCAnalysisManager CGAM; |
329 | ModuleAnalysisManager MAM; |
330 | |
331 | PassBuilder PB; |
332 | |
333 | PB.registerModuleAnalyses(MAM); |
334 | PB.registerCGSCCAnalyses(CGAM); |
335 | PB.registerFunctionAnalyses(FAM); |
336 | PB.registerLoopAnalyses(LAM); |
337 | PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); |
338 | |
339 | ModulePassManager PM; |
340 | PM.addPass(Pass: ExtractGVPass(Gvs, DeleteFn, KeepConstInit)); |
341 | PM.run(IR&: *M, AM&: MAM); |
342 | |
343 | // Now that we have all the GVs we want, mark the module as fully |
344 | // materialized. |
345 | // FIXME: should the GVExtractionPass handle this? |
346 | ExitOnErr(M->materializeAll()); |
347 | } |
348 | |
349 | // Extract the specified basic blocks from the module and erase the existing |
350 | // functions. |
351 | if (!ExtractBlocks.empty()) { |
352 | // Figure out which BasicBlocks we should extract. |
353 | std::vector<std::vector<BasicBlock *>> GroupOfBBs; |
354 | for (auto &P : BBMap) { |
355 | std::vector<BasicBlock *> BBs; |
356 | for (StringRef BBName : P.second) { |
357 | // The function has been materialized, so add its matching basic blocks |
358 | // to the block extractor list, or fail if a name is not found. |
359 | auto Res = llvm::find_if(Range&: *P.first, P: [&](const BasicBlock &BB) { |
360 | return BB.getName() == BBName; |
361 | }); |
362 | if (Res == P.first->end()) { |
363 | errs() << argv[0] << ": function " << P.first->getName() |
364 | << " doesn't contain a basic block named '" << BBName |
365 | << "'!\n" ; |
366 | return 1; |
367 | } |
368 | BBs.push_back(x: &*Res); |
369 | } |
370 | GroupOfBBs.push_back(x: BBs); |
371 | } |
372 | |
373 | LoopAnalysisManager LAM; |
374 | FunctionAnalysisManager FAM; |
375 | CGSCCAnalysisManager CGAM; |
376 | ModuleAnalysisManager MAM; |
377 | |
378 | PassBuilder PB; |
379 | |
380 | PB.registerModuleAnalyses(MAM); |
381 | PB.registerCGSCCAnalyses(CGAM); |
382 | PB.registerFunctionAnalyses(FAM); |
383 | PB.registerLoopAnalyses(LAM); |
384 | PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); |
385 | |
386 | ModulePassManager PM; |
387 | PM.addPass(Pass: BlockExtractorPass(std::move(GroupOfBBs), true)); |
388 | PM.run(IR&: *M, AM&: MAM); |
389 | } |
390 | |
391 | // In addition to deleting all other functions, we also want to spiff it |
392 | // up a little bit. Do this now. |
393 | |
394 | LoopAnalysisManager LAM; |
395 | FunctionAnalysisManager FAM; |
396 | CGSCCAnalysisManager CGAM; |
397 | ModuleAnalysisManager MAM; |
398 | |
399 | PassBuilder PB; |
400 | |
401 | PB.registerModuleAnalyses(MAM); |
402 | PB.registerCGSCCAnalyses(CGAM); |
403 | PB.registerFunctionAnalyses(FAM); |
404 | PB.registerLoopAnalyses(LAM); |
405 | PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); |
406 | |
407 | ModulePassManager PM; |
408 | if (!DeleteFn) |
409 | PM.addPass(Pass: GlobalDCEPass()); |
410 | PM.addPass(Pass: StripDeadDebugInfoPass()); |
411 | PM.addPass(Pass: StripDeadPrototypesPass()); |
412 | |
413 | std::error_code EC; |
414 | ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None); |
415 | if (EC) { |
416 | errs() << EC.message() << '\n'; |
417 | return 1; |
418 | } |
419 | |
420 | if (OutputAssembly) |
421 | PM.addPass(Pass: PrintModulePass(Out.os(), "" , PreserveAssemblyUseListOrder)); |
422 | else if (Force || !CheckBitcodeOutputToConsole(stream_to_check&: Out.os())) |
423 | PM.addPass(Pass: BitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder)); |
424 | |
425 | PM.run(IR&: *M, AM&: MAM); |
426 | |
427 | // Declare success. |
428 | Out.keep(); |
429 | |
430 | return 0; |
431 | } |
432 | |