//===- BalancedPartitioning.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements BalancedPartitioning, a recursive balanced graph
// partitioning algorithm.
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/BalancedPartitioning.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ThreadPool.h"

using namespace llvm;
#define DEBUG_TYPE "balanced-partitioning"

void BPFunctionNode::dump(raw_ostream &OS) const {
  OS << formatv("{{ID={0} Utilities={{{1:$[,]}} Bucket={2}}", Id,
                make_range(UtilityNodes.begin(), UtilityNodes.end()), Bucket);
}

template <typename Func>
void BalancedPartitioning::BPThreadPool::async(Func &&F) {
#if LLVM_ENABLE_THREADS
  // This new thread could spawn more threads, so mark it as active
  ++NumActiveThreads;
  TheThreadPool.async([=]() {
    // Run the task
    F();

    // This thread will no longer spawn new threads, so mark it as inactive
    if (--NumActiveThreads == 0) {
      // There are no more active threads, so mark as finished and notify
      {
        std::unique_lock<std::mutex> lock(mtx);
        assert(!IsFinishedSpawning);
        IsFinishedSpawning = true;
      }
      cv.notify_one();
    }
  });
#else
  llvm_unreachable("threads are disabled");
#endif
}

void BalancedPartitioning::BPThreadPool::wait() {
#if LLVM_ENABLE_THREADS
  // TODO: We could remove the mutex and condition variable and use
  // std::atomic::wait() instead, but that isn't available until C++20
  {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]() { return IsFinishedSpawning; });
    assert(IsFinishedSpawning && NumActiveThreads == 0);
  }
  // Now we can call ThreadPool::wait() since all tasks have been submitted
  TheThreadPool.wait();
#else
  llvm_unreachable("threads are disabled");
#endif
}

BalancedPartitioning::BalancedPartitioning(
    const BalancedPartitioningConfig &Config)
    : Config(Config) {
  // Pre-compute log2 values
  Log2Cache[0] = 0.0;
  for (unsigned I = 1; I < LOG_CACHE_SIZE; I++)
    Log2Cache[I] = std::log2(I);
}

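// Assign a Bucket to every node: record the input order, recursively bisect
// the full set of nodes starting from the root bucket, and finally re-sort the
// nodes by their resulting buckets.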
void BalancedPartitioning::run(std::vector<BPFunctionNode> &Nodes) const {
  LLVM_DEBUG(
      dbgs() << format(
          "Partitioning %d nodes using depth %d and %d iterations per split\n",
          Nodes.size(), Config.SplitDepth, Config.IterationsPerSplit));
  std::optional<BPThreadPool> TP;
#if LLVM_ENABLE_THREADS
  DefaultThreadPool TheThreadPool;
  if (Config.TaskSplitDepth > 1)
    TP.emplace(TheThreadPool);
#endif

  // Record the input order
  for (unsigned I = 0; I < Nodes.size(); I++)
    Nodes[I].InputOrderIndex = I;

  auto NodesRange = llvm::make_range(Nodes.begin(), Nodes.end());
  auto BisectTask = [=, &TP]() {
    bisect(NodesRange, /*RecDepth=*/0, /*RootBucket=*/1, /*Offset=*/0, TP);
  };
  if (TP) {
    TP->async(std::move(BisectTask));
    TP->wait();
  } else {
    BisectTask();
  }

  llvm::stable_sort(NodesRange, [](const auto &L, const auto &R) {
    return L.Bucket < R.Bucket;
  });

  LLVM_DEBUG(dbgs() << "Balanced partitioning completed\n");
}

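// Bisect the given range of nodes: split it into two halves, refine the split
// with several local-search iterations, and recurse into each half, spawning
// parallel tasks when a thread pool is available and the subproblem is large
// enough.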
void BalancedPartitioning::bisect(const FunctionNodeRange Nodes,
                                  unsigned RecDepth, unsigned RootBucket,
                                  unsigned Offset,
                                  std::optional<BPThreadPool> &TP) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  if (NumNodes <= 1 || RecDepth >= Config.SplitDepth) {
    // We've reached the lowest level of the recursion tree. Fall back to the
    // original order and assign to buckets.
    llvm::sort(Nodes, [](const auto &L, const auto &R) {
      return L.InputOrderIndex < R.InputOrderIndex;
    });
    for (auto &N : Nodes)
      N.Bucket = Offset++;
    return;
  }

  LLVM_DEBUG(dbgs() << format("Bisect with %d nodes and root bucket %d\n",
                              NumNodes, RootBucket));

  std::mt19937 RNG(RootBucket);

  unsigned LeftBucket = 2 * RootBucket;
  unsigned RightBucket = 2 * RootBucket + 1;

  // Split into two and assign to the left and right buckets
  split(Nodes, LeftBucket);

  runIterations(Nodes, LeftBucket, RightBucket, RNG);

  // Split nodes with respect to the resulting buckets
  auto NodesMid =
      llvm::partition(Nodes, [&](auto &N) { return N.Bucket == LeftBucket; });
  unsigned MidOffset = Offset + std::distance(Nodes.begin(), NodesMid);

  auto LeftNodes = llvm::make_range(Nodes.begin(), NodesMid);
  auto RightNodes = llvm::make_range(NodesMid, Nodes.end());

  auto LeftRecTask = [=, &TP]() {
    bisect(LeftNodes, RecDepth + 1, LeftBucket, Offset, TP);
  };
  auto RightRecTask = [=, &TP]() {
    bisect(RightNodes, RecDepth + 1, RightBucket, MidOffset, TP);
  };

  if (TP && RecDepth < Config.TaskSplitDepth && NumNodes >= 4) {
    TP->async(std::move(LeftRecTask));
    TP->async(std::move(RightRecTask));
  } else {
    LeftRecTask();
    RightRecTask();
  }
}

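// Refine the initial split of Nodes between LeftBucket and RightBucket.
// Utility nodes that cannot affect the objective (those with a single edge or
// connected to every node in the range) are dropped, the remaining ones are
// renumbered so they can index into Signatures, and then up to
// IterationsPerSplit local-search iterations are run, stopping early once an
// iteration moves no nodes.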
void BalancedPartitioning::runIterations(const FunctionNodeRange Nodes,
                                         unsigned LeftBucket,
                                         unsigned RightBucket,
                                         std::mt19937 &RNG) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  DenseMap<BPFunctionNode::UtilityNodeT, unsigned> UtilityNodeIndex;
  for (auto &N : Nodes)
    for (auto &UN : N.UtilityNodes)
      ++UtilityNodeIndex[UN];
  // Remove utility nodes if they have just one edge or are connected to all
  // functions
  for (auto &N : Nodes)
    llvm::erase_if(N.UtilityNodes, [&](auto &UN) {
      return UtilityNodeIndex[UN] == 1 || UtilityNodeIndex[UN] == NumNodes;
    });

  // Renumber utility nodes so they can be used to index into Signatures
  UtilityNodeIndex.clear();
  for (auto &N : Nodes)
    for (auto &UN : N.UtilityNodes)
      UN = UtilityNodeIndex.insert({UN, UtilityNodeIndex.size()}).first->second;

  // Initialize signatures
  SignaturesT Signatures(/*Size=*/UtilityNodeIndex.size());
  for (auto &N : Nodes) {
    for (auto &UN : N.UtilityNodes) {
      assert(UN < Signatures.size());
      if (N.Bucket == LeftBucket) {
        Signatures[UN].LeftCount++;
      } else {
        Signatures[UN].RightCount++;
      }
    }
  }

  for (unsigned I = 0; I < Config.IterationsPerSplit; I++) {
    unsigned NumMovedNodes =
        runIteration(Nodes, LeftBucket, RightBucket, Signatures, RNG);
    if (NumMovedNodes == 0)
      break;
  }
}

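// A single local-search iteration: refresh the cached per-utility-node move
// gains, compute the gain of moving every node to the opposite bucket, and
// greedily swap the highest-gain left/right pairs while a swap still improves
// the objective. Returns the number of nodes that were actually moved.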
unsigned BalancedPartitioning::runIteration(const FunctionNodeRange Nodes,
                                            unsigned LeftBucket,
                                            unsigned RightBucket,
                                            SignaturesT &Signatures,
                                            std::mt19937 &RNG) const {
  // Init signature cost caches
  for (auto &Signature : Signatures) {
    if (Signature.CachedGainIsValid)
      continue;
    unsigned L = Signature.LeftCount;
    unsigned R = Signature.RightCount;
    assert((L > 0 || R > 0) && "incorrect signature");
    float Cost = logCost(L, R);
    Signature.CachedGainLR = 0.f;
    Signature.CachedGainRL = 0.f;
    if (L > 0)
      Signature.CachedGainLR = Cost - logCost(L - 1, R + 1);
    if (R > 0)
      Signature.CachedGainRL = Cost - logCost(L + 1, R - 1);
    Signature.CachedGainIsValid = true;
  }

  // Compute move gains
  typedef std::pair<float, BPFunctionNode *> GainPair;
  std::vector<GainPair> Gains;
  for (auto &N : Nodes) {
    bool FromLeftToRight = (N.Bucket == LeftBucket);
    float Gain = moveGain(N, FromLeftToRight, Signatures);
    Gains.push_back(std::make_pair(Gain, &N));
  }

  // Collect left and right gains
  auto LeftEnd = llvm::partition(
      Gains, [&](const auto &GP) { return GP.second->Bucket == LeftBucket; });
  auto LeftRange = llvm::make_range(Gains.begin(), LeftEnd);
  auto RightRange = llvm::make_range(LeftEnd, Gains.end());

  // Sort gains in descending order
  auto LargerGain = [](const auto &L, const auto &R) {
    return L.first > R.first;
  };
  llvm::stable_sort(LeftRange, LargerGain);
  llvm::stable_sort(RightRange, LargerGain);

  unsigned NumMovedDataVertices = 0;
  for (auto [LeftPair, RightPair] : llvm::zip(LeftRange, RightRange)) {
    auto &[LeftGain, LeftNode] = LeftPair;
    auto &[RightGain, RightNode] = RightPair;
    // Stop when the gain is no longer beneficial
    if (LeftGain + RightGain <= 0.f)
      break;
    // Try to exchange the nodes between buckets
    if (moveFunctionNode(*LeftNode, LeftBucket, RightBucket, Signatures, RNG))
      ++NumMovedDataVertices;
    if (moveFunctionNode(*RightNode, LeftBucket, RightBucket, Signatures, RNG))
      ++NumMovedDataVertices;
  }
  return NumMovedDataVertices;
}

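// Move node N to the opposite bucket, updating the signatures of its utility
// nodes and invalidating their cached gains. With probability SkipProbability
// the move is skipped to help escape local optima; returns true only if the
// node was actually moved.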
bool BalancedPartitioning::moveFunctionNode(BPFunctionNode &N,
                                            unsigned LeftBucket,
                                            unsigned RightBucket,
                                            SignaturesT &Signatures,
                                            std::mt19937 &RNG) const {
  // Sometimes we skip the move. This helps to escape local optima
  if (std::uniform_real_distribution<float>(0.f, 1.f)(RNG) <=
      Config.SkipProbability)
    return false;

  bool FromLeftToRight = (N.Bucket == LeftBucket);
  // Update the current bucket
  N.Bucket = (FromLeftToRight ? RightBucket : LeftBucket);

  // Update signatures and invalidate gain cache
  if (FromLeftToRight) {
    for (auto &UN : N.UtilityNodes) {
      auto &Signature = Signatures[UN];
      Signature.LeftCount--;
      Signature.RightCount++;
      Signature.CachedGainIsValid = false;
    }
  } else {
    for (auto &UN : N.UtilityNodes) {
      auto &Signature = Signatures[UN];
      Signature.LeftCount++;
      Signature.RightCount--;
      Signature.CachedGainIsValid = false;
    }
  }
  return true;
}

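// Form the initial split: partition the nodes around the median
// InputOrderIndex so that the first half (by input order) is assigned to
// StartBucket and the second half to StartBucket + 1.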
void BalancedPartitioning::split(const FunctionNodeRange Nodes,
                                 unsigned StartBucket) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  auto NodesMid = Nodes.begin() + (NumNodes + 1) / 2;

  std::nth_element(Nodes.begin(), NodesMid, Nodes.end(), [](auto &L, auto &R) {
    return L.InputOrderIndex < R.InputOrderIndex;
  });

  for (auto &N : llvm::make_range(Nodes.begin(), NodesMid))
    N.Bucket = StartBucket;
  for (auto &N : llvm::make_range(NodesMid, Nodes.end()))
    N.Bucket = StartBucket + 1;
}

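// The gain of moving node N to the other bucket is the sum of the cached
// per-utility-node gains in the corresponding direction (left-to-right or
// right-to-left).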
float BalancedPartitioning::moveGain(const BPFunctionNode &N,
                                     bool FromLeftToRight,
                                     const SignaturesT &Signatures) {
  float Gain = 0.f;
  for (auto &UN : N.UtilityNodes)
    Gain += (FromLeftToRight ? Signatures[UN].CachedGainLR
                             : Signatures[UN].CachedGainRL);
  return Gain;
}

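// Cost of a utility node with X incident nodes in one bucket and Y in the
// other: -(X * log2(X + 1) + Y * log2(Y + 1)). The cached gains in
// runIteration are differences of this quantity before and after a move.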
float BalancedPartitioning::logCost(unsigned X, unsigned Y) const {
  return -(X * log2Cached(X + 1) + Y * log2Cached(Y + 1));
}

float BalancedPartitioning::log2Cached(unsigned i) const {
  return (i < LOG_CACHE_SIZE) ? Log2Cache[i] : std::log2(i);
}