1//===- llvm/Support/Parallel.h - Parallel algorithms ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_SUPPORT_PARALLEL_H
10#define LLVM_SUPPORT_PARALLEL_H
11
#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Threading.h"

#include <algorithm>
#include <condition_variable>
#include <functional>
#include <iterator>
#include <mutex>
#include <utility>
22
23namespace llvm {
24
25namespace parallel {
26
// Strategy for the default executor used by the parallel routines provided by
// this file. It defaults to using all hardware threads and should be
// initialized before the first use of parallel routines.
//
// getThreadIndex(), parallelSort() and parallelTransformReduce() in this header
// check strategy.ThreadsRequested: when it is exactly 1 they take their
// single-threaded path instead of the parallel one.
extern ThreadPoolStrategy strategy;
31
#if LLVM_ENABLE_THREADS
// Function body shared by getThreadIndex() definitions. If the strategy
// requests exactly one thread it returns 0. Otherwise it asserts that the
// calling thread's `threadIndex` has been assigned (is not UINT_MAX) and
// returns it. The expansion names `threadIndex`, so it must be used where that
// variable is visible.
// NOTE(review): kept as a macro presumably so the out-of-line _WIN32 definition
// can expand the same body; it is not #undef'd here, so every includer sees it.
#define GET_THREAD_INDEX_IMPL \
  if (parallel::strategy.ThreadsRequested == 1) \
    return 0; \
  assert((threadIndex != UINT_MAX) && \
         "getThreadIndex() must be called from a thread created by " \
         "ThreadPoolExecutor"); \
  return threadIndex;

#ifdef _WIN32
// Direct access to thread_local variables from a different DLL isn't
// possible with Windows Native TLS.
// Therefore on Windows the function is only declared here and defined out of
// line.
unsigned getThreadIndex();
#else
// Don't access this directly, use the getThreadIndex wrapper.
extern thread_local unsigned threadIndex;

// Index of the calling thread; see GET_THREAD_INDEX_IMPL for the exact rules.
inline unsigned getThreadIndex() { GET_THREAD_INDEX_IMPL; }
#endif

// Declared here; the definition is out of line.
size_t getThreadCount();
#else
// Builds without thread support report a single thread: index 0, count 1.
inline unsigned getThreadIndex() { return 0; }
inline size_t getThreadCount() { return 1; }
#endif
57
namespace detail {
/// A mutex-protected counter that callers can block on until it reaches zero.
///
/// inc() and dec() adjust the count; sync() waits until the count is zero.
/// The destructor asserts that the count is back at zero.
class Latch {
  uint32_t Count;
  mutable std::mutex Mutex;
  mutable std::condition_variable Cond;

public:
  /// Creates a latch whose count starts at \p Count.
  explicit Latch(uint32_t Count = 0) : Count(Count) {}
  ~Latch() {
    // Ensure at least that sync() was called.
    assert(Count == 0);
  }

  /// Increments the count.
  void inc() {
    std::lock_guard<std::mutex> lock(Mutex);
    ++Count;
  }

  /// Decrements the count and wakes every sync() waiter once it reaches zero.
  void dec() {
    std::lock_guard<std::mutex> lock(Mutex);
    // An unmatched dec() would wrap Count around to UINT32_MAX, and every
    // later sync() would then block forever. Catch that in debug builds.
    assert(Count > 0 && "Latch::dec() called when the count is already zero");
    // Notify while still holding the lock. A waiter that observes Count == 0
    // may return from sync() and go on to destroy this Latch, so the condition
    // variable must not be touched after the mutex has been released.
    if (--Count == 0)
      Cond.notify_all();
  }

  /// Blocks the caller until the count is zero.
  void sync() const {
    std::unique_lock<std::mutex> lock(Mutex);
    Cond.wait(lock, [&] { return Count == 0; });
  }
};
} // namespace detail
88
89class TaskGroup {
90 detail::Latch L;
91 bool Parallel;
92
93public:
94 TaskGroup();
95 ~TaskGroup();
96
97 // Spawn a task, but does not wait for it to finish.
98 // Tasks marked with \p Sequential will be executed
99 // exactly in the order which they were spawned.
100 // Note: Sequential tasks may be executed on different
101 // threads, but strictly in sequential order.
102 void spawn(std::function<void()> f, bool Sequential = false);
103
104 void sync() const { L.sync(); }
105
106 bool isParallel() const { return Parallel; }
107};
108
109namespace detail {
110
111#if LLVM_ENABLE_THREADS
// parallel_quick_sort hands ranges shorter than this many elements straight to
// the sequential sort instead of partitioning them further.
constexpr ptrdiff_t MinParallelSize = 1024;
113
/// Returns an iterator to the median of *Start, the middle element and
/// *(End - 1) under \p Comp. "Inclusive" because the last element examined is
/// End - 1, i.e. inside the half-open range [Start, End).
template <class RandomAccessIterator, class Comparator>
RandomAccessIterator medianOf3(RandomAccessIterator Start,
                               RandomAccessIterator End,
                               const Comparator &Comp) {
  RandomAccessIterator Mid = Start + (std::distance(Start, End) / 2);
  RandomAccessIterator Last = End - 1;

  if (Comp(*Start, *Last)) {
    // Here *Start < *Last. If *Mid is not below *Last, *Last is the median.
    if (!Comp(*Mid, *Last))
      return Last;
    return Comp(*Start, *Mid) ? Mid : Start;
  }

  // Here *Last <= *Start. If *Mid is not below *Start, *Start is the median.
  if (!Comp(*Mid, *Start))
    return Start;
  return Comp(*Last, *Mid) ? Mid : Last;
}
126
127template <class RandomAccessIterator, class Comparator>
128void parallel_quick_sort(RandomAccessIterator Start, RandomAccessIterator End,
129 const Comparator &Comp, TaskGroup &TG, size_t Depth) {
130 // Do a sequential sort for small inputs.
131 if (std::distance(Start, End) < detail::MinParallelSize || Depth == 0) {
132 llvm::sort(Start, End, Comp);
133 return;
134 }
135
136 // Partition.
137 auto Pivot = medianOf3(Start, End, Comp);
138 // Move Pivot to End.
139 std::swap(*(End - 1), *Pivot);
140 Pivot = std::partition(Start, End - 1, [&Comp, End](decltype(*Start) V) {
141 return Comp(V, *(End - 1));
142 });
143 // Move Pivot to middle of partition.
144 std::swap(*Pivot, *(End - 1));
145
146 // Recurse.
147 TG.spawn(f: [=, &Comp, &TG] {
148 parallel_quick_sort(Start, Pivot, Comp, TG, Depth - 1);
149 });
150 parallel_quick_sort(Pivot + 1, End, Comp, TG, Depth - 1);
151}
152
153template <class RandomAccessIterator, class Comparator>
154void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End,
155 const Comparator &Comp) {
156 TaskGroup TG;
157 parallel_quick_sort(Start, End, Comp, TG,
158 llvm::Log2_64(Value: std::distance(Start, End)) + 1);
159}
160
// TaskGroup has a relatively high overhead, so we want to reduce the number of
// spawn() calls. parallel_transform_reduce below therefore splits its input
// into at most MaxTasksPerGroup tasks. (1024 is an arbitrary number; this code
// probably needs improving to take the number of available cores into
// account.)
enum { MaxTasksPerGroup = 1024 };
166
167template <class IterTy, class ResultTy, class ReduceFuncTy,
168 class TransformFuncTy>
169ResultTy parallel_transform_reduce(IterTy Begin, IterTy End, ResultTy Init,
170 ReduceFuncTy Reduce,
171 TransformFuncTy Transform) {
172 // Limit the number of tasks to MaxTasksPerGroup to limit job scheduling
173 // overhead on large inputs.
174 size_t NumInputs = std::distance(Begin, End);
175 if (NumInputs == 0)
176 return std::move(Init);
177 size_t NumTasks = std::min(a: static_cast<size_t>(MaxTasksPerGroup), b: NumInputs);
178 std::vector<ResultTy> Results(NumTasks, Init);
179 {
180 // Each task processes either TaskSize or TaskSize+1 inputs. Any inputs
181 // remaining after dividing them equally amongst tasks are distributed as
182 // one extra input over the first tasks.
183 TaskGroup TG;
184 size_t TaskSize = NumInputs / NumTasks;
185 size_t RemainingInputs = NumInputs % NumTasks;
186 IterTy TBegin = Begin;
187 for (size_t TaskId = 0; TaskId < NumTasks; ++TaskId) {
188 IterTy TEnd = TBegin + TaskSize + (TaskId < RemainingInputs ? 1 : 0);
189 TG.spawn(f: [=, &Transform, &Reduce, &Results] {
190 // Reduce the result of transformation eagerly within each task.
191 ResultTy R = Init;
192 for (IterTy It = TBegin; It != TEnd; ++It)
193 R = Reduce(R, Transform(*It));
194 Results[TaskId] = R;
195 });
196 TBegin = TEnd;
197 }
198 assert(TBegin == End);
199 }
200
201 // Do a final reduction. There are at most 1024 tasks, so this only adds
202 // constant single-threaded overhead for large inputs. Hopefully most
203 // reductions are cheaper than the transformation.
204 ResultTy FinalResult = std::move(Results.front());
205 for (ResultTy &PartialResult :
206 MutableArrayRef(Results.data() + 1, Results.size() - 1))
207 FinalResult = Reduce(FinalResult, std::move(PartialResult));
208 return std::move(FinalResult);
209}
210
211#endif
212
213} // namespace detail
214} // namespace parallel
215
216template <class RandomAccessIterator,
217 class Comparator = std::less<
218 typename std::iterator_traits<RandomAccessIterator>::value_type>>
219void parallelSort(RandomAccessIterator Start, RandomAccessIterator End,
220 const Comparator &Comp = Comparator()) {
221#if LLVM_ENABLE_THREADS
222 if (parallel::strategy.ThreadsRequested != 1) {
223 parallel::detail::parallel_sort(Start, End, Comp);
224 return;
225 }
226#endif
227 llvm::sort(Start, End, Comp);
228}
229
// Invokes Fn(I) for the indices I in [Begin, End); parallelForEach below relies
// on this to visit every element. The definition is not in this header.
// NOTE(review): assume Fn may be called concurrently from several threads.
void parallelFor(size_t Begin, size_t End, function_ref<void(size_t)> Fn);
231
/// Applies \p Fn to every element of [Begin, End) by delegating to
/// parallelFor() over the element indices.
///
/// IterTy must support `End - Begin` and `Begin[I]`. \p Fn is captured by
/// reference in the per-index callback.
template <class IterTy, class FuncTy>
void parallelForEach(IterTy Begin, IterTy End, FuncTy Fn) {
  const auto NumElements = End - Begin;
  auto ApplyAt = [&](size_t Index) { Fn(Begin[Index]); };
  parallelFor(0, NumElements, ApplyAt);
}
236
/// Folds Transform(*I) for every I in [Begin, End) into \p Init with
/// \p Reduce and returns the result.
///
/// With thread support enabled, and unless the strategy requests exactly one
/// thread, the work is delegated to parallel::detail::parallel_transform_reduce.
/// That path seeds every chunk with \p Init, so \p Init should be an identity
/// for \p Reduce. Otherwise the fold runs sequentially, left to right.
template <class IterTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(IterTy Begin, IterTy End, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
#if LLVM_ENABLE_THREADS
  const bool UseThreads = parallel::strategy.ThreadsRequested != 1;
  if (UseThreads)
    return parallel::detail::parallel_transform_reduce(Begin, End, Init, Reduce,
                                                       Transform);
#endif
  // Sequential fallback: a plain left fold that threads the accumulator
  // through Reduce by move.
  while (Begin != End) {
    Init = Reduce(std::move(Init), Transform(*Begin));
    ++Begin;
  }
  return Init;
}
252
253// Range wrappers.
/// Range form of parallelSort(): sorts the elements of \p R in place.
///
/// The default comparator is std::less over the range's value type, the same
/// default as the iterator overload. It is derived through
/// std::declval<RangeTy &>() instead of value-initializing RangeTy: when \p R
/// is an lvalue, RangeTy deduces to a reference type, and `RangeTy()` would be
/// ill-formed. The same applies to ranges that are not default-constructible.
/// In either case a call without an explicit comparator would fail to compile.
template <class RangeTy,
          class Comparator = std::less<typename std::iterator_traits<decltype(
              std::begin(std::declval<RangeTy &>()))>::value_type>>
