1//===-- llvm-split: command line tool for testing module splitting --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This program can be used to test the llvm::SplitModule and
10// TargetMachine::splitModule functions.
11//
12//===----------------------------------------------------------------------===//
13
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/IPO/GlobalDCE.h"
#include "llvm/Transforms/Utils/SplitModule.h"
#include "llvm/Transforms/Utils/SplitModuleByCategory.h"
38
39using namespace llvm;
40
41static cl::OptionCategory SplitCategory("Split Options");
42
43static cl::opt<std::string> InputFilename(cl::Positional,
44 cl::desc("<input bitcode file>"),
45 cl::init(Val: "-"),
46 cl::value_desc("filename"),
47 cl::cat(SplitCategory));
48
49static cl::opt<std::string> OutputFilename("o",
50 cl::desc("Override output filename"),
51 cl::value_desc("filename"),
52 cl::cat(SplitCategory));
53
54static cl::opt<unsigned> NumOutputs("j", cl::Prefix, cl::init(Val: 2),
55 cl::desc("Number of output files"),
56 cl::cat(SplitCategory));
57
58static cl::opt<bool>
59 PreserveLocals("preserve-locals", cl::Prefix, cl::init(Val: false),
60 cl::desc("Split without externalizing locals"),
61 cl::cat(SplitCategory));
62
63static cl::opt<bool>
64 RoundRobin("round-robin", cl::Prefix, cl::init(Val: false),
65 cl::desc("Use round-robin distribution of functions to "
66 "modules instead of the default name-hash-based one"),
67 cl::cat(SplitCategory));
68
69static cl::opt<std::string>
70 MTriple("mtriple",
71 cl::desc("Target triple. When present, a TargetMachine is created "
72 "and TargetMachine::splitModule is used instead of the "
73 "common SplitModule logic."),
74 cl::value_desc("triple"), cl::cat(SplitCategory));
75
76static cl::opt<std::string>
77 MCPU("mcpu", cl::desc("Target CPU, ignored if --mtriple is not used"),
78 cl::value_desc("cpu"), cl::cat(SplitCategory));
79
80enum class SplitByCategoryType {
81 SBCT_ByModuleId,
82 SBCT_ByKernel,
83 SBCT_None,
84};
85
86static cl::opt<SplitByCategoryType> SplitByCategory(
87 "split-by-category",
88 cl::desc("Split by category. If present, splitting by category is used "
89 "with the specified categorization type."),
90 cl::Optional, cl::init(Val: SplitByCategoryType::SBCT_None),
91 cl::values(clEnumValN(SplitByCategoryType::SBCT_ByModuleId, "module-id",
92 "one output module per translation unit marked with "
93 "\"module-id\" attribute"),
94 clEnumValN(SplitByCategoryType::SBCT_ByKernel, "kernel",
95 "one output module per kernel")),
96 cl::cat(SplitCategory));
97
98static cl::opt<bool> OutputAssembly{
99 "S", cl::desc("Write output as LLVM assembly"), cl::cat(SplitCategory)};
100
101void writeStringToFile(StringRef Content, StringRef Path) {
102 std::error_code EC;
103 raw_fd_ostream OS(Path, EC);
104 if (EC) {
105 errs() << formatv(Fmt: "error opening file: {0}, error: {1}\n", Vals&: Path,
106 Vals: EC.message());
107 exit(status: 1);
108 }
109
110 OS << Content << "\n";
111}
112
113void writeModuleToFile(const Module &M, StringRef Path, bool OutputAssembly) {
114 int FD = -1;
115 if (std::error_code EC = sys::fs::openFileForWrite(Name: Path, ResultFD&: FD)) {
116 errs() << formatv(Fmt: "error opening file: {0}, error: {1}", Vals&: Path, Vals: EC.message())
117 << '\n';
118 exit(status: 1);
119 }
120
121 raw_fd_ostream OS(FD, /*ShouldClose*/ true);
122 if (OutputAssembly)
123 M.print(OS, /*AssemblyAnnotationWriter*/ AAW: nullptr);
124 else
125 WriteBitcodeToFile(M, Out&: OS);
126}
127
128/// EntryPointCategorizer is used for splitting by category either by module-id
129/// or by kernels. It doesn't provide categories for functions other than
130/// kernels. Categorizer computes a string key for the given Function and
131/// records the association between the string key and an integer category. If a
132/// string key is already belongs to some category than the corresponding
133/// integer category is returned.
134class EntryPointCategorizer {
135public:
136 EntryPointCategorizer(SplitByCategoryType Type) : Type(Type) {}
137
138 EntryPointCategorizer() = delete;
139 EntryPointCategorizer(EntryPointCategorizer &) = delete;
140 EntryPointCategorizer &operator=(const EntryPointCategorizer &) = delete;
141 EntryPointCategorizer(EntryPointCategorizer &&) = default;
142 EntryPointCategorizer &operator=(EntryPointCategorizer &&) = default;
143
144 /// Returns integer specifying the category for the given \p F.
145 /// If the given function isn't a kernel then returns std::nullopt.
146 std::optional<int> operator()(const Function &F) {
147 if (!isEntryPoint(F))
148 return std::nullopt; // skip the function.
149
150 auto StringKey = computeFunctionCategory(Type, F);
151 if (auto it = StrKeyToID.find(Val: StringRef(StringKey)); it != StrKeyToID.end())
152 return it->second;
153
154 int ID = static_cast<int>(StrKeyToID.size());
155 return StrKeyToID.try_emplace(Key: std::move(StringKey), Args&: ID).first->second;
156 }
157
158private:
159 static bool isEntryPoint(const Function &F) {
160 if (F.isDeclaration())
161 return false;
162
163 return F.getCallingConv() == CallingConv::SPIR_KERNEL ||
164 F.getCallingConv() == CallingConv::AMDGPU_KERNEL ||
165 F.getCallingConv() == CallingConv::PTX_Kernel;
166 }
167
168 static SmallString<0> computeFunctionCategory(SplitByCategoryType Type,
169 const Function &F) {
170 static constexpr char ATTR_MODULE_ID[] = "module-id";
171 SmallString<0> Key;
172 switch (Type) {
173 case SplitByCategoryType::SBCT_ByKernel:
174 Key = F.getName().str();
175 break;
176 case SplitByCategoryType::SBCT_ByModuleId:
177 Key = F.getFnAttribute(Kind: ATTR_MODULE_ID).getValueAsString().str();
178 break;
179 default:
180 llvm_unreachable("unexpected mode.");
181 }
182
183 return Key;
184 }
185
186private:
187 struct KeyInfo {
188 static SmallString<0> getEmptyKey() { return SmallString<0>(""); }
189
190 static SmallString<0> getTombstoneKey() { return SmallString<0>("-"); }
191
192 static bool isEqual(const SmallString<0> &LHS, const SmallString<0> &RHS) {
193 return LHS == RHS;
194 }
195
196 static unsigned getHashValue(const SmallString<0> &S) {
197 return llvm::hash_value(S: StringRef(S));
198 }
199 };
200
201 SplitByCategoryType Type;
202 DenseMap<SmallString<0>, int, KeyInfo> StrKeyToID;
203};
204
205void cleanupModule(Module &M) {
206 ModuleAnalysisManager MAM;
207 MAM.registerPass(PassBuilder: [&] { return PassInstrumentationAnalysis(); });
208 ModulePassManager MPM;
209 MPM.addPass(Pass: GlobalDCEPass()); // Delete unreachable globals.
210 MPM.run(IR&: M, AM&: MAM);
211}
212
213Error runSplitModuleByCategory(std::unique_ptr<Module> M) {
214 size_t OutputID = 0;
215 auto PostSplitCallback = [&](std::unique_ptr<Module> MPart) {
216 if (verifyModule(M: *MPart)) {
217 errs() << "Broken Module!\n";
218 exit(status: 1);
219 }
220
221 // TODO: DCE is a crucial pass since it removes unused declarations.
222 // At the moment, LIT checking can't be perfomed without DCE.
223 cleanupModule(M&: *MPart);
224 size_t ID = OutputID;
225 ++OutputID;
226 StringRef ModuleSuffix = OutputAssembly ? ".ll" : ".bc";
227 std::string ModulePath =
228 (Twine(OutputFilename) + "_" + Twine(ID) + ModuleSuffix).str();
229 writeModuleToFile(M: *MPart, Path: ModulePath, OutputAssembly);
230 };
231
232 auto Categorizer = EntryPointCategorizer(SplitByCategory);
233 splitModuleTransitiveFromEntryPoints(M: std::move(M), EntryPointCategorizer: Categorizer,
234 Callback: PostSplitCallback);
235 return Error::success();
236}
237
238int main(int argc, char **argv) {
239 InitLLVM X(argc, argv);
240
241 LLVMContext Context;
242 SMDiagnostic Err;
243 cl::HideUnrelatedOptions(Categories: {&SplitCategory, &getColorCategory()});
244 cl::ParseCommandLineOptions(argc, argv, Overview: "LLVM module splitter\n");
245
246 Triple TT(MTriple);
247
248 std::unique_ptr<TargetMachine> TM;
249 if (!MTriple.empty()) {
250 InitializeAllTargets();
251 InitializeAllTargetMCs();
252
253 std::string Error;
254 const Target *T = TargetRegistry::lookupTarget(TheTriple: TT, Error);
255 if (!T) {
256 errs() << "unknown target '" << MTriple << "': " << Error << "\n";
257 return 1;
258 }
259
260 TargetOptions Options;
261 TM = std::unique_ptr<TargetMachine>(T->createTargetMachine(
262 TT, CPU: MCPU, /*FS*/ Features: "", Options, RM: std::nullopt, CM: std::nullopt));
263 }
264
265 std::unique_ptr<Module> M = parseIRFile(Filename: InputFilename, Err, Context);
266
267 if (!M) {
268 Err.print(ProgName: argv[0], S&: errs());
269 return 1;
270 }
271
272 unsigned I = 0;
273 const auto HandleModulePart = [&](std::unique_ptr<Module> MPart) {
274 std::error_code EC;
275 std::unique_ptr<ToolOutputFile> Out(
276 new ToolOutputFile(OutputFilename + utostr(X: I++), EC, sys::fs::OF_None));
277 if (EC) {
278 errs() << EC.message() << '\n';
279 exit(status: 1);
280 }
281
282 if (verifyModule(M: *MPart, OS: &errs())) {
283 errs() << "Broken module!\n";
284 exit(status: 1);
285 }
286
287 WriteBitcodeToFile(M: *MPart, Out&: Out->os());
288
289 // Declare success.
290 Out->keep();
291 };
292
293 if (SplitByCategory != SplitByCategoryType::SBCT_None) {
294 auto E = runSplitModuleByCategory(M: std::move(M));
295 if (E) {
296 errs() << E << "\n";
297 Err.print(ProgName: argv[0], S&: errs());
298 return 1;
299 }
300
301 return 0;
302 }
303
304 if (TM) {
305 if (PreserveLocals) {
306 errs() << "warning: --preserve-locals has no effect when using "
307 "TargetMachine::splitModule\n";
308 }
309 if (RoundRobin)
310 errs() << "warning: --round-robin has no effect when using "
311 "TargetMachine::splitModule\n";
312
313 if (TM->splitModule(M&: *M, NumParts: NumOutputs, ModuleCallback: HandleModulePart))
314 return 0;
315
316 errs() << "warning: "
317 "TargetMachine::splitModule failed, falling back to default "
318 "splitModule implementation\n";
319 }
320
321 SplitModule(M&: *M, N: NumOutputs, ModuleCallback: HandleModulePart, PreserveLocals, RoundRobin);
322 return 0;
323}
324