| 1 | //===- llvm-extract.cpp - LLVM function extraction utility ----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This utility changes the input module to only contain a single function, |
| 10 | // which is primarily used for debugging transformations. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "llvm/ADT/SetVector.h" |
| 15 | #include "llvm/ADT/SmallPtrSet.h" |
| 16 | #include "llvm/Bitcode/BitcodeWriterPass.h" |
| 17 | #include "llvm/IR/DataLayout.h" |
| 18 | #include "llvm/IR/IRPrintingPasses.h" |
| 19 | #include "llvm/IR/Instructions.h" |
| 20 | #include "llvm/IR/LLVMContext.h" |
| 21 | #include "llvm/IR/Module.h" |
| 22 | #include "llvm/IRPrinter/IRPrintingPasses.h" |
| 23 | #include "llvm/IRReader/IRReader.h" |
| 24 | #include "llvm/Passes/PassBuilder.h" |
| 25 | #include "llvm/Support/CommandLine.h" |
| 26 | #include "llvm/Support/Error.h" |
| 27 | #include "llvm/Support/FileSystem.h" |
| 28 | #include "llvm/Support/InitLLVM.h" |
| 29 | #include "llvm/Support/Regex.h" |
| 30 | #include "llvm/Support/SourceMgr.h" |
| 31 | #include "llvm/Support/SystemUtils.h" |
| 32 | #include "llvm/Support/ToolOutputFile.h" |
| 33 | #include "llvm/Transforms/IPO.h" |
| 34 | #include "llvm/Transforms/IPO/BlockExtractor.h" |
| 35 | #include "llvm/Transforms/IPO/ExtractGV.h" |
| 36 | #include "llvm/Transforms/IPO/GlobalDCE.h" |
| 37 | #include "llvm/Transforms/IPO/StripDeadPrototypes.h" |
| 38 | #include "llvm/Transforms/IPO/StripSymbols.h" |
| 39 | #include <memory> |
| 40 | #include <utility> |
| 41 | |
| 42 | using namespace llvm; |
| 43 | |
| 44 | static cl::OptionCategory ("llvm-extract Options" ); |
| 45 | |
| 46 | // InputFilename - The filename to read from. |
| 47 | static cl::opt<std::string> InputFilename(cl::Positional, |
| 48 | cl::desc("<input bitcode file>" ), |
| 49 | cl::init(Val: "-" ), |
| 50 | cl::value_desc("filename" )); |
| 51 | |
| 52 | static cl::opt<std::string> OutputFilename("o" , |
| 53 | cl::desc("Specify output filename" ), |
| 54 | cl::value_desc("filename" ), |
| 55 | cl::init(Val: "-" ), cl::cat(ExtractCat)); |
| 56 | |
| 57 | static cl::opt<bool> Force("f" , cl::desc("Enable binary output on terminals" ), |
| 58 | cl::cat(ExtractCat)); |
| 59 | |
| 60 | static cl::opt<bool> DeleteFn("delete" , |
| 61 | cl::desc("Delete specified Globals from Module" ), |
| 62 | cl::cat(ExtractCat)); |
| 63 | |
| 64 | static cl::opt<bool> KeepConstInit("keep-const-init" , |
| 65 | cl::desc("Keep initializers of constants" ), |
| 66 | cl::cat(ExtractCat)); |
| 67 | |
| 68 | static cl::opt<bool> |
| 69 | Recursive("recursive" , cl::desc("Recursively extract all called functions" ), |
| 70 | cl::cat(ExtractCat)); |
| 71 | |
| 72 | // ExtractFuncs - The functions to extract from the module. |
| 73 | static cl::list<std::string> |
| 74 | ("func" , cl::desc("Specify function to extract" ), |
| 75 | cl::value_desc("function" ), cl::cat(ExtractCat)); |
| 76 | |
| 77 | // ExtractRegExpFuncs - The functions, matched via regular expression, to |
| 78 | // extract from the module. |
| 79 | static cl::list<std::string> |
| 80 | ("rfunc" , |
| 81 | cl::desc("Specify function(s) to extract using a " |
| 82 | "regular expression" ), |
| 83 | cl::value_desc("rfunction" ), cl::cat(ExtractCat)); |
| 84 | |
| 85 | // ExtractBlocks - The blocks to extract from the module. |
| 86 | static cl::list<std::string> ( |
| 87 | "bb" , |
| 88 | cl::desc( |
| 89 | "Specify <function, basic block1[;basic block2...]> pairs to extract.\n" |
| 90 | "Each pair will create a function.\n" |
| 91 | "If multiple basic blocks are specified in one pair,\n" |
| 92 | "the first block in the sequence should dominate the rest.\n" |
| 93 | "If an unnamed basic block is to be extracted,\n" |
| 94 | "'%' should be added before the basic block variable names.\n" |
| 95 | "eg:\n" |
| 96 | " --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n" |
| 97 | " --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one " |
| 98 | "with bb2.\n" |
| 99 | " --bb=f:%1 will extract one function with basic block 1;" ), |
| 100 | cl::value_desc("function:bb1[;bb2...]" ), cl::cat(ExtractCat)); |
| 101 | |
| 102 | // ExtractAlias - The alias to extract from the module. |
| 103 | static cl::list<std::string> |
| 104 | ("alias" , cl::desc("Specify alias to extract" ), |
| 105 | cl::value_desc("alias" ), cl::cat(ExtractCat)); |
| 106 | |
| 107 | // ExtractRegExpAliases - The aliases, matched via regular expression, to |
| 108 | // extract from the module. |
| 109 | static cl::list<std::string> |
| 110 | ("ralias" , |
| 111 | cl::desc("Specify alias(es) to extract using a " |
| 112 | "regular expression" ), |
| 113 | cl::value_desc("ralias" ), cl::cat(ExtractCat)); |
| 114 | |
| 115 | // ExtractGlobals - The globals to extract from the module. |
| 116 | static cl::list<std::string> |
| 117 | ("glob" , cl::desc("Specify global to extract" ), |
| 118 | cl::value_desc("global" ), cl::cat(ExtractCat)); |
| 119 | |
| 120 | // ExtractRegExpGlobals - The globals, matched via regular expression, to |
| 121 | // extract from the module... |
| 122 | static cl::list<std::string> |
| 123 | ("rglob" , |
| 124 | cl::desc("Specify global(s) to extract using a " |
| 125 | "regular expression" ), |
| 126 | cl::value_desc("rglobal" ), cl::cat(ExtractCat)); |
| 127 | |
| 128 | static cl::opt<bool> OutputAssembly("S" , |
| 129 | cl::desc("Write output as LLVM assembly" ), |
| 130 | cl::Hidden, cl::cat(ExtractCat)); |
| 131 | |
| 132 | static cl::opt<bool> PreserveBitcodeUseListOrder( |
| 133 | "preserve-bc-uselistorder" , |
| 134 | cl::desc("Preserve use-list order when writing LLVM bitcode." ), |
| 135 | cl::init(Val: true), cl::Hidden, cl::cat(ExtractCat)); |
| 136 | |
| 137 | static cl::opt<bool> PreserveAssemblyUseListOrder( |
| 138 | "preserve-ll-uselistorder" , |
| 139 | cl::desc("Preserve use-list order when writing LLVM assembly." ), |
| 140 | cl::init(Val: false), cl::Hidden, cl::cat(ExtractCat)); |
| 141 | |
| 142 | int main(int argc, char **argv) { |
| 143 | InitLLVM X(argc, argv); |
| 144 | |
| 145 | LLVMContext Context; |
| 146 | cl::HideUnrelatedOptions(Category&: ExtractCat); |
| 147 | cl::ParseCommandLineOptions(argc, argv, Overview: "llvm extractor\n" ); |
| 148 | |
| 149 | // Use lazy loading, since we only care about selected global values. |
| 150 | SMDiagnostic Err; |
| 151 | std::unique_ptr<Module> M = getLazyIRFileModule(Filename: InputFilename, Err, Context); |
| 152 | |
| 153 | if (!M) { |
| 154 | Err.print(ProgName: argv[0], S&: errs()); |
| 155 | return 1; |
| 156 | } |
| 157 | |
| 158 | // Use SetVector to avoid duplicates. |
| 159 | SetVector<GlobalValue *> GVs; |
| 160 | |
| 161 | // Figure out which aliases we should extract. |
| 162 | for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) { |
| 163 | GlobalAlias *GA = M->getNamedAlias(Name: ExtractAliases[i]); |
| 164 | if (!GA) { |
| 165 | errs() << argv[0] << ": program doesn't contain alias named '" |
| 166 | << ExtractAliases[i] << "'!\n" ; |
| 167 | return 1; |
| 168 | } |
| 169 | GVs.insert(X: GA); |
| 170 | } |
| 171 | |
| 172 | // Extract aliases via regular expression matching. |
| 173 | for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) { |
| 174 | std::string Error; |
| 175 | Regex RegEx(ExtractRegExpAliases[i]); |
| 176 | if (!RegEx.isValid(Error)) { |
| 177 | errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' " |
| 178 | "invalid regex: " << Error; |
| 179 | } |
| 180 | bool match = false; |
| 181 | for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end(); |
| 182 | GA != E; GA++) { |
| 183 | if (RegEx.match(String: GA->getName())) { |
| 184 | GVs.insert(X: &*GA); |
| 185 | match = true; |
| 186 | } |
| 187 | } |
| 188 | if (!match) { |
| 189 | errs() << argv[0] << ": program doesn't contain global named '" |
| 190 | << ExtractRegExpAliases[i] << "'!\n" ; |
| 191 | return 1; |
| 192 | } |
| 193 | } |
| 194 | |
| 195 | // Figure out which globals we should extract. |
| 196 | for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) { |
| 197 | GlobalValue *GV = M->getNamedGlobal(Name: ExtractGlobals[i]); |
| 198 | if (!GV) { |
| 199 | errs() << argv[0] << ": program doesn't contain global named '" |
| 200 | << ExtractGlobals[i] << "'!\n" ; |
| 201 | return 1; |
| 202 | } |
| 203 | GVs.insert(X: GV); |
| 204 | } |
| 205 | |
| 206 | // Extract globals via regular expression matching. |
| 207 | for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) { |
| 208 | std::string Error; |
| 209 | Regex RegEx(ExtractRegExpGlobals[i]); |
| 210 | if (!RegEx.isValid(Error)) { |
| 211 | errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' " |
| 212 | "invalid regex: " << Error; |
| 213 | } |
| 214 | bool match = false; |
| 215 | for (auto &GV : M->globals()) { |
| 216 | if (RegEx.match(String: GV.getName())) { |
| 217 | GVs.insert(X: &GV); |
| 218 | match = true; |
| 219 | } |
| 220 | } |
| 221 | if (!match) { |
| 222 | errs() << argv[0] << ": program doesn't contain global named '" |
| 223 | << ExtractRegExpGlobals[i] << "'!\n" ; |
| 224 | return 1; |
| 225 | } |
| 226 | } |
| 227 | |
| 228 | // Figure out which functions we should extract. |
| 229 | for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) { |
| 230 | GlobalValue *GV = M->getFunction(Name: ExtractFuncs[i]); |
| 231 | if (!GV) { |
| 232 | errs() << argv[0] << ": program doesn't contain function named '" |
| 233 | << ExtractFuncs[i] << "'!\n" ; |
| 234 | return 1; |
| 235 | } |
| 236 | GVs.insert(X: GV); |
| 237 | } |
| 238 | // Extract functions via regular expression matching. |
| 239 | for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) { |
| 240 | std::string Error; |
| 241 | StringRef RegExStr = ExtractRegExpFuncs[i]; |
| 242 | Regex RegEx(RegExStr); |
| 243 | if (!RegEx.isValid(Error)) { |
| 244 | errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' " |
| 245 | "invalid regex: " << Error; |
| 246 | } |
| 247 | bool match = false; |
| 248 | for (Module::iterator F = M->begin(), E = M->end(); F != E; |
| 249 | F++) { |
| 250 | if (RegEx.match(String: F->getName())) { |
| 251 | GVs.insert(X: &*F); |
| 252 | match = true; |
| 253 | } |
| 254 | } |
| 255 | if (!match) { |
| 256 | errs() << argv[0] << ": program doesn't contain global named '" |
| 257 | << ExtractRegExpFuncs[i] << "'!\n" ; |
| 258 | return 1; |
| 259 | } |
| 260 | } |
| 261 | |
| 262 | // Figure out which BasicBlocks we should extract. |
| 263 | SmallVector<std::pair<Function *, SmallVector<StringRef, 16>>, 2> BBMap; |
| 264 | for (StringRef StrPair : ExtractBlocks) { |
| 265 | SmallVector<StringRef, 16> BBNames; |
| 266 | auto BBInfo = StrPair.split(Separator: ':'); |
| 267 | // Get the function. |
| 268 | Function *F = M->getFunction(Name: BBInfo.first); |
| 269 | if (!F) { |
| 270 | errs() << argv[0] << ": program doesn't contain a function named '" |
| 271 | << BBInfo.first << "'!\n" ; |
| 272 | return 1; |
| 273 | } |
| 274 | // Add the function to the materialize list, and store the basic block names |
| 275 | // to check after materialization. |
| 276 | GVs.insert(X: F); |
| 277 | BBInfo.second.split(A&: BBNames, Separator: ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false); |
| 278 | BBMap.push_back(Elt: {F, std::move(BBNames)}); |
| 279 | } |
| 280 | |
| 281 | // Use *argv instead of argv[0] to work around a wrong GCC warning. |
| 282 | ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: " ); |
| 283 | |
| 284 | if (Recursive) { |
| 285 | std::vector<llvm::Function *> Workqueue; |
| 286 | for (GlobalValue *GV : GVs) { |
| 287 | if (auto *F = dyn_cast<Function>(Val: GV)) { |
| 288 | Workqueue.push_back(x: F); |
| 289 | } |
| 290 | } |
| 291 | while (!Workqueue.empty()) { |
| 292 | Function *F = &*Workqueue.back(); |
| 293 | Workqueue.pop_back(); |
| 294 | ExitOnErr(F->materialize()); |
| 295 | for (auto &BB : *F) { |
| 296 | for (auto &I : BB) { |
| 297 | CallBase *CB = dyn_cast<CallBase>(Val: &I); |
| 298 | if (!CB) |
| 299 | continue; |
| 300 | Function *CF = CB->getCalledFunction(); |
| 301 | if (!CF) |
| 302 | continue; |
| 303 | if (CF->isDeclaration() || !GVs.insert(X: CF)) |
| 304 | continue; |
| 305 | Workqueue.push_back(x: CF); |
| 306 | } |
| 307 | } |
| 308 | } |
| 309 | } |
| 310 | |
| 311 | auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); }; |
| 312 | |
| 313 | // Materialize requisite global values. |
| 314 | if (!DeleteFn) { |
| 315 | for (size_t i = 0, e = GVs.size(); i != e; ++i) |
| 316 | Materialize(*GVs[i]); |
| 317 | } else { |
| 318 | // Deleting. Materialize every GV that's *not* in GVs. |
| 319 | SmallPtrSet<GlobalValue *, 8> GVSet(llvm::from_range, GVs); |
| 320 | for (auto &F : *M) { |
| 321 | if (!GVSet.count(Ptr: &F)) |
| 322 | Materialize(F); |
| 323 | } |
| 324 | } |
| 325 | |
| 326 | { |
| 327 | std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end()); |
| 328 | LoopAnalysisManager LAM; |
| 329 | FunctionAnalysisManager FAM; |
| 330 | CGSCCAnalysisManager CGAM; |
| 331 | ModuleAnalysisManager MAM; |
| 332 | |
| 333 | PassBuilder PB; |
| 334 | |
| 335 | PB.registerModuleAnalyses(MAM); |
| 336 | PB.registerCGSCCAnalyses(CGAM); |
| 337 | PB.registerFunctionAnalyses(FAM); |
| 338 | PB.registerLoopAnalyses(LAM); |
| 339 | PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); |
| 340 | |
| 341 | ModulePassManager PM; |
| 342 | PM.addPass(Pass: ExtractGVPass(Gvs, DeleteFn, KeepConstInit)); |
| 343 | PM.run(IR&: *M, AM&: MAM); |
| 344 | |
| 345 | // Now that we have all the GVs we want, mark the module as fully |
| 346 | // materialized. |
| 347 | // FIXME: should the GVExtractionPass handle this? |
| 348 | ExitOnErr(M->materializeAll()); |
| 349 | } |
| 350 | |
| 351 | // Extract the specified basic blocks from the module and erase the existing |
| 352 | // functions. |
| 353 | if (!ExtractBlocks.empty()) { |
| 354 | // Figure out which BasicBlocks we should extract. |
| 355 | std::vector<std::vector<BasicBlock *>> GroupOfBBs; |
| 356 | for (auto &P : BBMap) { |
| 357 | std::vector<BasicBlock *> BBs; |
| 358 | for (StringRef BBName : P.second) { |
| 359 | // The function has been materialized, so add its matching basic blocks |
| 360 | // to the block extractor list, or fail if a name is not found. |
| 361 | auto Res = llvm::find_if(Range&: *P.first, P: [&](const BasicBlock &BB) { |
| 362 | return BB.getNameOrAsOperand() == BBName; |
| 363 | }); |
| 364 | if (Res == P.first->end()) { |
| 365 | errs() << argv[0] << ": function " << P.first->getName() |
| 366 | << " doesn't contain a basic block named '" << BBName |
| 367 | << "'!\n" ; |
| 368 | return 1; |
| 369 | } |
| 370 | BBs.push_back(x: &*Res); |
| 371 | } |
| 372 | GroupOfBBs.push_back(x: BBs); |
| 373 | } |
| 374 | |
| 375 | LoopAnalysisManager LAM; |
| 376 | FunctionAnalysisManager FAM; |
| 377 | CGSCCAnalysisManager CGAM; |
| 378 | ModuleAnalysisManager MAM; |
| 379 | |
| 380 | PassBuilder PB; |
| 381 | |
| 382 | PB.registerModuleAnalyses(MAM); |
| 383 | PB.registerCGSCCAnalyses(CGAM); |
| 384 | PB.registerFunctionAnalyses(FAM); |
| 385 | PB.registerLoopAnalyses(LAM); |
| 386 | PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); |
| 387 | |
| 388 | ModulePassManager PM; |
| 389 | PM.addPass(Pass: BlockExtractorPass(std::move(GroupOfBBs), true)); |
| 390 | PM.run(IR&: *M, AM&: MAM); |
| 391 | } |
| 392 | |
| 393 | // In addition to deleting all other functions, we also want to spiff it |
| 394 | // up a little bit. Do this now. |
| 395 | |
| 396 | LoopAnalysisManager LAM; |
| 397 | FunctionAnalysisManager FAM; |
| 398 | CGSCCAnalysisManager CGAM; |
| 399 | ModuleAnalysisManager MAM; |
| 400 | |
| 401 | PassBuilder PB; |
| 402 | |
| 403 | PB.registerModuleAnalyses(MAM); |
| 404 | PB.registerCGSCCAnalyses(CGAM); |
| 405 | PB.registerFunctionAnalyses(FAM); |
| 406 | PB.registerLoopAnalyses(LAM); |
| 407 | PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); |
| 408 | |
| 409 | ModulePassManager PM; |
| 410 | if (!DeleteFn) |
| 411 | PM.addPass(Pass: GlobalDCEPass()); |
| 412 | PM.addPass(Pass: StripDeadDebugInfoPass()); |
| 413 | PM.addPass(Pass: StripDeadPrototypesPass()); |
| 414 | PM.addPass(Pass: StripDeadCGProfilePass()); |
| 415 | |
| 416 | std::error_code EC; |
| 417 | ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None); |
| 418 | if (EC) { |
| 419 | errs() << EC.message() << '\n'; |
| 420 | return 1; |
| 421 | } |
| 422 | |
| 423 | if (OutputAssembly) |
| 424 | PM.addPass(Pass: PrintModulePass(Out.os(), "" , PreserveAssemblyUseListOrder)); |
| 425 | else if (Force || !CheckBitcodeOutputToConsole(stream_to_check&: Out.os())) |
| 426 | PM.addPass(Pass: BitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder)); |
| 427 | |
| 428 | PM.run(IR&: *M, AM&: MAM); |
| 429 | |
| 430 | // Declare success. |
| 431 | Out.keep(); |
| 432 | |
| 433 | return 0; |
| 434 | } |
| 435 | |