1//===- AMDGPUSplitModule.cpp ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file Implements a module splitting algorithm designed to support the
10/// FullLTO --lto-partitions option for parallel codegen. This is completely
11/// different from the common SplitModule pass, as this system is designed with
12/// AMDGPU in mind.
13///
14/// The basic idea of this module splitting implementation is the same as
15/// SplitModule: load-balance the module's functions across a set of N
16/// partitions to allow parallel codegen. However, it does it very
17/// differently than the target-agnostic variant:
18/// - The module has "split roots", which are kernels in the vast
/// majority of cases.
20/// - Each root has a set of dependencies, and when a root and its
21/// dependencies is considered "big", we try to put it in a partition where
22/// most dependencies are already imported, to avoid duplicating large
23/// amounts of code.
24/// - There's special care for indirect calls in order to ensure
25/// AMDGPUResourceUsageAnalysis can work correctly.
26///
27/// This file also includes a more elaborate logging system to enable
28/// users to easily generate logs that (if desired) do not include any value
29/// names, in order to not leak information about the source file.
30/// Such logs are very helpful to understand and fix potential issues with
31/// module splitting.
32
33#include "AMDGPUSplitModule.h"
34#include "AMDGPUTargetMachine.h"
35#include "Utils/AMDGPUBaseInfo.h"
36#include "llvm/ADT/DenseMap.h"
37#include "llvm/ADT/SmallVector.h"
38#include "llvm/ADT/StringExtras.h"
39#include "llvm/ADT/StringRef.h"
40#include "llvm/Analysis/CallGraph.h"
41#include "llvm/Analysis/TargetTransformInfo.h"
42#include "llvm/IR/Function.h"
43#include "llvm/IR/Instruction.h"
44#include "llvm/IR/Module.h"
45#include "llvm/IR/User.h"
46#include "llvm/IR/Value.h"
47#include "llvm/Support/Casting.h"
48#include "llvm/Support/Debug.h"
49#include "llvm/Support/FileSystem.h"
50#include "llvm/Support/Path.h"
51#include "llvm/Support/Process.h"
52#include "llvm/Support/SHA256.h"
53#include "llvm/Support/Threading.h"
54#include "llvm/Support/raw_ostream.h"
55#include "llvm/Transforms/Utils/Cloning.h"
56#include <algorithm>
57#include <cassert>
58#include <iterator>
59#include <memory>
60#include <utility>
61#include <vector>
62
63using namespace llvm;
64
65#define DEBUG_TYPE "amdgpu-split-module"
66
67namespace {
68
69static cl::opt<float> LargeFnFactor(
70 "amdgpu-module-splitting-large-function-threshold", cl::init(Val: 2.0f),
71 cl::Hidden,
72 cl::desc(
73 "consider a function as large and needing special treatment when the "
74 "cost of importing it into a partition"
75 "exceeds the average cost of a partition by this factor; e;g. 2.0 "
76 "means if the function and its dependencies is 2 times bigger than "
77 "an average partition; 0 disables large functions handling entirely"));
78
/// Minimum fraction (0.0-1.0) of shared dependencies required before a large
/// function is merged into an already-populated partition.
static cl::opt<float> LargeFnOverlapForMerge(
    "amdgpu-module-splitting-large-function-merge-overlap", cl::init(Val: 0.8f),
    cl::Hidden,
    cl::desc(
        "defines how much overlap between two large function's dependencies "
        "is needed to put them in the same partition"));

/// Opt-out for externalizing local globals (see splitAMDGPUModule); when set,
/// local GVs are conservatively imported into every partition instead.
static cl::opt<bool> NoExternalizeGlobals(
    "amdgpu-module-splitting-no-externalize-globals", cl::Hidden,
    cl::desc("disables externalization of global variable with local linkage; "
             "may cause globals to be duplicated which increases binary size"));

/// Log directory; the AMD_SPLIT_MODULE_LOG_DIR environment variable is the
/// fallback when this option is not given (the option takes priority).
static cl::opt<std::string>
    LogDirOpt("amdgpu-module-splitting-log-dir", cl::Hidden,
              cl::desc("output directory for AMDGPU module splitting logs"));

/// Hide value names in logs; the AMD_SPLIT_MODULE_LOG_PRIVATE environment
/// variable is the fallback when this option is not given (see getName).
static cl::opt<bool>
    LogPrivate("amdgpu-module-splitting-log-private", cl::Hidden,
               cl::desc("hash value names before printing them in the AMDGPU "
                        "module splitting logs"));
99
// Integer cost unit used throughout this pass (unwrapped InstructionCost).
using CostType = InstructionCost::CostType;
// Index of a partition; partitions are numbered [0, NumParts).
using PartitionID = unsigned;
// Abstract TTI getter so the implementation is pass-manager agnostic.
using GetTTIFn = function_ref<const TargetTransformInfo &(Function &)>;

/// \returns true if \p F uses an AMDGPU entry-point (kernel) calling
/// convention. Entry points are the splitting roots.
static bool isEntryPoint(const Function *F) {
  return AMDGPU::isEntryFunctionCC(CC: F->getCallingConv());
}
107
108static std::string getName(const Value &V) {
109 static bool HideNames;
110
111 static llvm::once_flag HideNameInitFlag;
112 llvm::call_once(flag&: HideNameInitFlag, F: [&]() {
113 if (LogPrivate.getNumOccurrences())
114 HideNames = LogPrivate;
115 else {
116 const auto EV = sys::Process::GetEnv(name: "AMD_SPLIT_MODULE_LOG_PRIVATE");
117 HideNames = (EV.value_or(u: "0") != "0");
118 }
119 });
120
121 if (!HideNames)
122 return V.getName().str();
123 return toHex(Input: SHA256::hash(Data: arrayRefFromStringRef(Input: V.getName())),
124 /*LowerCase=*/true);
125}
126
127/// Main logging helper.
128///
129/// Logging can be configured by the following environment variable.
130/// AMD_SPLIT_MODULE_LOG_DIR=<filepath>
131/// If set, uses <filepath> as the directory to write logfiles to
132/// each time module splitting is used.
133/// AMD_SPLIT_MODULE_LOG_PRIVATE
134/// If set to anything other than zero, all names are hidden.
135///
/// Both environment variables have corresponding CL options which
/// take priority over them.
138///
139/// Any output printed to the log files is also printed to dbgs() when -debug is
140/// used and LLVM_DEBUG is defined.
141///
142/// This approach has a small disadvantage over LLVM_DEBUG though: logging logic
143/// cannot be removed from the code (by building without debug). This probably
144/// has a small performance cost because if some computation/formatting is
/// needed for logging purposes, it may be done every time only to be ignored
146/// by the logger.
147///
148/// As this pass only runs once and is not doing anything computationally
149/// expensive, this is likely a reasonable trade-off.
150///
151/// If some computation should really be avoided when unused, users of the class
152/// can check whether any logging will occur by using the bool operator.
153///
154/// \code
155/// if (SML) {
156/// // Executes only if logging to a file or if -debug is available and
157/// used.
158/// }
159/// \endcode
class SplitModuleLogger {
public:
  /// Opens a uniquely-named logfile if a log directory was configured, either
  /// through -amdgpu-module-splitting-log-dir or the AMD_SPLIT_MODULE_LOG_DIR
  /// environment variable (the option takes priority). Without a directory,
  /// the logger only mirrors output to dbgs() under -debug.
  SplitModuleLogger(const Module &M) {
    std::string LogDir = LogDirOpt;
    if (LogDir.empty())
      LogDir = sys::Process::GetEnv(name: "AMD_SPLIT_MODULE_LOG_DIR").value_or(u: "");

    // No log dir specified means we don't need to log to a file.
    // We may still log to dbgs(), though.
    if (LogDir.empty())
      return;

    // If a log directory is specified, create a new file with a unique name in
    // that directory.
    int Fd;
    SmallString<0> PathTemplate;
    SmallString<0> RealPath;
    sys::path::append(path&: PathTemplate, a: LogDir, b: "Module-%%-%%-%%-%%-%%-%%-%%.txt");
    if (auto Err =
            sys::fs::createUniqueFile(Model: PathTemplate.str(), ResultFD&: Fd, ResultPath&: RealPath)) {
      report_fatal_error(reason: "Failed to create log file at '" + Twine(LogDir) +
                             "': " + Err.message(),
                         /*CrashDiag=*/gen_crash_diag: false);
    }

    // The stream takes ownership of the descriptor and closes it on
    // destruction of this logger.
    FileOS = std::make_unique<raw_fd_ostream>(args&: Fd, /*shouldClose=*/args: true);
  }

  /// \returns true when a logfile was successfully opened.
  bool hasLogFile() const { return FileOS != nullptr; }

  /// \returns the logfile stream. Only valid when hasLogFile() is true.
  raw_ostream &logfile() {
    assert(FileOS && "no logfile!");
    return *FileOS;
  }

  /// \returns true if this SML will log anything either to a file or dbgs().
  /// Can be used to avoid expensive computations that are ignored when logging
  /// is disabled.
  operator bool() const {
    return hasLogFile() || (DebugFlag && isCurrentDebugType(DEBUG_TYPE));
  }

private:
  // Owning stream over the opened logfile; null when file logging is off.
  std::unique_ptr<raw_fd_ostream> FileOS;
};
205
206template <typename Ty>
207static SplitModuleLogger &operator<<(SplitModuleLogger &SML, const Ty &Val) {
208 static_assert(
209 !std::is_same_v<Ty, Value>,
210 "do not print values to logs directly, use handleName instead!");
211 LLVM_DEBUG(dbgs() << Val);
212 if (SML.hasLogFile())
213 SML.logfile() << Val;
214 return SML;
215}
216
217/// Calculate the cost of each function in \p M
218/// \param SML Log Helper
219/// \param GetTTI Abstract getter for TargetTransformInfo.
220/// \param M Module to analyze.
221/// \param CostMap[out] Resulting Function -> Cost map.
222/// \return The module's total cost.
223static CostType
224calculateFunctionCosts(SplitModuleLogger &SML, GetTTIFn GetTTI, Module &M,
225 DenseMap<const Function *, CostType> &CostMap) {
226 CostType ModuleCost = 0;
227 CostType KernelCost = 0;
228
229 for (auto &Fn : M) {
230 if (Fn.isDeclaration())
231 continue;
232
233 CostType FnCost = 0;
234 const auto &TTI = GetTTI(Fn);
235 for (const auto &BB : Fn) {
236 for (const auto &I : BB) {
237 auto Cost =
238 TTI.getInstructionCost(U: &I, CostKind: TargetTransformInfo::TCK_CodeSize);
239 assert(Cost != InstructionCost::getMax());
240 // Assume expensive if we can't tell the cost of an instruction.
241 CostType CostVal =
242 Cost.getValue().value_or(u: TargetTransformInfo::TCC_Expensive);
243 assert((FnCost + CostVal) >= FnCost && "Overflow!");
244 FnCost += CostVal;
245 }
246 }
247
248 assert(FnCost != 0);
249
250 CostMap[&Fn] = FnCost;
251 assert((ModuleCost + FnCost) >= ModuleCost && "Overflow!");
252 ModuleCost += FnCost;
253
254 if (isEntryPoint(F: &Fn))
255 KernelCost += FnCost;
256 }
257
258 CostType FnCost = (ModuleCost - KernelCost);
259 CostType ModuleCostOr1 = ModuleCost ? ModuleCost : 1;
260 SML << "=> Total Module Cost: " << ModuleCost << '\n'
261 << " => KernelCost: " << KernelCost << " ("
262 << format(Fmt: "%0.2f", Vals: (float(KernelCost) / ModuleCostOr1) * 100) << "%)\n"
263 << " => FnsCost: " << FnCost << " ("
264 << format(Fmt: "%0.2f", Vals: (float(FnCost) / ModuleCostOr1) * 100) << "%)\n";
265
266 return ModuleCost;
267}
268
269static bool canBeIndirectlyCalled(const Function &F) {
270 if (F.isDeclaration() || isEntryPoint(F: &F))
271 return false;
272 return !F.hasLocalLinkage() ||
273 F.hasAddressTaken(/*PutOffender=*/nullptr,
274 /*IgnoreCallbackUses=*/false,
275 /*IgnoreAssumeLikeCalls=*/true,
276 /*IgnoreLLVMUsed=*/IngoreLLVMUsed: true,
277 /*IgnoreARCAttachedCall=*/false,
278 /*IgnoreCastedDirectCall=*/true);
279}
280
281/// When a function or any of its callees performs an indirect call, this
282/// takes over \ref addAllDependencies and adds all potentially callable
283/// functions to \p Fns so they can be counted as dependencies of the function.
284///
285/// This is needed due to how AMDGPUResourceUsageAnalysis operates: in the
286/// presence of an indirect call, the function's resource usage is the same as
287/// the most expensive function in the module.
288/// \param M The module.
289/// \param Fns[out] Resulting list of functions.
290static void addAllIndirectCallDependencies(const Module &M,
291 DenseSet<const Function *> &Fns) {
292 for (const auto &Fn : M) {
293 if (canBeIndirectlyCalled(F: Fn))
294 Fns.insert(V: &Fn);
295 }
296}
297
298/// Adds the functions that \p Fn may call to \p Fns, then recurses into each
299/// callee until all reachable functions have been gathered.
300///
301/// \param SML Log Helper
302/// \param CG Call graph for \p Fn's module.
303/// \param Fn Current function to look at.
304/// \param Fns[out] Resulting list of functions.
305/// \param OnlyDirect Whether to only consider direct callees.
306/// \param HadIndirectCall[out] Set to true if an indirect call was seen at some
307/// point, either in \p Fn or in one of the function it calls. When that
308/// happens, we fall back to adding all callable functions inside \p Fn's module
309/// to \p Fns.
static void addAllDependencies(SplitModuleLogger &SML, const CallGraph &CG,
                               const Function &Fn,
                               DenseSet<const Function *> &Fns, bool OnlyDirect,
                               bool &HadIndirectCall) {
  assert(!Fn.isDeclaration());

  const Module &M = *Fn.getParent();
  // Iterative DFS over the call graph; \p Fns doubles as the visited set so
  // each function is expanded at most once.
  SmallVector<const Function *> WorkList({&Fn});
  while (!WorkList.empty()) {
    const auto &CurFn = *WorkList.pop_back_val();
    assert(!CurFn.isDeclaration());

    // Scan for an indirect call. If such a call is found, we have to
    // conservatively assume this can call all non-entrypoint functions in the
    // module.

    for (auto &CGEntry : *CG[&CurFn]) {
      auto *CGNode = CGEntry.second;
      auto *Callee = CGNode->getFunction();
      if (!Callee) {
        // No Function attached to this node: either the "calls external"
        // sentinel node or some other synthetic node.
        if (OnlyDirect)
          continue;

        // Functions have an edge towards CallsExternalNode if they're external
        // declarations, or if they do an indirect call. As we only process
        // definitions here, we know this means the function has an indirect
        // call. We then have to conservatively assume this can call all
        // non-entrypoint functions in the module.
        if (CGNode != CG.getCallsExternalNode())
          continue; // this is another function-less node we don't care about.

        SML << "Indirect call detected in " << getName(V: CurFn)
            << " - treating all non-entrypoint functions as "
               "potential dependencies\n";

        // TODO: Print an ORE as well ?
        addAllIndirectCallDependencies(M, Fns);
        HadIndirectCall = true;
        continue;
      }

      // Declarations live in another module after splitting; only definitions
      // are dependencies we may need to import.
      if (Callee->isDeclaration())
        continue;

      // Only recurse into callees we haven't seen yet.
      auto [It, Inserted] = Fns.insert(V: Callee);
      if (Inserted)
        WorkList.push_back(Elt: Callee);
    }
  }
}
360
361/// Contains information about a function and its dependencies.
362/// This is a splitting root. The splitting algorithm works by
363/// assigning these to partitions.
struct FunctionWithDependencies {
  /// Gathers \p Fn's reachable callees into \ref Dependencies and computes
  /// \ref TotalCost from \p FnCosts (the root's own cost plus every
  /// dependency's cost).
  FunctionWithDependencies(SplitModuleLogger &SML, CallGraph &CG,
                           const DenseMap<const Function *, CostType> &FnCosts,
                           const Function *Fn)
      : Fn(Fn) {
    // When Fn is not a kernel, we don't need to collect indirect callees.
    // Resource usage analysis is only performed on kernels, and we collect
    // indirect callees for resource usage analysis.
    addAllDependencies(SML, CG, Fn: *Fn, Fns&: Dependencies,
                       /*OnlyDirect*/ !isEntryPoint(F: Fn), HadIndirectCall&: HasIndirectCall);
    TotalCost = FnCosts.at(Val: Fn);
    for (const auto *Dep : Dependencies) {
      TotalCost += FnCosts.at(Val: Dep);

      // We cannot duplicate functions with external linkage, or functions that
      // may be overridden at runtime.
      HasNonDuplicatableDependecy |=
          (Dep->hasExternalLinkage() || !Dep->isDefinitionExact());
    }
  }

  /// The splitting root itself.
  const Function *Fn = nullptr;
  /// Every function definition reachable from \ref Fn (excluding Fn itself).
  DenseSet<const Function *> Dependencies;
  /// Whether \p Fn or any of its \ref Dependencies contains an indirect call.
  bool HasIndirectCall = false;
  /// Whether any of \p Fn's dependencies cannot be duplicated.
  bool HasNonDuplicatableDependecy = false;

  /// Cost of \ref Fn plus the cost of all of its \ref Dependencies.
  CostType TotalCost = 0;

  /// \returns true if this function and its dependencies can be considered
  /// large according to \p Threshold.
  bool isLarge(CostType Threshold) const {
    return TotalCost > Threshold && !Dependencies.empty();
  }
};
400
401/// Calculates how much overlap there is between \p A and \p B.
402/// \return A number between 0.0 and 1.0, where 1.0 means A == B and 0.0 means A
403/// and B have no shared elements. Kernels do not count in overlap calculation.
404static float calculateOverlap(const DenseSet<const Function *> &A,
405 const DenseSet<const Function *> &B) {
406 DenseSet<const Function *> Total;
407 for (const auto *F : A) {
408 if (!isEntryPoint(F))
409 Total.insert(V: F);
410 }
411
412 if (Total.empty())
413 return 0.0f;
414
415 unsigned NumCommon = 0;
416 for (const auto *F : B) {
417 if (isEntryPoint(F))
418 continue;
419
420 auto [It, Inserted] = Total.insert(V: F);
421 if (!Inserted)
422 ++NumCommon;
423 }
424
425 return static_cast<float>(NumCommon) / Total.size();
426}
427
428/// Performs all of the partitioning work on \p M.
429/// \param SML Log Helper
430/// \param M Module to partition.
431/// \param NumParts Number of partitions to create.
432/// \param ModuleCost Total cost of all functions in \p M.
433/// \param FnCosts Map of Function -> Cost
434/// \param WorkList Functions and their dependencies to process in order.
435/// \returns The created partitions (a vector of size \p NumParts )
436static std::vector<DenseSet<const Function *>>
437doPartitioning(SplitModuleLogger &SML, Module &M, unsigned NumParts,
438 CostType ModuleCost,
439 const DenseMap<const Function *, CostType> &FnCosts,
440 const SmallVector<FunctionWithDependencies> &WorkList) {
441
442 SML << "\n--Partitioning Starts--\n";
443
444 // Calculate a "large function threshold". When more than one function's total
445 // import cost exceeds this value, we will try to assign it to an existing
446 // partition to reduce the amount of duplication needed.
447 //
448 // e.g. let two functions X and Y have a import cost of ~10% of the module, we
449 // assign X to a partition as usual, but when we get to Y, we check if it's
450 // worth also putting it in Y's partition.
451 const CostType LargeFnThreshold =
452 LargeFnFactor ? CostType(((ModuleCost / NumParts) * LargeFnFactor))
453 : std::numeric_limits<CostType>::max();
454
455 std::vector<DenseSet<const Function *>> Partitions;
456 Partitions.resize(new_size: NumParts);
457
458 // Assign functions to partitions, and try to keep the partitions more or
459 // less balanced. We do that through a priority queue sorted in reverse, so we
460 // can always look at the partition with the least content.
461 //
462 // There are some cases where we will be deliberately unbalanced though.
463 // - Large functions: we try to merge with existing partitions to reduce code
464 // duplication.
465 // - Functions with indirect or external calls always go in the first
466 // partition (P0).
467 auto ComparePartitions = [](const std::pair<PartitionID, CostType> &a,
468 const std::pair<PartitionID, CostType> &b) {
469 // When two partitions have the same cost, assign to the one with the
470 // biggest ID first. This allows us to put things in P0 last, because P0 may
471 // have other stuff added later.
472 if (a.second == b.second)
473 return a.first < b.first;
474 return a.second > b.second;
475 };
476
477 // We can't use priority_queue here because we need to be able to access any
478 // element. This makes this a bit inefficient as we need to sort it again
479 // everytime we change it, but it's a very small array anyway (likely under 64
480 // partitions) so it's a cheap operation.
481 std::vector<std::pair<PartitionID, CostType>> BalancingQueue;
482 for (unsigned I = 0; I < NumParts; ++I)
483 BalancingQueue.emplace_back(args&: I, args: 0);
484
485 // Helper function to handle assigning a function to a partition. This takes
486 // care of updating the balancing queue.
487 const auto AssignToPartition = [&](PartitionID PID,
488 const FunctionWithDependencies &FWD) {
489 auto &FnsInPart = Partitions[PID];
490 FnsInPart.insert(V: FWD.Fn);
491 FnsInPart.insert(I: FWD.Dependencies.begin(), E: FWD.Dependencies.end());
492
493 SML << "assign " << getName(V: *FWD.Fn) << " to P" << PID << "\n -> ";
494 if (!FWD.Dependencies.empty()) {
495 SML << FWD.Dependencies.size() << " dependencies added\n";
496 };
497
498 // Update the balancing queue. we scan backwards because in the common case
499 // the partition is at the end.
500 for (auto &[QueuePID, Cost] : reverse(C&: BalancingQueue)) {
501 if (QueuePID == PID) {
502 CostType NewCost = 0;
503 for (auto *Fn : Partitions[PID])
504 NewCost += FnCosts.at(Val: Fn);
505
506 SML << "[Updating P" << PID << " Cost]:" << Cost << " -> " << NewCost;
507 if (Cost) {
508 SML << " (" << unsigned(((float(NewCost) / Cost) - 1) * 100)
509 << "% increase)";
510 }
511 SML << '\n';
512
513 Cost = NewCost;
514 }
515 }
516
517 sort(C&: BalancingQueue, Comp: ComparePartitions);
518 };
519
520 for (auto &CurFn : WorkList) {
521 // When a function has indirect calls, it must stay in the first partition
522 // alongside every reachable non-entry function. This is a nightmare case
523 // for splitting as it severely limits what we can do.
524 if (CurFn.HasIndirectCall) {
525 SML << "Function with indirect call(s): " << getName(V: *CurFn.Fn)
526 << " defaulting to P0\n";
527 AssignToPartition(0, CurFn);
528 continue;
529 }
530
531 // When a function has non duplicatable dependencies, we have to keep it in
532 // the first partition as well. This is a conservative approach, a
533 // finer-grained approach could keep track of which dependencies are
534 // non-duplicatable exactly and just make sure they're grouped together.
535 if (CurFn.HasNonDuplicatableDependecy) {
536 SML << "Function with externally visible dependency "
537 << getName(V: *CurFn.Fn) << " defaulting to P0\n";
538 AssignToPartition(0, CurFn);
539 continue;
540 }
541
542 // Be smart with large functions to avoid duplicating their dependencies.
543 if (CurFn.isLarge(Threshold: LargeFnThreshold)) {
544 assert(LargeFnOverlapForMerge >= 0.0f && LargeFnOverlapForMerge <= 1.0f);
545 SML << "Large Function: " << getName(V: *CurFn.Fn)
546 << " - looking for partition with at least "
547 << format(Fmt: "%0.2f", Vals: LargeFnOverlapForMerge * 100) << "% overlap\n";
548
549 bool Assigned = false;
550 for (const auto &[PID, Fns] : enumerate(First&: Partitions)) {
551 float Overlap = calculateOverlap(A: CurFn.Dependencies, B: Fns);
552 SML << " => " << format(Fmt: "%0.2f", Vals: Overlap * 100) << "% overlap with P"
553 << PID << '\n';
554 if (Overlap > LargeFnOverlapForMerge) {
555 SML << " selecting P" << PID << '\n';
556 AssignToPartition(PID, CurFn);
557 Assigned = true;
558 }
559 }
560
561 if (Assigned)
562 continue;
563 }
564
565 // Normal "load-balancing", assign to partition with least pressure.
566 auto [PID, CurCost] = BalancingQueue.back();
567 AssignToPartition(PID, CurFn);
568 }
569
570 if (SML) {
571 for (const auto &[Idx, Part] : enumerate(First&: Partitions)) {
572 CostType Cost = 0;
573 for (auto *Fn : Part)
574 Cost += FnCosts.at(Val: Fn);
575 SML << "P" << Idx << " has a total cost of " << Cost << " ("
576 << format(Fmt: "%0.2f", Vals: (float(Cost) / ModuleCost) * 100)
577 << "% of source module)\n";
578 }
579
580 SML << "--Partitioning Done--\n\n";
581 }
582
583 // Check no functions were missed.
584#ifndef NDEBUG
585 DenseSet<const Function *> AllFunctions;
586 for (const auto &Part : Partitions)
587 AllFunctions.insert(Part.begin(), Part.end());
588
589 for (auto &Fn : M) {
590 if (!Fn.isDeclaration() && !AllFunctions.contains(&Fn)) {
591 assert(AllFunctions.contains(&Fn) && "Missed a function?!");
592 }
593 }
594#endif
595
596 return Partitions;
597}
598
599static void externalize(GlobalValue &GV) {
600 if (GV.hasLocalLinkage()) {
601 GV.setLinkage(GlobalValue::ExternalLinkage);
602 GV.setVisibility(GlobalValue::HiddenVisibility);
603 }
604
605 // Unnamed entities must be named consistently between modules. setName will
606 // give a distinct name to each such entity.
607 if (!GV.hasName())
608 GV.setName("__llvmsplit_unnamed");
609}
610
611static bool hasDirectCaller(const Function &Fn) {
612 for (auto &U : Fn.uses()) {
613 if (auto *CB = dyn_cast<CallBase>(Val: U.getUser()); CB && CB->isCallee(U: &U))
614 return true;
615 }
616 return false;
617}
618
/// Driver for the whole splitting process: prepares \p M (externalization),
/// collects splitting roots and their costs, partitions them, then clones
/// each partition into a fresh module handed to \p ModuleCallback.
/// \param GetTTI Abstract getter for TargetTransformInfo.
/// \param M Module to split. NOTE(review): M itself is mutated here
///        (linkage/visibility/name changes from externalization) even though
///        the callers receive clones — confirm this is intended alongside
///        run() returning PreservedAnalyses::all().
/// \param N Number of partition modules to create.
/// \param ModuleCallback Invoked once per created partition module, in order.
static void splitAMDGPUModule(
    GetTTIFn GetTTI, Module &M, unsigned N,
    function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback) {

  SplitModuleLogger SML(M);

  CallGraph CG(M);

  // Externalize functions whose address are taken.
  //
  // This is needed because partitioning is purely based on calls, but sometimes
  // a kernel/function may just look at the address of another local function
  // and not do anything (no calls). After partitioning, that local function may
  // end up in a different module (so it's just a declaration in the module
  // where its address is taken), which emits a "undefined hidden symbol" linker
  // error.
  //
  // Additionally, it guides partitioning to not duplicate this function if it's
  // called directly at some point.
  for (auto &Fn : M) {
    if (Fn.hasAddressTaken()) {
      if (Fn.hasLocalLinkage()) {
        SML << "[externalize] " << Fn.getName()
            << " because its address is taken\n";
      }
      externalize(GV&: Fn);
    }
  }

  // Externalize local GVs, which avoids duplicating their initializers, which
  // in turns helps keep code size in check.
  if (!NoExternalizeGlobals) {
    for (auto &GV : M.globals()) {
      if (GV.hasLocalLinkage())
        SML << "[externalize] GV " << GV.getName() << '\n';
      externalize(GV);
    }
  }

  // Start by calculating the cost of every function in the module, as well as
  // the module's overall cost.
  DenseMap<const Function *, CostType> FnCosts;
  const CostType ModuleCost = calculateFunctionCosts(SML, GetTTI, M, CostMap&: FnCosts);

  // First, gather every kernel into the worklist.
  SmallVector<FunctionWithDependencies> WorkList;
  for (auto &Fn : M) {
    if (isEntryPoint(F: &Fn) && !Fn.isDeclaration())
      WorkList.emplace_back(Args&: SML, Args&: CG, Args&: FnCosts, Args: &Fn);
  }

  // Then, find missing functions that need to be considered as additional
  // roots. These can't be called in theory, but in practice we still have to
  // handle them to avoid linker errors.
  {
    DenseSet<const Function *> SeenFunctions;
    for (const auto &FWD : WorkList) {
      SeenFunctions.insert(V: FWD.Fn);
      SeenFunctions.insert(I: FWD.Dependencies.begin(), E: FWD.Dependencies.end());
    }

    for (auto &Fn : M) {
      // If this function is not part of any kernel's dependencies and isn't
      // directly called, consider it as a root.
      if (!Fn.isDeclaration() && !isEntryPoint(F: &Fn) &&
          !SeenFunctions.count(V: &Fn) && !hasDirectCaller(Fn)) {
        WorkList.emplace_back(Args&: SML, Args&: CG, Args&: FnCosts, Args: &Fn);
      }
    }
  }

  // Sort the worklist so the most expensive roots are seen first.
  sort(C&: WorkList, Comp: [&](auto &A, auto &B) {
    // Sort by total cost, and if the total cost is identical, sort
    // alphabetically.
    if (A.TotalCost == B.TotalCost)
      return A.Fn->getName() < B.Fn->getName();
    return A.TotalCost > B.TotalCost;
  });

  if (SML) {
    SML << "Worklist\n";
    for (const auto &FWD : WorkList) {
      SML << "[root] " << getName(V: *FWD.Fn) << " (totalCost:" << FWD.TotalCost
          << " indirect:" << FWD.HasIndirectCall
          << " hasNonDuplicatableDep:" << FWD.HasNonDuplicatableDependecy
          << ")\n";
      // Sort function names before printing to ensure determinism.
      SmallVector<std::string> SortedDepNames;
      SortedDepNames.reserve(N: FWD.Dependencies.size());
      for (const auto *Dep : FWD.Dependencies)
        SortedDepNames.push_back(Elt: getName(V: *Dep));
      sort(C&: SortedDepNames);

      for (const auto &Name : SortedDepNames)
        SML << " [dependency] " << Name << '\n';
    }
  }

  // This performs all of the partitioning work.
  auto Partitions = doPartitioning(SML, M, NumParts: N, ModuleCost, FnCosts, WorkList);
  assert(Partitions.size() == N);

  // If we didn't externalize GVs, then local GVs need to be conservatively
  // imported into every module (including their initializers), and then cleaned
  // up afterwards.
  const auto NeedsConservativeImport = [&](const GlobalValue *GV) {
    // We conservatively import private/internal GVs into every module and clean
    // them up afterwards.
    const auto *Var = dyn_cast<GlobalVariable>(Val: GV);
    return Var && Var->hasLocalLinkage();
  };

  SML << "Creating " << N << " modules...\n";
  unsigned TotalFnImpls = 0;
  for (unsigned I = 0; I < N; ++I) {
    const auto &FnsInPart = Partitions[I];

    // Clone only this partition's functions (plus shared globals) into a new
    // module; VMap ties originals to their clones.
    ValueToValueMapTy VMap;
    std::unique_ptr<Module> MPart(
        CloneModule(M, VMap, ShouldCloneDefinition: [&](const GlobalValue *GV) {
          // Functions go in their assigned partition.
          if (const auto *Fn = dyn_cast<Function>(Val: GV))
            return FnsInPart.contains(V: Fn);

          if (NeedsConservativeImport(GV))
            return true;

          // Everything else goes in the first partition.
          return I == 0;
        }));

    // Clean-up conservatively imported GVs without any users.
    for (auto &GV : make_early_inc_range(Range: MPart->globals())) {
      if (NeedsConservativeImport(&GV) && GV.use_empty())
        GV.eraseFromParent();
    }

    // Count definitions (and kernels among them) for the summary log.
    unsigned NumAllFns = 0, NumKernels = 0;
    for (auto &Cur : *MPart) {
      if (!Cur.isDeclaration()) {
        ++NumAllFns;
        if (isEntryPoint(F: &Cur))
          ++NumKernels;
      }
    }
    TotalFnImpls += NumAllFns;
    SML << " - Module " << I << " with " << NumAllFns << " functions ("
        << NumKernels << " kernels)\n";
    ModuleCallback(std::move(MPart));
  }

  SML << TotalFnImpls << " function definitions across all modules ("
      << format(Fmt: "%0.2f", Vals: (float(TotalFnImpls) / FnCosts.size()) * 100)
      << "% of original module)\n";
}
775} // namespace
776
777PreservedAnalyses AMDGPUSplitModulePass::run(Module &M,
778 ModuleAnalysisManager &MAM) {
779 FunctionAnalysisManager &FAM =
780 MAM.getResult<FunctionAnalysisManagerModuleProxy>(IR&: M).getManager();
781 const auto TTIGetter = [&FAM](Function &F) -> const TargetTransformInfo & {
782 return FAM.getResult<TargetIRAnalysis>(IR&: F);
783 };
784 splitAMDGPUModule(GetTTI: TTIGetter, M, N, ModuleCallback);
785 // We don't change the original module.
786 return PreservedAnalyses::all();
787}
788