1 | //===----------------------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H |
10 | #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H |
11 | |
12 | #include <__algorithm/upper_bound.h> |
13 | #include <__config> |
14 | #include <__random/is_valid.h> |
15 | #include <__random/uniform_real_distribution.h> |
16 | #include <cstddef> |
17 | #include <iosfwd> |
18 | #include <numeric> |
19 | #include <vector> |
20 | |
21 | #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) |
22 | # pragma GCC system_header |
23 | #endif |
24 | |
25 | _LIBCPP_PUSH_MACROS |
26 | #include <__undef_macros> |
27 | |
28 | _LIBCPP_BEGIN_NAMESPACE_STD |
29 | |
// discrete_distribution<_IntType>
//
// Random number distribution producing integers i in [0, n), where n is the
// number of weights supplied at construction. Outcome i is produced with
// probability w_i / sum(w). With no weights, or with exactly one weight, the
// only possible result is 0.
template <class _IntType = int>
class _LIBCPP_TEMPLATE_VIS discrete_distribution {
  static_assert(__libcpp_random_is_valid_inttype<_IntType>::value, "IntType must be a supported integer type" );

public:
  // types
  typedef _IntType result_type;

  // Parameter set of the distribution.
  //
  // The weights are not stored as given. __init() normalizes them to sum to 1
  // and keeps only the first n-1 partial sums (the last partial sum would
  // always be 1). An empty __p_ therefore represents a single outcome (0)
  // with probability 1.
  class _LIBCPP_TEMPLATE_VIS param_type {
    // Normalized cumulative probabilities, excluding the final 1.
    vector<double> __p_;

  public:
    typedef discrete_distribution distribution_type;

    // Empty parameter set: the distribution always yields 0.
    _LIBCPP_HIDE_FROM_ABI param_type() {}
    // Weights are taken from the iterator range [__f, __l).
    template <class _InputIterator>
    _LIBCPP_HIDE_FROM_ABI param_type(_InputIterator __f, _InputIterator __l) : __p_(__f, __l) {
      __init();
    }
#ifndef _LIBCPP_CXX03_LANG
    // Weights are taken from the braced list __wl.
    _LIBCPP_HIDE_FROM_ABI param_type(initializer_list<double> __wl) : __p_(__wl.begin(), __wl.end()) { __init(); }
#endif // _LIBCPP_CXX03_LANG
    // Weights are obtained by evaluating __fw at __nw equally spaced points
    // spanning [__xmin, __xmax].
    template <class _UnaryOperation>
    _LIBCPP_HIDE_FROM_ABI param_type(size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw);

    // Returns the n normalized probabilities, reconstructed from the stored
    // partial sums.
    _LIBCPP_HIDE_FROM_ABI vector<double> probabilities() const;

    // Equality compares the stored normalized tables, not the weights
    // originally passed in.
    friend _LIBCPP_HIDE_FROM_ABI bool operator==(const param_type& __x, const param_type& __y) {
      return __x.__p_ == __y.__p_;
    }
    friend _LIBCPP_HIDE_FROM_ABI bool operator!=(const param_type& __x, const param_type& __y) { return !(__x == __y); }

  private:
    // Normalizes the raw weights in __p_ into the partial-sum representation.
    _LIBCPP_HIDE_FROM_ABI void __init();

    // The distribution reads __p_ directly when sampling and in max().
    friend class discrete_distribution;

    // The stream operators serialize and deserialize __p_ directly.
    template <class _CharT, class _Traits, class _IT>
    friend basic_ostream<_CharT, _Traits>&
    operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x);

    template <class _CharT, class _Traits, class _IT>
    friend basic_istream<_CharT, _Traits>&
    operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x);
  };

private:
  // The distribution's only state: its current parameter set.
  param_type __p_;

public:
  // constructor and reset functions
  _LIBCPP_HIDE_FROM_ABI discrete_distribution() {}
  template <class _InputIterator>
  _LIBCPP_HIDE_FROM_ABI discrete_distribution(_InputIterator __f, _InputIterator __l) : __p_(__f, __l) {}
#ifndef _LIBCPP_CXX03_LANG
  _LIBCPP_HIDE_FROM_ABI discrete_distribution(initializer_list<double> __wl) : __p_(__wl) {}
