//===- BalancedPartitioning.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements BalancedPartitioning, a recursive balanced graph
// partitioning algorithm.
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/BalancedPartitioning.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ThreadPool.h"

using namespace llvm;
#define DEBUG_TYPE "balanced-partitioning"

void BPFunctionNode::dump(raw_ostream &OS) const {
  OS << formatv("{{ID={0} Utilities={{{1:$[,]}} Bucket={2}}", Id,
                make_range(UtilityNodes.begin(), UtilityNodes.end()), Bucket);
}

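// Schedule F on the underlying thread pool and count it as an active task so
// that wait() can tell when no further tasks will be spawned.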
template <typename Func>
void BalancedPartitioning::BPThreadPool::async(Func &&F) {
#if LLVM_ENABLE_THREADS
  // This new thread could spawn more threads, so mark it as active
  ++NumActiveThreads;
  TheThreadPool.async([=]() {
    // Run the task
    F();

    // This thread will no longer spawn new threads, so mark it as inactive
    if (--NumActiveThreads == 0) {
      // There are no more active threads, so mark as finished and notify
      {
        std::unique_lock<std::mutex> lock(mtx);
        assert(!IsFinishedSpawning);
        IsFinishedSpawning = true;
      }
      cv.notify_one();
    }
  });
#else
  llvm_unreachable("threads are disabled");
#endif
}

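// Block until every task scheduled via async(), including tasks spawned by
// other tasks, has finished executing.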
void BalancedPartitioning::BPThreadPool::wait() {
#if LLVM_ENABLE_THREADS
  // TODO: We could remove the mutex and condition variable and use
  // std::atomic::wait() instead, but that isn't available until C++20
  {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]() { return IsFinishedSpawning; });
    assert(IsFinishedSpawning && NumActiveThreads == 0);
  }
  // Now we can call ThreadPool::wait() since all tasks have been submitted
  TheThreadPool.wait();
#else
  llvm_unreachable("threads are disabled");
#endif
}

BalancedPartitioning::BalancedPartitioning(
    const BalancedPartitioningConfig &Config)
    : Config(Config) {
  // Pre-computing log2 values
  Log2Cache[0] = 0.0;
  for (unsigned I = 1; I < LOG_CACHE_SIZE; I++)
    Log2Cache[I] = std::log2(I);
}

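// Run recursive bisection over all nodes: record the input order, bisect the
// whole range (optionally in parallel), and finally order the nodes by their
// assigned buckets.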
void BalancedPartitioning::run(std::vector<BPFunctionNode> &Nodes) const {
  LLVM_DEBUG(
      dbgs() << format(
          "Partitioning %d nodes using depth %d and %d iterations per split\n",
          Nodes.size(), Config.SplitDepth, Config.IterationsPerSplit));
  std::optional<BPThreadPool> TP;
#if LLVM_ENABLE_THREADS
  DefaultThreadPool TheThreadPool;
  if (Config.TaskSplitDepth > 1)
    TP.emplace(TheThreadPool);
#endif

  // Record the input order
  for (unsigned I = 0; I < Nodes.size(); I++)
    Nodes[I].InputOrderIndex = I;

  auto NodesRange = llvm::make_range(Nodes.begin(), Nodes.end());
  auto BisectTask = [=, &TP]() {
    bisect(NodesRange, /*RecDepth=*/0, /*RootBucket=*/1, /*Offset=*/0, TP);
  };
  if (TP) {
    TP->async(std::move(BisectTask));
    TP->wait();
  } else {
    BisectTask();
  }

  llvm::stable_sort(NodesRange, [](const auto &L, const auto &R) {
    return L.Bucket < R.Bucket;
  });

  LLVM_DEBUG(dbgs() << "Balanced partitioning completed\n");
}

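// Recursively split a range of nodes into two buckets (LeftBucket and
// RightBucket), refine the split with local search, and recurse on each half
// until the recursion depth limit or a trivially small range is reached.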
void BalancedPartitioning::bisect(const FunctionNodeRange Nodes,
                                  unsigned RecDepth, unsigned RootBucket,
                                  unsigned Offset,
                                  std::optional<BPThreadPool> &TP) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  if (NumNodes <= 1 || RecDepth >= Config.SplitDepth) {
    // We've reached the lowest level of the recursion tree. Fall back to the
    // original order and assign to buckets.
    llvm::sort(Nodes, [](const auto &L, const auto &R) {
      return L.InputOrderIndex < R.InputOrderIndex;
    });
    for (auto &N : Nodes)
      N.Bucket = Offset++;
    return;
  }

  LLVM_DEBUG(dbgs() << format("Bisect with %d nodes and root bucket %d\n",
                              NumNodes, RootBucket));

  std::mt19937 RNG(RootBucket);

  unsigned LeftBucket = 2 * RootBucket;
  unsigned RightBucket = 2 * RootBucket + 1;

  // Split into two and assign to the left and right buckets
  split(Nodes, LeftBucket);

  runIterations(Nodes, LeftBucket, RightBucket, RNG);

  // Split nodes wrt the resulting buckets
  auto NodesMid =
      llvm::partition(Nodes, [&](auto &N) { return N.Bucket == LeftBucket; });
  unsigned MidOffset = Offset + std::distance(Nodes.begin(), NodesMid);

  auto LeftNodes = llvm::make_range(Nodes.begin(), NodesMid);
  auto RightNodes = llvm::make_range(NodesMid, Nodes.end());

  auto LeftRecTask = [=, &TP]() {
    bisect(LeftNodes, RecDepth + 1, LeftBucket, Offset, TP);
  };
  auto RightRecTask = [=, &TP]() {
    bisect(RightNodes, RecDepth + 1, RightBucket, MidOffset, TP);
  };

  if (TP && RecDepth < Config.TaskSplitDepth && NumNodes >= 4) {
    TP->async(std::move(LeftRecTask));
    TP->async(std::move(RightRecTask));
  } else {
    LeftRecTask();
    RightRecTask();
  }
}

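// Prepare a range of nodes for local search: prune utility nodes with a
// single edge or with edges to every node, renumber the remaining ones, build
// their left/right signatures, and run up to Config.IterationsPerSplit
// improvement iterations.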
void BalancedPartitioning::runIterations(const FunctionNodeRange Nodes,
                                         unsigned LeftBucket,
                                         unsigned RightBucket,
                                         std::mt19937 &RNG) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  DenseMap<BPFunctionNode::UtilityNodeT, unsigned> UtilityNodeIndex;
  for (auto &N : Nodes)
    for (auto &UN : N.UtilityNodes)
      ++UtilityNodeIndex[UN];
  // Remove utility nodes if they have just one edge or are connected to all
  // functions
  for (auto &N : Nodes)
    llvm::erase_if(N.UtilityNodes, [&](auto &UN) {
      return UtilityNodeIndex[UN] == 1 || UtilityNodeIndex[UN] == NumNodes;
    });

  // Renumber utility nodes so they can be used to index into Signatures
  UtilityNodeIndex.clear();
  for (auto &N : Nodes)
    for (auto &UN : N.UtilityNodes)
      UN = UtilityNodeIndex.insert({UN, UtilityNodeIndex.size()}).first->second;

  // Initialize signatures
  SignaturesT Signatures(/*Size=*/UtilityNodeIndex.size());
  for (auto &N : Nodes) {
    for (auto &UN : N.UtilityNodes) {
      assert(UN < Signatures.size());
      if (N.Bucket == LeftBucket) {
        Signatures[UN].LeftCount++;
      } else {
        Signatures[UN].RightCount++;
      }
    }
  }

  for (unsigned I = 0; I < Config.IterationsPerSplit; I++) {
    unsigned NumMovedNodes =
        runIteration(Nodes, LeftBucket, RightBucket, Signatures, RNG);
    if (NumMovedNodes == 0)
      break;
  }
}

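// Run a single improvement iteration: refresh the cached per-utility-node
// move gains, compute a gain for every node, and greedily swap the
// highest-gain left/right pairs while their combined gain stays positive.
// Returns the number of nodes that changed buckets.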
unsigned BalancedPartitioning::runIteration(const FunctionNodeRange Nodes,
                                            unsigned LeftBucket,
                                            unsigned RightBucket,
                                            SignaturesT &Signatures,
                                            std::mt19937 &RNG) const {
  // Init signature cost caches
  for (auto &Signature : Signatures) {
    if (Signature.CachedGainIsValid)
      continue;
    unsigned L = Signature.LeftCount;
    unsigned R = Signature.RightCount;
    assert((L > 0 || R > 0) && "incorrect signature");
    float Cost = logCost(L, R);
    Signature.CachedGainLR = 0.f;
    Signature.CachedGainRL = 0.f;
    if (L > 0)
      Signature.CachedGainLR = Cost - logCost(L - 1, R + 1);
    if (R > 0)
      Signature.CachedGainRL = Cost - logCost(L + 1, R - 1);
    Signature.CachedGainIsValid = true;
  }

  // Compute move gains
  typedef std::pair<float, BPFunctionNode *> GainPair;
  std::vector<GainPair> Gains;
  for (auto &N : Nodes) {
    bool FromLeftToRight = (N.Bucket == LeftBucket);
    float Gain = moveGain(N, FromLeftToRight, Signatures);
    Gains.push_back(std::make_pair(Gain, &N));
  }

  // Collect left and right gains
  auto LeftEnd = llvm::partition(
      Gains, [&](const auto &GP) { return GP.second->Bucket == LeftBucket; });
  auto LeftRange = llvm::make_range(Gains.begin(), LeftEnd);
  auto RightRange = llvm::make_range(LeftEnd, Gains.end());

  // Sort gains in descending order
  auto LargerGain = [](const auto &L, const auto &R) {
    return L.first > R.first;
  };
  llvm::stable_sort(LeftRange, LargerGain);
  llvm::stable_sort(RightRange, LargerGain);

  unsigned NumMovedDataVertices = 0;
  for (auto [LeftPair, RightPair] : llvm::zip(LeftRange, RightRange)) {
    auto &[LeftGain, LeftNode] = LeftPair;
    auto &[RightGain, RightNode] = RightPair;
    // Stop when the gain is no longer beneficial
    if (LeftGain + RightGain <= 0.f)
      break;
    // Try to exchange the nodes between buckets
    if (moveFunctionNode(*LeftNode, LeftBucket, RightBucket, Signatures, RNG))
      ++NumMovedDataVertices;
    if (moveFunctionNode(*RightNode, LeftBucket, RightBucket, Signatures, RNG))
      ++NumMovedDataVertices;
  }
  return NumMovedDataVertices;
}

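// Move node N to the opposite bucket, updating the signatures of its utility
// nodes and invalidating their cached gains. The move is randomly skipped
// with probability Config.SkipProbability to help escape local optima.
// Returns true iff the node was moved.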
bool BalancedPartitioning::moveFunctionNode(BPFunctionNode &N,
                                            unsigned LeftBucket,
                                            unsigned RightBucket,
                                            SignaturesT &Signatures,
                                            std::mt19937 &RNG) const {
  // Sometimes we skip the move. This helps to escape local optima
  if (std::uniform_real_distribution<float>(0.f, 1.f)(RNG) <=
      Config.SkipProbability)
    return false;

  bool FromLeftToRight = (N.Bucket == LeftBucket);
  // Update the current bucket
  N.Bucket = (FromLeftToRight ? RightBucket : LeftBucket);

  // Update signatures and invalidate gain cache
  if (FromLeftToRight) {
    for (auto &UN : N.UtilityNodes) {
      auto &Signature = Signatures[UN];
      Signature.LeftCount--;
      Signature.RightCount++;
      Signature.CachedGainIsValid = false;
    }
  } else {
    for (auto &UN : N.UtilityNodes) {
      auto &Signature = Signatures[UN];
      Signature.LeftCount++;
      Signature.RightCount--;
      Signature.CachedGainIsValid = false;
    }
  }
  return true;
}

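// Produce the initial assignment: the first half of the nodes (in input
// order) go to StartBucket and the second half to StartBucket + 1.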
void BalancedPartitioning::split(const FunctionNodeRange Nodes,
                                 unsigned StartBucket) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  auto NodesMid = Nodes.begin() + (NumNodes + 1) / 2;

  std::nth_element(Nodes.begin(), NodesMid, Nodes.end(), [](auto &L, auto &R) {
    return L.InputOrderIndex < R.InputOrderIndex;
  });

  for (auto &N : llvm::make_range(Nodes.begin(), NodesMid))
    N.Bucket = StartBucket;
  for (auto &N : llvm::make_range(NodesMid, Nodes.end()))
    N.Bucket = StartBucket + 1;
}

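// Gain of moving node N to the other bucket: the sum of the cached gains of
// its utility nodes for the corresponding move direction.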
float BalancedPartitioning::moveGain(const BPFunctionNode &N,
                                     bool FromLeftToRight,
                                     const SignaturesT &Signatures) {
  float Gain = 0.f;
  for (auto &UN : N.UtilityNodes)
    Gain += (FromLeftToRight ? Signatures[UN].CachedGainLR
                             : Signatures[UN].CachedGainRL);
  return Gain;
}

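// Cost of a utility node with X edges in the left bucket and Y edges in the
// right one: -(X * log2(X + 1) + Y * log2(Y + 1)). Lower is better; the
// cached gains above are differences of this cost before and after a move.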
float BalancedPartitioning::logCost(unsigned X, unsigned Y) const {
  return -(X * log2Cached(X + 1) + Y * log2Cached(Y + 1));
}

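// log2 with a small-value lookup table to avoid recomputing common values.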
float BalancedPartitioning::log2Cached(unsigned i) const {
  return (i < LOG_CACHE_SIZE) ? Log2Cache[i] : std::log2(i);
}