1 | //===- llvm/Support/Parallel.h - Parallel algorithms ----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef LLVM_SUPPORT_PARALLEL_H |
10 | #define LLVM_SUPPORT_PARALLEL_H |
11 | |
12 | #include "llvm/ADT/STLExtras.h" |
13 | #include "llvm/Config/llvm-config.h" |
14 | #include "llvm/Support/Error.h" |
15 | #include "llvm/Support/MathExtras.h" |
16 | #include "llvm/Support/Threading.h" |
17 | |
18 | #include <algorithm> |
19 | #include <condition_variable> |
20 | #include <functional> |
21 | #include <mutex> |
22 | |
23 | namespace llvm { |
24 | |
25 | namespace parallel { |
26 | |
27 | // Strategy for the default executor used by the parallel routines provided by |
28 | // this file. It defaults to using all hardware threads and should be |
29 | // initialized before the first use of parallel routines. |
30 | extern ThreadPoolStrategy strategy; |
31 | |
#if LLVM_ENABLE_THREADS
// Shared body of getThreadIndex(). Returns 0 when exactly one thread was
// requested through parallel::strategy; otherwise asserts that threadIndex is
// not UINT_MAX (presumably the "not a pool thread" sentinel -- confirm against
// the definition of threadIndex) and returns it.
#define GET_THREAD_INDEX_IMPL                                                  \
  if (parallel::strategy.ThreadsRequested == 1)                                \
    return 0;                                                                  \
  assert((threadIndex != UINT_MAX) &&                                          \
         "getThreadIndex() must be called from a thread created by "           \
         "ThreadPoolExecutor");                                                \
  return threadIndex;

#ifdef _WIN32
// Direct access to thread_local variables from a different DLL isn't
// possible with Windows Native TLS, so only the declaration lives here; the
// definition is out of line.
unsigned getThreadIndex();
#else
// Don't access this directly, use the getThreadIndex wrapper.
extern thread_local unsigned threadIndex;

/// Returns the index of the calling thread (see GET_THREAD_INDEX_IMPL).
inline unsigned getThreadIndex() { GET_THREAD_INDEX_IMPL; }
#endif

/// Declared here, defined elsewhere when threads are enabled.
size_t getThreadCount();
#else
// Without thread support everything runs on a single thread with index 0.
inline unsigned getThreadIndex() { return 0; }
inline size_t getThreadCount() { return 1; }
#endif
57 | |
58 | namespace detail { |
/// A counter that threads can wait on until it drops to zero.
///
/// All accesses to the counter made through inc(), dec() and sync() are
/// serialized by an internal mutex.
class Latch {
  uint32_t Count;
  // Mutable so that the logically read-only sync() can lock and wait.
  mutable std::mutex Mutex;
  mutable std::condition_variable Cond;

public:
  /// Creates a latch whose counter starts at \p Count (zero by default).
  explicit Latch(uint32_t Count = 0) : Count(Count) {}

  ~Latch() {
    // Ensure at least that sync() was called.
    assert(Count == 0);
  }

  /// Increments the counter by one.
  void inc() {
    std::lock_guard<std::mutex> Lock(Mutex);
    ++Count;
  }

  /// Decrements the counter by one and wakes every waiter once it hits zero.
  void dec() {
    std::lock_guard<std::mutex> Lock(Mutex);
    if (--Count == 0)
      Cond.notify_all();
  }

  /// Blocks the caller until the counter is zero. Returns immediately if it
  /// already is.
  void sync() const {
    std::unique_lock<std::mutex> Lock(Mutex);
    Cond.wait(Lock, [&] { return Count == 0; });
  }
};
87 | } // namespace detail |
88 | |
89 | class TaskGroup { |
90 | detail::Latch L; |
91 | bool Parallel; |
92 | |
93 | public: |
94 | TaskGroup(); |
95 | ~TaskGroup(); |
96 | |
97 | // Spawn a task, but does not wait for it to finish. |
98 | // Tasks marked with \p Sequential will be executed |
99 | // exactly in the order which they were spawned. |
100 | // Note: Sequential tasks may be executed on different |
101 | // threads, but strictly in sequential order. |
102 | void spawn(std::function<void()> f, bool Sequential = false); |
103 | |
104 | void sync() const { L.sync(); } |
105 | |
106 | bool isParallel() const { return Parallel; } |
107 | }; |
108 | |
109 | namespace detail { |
110 | |
111 | #if LLVM_ENABLE_THREADS |
/// Ranges with fewer elements than this are sorted sequentially by
/// parallel_quick_sort.
constexpr ptrdiff_t MinParallelSize = 1024;
113 | |
/// Inclusive median.
///
/// Returns an iterator to the median (under \p Comp) of the first element,
/// the middle element and the last element of the range [Start, End). The
/// range must not be empty because `*(End - 1)` is dereferenced.
template <class RandomAccessIterator, class Comparator>
RandomAccessIterator medianOf3(RandomAccessIterator Start,
                               RandomAccessIterator End,
                               const Comparator &Comp) {
  RandomAccessIterator Mid = Start + (std::distance(Start, End) / 2);
  RandomAccessIterator Last = End - 1;

  if (Comp(*Start, *Last)) {
    // Start < Last: Last is the median unless Mid is smaller than it.
    if (!Comp(*Mid, *Last))
      return Last;
    return Comp(*Start, *Mid) ? Mid : Start;
  }

  // Last <= Start: Start is the median unless Mid is smaller than it.
  if (!Comp(*Mid, *Start))
    return Start;
  return Comp(*Last, *Mid) ? Mid : Last;
}
126 | |
127 | template <class RandomAccessIterator, class Comparator> |
128 | void parallel_quick_sort(RandomAccessIterator Start, RandomAccessIterator End, |
129 | const Comparator &Comp, TaskGroup &TG, size_t Depth) { |
130 | // Do a sequential sort for small inputs. |
131 | if (std::distance(Start, End) < detail::MinParallelSize || Depth == 0) { |
132 | llvm::sort(Start, End, Comp); |
133 | return; |
134 | } |
135 | |
136 | // Partition. |
137 | auto Pivot = medianOf3(Start, End, Comp); |
138 | // Move Pivot to End. |
139 | std::swap(*(End - 1), *Pivot); |
140 | Pivot = std::partition(Start, End - 1, [&Comp, End](decltype(*Start) V) { |
141 | return Comp(V, *(End - 1)); |
142 | }); |
143 | // Move Pivot to middle of partition. |
144 | std::swap(*Pivot, *(End - 1)); |
145 | |
146 | // Recurse. |
147 | TG.spawn(f: [=, &Comp, &TG] { |
148 | parallel_quick_sort(Start, Pivot, Comp, TG, Depth - 1); |
149 | }); |
150 | parallel_quick_sort(Pivot + 1, End, Comp, TG, Depth - 1); |
151 | } |
152 | |
153 | template <class RandomAccessIterator, class Comparator> |
154 | void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End, |
155 | const Comparator &Comp) { |
156 | TaskGroup TG; |
157 | parallel_quick_sort(Start, End, Comp, TG, |
158 | llvm::Log2_64(Value: std::distance(Start, End)) + 1); |
159 | } |
160 | |
// TaskGroup has a relatively high overhead, so we want to reduce
// the number of spawn() calls. We'll create up to 1024 tasks here.
// (Note that 1024 is an arbitrary number. This code probably needs
// improving to take the number of available cores into account.)
enum { MaxTasksPerGroup = 1024 };
166 | |
167 | template <class IterTy, class ResultTy, class ReduceFuncTy, |
168 | class TransformFuncTy> |
169 | ResultTy parallel_transform_reduce(IterTy Begin, IterTy End, ResultTy Init, |
170 | ReduceFuncTy Reduce, |
171 | TransformFuncTy Transform) { |
172 | // Limit the number of tasks to MaxTasksPerGroup to limit job scheduling |
173 | // overhead on large inputs. |
174 | size_t NumInputs = std::distance(Begin, End); |
175 | if (NumInputs == 0) |
176 | return std::move(Init); |
177 | size_t NumTasks = std::min(a: static_cast<size_t>(MaxTasksPerGroup), b: NumInputs); |
178 | std::vector<ResultTy> Results(NumTasks, Init); |
179 | { |
180 | // Each task processes either TaskSize or TaskSize+1 inputs. Any inputs |
181 | // remaining after dividing them equally amongst tasks are distributed as |
182 | // one extra input over the first tasks. |
183 | TaskGroup TG; |
184 | size_t TaskSize = NumInputs / NumTasks; |
185 | size_t RemainingInputs = NumInputs % NumTasks; |
186 | IterTy TBegin = Begin; |
187 | for (size_t TaskId = 0; TaskId < NumTasks; ++TaskId) { |
188 | IterTy TEnd = TBegin + TaskSize + (TaskId < RemainingInputs ? 1 : 0); |
189 | TG.spawn(f: [=, &Transform, &Reduce, &Results] { |
190 | // Reduce the result of transformation eagerly within each task. |
191 | ResultTy R = Init; |
192 | for (IterTy It = TBegin; It != TEnd; ++It) |
193 | R = Reduce(R, Transform(*It)); |
194 | Results[TaskId] = R; |
195 | }); |
196 | TBegin = TEnd; |
197 | } |
198 | assert(TBegin == End); |
199 | } |
200 | |
201 | // Do a final reduction. There are at most 1024 tasks, so this only adds |
202 | // constant single-threaded overhead for large inputs. Hopefully most |
203 | // reductions are cheaper than the transformation. |
204 | ResultTy FinalResult = std::move(Results.front()); |
205 | for (ResultTy &PartialResult : |
206 | MutableArrayRef(Results.data() + 1, Results.size() - 1)) |
207 | FinalResult = Reduce(FinalResult, std::move(PartialResult)); |
208 | return std::move(FinalResult); |
209 | } |
210 | |
211 | #endif |
212 | |
213 | } // namespace detail |
214 | } // namespace parallel |
215 | |
216 | template <class RandomAccessIterator, |
217 | class Comparator = std::less< |
218 | typename std::iterator_traits<RandomAccessIterator>::value_type>> |
219 | void parallelSort(RandomAccessIterator Start, RandomAccessIterator End, |
220 | const Comparator &Comp = Comparator()) { |
221 | #if LLVM_ENABLE_THREADS |
222 | if (parallel::strategy.ThreadsRequested != 1) { |
223 | parallel::detail::parallel_sort(Start, End, Comp); |
224 | return; |
225 | } |
226 | #endif |
227 | llvm::sort(Start, End, Comp); |
228 | } |
229 | |
230 | void parallelFor(size_t Begin, size_t End, function_ref<void(size_t)> Fn); |
231 | |
/// Calls \p Fn on every element of [Begin, End) via parallelFor.
///
/// \p IterTy must support `End - Begin` and `Begin[I]`.
template <class IterTy, class FuncTy>
void parallelForEach(IterTy Begin, IterTy End, FuncTy Fn) {
  parallelFor(0, End - Begin, [&](size_t I) { Fn(Begin[I]); });
}
236 | |
/// Applies \p Transform to every element of [Begin, End) and folds the results
/// into \p Init with \p Reduce, returning the final value.
///
/// Uses the parallel implementation when threads are enabled and the
/// configured strategy does not request exactly one thread; otherwise folds
/// sequentially from left to right.
template <class IterTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(IterTy Begin, IterTy End, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
#if LLVM_ENABLE_THREADS
  if (parallel::strategy.ThreadsRequested != 1) {
    // Init is not used again on this path, so hand it over without a copy.
    return parallel::detail::parallel_transform_reduce(
        Begin, End, std::move(Init), Reduce, Transform);
  }
#endif
  for (IterTy I = Begin; I != End; ++I)
    Init = Reduce(std::move(Init), Transform(*I));
  return Init; // By-value parameter: moved implicitly on return.
}
252 | |
// Range wrappers.