void parallelSort(RangeTy &&R, const Comparator &Comp = Comparator()) {
  parallelSort(std::begin(R), std::end(R), Comp);
}
259
/// Range form of parallelForEach(): applies \p Fn to every element of \p R.
template <class RangeTy, class FuncTy>
void parallelForEach(RangeTy &&R, FuncTy Fn) {
  auto First = std::begin(R);
  auto Last = std::end(R);
  parallelForEach(First, Last, Fn);
}
264
/// Range form of parallelTransformReduce(): folds Transform(E) for every
/// element E of \p R into \p Init with \p Reduce.
template <class RangeTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(RangeTy &&R, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
  auto First = std::begin(R);
  auto Last = std::end(R);
  return parallelTransformReduce(First, Last, Init, Reduce, Transform);
}
273
274// Parallel for-each, but with error handling.
275template <class RangeTy, class FuncTy>
276Error parallelForEachError(RangeTy &&R, FuncTy Fn) {
277 // The transform_reduce algorithm requires that the initial value be copyable.
278 // Error objects are uncopyable. We only need to copy initial success values,
279 // so work around this mismatch via the C API. The C API represents success
280 // values with a null pointer. The joinErrors discards null values and joins
281 // multiple errors into an ErrorList.
282 return unwrap(parallelTransformReduce(
283 std::begin(R), std::end(R), wrap(Err: Error::success()),
284 [](LLVMErrorRef Lhs, LLVMErrorRef Rhs) {
285 return wrap(Err: joinErrors(E1: unwrap(ErrRef: Lhs), E2: unwrap(ErrRef: Rhs)));
286 },
287 [&Fn](auto &&V) { return wrap(Fn(V)); }));
288}
289
290} // namespace llvm
291
292#endif // LLVM_SUPPORT_PARALLEL_H
293