1//===-- llvm-split: command line tool for testing module splitting --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This program can be used to test the llvm::SplitModule and
10// TargetMachine::splitModule functions.
11//
12//===----------------------------------------------------------------------===//
13
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/IPO/GlobalDCE.h"
#include "llvm/Transforms/Utils/SplitModule.h"
#include "llvm/Transforms/Utils/SplitModuleByCategory.h"
38
39using namespace llvm;
40
41static cl::OptionCategory SplitCategory("Split Options");
42
43static cl::opt<std::string> InputFilename(cl::Positional,
44 cl::desc("<input bitcode file>"),
45 cl::init(Val: "-"),
46 cl::value_desc("filename"),
47 cl::cat(SplitCategory));
48
49static cl::opt<std::string> OutputFilename("o",
50 cl::desc("Override output filename"),
51 cl::value_desc("filename"),
52 cl::cat(SplitCategory));
53
54static cl::opt<unsigned> NumOutputs("j", cl::Prefix, cl::init(Val: 2),
55 cl::desc("Number of output files"),
56 cl::cat(SplitCategory));
57
58static cl::opt<bool>
59 PreserveLocals("preserve-locals", cl::Prefix, cl::init(Val: false),
60 cl::desc("Split without externalizing locals"),
61 cl::cat(SplitCategory));
62
63static cl::opt<bool>
64 RoundRobin("round-robin", cl::Prefix, cl::init(Val: false),
65 cl::desc("Use round-robin distribution of functions to "
66 "modules instead of the default name-hash-based one"),
67 cl::cat(SplitCategory));
68
69static cl::opt<std::string>
70 MTriple("mtriple",
71 cl::desc("Target triple. When present, a TargetMachine is created "
72 "and TargetMachine::splitModule is used instead of the "
73 "common SplitModule logic."),
74 cl::value_desc("triple"), cl::cat(SplitCategory));
75
76static cl::opt<std::string>
77 MCPU("mcpu", cl::desc("Target CPU, ignored if --mtriple is not used"),
78 cl::value_desc("cpu"), cl::cat(SplitCategory));
79
/// Categorization mode selected with -split-by-category.
enum class SplitByCategoryType {
  /// Group kernels by the value of a named function attribute.
  SBCT_ByAttribute,
  /// Give every kernel its own category.
  SBCT_ByKernel,
  /// Splitting by category is not requested.
  SBCT_None,
};
85
86static cl::opt<SplitByCategoryType> SplitByCategory(
87 "split-by-category",
88 cl::desc("Split by category. If present, splitting by category is used "
89 "with the specified categorization type."),
90 cl::Optional, cl::init(Val: SplitByCategoryType::SBCT_None),
91 cl::values(clEnumValN(SplitByCategoryType::SBCT_ByAttribute, "attribute",
92 "one output module per unique value of the function "
93 "attribute named by --category-attribute"),
94 clEnumValN(SplitByCategoryType::SBCT_ByKernel, "kernel",
95 "one output module per kernel")),
96 cl::cat(SplitCategory));
97
98static cl::opt<std::string>
99 CategoryAttribute("category-attribute",
100 cl::desc("Function attribute name to use when splitting "
101 "with -split-by-category=attribute"),
102 cl::value_desc("name"), cl::cat(SplitCategory));
103
104static cl::opt<bool> OutputAssembly{
105 "S", cl::desc("Write output as LLVM assembly"), cl::cat(SplitCategory)};
106
107void writeStringToFile(StringRef Content, StringRef Path) {
108 std::error_code EC;
109 raw_fd_ostream OS(Path, EC);
110 if (EC) {
111 errs() << formatv(Fmt: "error opening file: {0}, error: {1}\n", Vals&: Path,
112 Vals: EC.message());
113 exit(status: 1);
114 }
115
116 OS << Content << "\n";
117}
118
119void writeModuleToFile(const Module &M, StringRef Path, bool OutputAssembly) {
120 int FD = -1;
121 if (std::error_code EC = sys::fs::openFileForWrite(Name: Path, ResultFD&: FD)) {
122 errs() << formatv(Fmt: "error opening file: {0}, error: {1}", Vals&: Path, Vals: EC.message())
123 << '\n';
124 exit(status: 1);
125 }
126
127 raw_fd_ostream OS(FD, /*ShouldClose*/ true);
128 if (OutputAssembly)
129 M.print(OS, /*AssemblyAnnotationWriter*/ AAW: nullptr);
130 else
131 WriteBitcodeToFile(M, Out&: OS);
132}
133
134/// EntryPointCategorizer is used for splitting by category either by a named
135/// function attribute or by kernels. It doesn't provide categories for
136/// functions other than kernels. Categorizer computes a string key for the
137/// given Function and records the association between the string key and an
138/// integer category. If a string key already belongs to some category then the
139/// corresponding integer category is returned.
140class EntryPointCategorizer {
141public:
142 EntryPointCategorizer(SplitByCategoryType Type, StringRef AttributeName)
143 : Type(Type), AttributeName(AttributeName) {}
144
145 EntryPointCategorizer() = delete;
146 EntryPointCategorizer(EntryPointCategorizer &) = delete;
147 EntryPointCategorizer &operator=(const EntryPointCategorizer &) = delete;
148 EntryPointCategorizer(EntryPointCategorizer &&) = default;
149 EntryPointCategorizer &operator=(EntryPointCategorizer &&) = default;
150
151 /// Returns integer specifying the category for the given \p F.
152 /// If the given function isn't a kernel then returns std::nullopt.
153 std::optional<int> operator()(const Function &F) {
154 if (!isEntryPoint(F))
155 return std::nullopt; // skip the function.
156
157 auto StringKey = computeFunctionCategory(Type, F);
158 if (auto it = StrKeyToID.find(Val: StringRef(StringKey)); it != StrKeyToID.end())
159 return it->second;
160
161 int ID = static_cast<int>(StrKeyToID.size());
162 return StrKeyToID.try_emplace(Key: std::move(StringKey), Args&: ID).first->second;
163 }
164
165private:
166 static bool isEntryPoint(const Function &F) {
167 if (F.isDeclaration())
168 return false;
169
170 return F.hasKernelCallingConv();
171 }
172
173 SmallString<0> computeFunctionCategory(SplitByCategoryType Type,
174 const Function &F) {
175 SmallString<0> Key;
176 switch (Type) {
177 case SplitByCategoryType::SBCT_ByKernel:
178 Key = F.getName().str();
179 break;
180 case SplitByCategoryType::SBCT_ByAttribute:
181 Key = F.getFnAttribute(Kind: AttributeName).getValueAsString().str();
182 break;
183 default:
184 llvm_unreachable("unexpected mode.");
185 }
186
187 return Key;
188 }
189
190private:
191 struct KeyInfo {
192 static bool isEqual(const SmallString<0> &LHS, const SmallString<0> &RHS) {
193 return LHS == RHS;
194 }
195
196 static unsigned getHashValue(const SmallString<0> &S) {
197 return llvm::hash_value(S: StringRef(S));
198 }
199 };
200
201 SplitByCategoryType Type;
202 std::string AttributeName;
203 DenseMap<SmallString<0>, int, KeyInfo> StrKeyToID;
204};
205
206void cleanupModule(Module &M) {
207 ModuleAnalysisManager MAM;
208 MAM.registerPass(PassBuilder: [&] { return PassInstrumentationAnalysis(); });
209 ModulePassManager MPM;
210 MPM.addPass(Pass: GlobalDCEPass()); // Delete unreachable globals.
211 MPM.run(IR&: M, AM&: MAM);
212}
213
214Error runSplitModuleByCategory(std::unique_ptr<Module> M) {
215 if (SplitByCategory == SplitByCategoryType::SBCT_ByAttribute &&
216 CategoryAttribute.empty())
217 return createStringError(
218 Fmt: "-split-by-category=attribute requires --category-attribute=<name>");
219
220 size_t OutputID = 0;
221 auto PostSplitCallback = [&](std::unique_ptr<Module> MPart) -> Error {
222 if (verifyModule(M: *MPart)) {
223 errs() << "Broken Module!\n";
224 exit(status: 1);
225 }
226
227 // TODO: DCE is a crucial pass since it removes unused declarations.
228 // At the moment, LIT checking can't be perfomed without DCE.
229 cleanupModule(M&: *MPart);
230 size_t ID = OutputID;
231 ++OutputID;
232 StringRef ModuleSuffix = OutputAssembly ? ".ll" : ".bc";
233 std::string ModulePath =
234 (Twine(OutputFilename) + "_" + Twine(ID) + ModuleSuffix).str();
235 writeModuleToFile(M: *MPart, Path: ModulePath, OutputAssembly);
236 return Error::success();
237 };
238
239 auto Categorizer = EntryPointCategorizer(SplitByCategory, CategoryAttribute);
240 return splitModuleTransitiveFromEntryPoints(M: std::move(M), EntryPointCategorizer: Categorizer,
241 Callback: PostSplitCallback);
242}
243
244int main(int argc, char **argv) {
245 InitLLVM X(argc, argv);
246
247 LLVMContext Context;
248 SMDiagnostic Err;
249 cl::HideUnrelatedOptions(Categories: {&SplitCategory, &getColorCategory()});
250 cl::ParseCommandLineOptions(argc, argv, Overview: "LLVM module splitter\n");
251
252 Triple TT(MTriple);
253
254 std::unique_ptr<TargetMachine> TM;
255 if (!MTriple.empty()) {
256 InitializeAllTargets();
257 InitializeAllTargetMCs();
258
259 std::string Error;
260 const Target *T = TargetRegistry::lookupTarget(TheTriple: TT, Error);
261 if (!T) {
262 errs() << "unknown target '" << MTriple << "': " << Error << "\n";
263 return 1;
264 }
265
266 TargetOptions Options;
267 TM = std::unique_ptr<TargetMachine>(T->createTargetMachine(
268 TT, CPU: MCPU, /*FS*/ Features: "", Options, RM: std::nullopt, CM: std::nullopt));
269 }
270
271 std::unique_ptr<Module> M = parseIRFile(Filename: InputFilename, Err, Context);
272
273 if (!M) {
274 Err.print(ProgName: argv[0], S&: errs());
275 return 1;
276 }
277
278 unsigned I = 0;
279 const auto HandleModulePart = [&](std::unique_ptr<Module> MPart) {
280 std::error_code EC;
281 std::unique_ptr<ToolOutputFile> Out(
282 new ToolOutputFile(OutputFilename + utostr(X: I++), EC, sys::fs::OF_None));
283 if (EC) {
284 errs() << EC.message() << '\n';
285 exit(status: 1);
286 }
287
288 if (verifyModule(M: *MPart, OS: &errs())) {
289 errs() << "Broken module!\n";
290 exit(status: 1);
291 }
292
293 WriteBitcodeToFile(M: *MPart, Out&: Out->os());
294
295 // Declare success.
296 Out->keep();
297 };
298
299 if (SplitByCategory != SplitByCategoryType::SBCT_None) {
300 auto E = runSplitModuleByCategory(M: std::move(M));
301 if (E) {
302 errs() << "error: " << toString(E: std::move(E)) << "\n";
303 return 1;
304 }
305
306 return 0;
307 }
308
309 if (TM) {
310 if (PreserveLocals) {
311 errs() << "warning: --preserve-locals has no effect when using "
312 "TargetMachine::splitModule\n";
313 }
314 if (RoundRobin)
315 errs() << "warning: --round-robin has no effect when using "
316 "TargetMachine::splitModule\n";
317
318 if (TM->splitModule(M&: *M, NumParts: NumOutputs, ModuleCallback: HandleModulePart))
319 return 0;
320
321 errs() << "warning: "
322 "TargetMachine::splitModule failed, falling back to default "
323 "splitModule implementation\n";
324 }
325
326 SplitModule(M&: *M, N: NumOutputs, ModuleCallback: HandleModulePart, PreserveLocals, RoundRobin);
327 return 0;
328}
329