#endif // _LIBCPP_CXX03_LANG
  template <class _UnaryOperation>
  _LIBCPP_HIDE_FROM_ABI discrete_distribution(size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw)
      : __p_(__nw, __xmin, __xmax, __fw) {}
  _LIBCPP_HIDE_FROM_ABI explicit discrete_distribution(const param_type& __p) : __p_(__p) {}
  // No-op: sampling keeps no state between calls.
  _LIBCPP_HIDE_FROM_ABI void reset() {}

  // generating functions
  // Samples using the distribution's own parameter set.
  template <class _URNG>
  _LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g) {
    return (*this)(__g, __p_);
  }
  // Samples using the explicitly supplied parameter set __p.
  template <class _URNG>
  _LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g, const param_type& __p);

  // property functions
  _LIBCPP_HIDE_FROM_ABI vector<double> probabilities() const { return __p_.probabilities(); }

  _LIBCPP_HIDE_FROM_ABI param_type param() const { return __p_; }
  _LIBCPP_HIDE_FROM_ABI void param(const param_type& __p) { __p_ = __p; }

  // Smallest possible result.
  _LIBCPP_HIDE_FROM_ABI result_type min() const { return 0; }
  // Largest possible result: n-1 for n weights, since the table stores n-1
  // partial sums.
  _LIBCPP_HIDE_FROM_ABI result_type max() const { return __p_.__p_.size(); }

  friend _LIBCPP_HIDE_FROM_ABI bool operator==(const discrete_distribution& __x, const discrete_distribution& __y) {
    return __x.__p_ == __y.__p_;
  }
  friend _LIBCPP_HIDE_FROM_ABI bool operator!=(const discrete_distribution& __x, const discrete_distribution& __y) {
    return !(__x == __y);
  }

  template <class _CharT, class _Traits, class _IT>
  friend basic_ostream<_CharT, _Traits>&
  operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x);

  template <class _CharT, class _Traits, class _IT>
  friend basic_istream<_CharT, _Traits>&
  operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x);
};
125 | |
126 | template <class _IntType> |
127 | template <class _UnaryOperation> |
128 | discrete_distribution<_IntType>::param_type::param_type( |
129 | size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw) { |
130 | if (__nw > 1) { |
131 | __p_.reserve(__nw - 1); |
132 | double __d = (__xmax - __xmin) / __nw; |
133 | double __d2 = __d / 2; |
134 | for (size_t __k = 0; __k < __nw; ++__k) |
135 | __p_.push_back(__fw(__xmin + __k * __d + __d2)); |
136 | __init(); |
137 | } |
138 | } |
139 | |
140 | template <class _IntType> |
141 | void discrete_distribution<_IntType>::param_type::__init() { |
142 | if (!__p_.empty()) { |
143 | if (__p_.size() > 1) { |
144 | double __s = std::accumulate(__p_.begin(), __p_.end(), 0.0); |
145 | for (vector<double>::iterator __i = __p_.begin(), __e = __p_.end(); __i < __e; ++__i) |
146 | *__i /= __s; |
147 | vector<double> __t(__p_.size() - 1); |
148 | std::partial_sum(__p_.begin(), __p_.end() - 1, __t.begin()); |
149 | swap(__p_, __t); |
150 | } else { |
151 | __p_.clear(); |
152 | __p_.shrink_to_fit(); |
153 | } |
154 | } |
155 | } |
156 | |
157 | template <class _IntType> |
158 | vector<double> discrete_distribution<_IntType>::param_type::probabilities() const { |
159 | size_t __n = __p_.size(); |
160 | vector<double> __p(__n + 1); |
161 | std::adjacent_difference(__p_.begin(), __p_.end(), __p.begin()); |
162 | if (__n > 0) |
163 | __p[__n] = 1 - __p_[__n - 1]; |
164 | else |
165 | __p[0] = 1; |
166 | return __p; |
167 | } |
168 | |
169 | template <class _IntType> |
170 | template <class _URNG> |
171 | _IntType discrete_distribution<_IntType>::operator()(_URNG& __g, const param_type& __p) { |
172 | static_assert(__libcpp_random_is_valid_urng<_URNG>::value, "" ); |
173 | uniform_real_distribution<double> __gen; |
174 | return static_cast<_IntType>(std::upper_bound(__p.__p_.begin(), __p.__p_.end(), __gen(__g)) - __p.__p_.begin()); |
175 | } |
176 | |
177 | template <class _CharT, class _Traits, class _IT> |
178 | _LIBCPP_HIDE_FROM_ABI basic_ostream<_CharT, _Traits>& |
179 | operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x) { |
180 | __save_flags<_CharT, _Traits> __lx(__os); |
181 | typedef basic_ostream<_CharT, _Traits> _OStream; |
182 | __os.flags(_OStream::dec | _OStream::left | _OStream::fixed | _OStream::scientific); |
183 | _CharT __sp = __os.widen(' '); |
184 | __os.fill(__sp); |
185 | size_t __n = __x.__p_.__p_.size(); |
186 | __os << __n; |
187 | for (size_t __i = 0; __i < __n; ++__i) |
188 | __os << __sp << __x.__p_.__p_[__i]; |
189 | return __os; |
190 | } |
191 | |
192 | template <class _CharT, class _Traits, class _IT> |
193 | _LIBCPP_HIDE_FROM_ABI basic_istream<_CharT, _Traits>& |
194 | operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x) { |
195 | __save_flags<_CharT, _Traits> __lx(__is); |
196 | typedef basic_istream<_CharT, _Traits> _Istream; |
197 | __is.flags(_Istream::dec | _Istream::skipws); |
198 | size_t __n; |
199 | __is >> __n; |
200 | vector<double> __p(__n); |
201 | for (size_t __i = 0; __i < __n; ++__i) |
202 | __is >> __p[__i]; |
203 | if (!__is.fail()) |
204 | swap(__x.__p_.__p_, __p); |
205 | return __is; |
206 | } |
207 | |
208 | _LIBCPP_END_NAMESPACE_STD |
209 | |
210 | _LIBCPP_POP_MACROS |
211 | |
212 | #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H |
213 | |