1//===-- llvm-split: command line tool for testing module splitting --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This program can be used to test the llvm::SplitModule and
10// TargetMachine::splitModule functions.
11//
12//===----------------------------------------------------------------------===//
13
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/IPO/GlobalDCE.h"
#include "llvm/Transforms/Utils/SplitModule.h"
#include "llvm/Transforms/Utils/SplitModuleByCategory.h"
38
39using namespace llvm;
40
41static cl::OptionCategory SplitCategory("Split Options");
42
43static cl::opt<std::string> InputFilename(cl::Positional,
44 cl::desc("<input bitcode file>"),
45 cl::init(Val: "-"),
46 cl::value_desc("filename"),
47 cl::cat(SplitCategory));
48
49static cl::opt<std::string> OutputFilename("o",
50 cl::desc("Override output filename"),
51 cl::value_desc("filename"),
52 cl::cat(SplitCategory));
53
54static cl::opt<unsigned> NumOutputs("j", cl::Prefix, cl::init(Val: 2),
55 cl::desc("Number of output files"),
56 cl::cat(SplitCategory));
57
58static cl::opt<bool>
59 PreserveLocals("preserve-locals", cl::Prefix, cl::init(Val: false),
60 cl::desc("Split without externalizing locals"),
61 cl::cat(SplitCategory));
62
63static cl::opt<bool>
64 RoundRobin("round-robin", cl::Prefix, cl::init(Val: false),
65 cl::desc("Use round-robin distribution of functions to "
66 "modules instead of the default name-hash-based one"),
67 cl::cat(SplitCategory));
68
69static cl::opt<std::string>
70 MTriple("mtriple",
71 cl::desc("Target triple. When present, a TargetMachine is created "
72 "and TargetMachine::splitModule is used instead of the "
73 "common SplitModule logic."),
74 cl::value_desc("triple"), cl::cat(SplitCategory));
75
76static cl::opt<std::string>
77 MCPU("mcpu", cl::desc("Target CPU, ignored if --mtriple is not used"),
78 cl::value_desc("cpu"), cl::cat(SplitCategory));
79
80enum class SplitByCategoryType {
81 SBCT_ByModuleId,
82 SBCT_ByKernel,
83 SBCT_None,
84};
85
86static cl::opt<SplitByCategoryType> SplitByCategory(
87 "split-by-category",
88 cl::desc("Split by category. If present, splitting by category is used "
89 "with the specified categorization type."),
90 cl::Optional, cl::init(Val: SplitByCategoryType::SBCT_None),
91 cl::values(clEnumValN(SplitByCategoryType::SBCT_ByModuleId, "module-id",
92 "one output module per translation unit marked with "
93 "\"module-id\" attribute"),
94 clEnumValN(SplitByCategoryType::SBCT_ByKernel, "kernel",
95 "one output module per kernel")),
96 cl::cat(SplitCategory));
97
98static cl::opt<bool> OutputAssembly{
99 "S", cl::desc("Write output as LLVM assembly"), cl::cat(SplitCategory)};
100
101void writeStringToFile(StringRef Content, StringRef Path) {
102 std::error_code EC;
103 raw_fd_ostream OS(Path, EC);
104 if (EC) {
105 errs() << formatv(Fmt: "error opening file: {0}, error: {1}\n", Vals&: Path,
106 Vals: EC.message());
107 exit(status: 1);
108 }
109
110 OS << Content << "\n";
111}
112
113void writeModuleToFile(const Module &M, StringRef Path, bool OutputAssembly) {
114 int FD = -1;
115 if (std::error_code EC = sys::fs::openFileForWrite(Name: Path, ResultFD&: FD)) {
116 errs() << formatv(Fmt: "error opening file: {0}, error: {1}", Vals&: Path, Vals: EC.message())
117 << '\n';
118 exit(status: 1);
119 }
120
121 raw_fd_ostream OS(FD, /*ShouldClose*/ true);
122 if (OutputAssembly)
123 M.print(OS, /*AssemblyAnnotationWriter*/ AAW: nullptr);
124 else
125 WriteBitcodeToFile(M, Out&: OS);
126}
127
128/// EntryPointCategorizer is used for splitting by category either by module-id
129/// or by kernels. It doesn't provide categories for functions other than
130/// kernels. Categorizer computes a string key for the given Function and
131/// records the association between the string key and an integer category. If a
132/// string key is already belongs to some category than the corresponding
133/// integer category is returned.
134class EntryPointCategorizer {
135public:
136 EntryPointCategorizer(SplitByCategoryType Type) : Type(Type) {}
137
138 EntryPointCategorizer() = delete;
139 EntryPointCategorizer(EntryPointCategorizer &) = delete;
140 EntryPointCategorizer &operator=(const EntryPointCategorizer &) = delete;
141 EntryPointCategorizer(EntryPointCategorizer &&) = default;
142 EntryPointCategorizer &operator=(EntryPointCategorizer &&) = default;
143
144 /// Returns integer specifying the category for the given \p F.
145 /// If the given function isn't a kernel then returns std::nullopt.
146 std::optional<int> operator()(const Function &F) {
147 if (!isEntryPoint(F))
148 return std::nullopt; // skip the function.
149
150 auto StringKey = computeFunctionCategory(Type, F);
151 if (auto it = StrKeyToID.find(Val: StringRef(StringKey)); it != StrKeyToID.end())
152 return it->second;
153
154 int ID = static_cast<int>(StrKeyToID.size());
155 return StrKeyToID.try_emplace(Key: std::move(StringKey), Args&: ID).first->second;
156 }
157
158private:
159 static bool isEntryPoint(const Function &F) {
160 if (F.isDeclaration())
161 return false;
162
163 return F.hasKernelCallingConv();
164 }
165
166 static SmallString<0> computeFunctionCategory(SplitByCategoryType Type,
167 const Function &F) {
168 static constexpr char ATTR_MODULE_ID[] = "module-id";
169 SmallString<0> Key;
170 switch (Type) {
171 case SplitByCategoryType::SBCT_ByKernel:
172 Key = F.getName().str();
173 break;
174 case SplitByCategoryType::SBCT_ByModuleId:
175 Key = F.getFnAttribute(Kind: ATTR_MODULE_ID).getValueAsString().str();
176 break;
177 default:
178 llvm_unreachable("unexpected mode.");
179 }
180
181 return Key;
182 }
183
184private:
185 struct KeyInfo {
186 static SmallString<0> getEmptyKey() { return SmallString<0>(""); }
187
188 static SmallString<0> getTombstoneKey() { return SmallString<0>("-"); }
189
190 static bool isEqual(const SmallString<0> &LHS, const SmallString<0> &RHS) {
191 return LHS == RHS;
192 }
193
194 static unsigned getHashValue(const SmallString<0> &S) {
195 return llvm::hash_value(S: StringRef(S));
196 }
197 };
198
199 SplitByCategoryType Type;
200 DenseMap<SmallString<0>, int, KeyInfo> StrKeyToID;
201};
202
203void cleanupModule(Module &M) {
204 ModuleAnalysisManager MAM;
205 MAM.registerPass(PassBuilder: [&] { return PassInstrumentationAnalysis(); });
206 ModulePassManager MPM;
207 MPM.addPass(Pass: GlobalDCEPass()); // Delete unreachable globals.
208 MPM.run(IR&: M, AM&: MAM);
209}
210
211Error runSplitModuleByCategory(std::unique_ptr<Module> M) {
212 size_t OutputID = 0;
213 auto PostSplitCallback = [&](std::unique_ptr<Module> MPart) {
214 if (verifyModule(M: *MPart)) {
215 errs() << "Broken Module!\n";
216 exit(status: 1);
217 }
218
219 // TODO: DCE is a crucial pass since it removes unused declarations.
220 // At the moment, LIT checking can't be perfomed without DCE.
221 cleanupModule(M&: *MPart);
222 size_t ID = OutputID;
223 ++OutputID;
224 StringRef ModuleSuffix = OutputAssembly ? ".ll" : ".bc";
225 std::string ModulePath =
226 (Twine(OutputFilename) + "_" + Twine(ID) + ModuleSuffix).str();
227 writeModuleToFile(M: *MPart, Path: ModulePath, OutputAssembly);
228 };
229
230 auto Categorizer = EntryPointCategorizer(SplitByCategory);
231 splitModuleTransitiveFromEntryPoints(M: std::move(M), EntryPointCategorizer: Categorizer,
232 Callback: PostSplitCallback);
233 return Error::success();
234}
235
236int main(int argc, char **argv) {
237 InitLLVM X(argc, argv);
238
239 LLVMContext Context;
240 SMDiagnostic Err;
241 cl::HideUnrelatedOptions(Categories: {&SplitCategory, &getColorCategory()});
242 cl::ParseCommandLineOptions(argc, argv, Overview: "LLVM module splitter\n");
243
244 Triple TT(MTriple);
245
246 std::unique_ptr<TargetMachine> TM;
247 if (!MTriple.empty()) {
248 InitializeAllTargets();
249 InitializeAllTargetMCs();
250
251 std::string Error;
252 const Target *T = TargetRegistry::lookupTarget(TheTriple: TT, Error);
253 if (!T) {
254 errs() << "unknown target '" << MTriple << "': " << Error << "\n";
255 return 1;
256 }
257
258 TargetOptions Options;
259 TM = std::unique_ptr<TargetMachine>(T->createTargetMachine(
260 TT, CPU: MCPU, /*FS*/ Features: "", Options, RM: std::nullopt, CM: std::nullopt));
261 }
262
263 std::unique_ptr<Module> M = parseIRFile(Filename: InputFilename, Err, Context);
264
265 if (!M) {
266 Err.print(ProgName: argv[0], S&: errs());
267 return 1;
268 }
269
270 unsigned I = 0;
271 const auto HandleModulePart = [&](std::unique_ptr<Module> MPart) {
272 std::error_code EC;
273 std::unique_ptr<ToolOutputFile> Out(
274 new ToolOutputFile(OutputFilename + utostr(X: I++), EC, sys::fs::OF_None));
275 if (EC) {
276 errs() << EC.message() << '\n';
277 exit(status: 1);
278 }
279
280 if (verifyModule(M: *MPart, OS: &errs())) {
281 errs() << "Broken module!\n";
282 exit(status: 1);
283 }
284
285 WriteBitcodeToFile(M: *MPart, Out&: Out->os());
286
287 // Declare success.
288 Out->keep();
289 };
290
291 if (SplitByCategory != SplitByCategoryType::SBCT_None) {
292 auto E = runSplitModuleByCategory(M: std::move(M));
293 if (E) {
294 errs() << E << "\n";
295 Err.print(ProgName: argv[0], S&: errs());
296 return 1;
297 }
298
299 return 0;
300 }
301
302 if (TM) {
303 if (PreserveLocals) {
304 errs() << "warning: --preserve-locals has no effect when using "
305 "TargetMachine::splitModule\n";
306 }
307 if (RoundRobin)
308 errs() << "warning: --round-robin has no effect when using "
309 "TargetMachine::splitModule\n";
310
311 if (TM->splitModule(M&: *M, NumParts: NumOutputs, ModuleCallback: HandleModulePart))
312 return 0;
313
314 errs() << "warning: "
315 "TargetMachine::splitModule failed, falling back to default "
316 "splitModule implementation\n";
317 }
318
319 SplitModule(M&: *M, N: NumOutputs, ModuleCallback: HandleModulePart, PreserveLocals, RoundRobin);
320 return 0;
321}
322