1// -*- C++ -*-
2//===----------------------------------------------------------------------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
10#ifndef _LIBCPP___ALGORITHM_EQUAL_H
11#define _LIBCPP___ALGORITHM_EQUAL_H
12
13#include <__algorithm/comp.h>
14#include <__algorithm/find_segment_if.h>
15#include <__algorithm/min.h>
16#include <__algorithm/unwrap_iter.h>
17#include <__config>
18#include <__functional/identity.h>
19#include <__fwd/bit_reference.h>
20#include <__iterator/iterator_traits.h>
21#include <__iterator/segmented_iterator.h>
22#include <__string/constexpr_c_functions.h>
23#include <__type_traits/common_type.h>
24#include <__type_traits/desugars_to.h>
25#include <__type_traits/enable_if.h>
26#include <__type_traits/invoke.h>
27#include <__type_traits/is_equality_comparable.h>
28#include <__type_traits/is_same.h>
29#include <__type_traits/is_volatile.h>
30#include <__utility/move.h>
31#include <__utility/unreachable.h>
32
33#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
34# pragma GCC system_header
35#endif
36
37_LIBCPP_PUSH_MACROS
38#include <__undef_macros>
39
40_LIBCPP_BEGIN_NAMESPACE_STD
41
42template <class _Cp, bool _IsConst1, bool _IsConst2>
43[[__nodiscard__]] _LIBCPP_CONSTEXPR_SINCE_CXX20 _LIBCPP_HIDE_FROM_ABI bool
44__equal_unaligned(__bit_iterator<_Cp, _IsConst1> __first1,
45 __bit_iterator<_Cp, _IsConst1> __last1,
46 __bit_iterator<_Cp, _IsConst2> __first2) {
47 using _It = __bit_iterator<_Cp, _IsConst1>;
48 using difference_type = typename _It::difference_type;
49 using __storage_type = typename _It::__storage_type;
50
51 const int __bits_per_word = _It::__bits_per_word;
52 difference_type __n = __last1 - __first1;
53 if (__n > 0) {
54 // do first word
55 if (__first1.__ctz_ != 0) {
56 unsigned __clz_f = __bits_per_word - __first1.__ctz_;
57 difference_type __dn = std::min(static_cast<difference_type>(__clz_f), __n);
58 __n -= __dn;
59 __storage_type __m = std::__middle_mask<__storage_type>(__clz_f - __dn, __first1.__ctz_);
60 __storage_type __b = *__first1.__seg_ & __m;
61 unsigned __clz_r = __bits_per_word - __first2.__ctz_;
62 __storage_type __ddn = std::min<__storage_type>(__dn, __clz_r);
63 __m = std::__middle_mask<__storage_type>(__clz_r - __ddn, __first2.__ctz_);
64 if (__first2.__ctz_ > __first1.__ctz_) {
65 if (static_cast<__storage_type>(*__first2.__seg_ & __m) !=
66 static_cast<__storage_type>(__b << (__first2.__ctz_ - __first1.__ctz_)))
67 return false;
68 } else {
69 if (static_cast<__storage_type>(*__first2.__seg_ & __m) !=
70 static_cast<__storage_type>(__b >> (__first1.__ctz_ - __first2.__ctz_)))
71 return false;
72 }
73 __first2.__seg_ += (__ddn + __first2.__ctz_) / __bits_per_word;
74 __first2.__ctz_ = static_cast<unsigned>((__ddn + __first2.__ctz_) % __bits_per_word);
75 __dn -= __ddn;
76 if (__dn > 0) {
77 __m = std::__trailing_mask<__storage_type>(__bits_per_word - __n);
78 if (static_cast<__storage_type>(*__first2.__seg_ & __m) !=
79 static_cast<__storage_type>(__b >> (__first1.__ctz_ + __ddn)))
80 return false;
81 __first2.__ctz_ = static_cast<unsigned>(__dn);
82 }
83 ++__first1.__seg_;
84 // __first1.__ctz_ = 0;
85 }
86 // __first1.__ctz_ == 0;
87 // do middle words
88 unsigned __clz_r = __bits_per_word - __first2.__ctz_;
89 __storage_type __m = std::__leading_mask<__storage_type>(__first2.__ctz_);
90 for (; __n >= __bits_per_word; __n -= __bits_per_word, ++__first1.__seg_) {
91 __storage_type __b = *__first1.__seg_;
92 if (static_cast<__storage_type>(*__first2.__seg_ & __m) != static_cast<__storage_type>(__b << __first2.__ctz_))
93 return false;
94 ++__first2.__seg_;
95 if (static_cast<__storage_type>(*__first2.__seg_ & static_cast<__storage_type>(~__m)) !=
96 static_cast<__storage_type>(__b >> __clz_r))
97 return false;
98 }
99 // do last word
100 if (__n > 0) {
101 __m = std::__trailing_mask<__storage_type>(__bits_per_word - __n);
102 __storage_type __b = *__first1.__seg_ & __m;
103 __storage_type __dn = std::min(__n, static_cast<difference_type>(__clz_r));
104 __m = std::__middle_mask<__storage_type>(__clz_r - __dn, __first2.__ctz_);
105 if (static_cast<__storage_type>(*__first2.__seg_ & __m) != static_cast<__storage_type>(__b << __first2.__ctz_))
106 return false;
107 __first2.__seg_ += (__dn + __first2.__ctz_) / __bits_per_word;
108 __first2.__ctz_ = static_cast<unsigned>((__dn + __first2.__ctz_) % __bits_per_word);
109 __n -= __dn;
110 if (__n > 0) {
111 __m = std::__trailing_mask<__storage_type>(__bits_per_word - __n);
112 if (static_cast<__storage_type>(*__first2.__seg_ & __m) != static_cast<__storage_type>(__b >> __dn))
113 return false;
114 }
115 }
116 }
117 return true;
118}
119
120template <class _Cp, bool _IsConst1, bool _IsConst2>
121[[__nodiscard__]] _LIBCPP_CONSTEXPR_SINCE_CXX20 _LIBCPP_HIDE_FROM_ABI bool
122__equal_aligned(__bit_iterator<_Cp, _IsConst1> __first1,
123 __bit_iterator<_Cp, _IsConst1> __last1,
124 __bit_iterator<_Cp, _IsConst2> __first2) {
125 using _It = __bit_iterator<_Cp, _IsConst1>;
126 using difference_type = typename _It::difference_type;
127 using __storage_type = typename _It::__storage_type;
128
129 const int __bits_per_word = _It::__bits_per_word;
130 difference_type __n = __last1 - __first1;
131 if (__n > 0) {
132 // do first word
133 if (__first1.__ctz_ != 0) {
134 unsigned __clz = __bits_per_word - __first1.__ctz_;
135 difference_type __dn = std::min(static_cast<difference_type>(__clz), __n);
136 __n -= __dn;
137 __storage_type __m = std::__middle_mask<__storage_type>(__clz - __dn, __first1.__ctz_);
138 if ((*__first2.__seg_ & __m) != (*__first1.__seg_ & __m))
139 return false;
140 ++__first2.__seg_;
141 ++__first1.__seg_;
142 // __first1.__ctz_ = 0;
143 // __first2.__ctz_ = 0;
144 }
145 // __first1.__ctz_ == 0;
146 // __first2.__ctz_ == 0;
147 // do middle words
148 for (; __n >= __bits_per_word; __n -= __bits_per_word, ++__first1.__seg_, ++__first2.__seg_)
149 if (*__first2.__seg_ != *__first1.__seg_)
150 return false;
151 // do last word
152 if (__n > 0) {
153 __storage_type __m = std::__trailing_mask<__storage_type>(__bits_per_word - __n);
154 if ((*__first2.__seg_ & __m) != (*__first1.__seg_ & __m))
155 return false;
156 }
157 }
158 return true;
159}
160
161template <class _Cp,
162 bool _IsConst1,
163 bool _IsConst2,
164 class _BinaryPredicate,
165 class _Proj1,
166 class _Proj2,
167 __enable_if_t<__is_identity<_Proj1>::value && __is_identity<_Proj2>::value &&
168 __desugars_to_v<__equal_tag, _BinaryPredicate, bool, bool>,
169 int> = 0>
170[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_iter_impl(
171 __bit_iterator<_Cp, _IsConst1> __first1,
172 __bit_iterator<_Cp, _IsConst1> __last1,
173 __bit_iterator<_Cp, _IsConst2> __first2,
174 _BinaryPredicate,
175 _Proj1&,
176 _Proj2&) {
177 if (__first1.__ctz_ == __first2.__ctz_)
178 return std::__equal_aligned(__first1, __last1, __first2);
179 return std::__equal_unaligned(__first1, __last1, __first2);
180}
181
182template <class _Tp,
183 class _Up,
184 class _BinaryPredicate,
185 class _Proj1,
186 class _Proj2,
187 __enable_if_t<__is_identity<_Proj1>::value && __is_identity<_Proj2>::value &&
188 __desugars_to_v<__equal_tag, _BinaryPredicate, _Tp, _Up> && !is_volatile<_Tp>::value &&
189 !is_volatile<_Up>::value && __is_trivially_equality_comparable_v<_Tp, _Up>,
190 int> = 0>
191[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
192__equal_iter_impl(_Tp* __first1, _Tp* __last1, _Up* __first2, _BinaryPredicate&, _Proj1&, _Proj2&) {
193 return std::__constexpr_memcmp_equal(__first1, __first2, __element_count(__last1 - __first1));
194}
195
196template <class _InIter1, class _Sent1, class _InIter2, class _Pred, class _Proj1, class _Proj2>
197[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_iter_impl(
198 _InIter1 __first1, _Sent1 __last1, _InIter2 __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
199#ifndef _LIBCPP_CXX03_LANG
200 if constexpr (__has_random_access_iterator_category<_InIter1>::value &&
201 __has_random_access_iterator_category<_InIter2>::value) {
202 if constexpr (is_same<_InIter1, _Sent1>::value && __is_segmented_iterator_v<_InIter1>) {
203 using __local_iterator_t = typename __segmented_iterator_traits<_InIter1>::__local_iterator;
204 bool __is_equal = true;
205 std::__find_segment_if(__first1, __last1, [&](__local_iterator_t __lfirst, __local_iterator_t __llast) {
206 if (std::__equal_iter_impl(
207 std::__unwrap_iter(__lfirst), std::__unwrap_iter(__llast), __first2, __pred, __proj1, __proj2)) {
208 __first2 += __llast - __lfirst;
209 return __llast;
210 }
211 __is_equal = false;
212 return __lfirst;
213 });
214 return __is_equal;
215 } else if constexpr (__is_segmented_iterator_v<_InIter2>) {
216 using _Traits = __segmented_iterator_traits<_InIter2>;
217 using _DiffT =
218 typename common_type<__iterator_difference_type<_InIter1>, __iterator_difference_type<_InIter2> >::type;
219
220 if (__first1 == __last1)
221 return true;
222
223 auto __local_first = _Traits::__local(__first2);
224 auto __segment_iterator = _Traits::__segment(__first2);
225
226 while (true) {
227 auto __local_last = _Traits::__end(__segment_iterator);
228 auto __size = std::min<_DiffT>(__local_last - __local_first, __last1 - __first1);
229 if (!std::__equal_iter_impl(
230 __first1, __first1 + __size, std::__unwrap_iter(__local_first), __pred, __proj1, __proj2))
231 return false;
232
233 __first1 += __size;
234 if (__first1 == __last1)
235 return true;
236
237 __local_first = _Traits::__begin(++__segment_iterator);
238 }
239 }
240 }
241#endif
242 for (; __first1 != __last1; ++__first1, (void)++__first2)
243 if (!std::__invoke(__pred, std::__invoke(__proj1, *__first1), std::__invoke(__proj2, *__first2)))
244 return false;
245 return true;
246}
247
248template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
249[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
250equal(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _BinaryPredicate __pred) {
251 __identity __proj;
252 return std::__equal_iter_impl(
253 std::__unwrap_iter(__first1), std::__unwrap_iter(__last1), std::__unwrap_iter(__first2), __pred, __proj, __proj);
254}
255
256template <class _InputIterator1, class _InputIterator2>
257[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
258equal(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2) {
259 return std::equal(__first1, __last1, __first2, __equal_to());
260}
261
262#if _LIBCPP_STD_VER >= 14
263
264template <bool __known_equal_length,
265 class _Iter1,
266 class _Sent1,
267 class _Iter2,
268 class _Sent2,
269 class _Pred,
270 class _Proj1,
271 class _Proj2>
272[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_impl(
273 _Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Sent2 __last2, _Pred& __comp, _Proj1& __proj1, _Proj2& __proj2) {
274 if constexpr (__known_equal_length) {
275 return std::__equal_iter_impl(
276 std::move(__first1), std::move(__last1), std::move(__first2), __comp, __proj1, __proj2);
277 } else {
278 while (__first1 != __last1 && __first2 != __last2) {
279 if (!std::__invoke(__comp, std::__invoke(__proj1, *__first1), std::__invoke(__proj2, *__first2)))
280 return false;
281 ++__first1;
282 ++__first2;
283 }
284 return __first1 == __last1 && __first2 == __last2;
285 }
286}
287
288template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
289[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
290equal(_InputIterator1 __first1,
291 _InputIterator1 __last1,
292 _InputIterator2 __first2,
293 _InputIterator2 __last2,
294 _BinaryPredicate __pred) {
295 constexpr bool __both_random_access =
296 __has_random_access_iterator_category<_InputIterator1>::value &&
297 __has_random_access_iterator_category<_InputIterator2>::value;
298 if constexpr (__both_random_access) {
299 if (__last1 - __first1 != __last2 - __first2)
300 return false;
301 }
302 __identity __proj;
303 return std::__equal_impl<__both_random_access>(
304 std::__unwrap_iter(__first1),
305 std::__unwrap_iter(__last1),
306 std::__unwrap_iter(__first2),
307 std::__unwrap_iter(__last2),
308 __pred,
309 __proj,
310 __proj);
311}
312
313template <class _InputIterator1, class _InputIterator2>
314[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
315equal(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _InputIterator2 __last2) {
316 return std::equal(__first1, __last1, __first2, __last2, __equal_to());
317}
318
319#endif // _LIBCPP_STD_VER >= 14
320
321_LIBCPP_END_NAMESPACE_STD
322
323_LIBCPP_POP_MACROS
324
325#endif // _LIBCPP___ALGORITHM_EQUAL_H
326