/// Sorts the whole range \p R under \p Comp.
///
/// The default comparator is the transparent std::less<>. The previous
/// default was spelled in terms of `RangeTy()`, which is ill-formed when
/// RangeTy deduces to an lvalue reference, so `parallelSort(Vec)` could not
/// be called without an explicit comparator.
template <class RangeTy, class Comparator = std::less<>>
void parallelSort(RangeTy &&R, const Comparator &Comp = Comparator()) {
  parallelSort(std::begin(R), std::end(R), Comp);
}
259 | |
/// Range form of parallelForEach: calls \p Fn on every element of \p R.
template <class RangeTy, class FuncTy>
void parallelForEach(RangeTy &&R, FuncTy Fn) {
  parallelForEach(std::begin(R), std::end(R), Fn);
}
264 | |
/// Range form of parallelTransformReduce: transforms every element of \p R
/// and folds the results into \p Init with \p Reduce.
template <class RangeTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(RangeTy &&R, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
  // Init is not used again here, so forward it without an extra copy.
  return parallelTransformReduce(std::begin(R), std::end(R), std::move(Init),
                                 Reduce, Transform);
}
273 | |
274 | // Parallel for-each, but with error handling. |
275 | template <class RangeTy, class FuncTy> |
276 | Error parallelForEachError(RangeTy &&R, FuncTy Fn) { |
277 | // The transform_reduce algorithm requires that the initial value be copyable. |
278 | // Error objects are uncopyable. We only need to copy initial success values, |
279 | // so work around this mismatch via the C API. The C API represents success |
280 | // values with a null pointer. The joinErrors discards null values and joins |
281 | // multiple errors into an ErrorList. |
282 | return unwrap(parallelTransformReduce( |
283 | std::begin(R), std::end(R), wrap(Err: Error::success()), |
284 | [](LLVMErrorRef Lhs, LLVMErrorRef Rhs) { |
285 | return wrap(Err: joinErrors(E1: unwrap(ErrRef: Lhs), E2: unwrap(ErrRef: Rhs))); |
286 | }, |
287 | [&Fn](auto &&V) { return wrap(Fn(V)); })); |
288 | } |
289 | |
290 | } // namespace llvm |
291 | |
292 | #endif // LLVM_SUPPORT_PARALLEL_H |
293 | |