//===- AMDGPUSplitModule.cpp ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Implements a module splitting algorithm designed to support the
/// FullLTO --lto-partitions option for parallel codegen.
///
/// The role of this module splitting pass is the same as
/// lib/Transforms/Utils/SplitModule.cpp: load-balance the module's functions
/// across a set of N partitions to allow for parallel codegen.
///
/// The similarities mostly end here, as this pass achieves load-balancing in a
/// more elaborate fashion which is targeted towards AMDGPU modules. It can take
/// advantage of the structure of AMDGPU modules (which are mostly
/// self-contained) to allow for more efficient splitting without affecting
/// codegen negatively, or causing inaccurate resource usage analysis.
///
/// High-level pass overview:
///  - SplitGraph & associated classes
///    - Graph representation of the module and of the dependencies that
///      matter for splitting.
///  - RecursiveSearchSplitting
///    - Core splitting algorithm.
///  - SplitProposal
///    - Represents a suggested solution for splitting the input module. These
///      solutions can be scored to determine the best one when multiple
///      solutions are available.
///  - Driver/pass "run" function glues everything together.

#include "AMDGPUSplitModule.h"
#include "AMDGPUTargetMachine.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DOTGraphTraits.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/Timer.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <cassert>
#include <cmath>
#include <memory>
#include <utility>
#include <vector>

#ifndef NDEBUG
#include "llvm/Support/LockFileManager.h"
#endif

#define DEBUG_TYPE "amdgpu-split-module"

namespace llvm {
namespace {

static cl::opt<unsigned> MaxDepth(
    "amdgpu-module-splitting-max-depth",
    cl::desc(
        "maximum search depth. 0 forces a greedy approach. "
        "warning: the algorithm is up to O(2^N), where N is the max depth."),
    cl::init(8));

static cl::opt<float> LargeFnFactor(
    "amdgpu-module-splitting-large-threshold", cl::init(2.0f), cl::Hidden,
    cl::desc(
        "when max depth is reached and we can no longer branch out, this "
        "value determines if a function is worth merging into an already "
        "existing partition to reduce code duplication. This is a factor "
        "of the ideal partition size, e.g. 2.0 means we consider the "
        "function for merging if its cost (including its callees) is 2x the "
        "size of an ideal partition."));

static cl::opt<float> LargeFnOverlapForMerge(
    "amdgpu-module-splitting-merge-threshold", cl::init(0.7f), cl::Hidden,
    cl::desc("when a function is considered for merging into a partition that "
             "already contains some of its callees, do the merge if at least "
             "n% of the code it can reach is already present inside the "
             "partition; e.g. 0.7 means only merge >70%"));

static cl::opt<bool> NoExternalizeGlobals(
    "amdgpu-module-splitting-no-externalize-globals", cl::Hidden,
    cl::desc("disables externalization of global variables with local "
             "linkage; may cause globals to be duplicated, which increases "
             "binary size"));

static cl::opt<bool> NoExternalizeOnAddrTaken(
    "amdgpu-module-splitting-no-externalize-address-taken", cl::Hidden,
    cl::desc(
        "disables externalization of functions whose addresses are taken"));

static cl::opt<std::string>
    ModuleDotCfgOutput("amdgpu-module-splitting-print-module-dotcfg",
                       cl::Hidden,
                       cl::desc("output file to write out the dotgraph "
                                "representation of the input module"));

static cl::opt<std::string> PartitionSummariesOutput(
    "amdgpu-module-splitting-print-partition-summaries", cl::Hidden,
    cl::desc("output file to write out a summary of "
             "the partitions created for each module"));

#ifndef NDEBUG
static cl::opt<bool>
    UseLockFile("amdgpu-module-splitting-serial-execution", cl::Hidden,
                cl::desc("use a lock file so only one process in the system "
                         "can run this pass at once. useful to avoid mangled "
                         "debug output in multithreaded environments."));

static cl::opt<bool>
    DebugProposalSearch("amdgpu-module-splitting-debug-proposal-search",
                        cl::Hidden,
                        cl::desc("print all proposals received and whether "
                                 "they were rejected or accepted"));
#endif
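
// Note: all of the above are standard cl::opt flags; e.g. they can be passed
// to LLVM tools directly ("-amdgpu-module-splitting-max-depth=4") or through a
// frontend's -mllvm forwarding (illustrative usage; the exact spelling depends
// on the driver).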

struct SplitModuleTimer : NamedRegionTimer {
  SplitModuleTimer(StringRef Name, StringRef Desc)
      : NamedRegionTimer(Name, Desc, DEBUG_TYPE, "AMDGPU Module Splitting",
                         TimePassesIsEnabled) {}
};

//===----------------------------------------------------------------------===//
// Utils
//===----------------------------------------------------------------------===//

using CostType = InstructionCost::CostType;
using FunctionsCostMap = DenseMap<const Function *, CostType>;
using GetTTIFn = function_ref<const TargetTransformInfo &(Function &)>;
static constexpr unsigned InvalidPID = -1;

/// \param Num numerator
/// \param Dem denominator
/// \returns a printable object that prints the ratio (Num/Dem) as a
/// percentage, using "%0.2f".
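/// e.g. formatRatioOf(50, 200) prints as "25.00".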
static auto formatRatioOf(CostType Num, CostType Dem) {
  CostType DemOr1 = Dem ? Dem : 1;
  return format("%0.2f", (static_cast<double>(Num) / DemOr1) * 100);
}

/// Checks whether a given function is non-copyable.
///
/// Non-copyable functions cannot be cloned into multiple partitions, and only
/// one copy of the function can be present across all partitions.
///
/// Kernel functions and external functions fall into this category. If we were
/// to clone them, we would end up with multiple symbol definitions and a very
/// unhappy linker.
static bool isNonCopyable(const Function &F) {
  return F.hasExternalLinkage() || !F.isDefinitionExact() ||
         AMDGPU::isEntryFunctionCC(F.getCallingConv());
}

/// If \p GV has local linkage, make it external + hidden.
static void externalize(GlobalValue &GV) {
  if (GV.hasLocalLinkage()) {
    GV.setLinkage(GlobalValue::ExternalLinkage);
    GV.setVisibility(GlobalValue::HiddenVisibility);
  }

  // Unnamed entities must be named consistently between modules. setName will
  // give a distinct name to each such entity.
  if (!GV.hasName())
    GV.setName("__llvmsplit_unnamed");
}
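
// Illustrative IR-level effect: a "define internal void @foo()" becomes
// "define hidden void @foo()" (i.e. external linkage with hidden visibility),
// so other partitions can declare and reference it instead of duplicating its
// body.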

/// Cost analysis function. Calculates the cost of each function in \p M
///
/// \param GetTTI Abstract getter for TargetTransformInfo.
/// \param M Module to analyze.
/// \param [out] CostMap Resulting Function -> Cost map.
/// \return The module's total cost.
static CostType calculateFunctionCosts(GetTTIFn GetTTI, Module &M,
                                       FunctionsCostMap &CostMap) {
  SplitModuleTimer SMT("calculateFunctionCosts", "cost analysis");

  LLVM_DEBUG(dbgs() << "[cost analysis] calculating function costs\n");
  CostType ModuleCost = 0;
  [[maybe_unused]] CostType KernelCost = 0;

  for (auto &Fn : M) {
    if (Fn.isDeclaration())
      continue;

    CostType FnCost = 0;
    const auto &TTI = GetTTI(Fn);
    for (const auto &BB : Fn) {
      for (const auto &I : BB) {
        auto Cost =
            TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
        assert(Cost != InstructionCost::getMax());
        // Assume expensive if we can't tell the cost of an instruction.
        CostType CostVal = Cost.isValid()
                               ? Cost.getValue()
                               : (CostType)TargetTransformInfo::TCC_Expensive;
        assert((FnCost + CostVal) >= FnCost && "Overflow!");
        FnCost += CostVal;
      }
    }

    assert(FnCost != 0);

    CostMap[&Fn] = FnCost;
    assert((ModuleCost + FnCost) >= ModuleCost && "Overflow!");
    ModuleCost += FnCost;

    if (AMDGPU::isEntryFunctionCC(Fn.getCallingConv()))
      KernelCost += FnCost;
  }

  if (CostMap.empty())
    return 0;

  assert(ModuleCost);
  LLVM_DEBUG({
    const CostType FnCost = ModuleCost - KernelCost;
    dbgs() << " - total module cost is " << ModuleCost << ". kernels cost "
           << KernelCost << " ("
           << format("%0.2f", (float(KernelCost) / ModuleCost) * 100)
           << "% of the module), functions cost " << FnCost << " ("
           << format("%0.2f", (float(FnCost) / ModuleCost) * 100)
           << "% of the module)\n";
  });

  return ModuleCost;
}

/// \return true if \p F can be indirectly called
static bool canBeIndirectlyCalled(const Function &F) {
  if (F.isDeclaration() || AMDGPU::isEntryFunctionCC(F.getCallingConv()))
    return false;
  return !F.hasLocalLinkage() ||
         F.hasAddressTaken(/*PutOffender=*/nullptr,
                           /*IgnoreCallbackUses=*/false,
                           /*IgnoreAssumeLikeCalls=*/true,
                           /*IgnoreLLVMUsed=*/true,
                           /*IgnoreARCAttachedCall=*/false,
                           /*IgnoreCastedDirectCall=*/true);
}
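
// Note the hasAddressTaken flags above: e.g. an internal function whose
// address only appears in llvm.used or in assume-like calls is not considered
// indirectly callable here, while one stored to a function pointer is.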

//===----------------------------------------------------------------------===//
// Graph-based Module Representation
//===----------------------------------------------------------------------===//

/// AMDGPUSplitModule's view of the source Module, as a graph of all components
/// that can be split into different modules.
///
/// The most trivial instance of this graph is just the CallGraph of the
/// module, but it is not guaranteed that the graph is strictly equal to the
/// CG. It currently always is but it's designed in a way that would
/// eventually allow us to create abstract nodes, or nodes for different
/// entities such as global variables or any other meaningful constraint we
/// must consider.
///
/// The graph is only mutable by this class, and is generally not modified
/// after \ref SplitGraph::buildGraph runs. No consumers of the graph can
/// mutate it.
class SplitGraph {
public:
  class Node;

  enum class EdgeKind : uint8_t {
    /// The nodes are related through a direct call. This is a "strong" edge as
    /// it means the Src will directly reference the Dst.
    DirectCall,
    /// The nodes are related through an indirect call.
    /// This is a "weaker" edge and is only considered when traversing the
    /// graph starting from a kernel. We need this edge for resource usage
    /// analysis.
    ///
    /// The reason why we have this edge in the first place is due to how
    /// AMDGPUResourceUsageAnalysis works. In the presence of an indirect call,
    /// the resource usage of the kernel containing the indirect call is the
    /// max resource usage of all functions that can be indirectly called.
    IndirectCall,
  };

  /// An edge between two nodes. Edges are directional, and tagged with a
  /// "kind".
  struct Edge {
    Edge(Node *Src, Node *Dst, EdgeKind Kind)
        : Src(Src), Dst(Dst), Kind(Kind) {}

    Node *Src; ///< Source
    Node *Dst; ///< Destination
    EdgeKind Kind;
  };

  using EdgesVec = SmallVector<const Edge *, 0>;
  using edges_iterator = EdgesVec::const_iterator;
  using nodes_iterator = const Node *const *;

  SplitGraph(const Module &M, const FunctionsCostMap &CostMap,
             CostType ModuleCost)
      : M(M), CostMap(CostMap), ModuleCost(ModuleCost) {}

  void buildGraph(CallGraph &CG);

#ifndef NDEBUG
  bool verifyGraph() const;
#endif

  bool empty() const { return Nodes.empty(); }
  const iterator_range<nodes_iterator> nodes() const {
    return {Nodes.begin(), Nodes.end()};
  }
  const Node &getNode(unsigned ID) const { return *Nodes[ID]; }

  unsigned getNumNodes() const { return Nodes.size(); }
  BitVector createNodesBitVector() const { return BitVector(Nodes.size()); }

  const Module &getModule() const { return M; }

  CostType getModuleCost() const { return ModuleCost; }
  CostType getCost(const Function &F) const { return CostMap.at(&F); }

  /// \returns the aggregated cost of all nodes in \p BV (bits set to 1 = node
  /// IDs).
  CostType calculateCost(const BitVector &BV) const;

private:
  /// Retrieves the node for \p GV in \p Cache, or creates a new node for it
  /// and updates \p Cache.
  Node &getNode(DenseMap<const GlobalValue *, Node *> &Cache,
                const GlobalValue &GV);

  // Create a new edge between two nodes and add it to both nodes.
  const Edge &createEdge(Node &Src, Node &Dst, EdgeKind EK);

  const Module &M;
  const FunctionsCostMap &CostMap;
  CostType ModuleCost;

  // Final list of nodes with stable ordering.
  SmallVector<Node *> Nodes;

  SpecificBumpPtrAllocator<Node> NodesPool;

  // Edges are trivially destructible objects, so as a small optimization we
  // use a BumpPtrAllocator which avoids destructor calls but also makes
  // allocation faster.
  static_assert(
      std::is_trivially_destructible_v<Edge>,
      "Edge must be trivially destructible to use the BumpPtrAllocator");
  BumpPtrAllocator EdgesPool;
};

/// Nodes in the SplitGraph contain both incoming and outgoing edges.
/// Incoming edges have this node as their Dst, and outgoing ones have this
/// node as their Src.
///
/// Edge objects are shared by both nodes in Src/Dst. They provide immediate
/// feedback on how two nodes are related, and in which direction they are
/// related, which is valuable information to make splitting decisions.
///
/// Nodes are fundamentally abstract, and any consumers of the graph should
/// treat them as such. While a node will be a function most of the time, we
/// could also create nodes for any other reason. In the future, we could have
/// single nodes for multiple functions, or nodes for GVs, etc.
class SplitGraph::Node {
  friend class SplitGraph;

public:
  Node(unsigned ID, const GlobalValue &GV, CostType IndividualCost,
       bool IsNonCopyable)
      : ID(ID), GV(GV), IndividualCost(IndividualCost),
        IsNonCopyable(IsNonCopyable), IsEntryFnCC(false), IsGraphEntry(false) {
    if (auto *Fn = dyn_cast<Function>(&GV))
      IsEntryFnCC = AMDGPU::isEntryFunctionCC(Fn->getCallingConv());
  }

  /// A 0-indexed ID for the node. The maximum ID (exclusive) is the number of
  /// nodes in the graph. This ID can be used as an index in a BitVector.
  unsigned getID() const { return ID; }

  const Function &getFunction() const { return cast<Function>(GV); }

  /// \returns the cost to import this component into a given module, not
  /// accounting for any dependencies that may need to be imported as well.
  CostType getIndividualCost() const { return IndividualCost; }

  bool isNonCopyable() const { return IsNonCopyable; }
  bool isEntryFunctionCC() const { return IsEntryFnCC; }

  /// \returns whether this is an entry point in the graph. Entry points are
  /// defined as follows: if you take all entry points in the graph, and
  /// iterate their dependencies, you are guaranteed to visit all nodes in the
  /// graph at least once.
  bool isGraphEntryPoint() const { return IsGraphEntry; }

  StringRef getName() const { return GV.getName(); }

  bool hasAnyIncomingEdges() const { return IncomingEdges.size(); }
  bool hasAnyIncomingEdgesOfKind(EdgeKind EK) const {
    return any_of(IncomingEdges, [&](const auto *E) { return E->Kind == EK; });
  }

  bool hasAnyOutgoingEdges() const { return OutgoingEdges.size(); }
  bool hasAnyOutgoingEdgesOfKind(EdgeKind EK) const {
    return any_of(OutgoingEdges, [&](const auto *E) { return E->Kind == EK; });
  }

  iterator_range<edges_iterator> incoming_edges() const {
    return IncomingEdges;
  }

  iterator_range<edges_iterator> outgoing_edges() const {
    return OutgoingEdges;
  }

  bool shouldFollowIndirectCalls() const { return isEntryFunctionCC(); }

  /// Visit all children of this node in a recursive fashion. Also visits Self.
  /// If \ref shouldFollowIndirectCalls returns false, then this only follows
  /// DirectCall edges.
  ///
  /// \param Visitor Visitor function.
  void visitAllDependencies(std::function<void(const Node &)> Visitor) const;
  /// Adds the dependencies of this node to \p BV by setting the bit
  /// corresponding to each node.
  ///
  /// Implemented using \ref visitAllDependencies, hence it follows the same
  /// rules regarding dependencies traversal.
  ///
  /// \param[out] BV The bitvector where the bits should be set.
  void getDependencies(BitVector &BV) const {
    visitAllDependencies([&](const Node &N) { BV.set(N.getID()); });
  }

private:
  void markAsGraphEntry() { IsGraphEntry = true; }

  unsigned ID;
  const GlobalValue &GV;
  CostType IndividualCost;
  bool IsNonCopyable : 1;
  bool IsEntryFnCC : 1;
  bool IsGraphEntry : 1;

  // TODO: Use a single sorted vector (with all incoming/outgoing edges grouped
  // together)
  EdgesVec IncomingEdges;
  EdgesVec OutgoingEdges;
};

void SplitGraph::Node::visitAllDependencies(
    std::function<void(const Node &)> Visitor) const {
  const bool FollowIndirect = shouldFollowIndirectCalls();
  // FIXME: If this can access SplitGraph in the future, use a BitVector
  // instead.
  DenseSet<const Node *> Seen;
  SmallVector<const Node *, 8> WorkList({this});
  while (!WorkList.empty()) {
    const Node *CurN = WorkList.pop_back_val();
    if (auto [It, Inserted] = Seen.insert(CurN); !Inserted)
      continue;

    Visitor(*CurN);

    for (const Edge *E : CurN->outgoing_edges()) {
      if (!FollowIndirect && E->Kind == EdgeKind::IndirectCall)
        continue;
      WorkList.push_back(E->Dst);
    }
  }
}

/// Checks if \p I has MD_callees and, if it does, parses it and puts the
/// functions in \p Callees.
///
/// \returns true if there was metadata and it was parsed correctly. false if
/// there was no MD or if it contained unknown entries and parsing failed.
/// If this returns false, \p Callees will contain incomplete information
/// and must not be used.
static bool handleCalleesMD(const Instruction &I,
                            SetVector<Function *> &Callees) {
  auto *MD = I.getMetadata(LLVMContext::MD_callees);
  if (!MD)
    return false;

  for (const auto &Op : MD->operands()) {
    Function *Callee = mdconst::extract_or_null<Function>(Op);
    if (!Callee)
      return false;
    Callees.insert(Callee);
  }

  return true;
}
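
// For reference, !callees metadata in IR looks like this (illustrative):
//   call void %fptr(), !callees !0
//   ...
//   !0 = !{ptr @foo, ptr @bar}
// mdconst::extract_or_null recovers the Function behind each operand.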

void SplitGraph::buildGraph(CallGraph &CG) {
  SplitModuleTimer SMT("buildGraph", "graph construction");
  LLVM_DEBUG(
      dbgs()
      << "[build graph] constructing graph representation of the input\n");

  // FIXME(?): Is the callgraph really worth using if we have to iterate the
  // function again whenever it fails to give us enough information?

  // We build the graph by just iterating all functions in the module and
  // working on their direct callees. At the end, all nodes should be linked
  // together as expected.
  DenseMap<const GlobalValue *, Node *> Cache;
  SmallVector<const Function *> FnsWithIndirectCalls, IndirectlyCallableFns;
  for (const Function &Fn : M) {
    if (Fn.isDeclaration())
      continue;

    // Look at direct callees and create the necessary edges in the graph.
    SetVector<const Function *> DirectCallees;
    bool CallsExternal = false;
    for (auto &CGEntry : *CG[&Fn]) {
      auto *CGNode = CGEntry.second;
      if (auto *Callee = CGNode->getFunction()) {
        if (!Callee->isDeclaration())
          DirectCallees.insert(Callee);
      } else if (CGNode == CG.getCallsExternalNode())
        CallsExternal = true;
    }

    // Keep track of this function if it contains an indirect call and/or if it
    // can be indirectly called.
    if (CallsExternal) {
      LLVM_DEBUG(dbgs() << " [!] callgraph is incomplete for ";
                 Fn.printAsOperand(dbgs());
                 dbgs() << " - analyzing function\n");

      SetVector<Function *> KnownCallees;
      bool HasUnknownIndirectCall = false;
      for (const auto &Inst : instructions(Fn)) {
        // look at all calls without a direct callee.
        const auto *CB = dyn_cast<CallBase>(&Inst);
        if (!CB || CB->getCalledFunction())
          continue;

        // Inline assembly can be ignored here; it is not treated as an
        // indirect call.
        if (CB->isInlineAsm()) {
          LLVM_DEBUG(dbgs() << " found inline assembly\n");
          continue;
        }

        if (handleCalleesMD(Inst, KnownCallees))
          continue;
        // If we failed to parse any !callees MD, or some was missing,
        // the entire KnownCallees list is now unreliable.
        KnownCallees.clear();

        // Everything else is handled conservatively. If we fall into the
        // conservative case don't bother analyzing further.
        HasUnknownIndirectCall = true;
        break;
      }

      if (HasUnknownIndirectCall) {
        LLVM_DEBUG(dbgs() << " indirect call found\n");
        FnsWithIndirectCalls.push_back(&Fn);
      } else if (!KnownCallees.empty())
        DirectCallees.insert_range(KnownCallees);
    }

    Node &N = getNode(Cache, Fn);
    for (const auto *Callee : DirectCallees)
      createEdge(N, getNode(Cache, *Callee), EdgeKind::DirectCall);

    if (canBeIndirectlyCalled(Fn))
      IndirectlyCallableFns.push_back(&Fn);
  }

  // Post-process functions with indirect calls.
  for (const Function *Fn : FnsWithIndirectCalls) {
    for (const Function *Candidate : IndirectlyCallableFns) {
      Node &Src = getNode(Cache, *Fn);
      Node &Dst = getNode(Cache, *Candidate);
      createEdge(Src, Dst, EdgeKind::IndirectCall);
    }
  }

  // Now, find all entry points.
  SmallVector<Node *, 16> CandidateEntryPoints;
  BitVector NodesReachableByKernels = createNodesBitVector();
  for (Node *N : Nodes) {
    // Functions with an Entry CC are always graph entry points too.
    if (N->isEntryFunctionCC()) {
      N->markAsGraphEntry();
      N->getDependencies(NodesReachableByKernels);
    } else if (!N->hasAnyIncomingEdgesOfKind(EdgeKind::DirectCall))
      CandidateEntryPoints.push_back(N);
  }

  for (Node *N : CandidateEntryPoints) {
    // This can be another entry point if it's not reachable by a kernel.
    // TODO: We could sort all of the possible new entries in a stable order
    // (e.g. by cost), then consume them one by one until
    // NodesReachableByKernels is all 1s. It'd allow us to avoid
    // considering some nodes as non-entries in some specific cases.
    if (!NodesReachableByKernels.test(N->getID()))
      N->markAsGraphEntry();
  }

#ifndef NDEBUG
  assert(verifyGraph());
#endif
}
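
// Illustrative example of the resulting graph: for a kernel K that directly
// calls a helper H and also makes an unknown indirect call, with A and B being
// the indirectly callable functions of the module, we get K->H (DirectCall)
// plus K->A and K->B (IndirectCall), and K is marked as a graph entry point.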

#ifndef NDEBUG
bool SplitGraph::verifyGraph() const {
  unsigned ExpectedID = 0;
  // Exceptionally using a set here in case IDs are messed up.
  DenseSet<const Node *> SeenNodes;
  DenseSet<const Function *> SeenFunctionNodes;
  for (const Node *N : Nodes) {
    if (N->getID() != (ExpectedID++)) {
      errs() << "Node IDs are incorrect!\n";
      return false;
    }

    if (!SeenNodes.insert(N).second) {
      errs() << "Node seen more than once!\n";
      return false;
    }

    if (&getNode(N->getID()) != N) {
      errs() << "getNode doesn't return the right node\n";
      return false;
    }

    for (const Edge *E : N->IncomingEdges) {
      if (!E->Src || !E->Dst || (E->Dst != N) ||
          (find(E->Src->OutgoingEdges, E) == E->Src->OutgoingEdges.end())) {
        errs() << "ill-formed incoming edges\n";
        return false;
      }
    }

    for (const Edge *E : N->OutgoingEdges) {
      if (!E->Src || !E->Dst || (E->Src != N) ||
          (find(E->Dst->IncomingEdges, E) == E->Dst->IncomingEdges.end())) {
        errs() << "ill-formed outgoing edges\n";
        return false;
      }
    }

    const Function &Fn = N->getFunction();
    if (AMDGPU::isEntryFunctionCC(Fn.getCallingConv())) {
      if (N->hasAnyIncomingEdges()) {
        errs() << "Kernels cannot have incoming edges\n";
        return false;
      }
    }

    if (Fn.isDeclaration()) {
      errs() << "declarations shouldn't have nodes!\n";
      return false;
    }

    auto [It, Inserted] = SeenFunctionNodes.insert(&Fn);
    if (!Inserted) {
      errs() << "one function has multiple nodes!\n";
      return false;
    }
  }

  if (ExpectedID != Nodes.size()) {
    errs() << "Node IDs out of sync!\n";
    return false;
  }

  if (createNodesBitVector().size() != getNumNodes()) {
    errs() << "nodes bit vector doesn't have the right size!\n";
    return false;
  }

  // Check we respect the promise of Node::isGraphEntryPoint.
  BitVector BV = createNodesBitVector();
  for (const Node *N : nodes()) {
    if (N->isGraphEntryPoint())
      N->getDependencies(BV);
  }

  // Ensure each function in the module has an associated node.
  for (const auto &Fn : M) {
    if (!Fn.isDeclaration()) {
      if (!SeenFunctionNodes.contains(&Fn)) {
        errs() << "Fn has no associated node in the graph!\n";
        return false;
      }
    }
  }

  if (!BV.all()) {
    errs() << "not all nodes are reachable through the graph's entry points!\n";
    return false;
  }

  return true;
}
#endif

CostType SplitGraph::calculateCost(const BitVector &BV) const {
  CostType Cost = 0;
  for (unsigned NodeID : BV.set_bits())
    Cost += getNode(NodeID).getIndividualCost();
  return Cost;
}

SplitGraph::Node &
SplitGraph::getNode(DenseMap<const GlobalValue *, Node *> &Cache,
                    const GlobalValue &GV) {
  auto &N = Cache[&GV];
  if (N)
    return *N;

  CostType Cost = 0;
  bool NonCopyable = false;
  if (const Function *Fn = dyn_cast<Function>(&GV)) {
    NonCopyable = isNonCopyable(*Fn);
    Cost = CostMap.at(Fn);
  }
  N = new (NodesPool.Allocate()) Node(Nodes.size(), GV, Cost, NonCopyable);
  Nodes.push_back(N);
  assert(&getNode(N->getID()) == N);
  return *N;
}

const SplitGraph::Edge &SplitGraph::createEdge(Node &Src, Node &Dst,
                                               EdgeKind EK) {
  const Edge *E = new (EdgesPool.Allocate<Edge>(1)) Edge(&Src, &Dst, EK);
  Src.OutgoingEdges.push_back(E);
  Dst.IncomingEdges.push_back(E);
  return *E;
}

//===----------------------------------------------------------------------===//
// Split Proposals
//===----------------------------------------------------------------------===//

/// Represents a module splitting proposal.
///
/// Proposals are made of N BitVectors, one for each partition, where each bit
/// set indicates that the node is present and should be copied inside that
/// partition.
///
/// Proposals have several metrics attached so they can be compared/sorted,
/// which allows the driver to try multiple strategies, collect the resulting
/// proposals, and choose the best one out of them.
class SplitProposal {
public:
  SplitProposal(const SplitGraph &SG, unsigned MaxPartitions) : SG(&SG) {
    Partitions.resize(MaxPartitions, {0, SG.createNodesBitVector()});
  }

  void setName(StringRef NewName) { Name = NewName; }
  StringRef getName() const { return Name; }

  const BitVector &operator[](unsigned PID) const {
    return Partitions[PID].second;
  }

  void add(unsigned PID, const BitVector &BV) {
    Partitions[PID].second |= BV;
    updateScore(PID);
  }

  void print(raw_ostream &OS) const;
  LLVM_DUMP_METHOD void dump() const { print(dbgs()); }

  // Find the cheapest partition (lowest cost). In case of ties, always returns
  // the highest partition number.
  unsigned findCheapestPartition() const;

  /// Calculate the CodeSize and Bottleneck scores.
  void calculateScores();

#ifndef NDEBUG
  void verifyCompleteness() const;
#endif

  /// Only available after \ref calculateScores is called.
  ///
  /// The aggregated cost of all partitions over the module's cost, which
  /// indicates how much code duplication this proposal creates. e.g. 1.2 means
  /// this proposal adds roughly 20% code size by duplicating some functions
  /// across partitions.
  ///
  /// Value is always rounded up to 2 decimal places.
  ///
  /// A perfect score would be 1.0 (no duplication); the higher above 1.0, the
  /// more code is duplicated.
  double getCodeSizeScore() const { return CodeSizeScore; }

  /// Only available after \ref calculateScores is called.
  ///
  /// A number between [0, 1] which indicates how big of a bottleneck is
  /// expected from the largest partition.
  ///
  /// A score of 1.0 means the biggest partition is as big as the source module,
  /// so build time will be equal to or greater than the build time of the
  /// initial input.
  ///
  /// Value is always rounded up to 2 decimal places.
  ///
  /// This is one of the metrics used to estimate this proposal's build time.
  double getBottleneckScore() const { return BottleneckScore; }

private:
  void updateScore(unsigned PID) {
    assert(SG);
    for (auto &[PCost, Nodes] : Partitions) {
      TotalCost -= PCost;
      PCost = SG->calculateCost(Nodes);
      TotalCost += PCost;
    }
  }

  /// \see getCodeSizeScore
  double CodeSizeScore = 0.0;
  /// \see getBottleneckScore
  double BottleneckScore = 0.0;
  /// Aggregated cost of all partitions
  CostType TotalCost = 0;

  const SplitGraph *SG = nullptr;
  std::string Name;

  std::vector<std::pair<CostType, BitVector>> Partitions;
};

void SplitProposal::print(raw_ostream &OS) const {
  assert(SG);

  OS << "[proposal] " << Name << ", total cost:" << TotalCost
     << ", code size score:" << format("%0.3f", CodeSizeScore)
     << ", bottleneck score:" << format("%0.3f", BottleneckScore) << '\n';
  for (const auto &[PID, Part] : enumerate(Partitions)) {
    const auto &[Cost, NodeIDs] = Part;
    OS << " - P" << PID << " nodes:" << NodeIDs.count() << " cost: " << Cost
       << '|' << formatRatioOf(Cost, SG->getModuleCost()) << "%\n";
  }
}

unsigned SplitProposal::findCheapestPartition() const {
  assert(!Partitions.empty());
  CostType CurCost = std::numeric_limits<CostType>::max();
  unsigned CurPID = InvalidPID;
  for (const auto &[Idx, Part] : enumerate(Partitions)) {
    if (Part.first <= CurCost) {
      CurPID = Idx;
      CurCost = Part.first;
    }
  }
  assert(CurPID != InvalidPID);
  return CurPID;
}

void SplitProposal::calculateScores() {
  if (Partitions.empty())
    return;

  assert(SG);
  CostType LargestPCost = 0;
  for (auto &[PCost, Nodes] : Partitions) {
    if (PCost > LargestPCost)
      LargestPCost = PCost;
  }

  CostType ModuleCost = SG->getModuleCost();
  CodeSizeScore = double(TotalCost) / ModuleCost;
  assert(CodeSizeScore >= 0.0);

  BottleneckScore = double(LargestPCost) / ModuleCost;

  CodeSizeScore = std::ceil(CodeSizeScore * 100.0) / 100.0;
  BottleneckScore = std::ceil(BottleneckScore * 100.0) / 100.0;
}
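
// Worked example (illustrative numbers): with ModuleCost = 100 and two
// partitions costing {60, 55}, TotalCost = 115, so CodeSizeScore = 1.15
// (roughly 15% duplicated code) and BottleneckScore = 0.60 (the largest
// partition holds 60% of the module).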

#ifndef NDEBUG
void SplitProposal::verifyCompleteness() const {
  if (Partitions.empty())
    return;

  BitVector Result = Partitions[0].second;
  for (const auto &P : drop_begin(Partitions))
    Result |= P.second;
  assert(Result.all() && "some nodes are missing from this proposal!");
}
#endif

//===-- RecursiveSearchStrategy -------------------------------------------===//

/// Partitioning algorithm.
///
/// This is a recursive search algorithm that can explore multiple
/// possibilities.
///
/// When a cluster of nodes can go into more than one partition, and we haven't
/// reached maximum search depth, we recurse and explore both options and their
/// consequences. Both branches will yield a proposal, and the driver will grade
/// both and choose the best one.
///
/// If max depth is reached, we will use some heuristics to make a choice. Most
/// of the time we will just use the least-pressured (cheapest) partition, but
/// if a cluster is particularly big and there is a good amount of overlap with
/// an existing partition, we will choose that partition instead.
class RecursiveSearchSplitting {
public:
  using SubmitProposalFn = function_ref<void(SplitProposal)>;

  RecursiveSearchSplitting(const SplitGraph &SG, unsigned NumParts,
                           SubmitProposalFn SubmitProposal);

  void run();

private:
  struct WorkListEntry {
    WorkListEntry(const BitVector &BV) : Cluster(BV) {}

    unsigned NumNonEntryNodes = 0;
    CostType TotalCost = 0;
    CostType CostExcludingGraphEntryPoints = 0;
    BitVector Cluster;
  };

  /// Collects all graph entry points' clusters and sorts them so the most
  /// expensive clusters are viewed first. This will merge clusters together if
  /// they share a non-copyable dependency.
  void setupWorkList();

  /// Recursive function that assigns the worklist item at \p Idx into a
  /// partition of \p SP.
  ///
  /// \p Depth is the current search depth. When this value is equal to
  /// \ref MaxDepth, we can no longer recurse.
  ///
  /// This function only recurses if there is more than one possible assignment,
  /// otherwise it is iterative to avoid creating a call stack that is as big as
  /// \ref WorkList.
  void pickPartition(unsigned Depth, unsigned Idx, SplitProposal SP);

  /// \return A pair: first element is the PID of the partition that has the
  /// most similarities with \p Entry, or \ref InvalidPID if no partition was
  /// found with at least one element in common. The second element is the
  /// aggregated cost of all dependencies in common between \p Entry and that
  /// partition.
  std::pair<unsigned, CostType>
  findMostSimilarPartition(const WorkListEntry &Entry, const SplitProposal &SP);

  const SplitGraph &SG;
  unsigned NumParts;
  SubmitProposalFn SubmitProposal;

  // A Cluster is considered large when its cost, excluding entry points,
  // exceeds this value.
  CostType LargeClusterThreshold = 0;
  unsigned NumProposalsSubmitted = 0;
  SmallVector<WorkListEntry> WorkList;
};

RecursiveSearchSplitting::RecursiveSearchSplitting(
    const SplitGraph &SG, unsigned NumParts, SubmitProposalFn SubmitProposal)
    : SG(SG), NumParts(NumParts), SubmitProposal(SubmitProposal) {
  // arbitrary max value as a safeguard. Anything above 10 will already be
  // slow, this is just a max value to prevent extreme resource exhaustion or
  // unbounded run time.
  if (MaxDepth > 16)
    report_fatal_error("[amdgpu-split-module] search depth of " +
                       Twine(MaxDepth) + " is too high!");
  LargeClusterThreshold =
      (LargeFnFactor != 0.0)
          ? CostType(((SG.getModuleCost() / NumParts) * LargeFnFactor))
          : std::numeric_limits<CostType>::max();
  LLVM_DEBUG(dbgs() << "[recursive search] large cluster threshold set at "
                    << LargeClusterThreshold << "\n");
}

void RecursiveSearchSplitting::run() {
  {
    SplitModuleTimer SMT("recursive_search_prepare", "preparing worklist");
    setupWorkList();
  }

  {
    SplitModuleTimer SMT("recursive_search_pick", "partitioning");
    SplitProposal SP(SG, NumParts);
    pickPartition(/*BranchDepth=*/0, /*Idx=*/0, SP);
  }
}

void RecursiveSearchSplitting::setupWorkList() {
  // e.g. if A and B are two worklist items, and they both call a non-copyable
  // dependency C, this does:
  //    A=C
  //    B=C
  //    => NodeEC will create a single group (A, B, C) and we create a new
  //    WorkList entry for that group.

  EquivalenceClasses<unsigned> NodeEC;
  for (const SplitGraph::Node *N : SG.nodes()) {
    if (!N->isGraphEntryPoint())
      continue;

    NodeEC.insert(N->getID());
    N->visitAllDependencies([&](const SplitGraph::Node &Dep) {
      if (&Dep != N && Dep.isNonCopyable())
        NodeEC.unionSets(N->getID(), Dep.getID());
    });
  }

  for (const auto &Node : NodeEC) {
    if (!Node->isLeader())
      continue;

    BitVector Cluster = SG.createNodesBitVector();
    for (unsigned M : NodeEC.members(*Node)) {
      const SplitGraph::Node &N = SG.getNode(M);
      if (N.isGraphEntryPoint())
        N.getDependencies(Cluster);
    }
    WorkList.emplace_back(std::move(Cluster));
  }

  // Calculate costs and other useful information.
  for (WorkListEntry &Entry : WorkList) {
    for (unsigned NodeID : Entry.Cluster.set_bits()) {
      const SplitGraph::Node &N = SG.getNode(NodeID);
      const CostType Cost = N.getIndividualCost();

      Entry.TotalCost += Cost;
      if (!N.isGraphEntryPoint()) {
        Entry.CostExcludingGraphEntryPoints += Cost;
        ++Entry.NumNonEntryNodes;
      }
    }
  }

  stable_sort(WorkList, [](const WorkListEntry &A, const WorkListEntry &B) {
    if (A.TotalCost != B.TotalCost)
      return A.TotalCost > B.TotalCost;

    if (A.CostExcludingGraphEntryPoints != B.CostExcludingGraphEntryPoints)
      return A.CostExcludingGraphEntryPoints > B.CostExcludingGraphEntryPoints;

    if (A.NumNonEntryNodes != B.NumNonEntryNodes)
      return A.NumNonEntryNodes > B.NumNonEntryNodes;

    return A.Cluster.count() > B.Cluster.count();
  });

  LLVM_DEBUG({
    dbgs() << "[recursive search] worklist:\n";
    for (const auto &[Idx, Entry] : enumerate(WorkList)) {
      dbgs() << " - [" << Idx << "]: ";
      for (unsigned NodeID : Entry.Cluster.set_bits())
        dbgs() << NodeID << " ";
      dbgs() << "(total_cost:" << Entry.TotalCost
             << ", cost_excl_entries:" << Entry.CostExcludingGraphEntryPoints
             << ")\n";
    }
  });
}

void RecursiveSearchSplitting::pickPartition(unsigned Depth, unsigned Idx,
                                             SplitProposal SP) {
  while (Idx < WorkList.size()) {
    // Step 1: Determine candidate PIDs.
    //
    const WorkListEntry &Entry = WorkList[Idx];
    const BitVector &Cluster = Entry.Cluster;

    // Default option is to do load-balancing, AKA assign to least pressured
    // partition.
    const unsigned CheapestPID = SP.findCheapestPartition();
    assert(CheapestPID != InvalidPID);

    // Explore assigning to the kernel that contains the most dependencies in
    // common.
    const auto [MostSimilarPID, SimilarDepsCost] =
        findMostSimilarPartition(Entry, SP);

    // We can choose to explore only one path if we only have one valid path,
    // or if we reached maximum search depth and can no longer branch out.
    unsigned SinglePIDToTry = InvalidPID;
    if (MostSimilarPID == InvalidPID) // no similar PID found
      SinglePIDToTry = CheapestPID;
    else if (MostSimilarPID == CheapestPID) // both landed on the same PID
      SinglePIDToTry = CheapestPID;
    else if (Depth >= MaxDepth) {
      // We have to choose one path. Use a heuristic to guess which one will be
      // more appropriate.
      if (Entry.CostExcludingGraphEntryPoints > LargeClusterThreshold) {
        // Check if the amount of code in common makes it worth it.
        assert(SimilarDepsCost && Entry.CostExcludingGraphEntryPoints);
        const double Ratio = static_cast<double>(SimilarDepsCost) /
                             Entry.CostExcludingGraphEntryPoints;
        assert(Ratio >= 0.0 && Ratio <= 1.0);
        if (Ratio > LargeFnOverlapForMerge) {
          // For debug, just print "L", so we'll see "L3=P3" for instance, which
          // will mean we reached max depth and chose P3 based on this
          // heuristic.
          LLVM_DEBUG(dbgs() << 'L');
          SinglePIDToTry = MostSimilarPID;
        }
      } else
        SinglePIDToTry = CheapestPID;
    }

    // Step 2: Explore candidates.

    // When we only explore one possible path, and thus branch depth doesn't
    // increase, do not recurse, iterate instead.
    if (SinglePIDToTry != InvalidPID) {
      LLVM_DEBUG(dbgs() << Idx << "=P" << SinglePIDToTry << ' ');
      // Only one path to explore, don't clone SP, don't increase depth.
      SP.add(SinglePIDToTry, Cluster);
      ++Idx;
      continue;
    }

    assert(MostSimilarPID != InvalidPID);

    // We explore multiple paths: recurse at increased depth, then stop this
    // function.

    LLVM_DEBUG(dbgs() << '\n');

    // lb = load balancing = put in cheapest partition
    {
      SplitProposal BranchSP = SP;
      LLVM_DEBUG(dbgs().indent(Depth)
                 << " [lb] " << Idx << "=P" << CheapestPID << "? ");
      BranchSP.add(CheapestPID, Cluster);
      pickPartition(Depth + 1, Idx + 1, BranchSP);
    }

    // ms = most similar = put in partition with the most in common
    {
      SplitProposal BranchSP = SP;
      LLVM_DEBUG(dbgs().indent(Depth)
                 << " [ms] " << Idx << "=P" << MostSimilarPID << "? ");
      BranchSP.add(MostSimilarPID, Cluster);
      pickPartition(Depth + 1, Idx + 1, BranchSP);
    }

    return;
  }

  // Step 3: If we assigned all WorkList items, submit the proposal.

  assert(Idx == WorkList.size());
  assert(NumProposalsSubmitted <= (2u << MaxDepth) &&
         "Search got out of bounds?");
  SP.setName("recursive_search (depth=" + std::to_string(Depth) + ") #" +
             std::to_string(NumProposalsSubmitted++));
  LLVM_DEBUG(dbgs() << '\n');
  SubmitProposal(SP);
}
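
// Illustrative trace: with two worklist items whose cheapest and most-similar
// partitions differ, and MaxDepth >= 2, the search explores four leaves
// (lb/lb, lb/ms, ms/lb, ms/ms) and thus submits four proposals; each extra
// branching level at most doubles the number of proposals.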

std::pair<unsigned, CostType>
RecursiveSearchSplitting::findMostSimilarPartition(const WorkListEntry &Entry,
                                                   const SplitProposal &SP) {
  if (!Entry.NumNonEntryNodes)
    return {InvalidPID, 0};

  // We take the partition that is the most similar using Cost as a metric.
  // So we take the set of nodes in common, compute their aggregated cost, and
  // pick the partition with the highest cost in common.
  unsigned ChosenPID = InvalidPID;
  CostType ChosenCost = 0;
  for (unsigned PID = 0; PID < NumParts; ++PID) {
    BitVector BV = SP[PID];
    BV &= Entry.Cluster; // FIXME: & doesn't work between BVs?!

    if (BV.none())
      continue;

    const CostType Cost = SG.calculateCost(BV);

    if (ChosenPID == InvalidPID || ChosenCost < Cost ||
        (ChosenCost == Cost && PID > ChosenPID)) {
      ChosenPID = PID;
      ChosenCost = Cost;
    }
  }

  return {ChosenPID, ChosenCost};
}

//===----------------------------------------------------------------------===//
// DOTGraph Printing Support
//===----------------------------------------------------------------------===//

const SplitGraph::Node *mapEdgeToDst(const SplitGraph::Edge *E) {
  return E->Dst;
}

using SplitGraphEdgeDstIterator =
    mapped_iterator<SplitGraph::edges_iterator, decltype(&mapEdgeToDst)>;

} // namespace

template <> struct GraphTraits<SplitGraph> {
  using NodeRef = const SplitGraph::Node *;
  using nodes_iterator = SplitGraph::nodes_iterator;
  using ChildIteratorType = SplitGraphEdgeDstIterator;

  using EdgeRef = const SplitGraph::Edge *;
  using ChildEdgeIteratorType = SplitGraph::edges_iterator;

  static NodeRef getEntryNode(NodeRef N) { return N; }

  static ChildIteratorType child_begin(NodeRef Ref) {
    return {Ref->outgoing_edges().begin(), mapEdgeToDst};
  }
  static ChildIteratorType child_end(NodeRef Ref) {
    return {Ref->outgoing_edges().end(), mapEdgeToDst};
  }

  static nodes_iterator nodes_begin(const SplitGraph &G) {
    return G.nodes().begin();
  }
  static nodes_iterator nodes_end(const SplitGraph &G) {
    return G.nodes().end();
  }
};

template <> struct DOTGraphTraits<SplitGraph> : public DefaultDOTGraphTraits {
  DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {}

  static std::string getGraphName(const SplitGraph &SG) {
    return SG.getModule().getName().str();
  }

  std::string getNodeLabel(const SplitGraph::Node *N, const SplitGraph &SG) {
    return N->getName().str();
  }

  static std::string getNodeDescription(const SplitGraph::Node *N,
                                        const SplitGraph &SG) {
    std::string Result;
    if (N->isEntryFunctionCC())
      Result += "entry-fn-cc ";
    if (N->isNonCopyable())
      Result += "non-copyable ";
    Result += "cost:" + std::to_string(N->getIndividualCost());
    return Result;
  }

  static std::string getNodeAttributes(const SplitGraph::Node *N,
                                       const SplitGraph &SG) {
    return N->hasAnyIncomingEdges() ? "" : "color=\"red\"";
  }

  static std::string getEdgeAttributes(const SplitGraph::Node *N,
                                       SplitGraphEdgeDstIterator EI,
                                       const SplitGraph &SG) {
    switch ((*EI.getCurrent())->Kind) {
    case SplitGraph::EdgeKind::DirectCall:
      return "";
    case SplitGraph::EdgeKind::IndirectCall:
      return "style=\"dashed\"";
    }
    llvm_unreachable("Unknown SplitGraph::EdgeKind enum");
  }
};

//===----------------------------------------------------------------------===//
// Driver
//===----------------------------------------------------------------------===//

namespace {

// If we didn't externalize GVs, then local GVs need to be conservatively
// imported into every module (including their initializers), and then cleaned
// up afterwards.
static bool needsConservativeImport(const GlobalValue *GV) {
  if (const auto *Var = dyn_cast<GlobalVariable>(GV))
    return Var->hasLocalLinkage();
  return isa<GlobalAlias>(GV);
}
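
// e.g. when -amdgpu-module-splitting-no-externalize-globals is set, a local
// GlobalVariable (or any GlobalAlias) is cloned into every partition, and the
// copies that end up unused are erased again by the cleanup loop further down.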

/// Prints a summary of the partition \p N, represented by module \p M, to \p
/// OS.
static void printPartitionSummary(raw_ostream &OS, unsigned N, const Module &M,
                                  unsigned PartCost, unsigned ModuleCost) {
  OS << "*** Partition P" << N << " ***\n";

  for (const auto &Fn : M) {
    if (!Fn.isDeclaration())
      OS << " - [function] " << Fn.getName() << "\n";
  }

  for (const auto &GV : M.globals()) {
    if (GV.hasInitializer())
      OS << " - [global] " << GV.getName() << "\n";
  }

  OS << "Partition contains " << formatRatioOf(PartCost, ModuleCost)
     << "% of the source\n";
}

static void evaluateProposal(SplitProposal &Best, SplitProposal New) {
  SplitModuleTimer SMT("proposal_evaluation", "proposal ranking algorithm");

  LLVM_DEBUG({
    New.verifyCompleteness();
    if (DebugProposalSearch)
      New.print(dbgs());
  });

  const double CurBScore = Best.getBottleneckScore();
  const double CurCSScore = Best.getCodeSizeScore();
  const double NewBScore = New.getBottleneckScore();
  const double NewCSScore = New.getCodeSizeScore();

  // TODO: Improve this
  //    We can probably lower the precision of the comparison at first
  //    e.g. if we have
  //      - (Current): BScore: 0.489 CSCore 1.105
  //      - (New):     BScore: 0.475 CSCore 1.305
  //    Currently we'd choose the new one because the bottleneck score is
  //    lower, but the new one duplicates more code. It may be worth it to
  //    discard the new proposal as the impact on build time is negligible.

  // Compare them
  bool IsBest = false;
  if (NewBScore < CurBScore)
    IsBest = true;
  else if (NewBScore == CurBScore)
    IsBest = (NewCSScore < CurCSScore); // Use code size as tie breaker.

  if (IsBest)
    Best = std::move(New);

  LLVM_DEBUG(if (DebugProposalSearch) {
    if (IsBest)
      dbgs() << "[search] new best proposal!\n";
    else
      dbgs() << "[search] discarding - not profitable\n";
  });
}

/// Trivial helper to create an identical copy of \p M.
static std::unique_ptr<Module> cloneAll(const Module &M) {
  ValueToValueMapTy VMap;
  return CloneModule(M, VMap, [&](const GlobalValue *GV) { return true; });
}

/// Writes \p SG as a DOTGraph to \ref ModuleDotCfgOutput if requested.
static void writeDOTGraph(const SplitGraph &SG) {
  if (ModuleDotCfgOutput.empty())
    return;

  std::error_code EC;
  raw_fd_ostream OS(ModuleDotCfgOutput, EC);
  if (EC) {
    errs() << "[" DEBUG_TYPE "]: cannot open '" << ModuleDotCfgOutput
           << "' - DOTGraph will not be printed\n";
    return;
  }
  WriteGraph(OS, SG, /*ShortName=*/false,
             /*Title=*/SG.getModule().getName());
}

static void splitAMDGPUModule(
    GetTTIFn GetTTI, Module &M, unsigned NumParts,
    function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback) {
  CallGraph CG(M);

  // Externalize functions whose addresses are taken.
  //
  // This is needed because partitioning is purely based on calls, but sometimes
  // a kernel/function may just look at the address of another local function
  // and not do anything (no calls). After partitioning, that local function may
  // end up in a different module (so it's just a declaration in the module
  // where its address is taken), which emits an "undefined hidden symbol"
  // linker error.
  //
  // Additionally, it guides partitioning to not duplicate this function if it's
  // called directly at some point.
  //
  // TODO: Could we be smarter about this? This makes all functions whose
  // addresses are taken non-copyable. We should probably model this type of
  // constraint in the graph and use it to guide splitting, instead of
  // externalizing like this. Maybe non-copyable should really mean "keep one
  // visible copy, then internalize all other copies" for some functions?
  if (!NoExternalizeOnAddrTaken) {
    for (auto &Fn : M) {
      // TODO: Should aliases count? Probably not but they're so rare I'm not
      // sure it's worth fixing.
      if (Fn.hasLocalLinkage() && Fn.hasAddressTaken()) {
        LLVM_DEBUG(dbgs() << "[externalize] "; Fn.printAsOperand(dbgs());
                   dbgs() << " because its address is taken\n");
        externalize(Fn);
      }
    }
  }

  // Externalize local GVs, which avoids duplicating their initializers, which
  // in turn helps keep code size in check.
  if (!NoExternalizeGlobals) {
    for (auto &GV : M.globals()) {
      if (GV.hasLocalLinkage())
        LLVM_DEBUG(dbgs() << "[externalize] GV " << GV.getName() << '\n');
      externalize(GV);
    }
  }

  // Start by calculating the cost of every function in the module, as well as
  // the module's overall cost.
  FunctionsCostMap FnCosts;
  const CostType ModuleCost = calculateFunctionCosts(GetTTI, M, FnCosts);

  // Build the SplitGraph, which represents the module's functions and models
  // their dependencies accurately.
  SplitGraph SG(M, FnCosts, ModuleCost);
  SG.buildGraph(CG);

  if (SG.empty()) {
    LLVM_DEBUG(
        dbgs()
        << "[!] no nodes in graph, input is empty - no splitting possible\n");
    ModuleCallback(cloneAll(M));
    return;
  }

  LLVM_DEBUG({
    dbgs() << "[graph] nodes:\n";
    for (const SplitGraph::Node *N : SG.nodes()) {
      dbgs() << " - [" << N->getID() << "]: " << N->getName() << " "
             << (N->isGraphEntryPoint() ? "(entry)" : "") << " "
             << (N->isNonCopyable() ? "(noncopyable)" : "") << "\n";
    }
  });

  writeDOTGraph(SG);

  LLVM_DEBUG(dbgs() << "[search] testing splitting strategies\n");

  std::optional<SplitProposal> Proposal;
  const auto EvaluateProposal = [&](SplitProposal SP) {
    SP.calculateScores();
    if (!Proposal)
      Proposal = std::move(SP);
    else
      evaluateProposal(*Proposal, std::move(SP));
  };

  // TODO: It would be very easy to create new strategies by just adding a base
  // class to RecursiveSearchSplitting and abstracting it away.
  RecursiveSearchSplitting(SG, NumParts, EvaluateProposal).run();
  LLVM_DEBUG(if (Proposal) dbgs() << "[search done] selected proposal: "
                                  << Proposal->getName() << "\n";);

  if (!Proposal) {
    LLVM_DEBUG(dbgs() << "[!] no proposal made, no splitting possible!\n");
    ModuleCallback(cloneAll(M));
    return;
  }

  LLVM_DEBUG(Proposal->print(dbgs()););

  std::optional<raw_fd_ostream> SummariesOS;
  if (!PartitionSummariesOutput.empty()) {
    std::error_code EC;
    SummariesOS.emplace(PartitionSummariesOutput, EC);
    if (EC)
      errs() << "[" DEBUG_TYPE "]: cannot open '" << PartitionSummariesOutput
             << "' - Partition summaries will not be printed\n";
  }

  // The first non-empty module created will import all GlobalValues that are
  // not functions; the other modules only import those that are subject to
  // conservative import.
  bool ImportAllGVs = true;

  for (unsigned PID = 0; PID < NumParts; ++PID) {
    SplitModuleTimer SMT2("modules_creation",
                          "creating modules for each partition");
    LLVM_DEBUG(dbgs() << "[split] creating new modules\n");

    DenseSet<const Function *> FnsInPart;
    for (unsigned NodeID : (*Proposal)[PID].set_bits())
      FnsInPart.insert(&SG.getNode(NodeID).getFunction());

    // Don't create empty modules.
    if (FnsInPart.empty()) {
      LLVM_DEBUG(dbgs() << "[split] P" << PID
                        << " is empty, not creating module\n");
      continue;
    }

    ValueToValueMapTy VMap;
    CostType PartCost = 0;
    std::unique_ptr<Module> MPart(
        CloneModule(M, VMap, [&](const GlobalValue *GV) {
          // Functions go in their assigned partition.
          if (const auto *Fn = dyn_cast<Function>(GV)) {
            if (FnsInPart.contains(Fn)) {
              PartCost += SG.getCost(*Fn);
              return true;
            }
            return false;
          }

          // Everything else goes in the first non-empty module we create.
          return ImportAllGVs || needsConservativeImport(GV);
        }));

    ImportAllGVs = false;

    // FIXME: Aliases aren't seen often, and their handling isn't perfect so
    // bugs are possible.

    // Clean-up conservatively imported GVs without any users.
    for (auto &GV : make_early_inc_range(MPart->global_values())) {
      if (needsConservativeImport(&GV) && GV.use_empty())
        GV.eraseFromParent();
    }

    if (SummariesOS)
      printPartitionSummary(*SummariesOS, PID, *MPart, PartCost, ModuleCost);

    LLVM_DEBUG(
        printPartitionSummary(dbgs(), PID, *MPart, PartCost, ModuleCost));

    ModuleCallback(std::move(MPart));
  }
}
} // namespace

PreservedAnalyses AMDGPUSplitModulePass::run(Module &M,
                                             ModuleAnalysisManager &MAM) {
  SplitModuleTimer SMT(
      "total", "total pass runtime (incl. potentially waiting for lockfile)");

  FunctionAnalysisManager &FAM =
      MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  const auto TTIGetter = [&FAM](Function &F) -> const TargetTransformInfo & {
    return FAM.getResult<TargetIRAnalysis>(F);
  };

  bool Done = false;
#ifndef NDEBUG
  if (UseLockFile) {
    SmallString<128> LockFilePath;
    sys::path::system_temp_directory(/*ErasedOnReboot=*/true, LockFilePath);
    sys::path::append(LockFilePath, "amdgpu-split-module-debug");
    LLVM_DEBUG(dbgs() << DEBUG_TYPE " using lockfile '" << LockFilePath
                      << "'\n");

    while (true) {
      llvm::LockFileManager Lock(LockFilePath.str());
      bool Owned;
      if (Error Err = Lock.tryLock().moveInto(Owned)) {
        consumeError(std::move(Err));
        LLVM_DEBUG(
            dbgs() << "[amdgpu-split-module] unable to acquire lockfile, debug "
                      "output may be mangled by other processes\n");
      } else if (!Owned) {
        switch (Lock.waitForUnlockFor(std::chrono::seconds(90))) {
        case WaitForUnlockResult::Success:
          break;
        case WaitForUnlockResult::OwnerDied:
          continue; // try again to get the lock.
        case WaitForUnlockResult::Timeout:
          LLVM_DEBUG(
              dbgs()
              << "[amdgpu-split-module] unable to acquire lockfile, debug "
                 "output may be mangled by other processes\n");
          Lock.unsafeMaybeUnlock();
          break; // give up
        }
      }

      splitAMDGPUModule(TTIGetter, M, N, ModuleCallback);
      Done = true;
      break;
    }
  }
#endif

  if (!Done)
    splitAMDGPUModule(TTIGetter, M, N, ModuleCallback);

  // We can change linkage/visibilities in the input, consider that nothing is
  // preserved just to be safe. This pass runs last anyway.
  return PreservedAnalyses::none();
}
} // namespace llvm