1// -*- C++ -*-
2//===----------------------------------------------------------------------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
10#ifndef _LIBCPP___ALGORITHM_EQUAL_H
11#define _LIBCPP___ALGORITHM_EQUAL_H
12
13#include <__algorithm/comp.h>
14#include <__algorithm/min.h>
15#include <__algorithm/unwrap_iter.h>
16#include <__config>
17#include <__functional/identity.h>
18#include <__fwd/bit_reference.h>
19#include <__iterator/iterator_traits.h>
20#include <__string/constexpr_c_functions.h>
21#include <__type_traits/desugars_to.h>
22#include <__type_traits/enable_if.h>
23#include <__type_traits/invoke.h>
24#include <__type_traits/is_equality_comparable.h>
25#include <__type_traits/is_volatile.h>
26#include <__utility/move.h>
27
28#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
29# pragma GCC system_header
30#endif
31
32_LIBCPP_PUSH_MACROS
33#include <__undef_macros>
34
35_LIBCPP_BEGIN_NAMESPACE_STD
36
37template <class _Cp, bool _IsConst1, bool _IsConst2>
38[[__nodiscard__]] _LIBCPP_CONSTEXPR_SINCE_CXX20 _LIBCPP_HIDE_FROM_ABI bool
39__equal_unaligned(__bit_iterator<_Cp, _IsConst1> __first1,
40 __bit_iterator<_Cp, _IsConst1> __last1,
41 __bit_iterator<_Cp, _IsConst2> __first2) {
42 using _It = __bit_iterator<_Cp, _IsConst1>;
43 using difference_type = typename _It::difference_type;
44 using __storage_type = typename _It::__storage_type;
45
46 const int __bits_per_word = _It::__bits_per_word;
47 difference_type __n = __last1 - __first1;
48 if (__n > 0) {
49 // do first word
50 if (__first1.__ctz_ != 0) {
51 unsigned __clz_f = __bits_per_word - __first1.__ctz_;
52 difference_type __dn = std::min(static_cast<difference_type>(__clz_f), __n);
53 __n -= __dn;
54 __storage_type __m = std::__middle_mask<__storage_type>(__clz_f - __dn, __first1.__ctz_);
55 __storage_type __b = *__first1.__seg_ & __m;
56 unsigned __clz_r = __bits_per_word - __first2.__ctz_;
57 __storage_type __ddn = std::min<__storage_type>(__dn, __clz_r);
58 __m = std::__middle_mask<__storage_type>(__clz_r - __ddn, __first2.__ctz_);
59 if (__first2.__ctz_ > __first1.__ctz_) {
60 if (static_cast<__storage_type>(*__first2.__seg_ & __m) !=
61 static_cast<__storage_type>(__b << (__first2.__ctz_ - __first1.__ctz_)))
62 return false;
63 } else {
64 if (static_cast<__storage_type>(*__first2.__seg_ & __m) !=
65 static_cast<__storage_type>(__b >> (__first1.__ctz_ - __first2.__ctz_)))
66 return false;
67 }
68 __first2.__seg_ += (__ddn + __first2.__ctz_) / __bits_per_word;
69 __first2.__ctz_ = static_cast<unsigned>((__ddn + __first2.__ctz_) % __bits_per_word);
70 __dn -= __ddn;
71 if (__dn > 0) {
72 __m = std::__trailing_mask<__storage_type>(__bits_per_word - __n);
73 if (static_cast<__storage_type>(*__first2.__seg_ & __m) !=
74 static_cast<__storage_type>(__b >> (__first1.__ctz_ + __ddn)))
75 return false;
76 __first2.__ctz_ = static_cast<unsigned>(__dn);
77 }
78 ++__first1.__seg_;
79 // __first1.__ctz_ = 0;
80 }
81 // __first1.__ctz_ == 0;
82 // do middle words
83 unsigned __clz_r = __bits_per_word - __first2.__ctz_;
84 __storage_type __m = std::__leading_mask<__storage_type>(__first2.__ctz_);
85 for (; __n >= __bits_per_word; __n -= __bits_per_word, ++__first1.__seg_) {
86 __storage_type __b = *__first1.__seg_;
87 if (static_cast<__storage_type>(*__first2.__seg_ & __m) != static_cast<__storage_type>(__b << __first2.__ctz_))
88 return false;
89 ++__first2.__seg_;
90 if (static_cast<__storage_type>(*__first2.__seg_ & static_cast<__storage_type>(~__m)) !=
91 static_cast<__storage_type>(__b >> __clz_r))
92 return false;
93 }
94 // do last word
95 if (__n > 0) {
96 __m = std::__trailing_mask<__storage_type>(__bits_per_word - __n);
97 __storage_type __b = *__first1.__seg_ & __m;
98 __storage_type __dn = std::min(__n, static_cast<difference_type>(__clz_r));
99 __m = std::__middle_mask<__storage_type>(__clz_r - __dn, __first2.__ctz_);
100 if (static_cast<__storage_type>(*__first2.__seg_ & __m) != static_cast<__storage_type>(__b << __first2.__ctz_))
101 return false;
102 __first2.__seg_ += (__dn + __first2.__ctz_) / __bits_per_word;
103 __first2.__ctz_ = static_cast<unsigned>((__dn + __first2.__ctz_) % __bits_per_word);
104 __n -= __dn;
105 if (__n > 0) {
106 __m = std::__trailing_mask<__storage_type>(__bits_per_word - __n);
107 if (static_cast<__storage_type>(*__first2.__seg_ & __m) != static_cast<__storage_type>(__b >> __dn))
108 return false;
109 }
110 }
111 }
112 return true;
113}
114
115template <class _Cp, bool _IsConst1, bool _IsConst2>
116[[__nodiscard__]] _LIBCPP_CONSTEXPR_SINCE_CXX20 _LIBCPP_HIDE_FROM_ABI bool
117__equal_aligned(__bit_iterator<_Cp, _IsConst1> __first1,
118 __bit_iterator<_Cp, _IsConst1> __last1,
119 __bit_iterator<_Cp, _IsConst2> __first2) {
120 using _It = __bit_iterator<_Cp, _IsConst1>;
121 using difference_type = typename _It::difference_type;
122 using __storage_type = typename _It::__storage_type;
123
124 const int __bits_per_word = _It::__bits_per_word;
125 difference_type __n = __last1 - __first1;
126 if (__n > 0) {
127 // do first word
128 if (__first1.__ctz_ != 0) {
129 unsigned __clz = __bits_per_word - __first1.__ctz_;
130 difference_type __dn = std::min(static_cast<difference_type>(__clz), __n);
131 __n -= __dn;
132 __storage_type __m = std::__middle_mask<__storage_type>(__clz - __dn, __first1.__ctz_);
133 if ((*__first2.__seg_ & __m) != (*__first1.__seg_ & __m))
134 return false;
135 ++__first2.__seg_;
136 ++__first1.__seg_;
137 // __first1.__ctz_ = 0;
138 // __first2.__ctz_ = 0;
139 }
140 // __first1.__ctz_ == 0;
141 // __first2.__ctz_ == 0;
142 // do middle words
143 for (; __n >= __bits_per_word; __n -= __bits_per_word, ++__first1.__seg_, ++__first2.__seg_)
144 if (*__first2.__seg_ != *__first1.__seg_)
145 return false;
146 // do last word
147 if (__n > 0) {
148 __storage_type __m = std::__trailing_mask<__storage_type>(__bits_per_word - __n);
149 if ((*__first2.__seg_ & __m) != (*__first1.__seg_ & __m))
150 return false;
151 }
152 }
153 return true;
154}
155
156template <class _Cp,
157 bool _IsConst1,
158 bool _IsConst2,
159 class _BinaryPredicate,
160 class _Proj1,
161 class _Proj2,
162 __enable_if_t<__is_identity<_Proj1>::value && __is_identity<_Proj2>::value &&
163 __desugars_to_v<__equal_tag, _BinaryPredicate, bool, bool>,
164 int> = 0>
165[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_iter_impl(
166 __bit_iterator<_Cp, _IsConst1> __first1,
167 __bit_iterator<_Cp, _IsConst1> __last1,
168 __bit_iterator<_Cp, _IsConst2> __first2,
169 _BinaryPredicate,
170 _Proj1&,
171 _Proj2&) {
172 if (__first1.__ctz_ == __first2.__ctz_)
173 return std::__equal_aligned(__first1, __last1, __first2);
174 return std::__equal_unaligned(__first1, __last1, __first2);
175}
176
177template <class _InIter1, class _Sent1, class _InIter2, class _Pred, class _Proj1, class _Proj2>
178[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_iter_impl(
179 _InIter1 __first1, _Sent1 __last1, _InIter2 __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
180 for (; __first1 != __last1; ++__first1, (void)++__first2)
181 if (!std::__invoke(__pred, std::__invoke(__proj1, *__first1), std::__invoke(__proj2, *__first2)))
182 return false;
183 return true;
184}
185
186template <class _Tp,
187 class _Up,
188 class _BinaryPredicate,
189 class _Proj1,
190 class _Proj2,
191 __enable_if_t<__is_identity<_Proj1>::value && __is_identity<_Proj2>::value &&
192 __desugars_to_v<__equal_tag, _BinaryPredicate, _Tp, _Up> && !is_volatile<_Tp>::value &&
193 !is_volatile<_Up>::value && __is_trivially_equality_comparable_v<_Tp, _Up>,
194 int> = 0>
195[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
196__equal_iter_impl(_Tp* __first1, _Tp* __last1, _Up* __first2, _BinaryPredicate&, _Proj1&, _Proj2&) {
197 return std::__constexpr_memcmp_equal(__first1, __first2, __element_count(__last1 - __first1));
198}
199
200template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
201[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
202equal(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _BinaryPredicate __pred) {
203 __identity __proj;
204 return std::__equal_iter_impl(
205 std::__unwrap_iter(__first1), std::__unwrap_iter(__last1), std::__unwrap_iter(__first2), __pred, __proj, __proj);
206}
207
208template <class _InputIterator1, class _InputIterator2>
209[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
210equal(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2) {
211 return std::equal(__first1, __last1, __first2, __equal_to());
212}
213
214#if _LIBCPP_STD_VER >= 14
215
216template <bool __known_equal_length,
217 class _Iter1,
218 class _Sent1,
219 class _Iter2,
220 class _Sent2,
221 class _Pred,
222 class _Proj1,
223 class _Proj2>
224[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_impl(
225 _Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Sent2 __last2, _Pred& __comp, _Proj1& __proj1, _Proj2& __proj2) {
226 if constexpr (__known_equal_length) {
227 return std::__equal_iter_impl(
228 std::move(__first1), std::move(__last1), std::move(__first2), __comp, __proj1, __proj2);
229 } else {
230 while (__first1 != __last1 && __first2 != __last2) {
231 if (!std::__invoke(__comp, std::__invoke(__proj1, *__first1), std::__invoke(__proj2, *__first2)))
232 return false;
233 ++__first1;
234 ++__first2;
235 }
236 return __first1 == __last1 && __first2 == __last2;
237 }
238}
239
240template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
241[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
242equal(_InputIterator1 __first1,
243 _InputIterator1 __last1,
244 _InputIterator2 __first2,
245 _InputIterator2 __last2,
246 _BinaryPredicate __pred) {
247 constexpr bool __both_random_access =
248 __has_random_access_iterator_category<_InputIterator1>::value &&
249 __has_random_access_iterator_category<_InputIterator2>::value;
250 if constexpr (__both_random_access) {
251 if (__last1 - __first1 != __last2 - __first2)
252 return false;
253 }
254 __identity __proj;
255 return std::__equal_impl<__both_random_access>(
256 std::__unwrap_iter(__first1),
257 std::__unwrap_iter(__last1),
258 std::__unwrap_iter(__first2),
259 std::__unwrap_iter(__last2),
260 __pred,
261 __proj,
262 __proj);
263}
264
265template <class _InputIterator1, class _InputIterator2>
266[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
267equal(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _InputIterator2 __last2) {
268 return std::equal(__first1, __last1, __first2, __last2, __equal_to());
269}
270
271#endif // _LIBCPP_STD_VER >= 14
272
273_LIBCPP_END_NAMESPACE_STD
274
275_LIBCPP_POP_MACROS
276
277#endif // _LIBCPP___ALGORITHM_EQUAL_H
278