1//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
10#define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
11
12#include "src/__support/CPP/bit.h" // countl_zero
13#include "src/__support/CPP/type_traits.h"
14#include "src/__support/FPUtil/FEnvImpl.h"
15#include "src/__support/FPUtil/FPBits.h"
16#include "src/__support/FPUtil/cast.h"
17#include "src/__support/FPUtil/dyadic_float.h"
18#include "src/__support/common.h"
19#include "src/__support/macros/config.h"
20#include "src/__support/uint128.h"
21
22#include "hdr/fenv_macros.h"
23
24#ifdef LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
25#include "sqrt_80_bit_long_double.h"
#endif // LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
27
28namespace LIBC_NAMESPACE_DECL {
29namespace fputil {
30
31namespace internal {
32
// Compile-time trait telling sqrt() whether a floating-point type must skip
// the generic shift-and-add implementation. It is false for every type except
// long double when that type is the x86 80-bit format, which sqrt() hands off
// to x86::sqrt instead.
template <typename T> struct SpecialLongDouble {
  static constexpr bool VALUE{false};
};

#ifdef LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
// x86 80-bit long double: routed to the dedicated x86::sqrt implementation.
template <> struct SpecialLongDouble<long double> {
  static constexpr bool VALUE{true};
};
#endif // LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
42
43template <typename T>
44LIBC_INLINE constexpr void
45normalize(int &exponent, typename FPBits<T>::StorageType &mantissa) {
46 const int shift =
47 cpp::countl_zero(mantissa) -
48 (8 * static_cast<int>(sizeof(mantissa)) - 1 - FPBits<T>::FRACTION_LEN);
49 exponent -= shift;
50 mantissa <<= shift;
51}
52
#ifdef LIBC_TYPES_LONG_DOUBLE_IS_FLOAT64
// long double is IEEE binary64 here: same layout and 64-bit storage as
// double, so defer to the generic double normalization.
template <>
LIBC_INLINE constexpr void normalize<long double>(int &exponent,
                                                  uint64_t &mantissa) {
  normalize<double>(exponent, mantissa);
}
#elif defined(LIBC_TYPES_LONG_DOUBLE_IS_FLOAT128)
// long double is IEEE binary128 here, stored in a UInt128. Leading zeros are
// counted on the 64-bit halves separately. The shift moves the top set bit to
// bit 112 (FRACTION_LEN), i.e. bit 48 of the high half:
//   - top bit in the high half: shift = clz(high) - 15  (63 - 48 = 15)
//   - top bit in the low half:  shift = clz(low) + 49   (112 - 63 = 49)
template <>
LIBC_INLINE constexpr void normalize<long double>(int &exponent,
                                                  UInt128 &mantissa) {
  const auto high_half = static_cast<uint64_t>(mantissa >> 64);
  const int shift =
      (high_half != 0)
          ? cpp::countl_zero(high_half) - 15
          : cpp::countl_zero(static_cast<uint64_t>(mantissa)) + 49;
  mantissa <<= shift;
  exponent -= shift;
}
#endif
71
72} // namespace internal
73
74// Correctly rounded IEEE 754 SQRT for all rounding modes.
75// Shift-and-add algorithm.
76template <typename OutType, typename InType>
77LIBC_INLINE LIBC_CONSTEXPR_DEFAULT static cpp::enable_if_t<
78 cpp::is_floating_point_v<OutType> && cpp::is_floating_point_v<InType> &&
79 sizeof(OutType) <= sizeof(InType),
80 OutType>
81sqrt(InType x) {
82 if constexpr (internal::SpecialLongDouble<OutType>::VALUE &&
83 internal::SpecialLongDouble<InType>::VALUE) {
84#ifdef LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
85 // Special 80-bit long double.
86 return x86::sqrt(x);
87#endif // !LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
88 } else {
89 // IEEE floating points formats.
90 using OutFPBits = FPBits<OutType>;
91 using InFPBits = FPBits<InType>;
92 using InStorageType = typename InFPBits::StorageType;
93 using DyadicFloat =
94 DyadicFloat<cpp::bit_ceil(value: static_cast<size_t>(InFPBits::STORAGE_LEN))>;
95
96 constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN;
97 LIBC_BIT_CAST_CONSTEXPR_VAR auto FLT_NAN = OutFPBits::quiet_nan().get_val();
98
99 InFPBits bits(x);
100
101 if (bits == InFPBits::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
102 // sqrt(+Inf) = +Inf
103 // sqrt(+0) = +0
104 // sqrt(-0) = -0
105 // sqrt(NaN) = NaN
106 // sqrt(-NaN) = -NaN
107 return cast<OutType>(x);
108 } else if (bits.is_neg()) {
109 // sqrt(-Inf) = NaN
110 // sqrt(-x) = NaN
111 return FLT_NAN;
112 } else {
113 int x_exp = bits.get_exponent();
114 InStorageType x_mant = bits.get_mantissa();
115
116 // Step 1a: Normalize denormal input and append hidden bit to the mantissa
117 if (bits.is_subnormal()) {
118 ++x_exp; // let x_exp be the correct exponent of ONE bit.
119 internal::normalize<InType>(x_exp, x_mant);
120 } else {
121 x_mant |= ONE;
122 }
123
124 // Step 1b: Make sure the exponent is even.
125 if (x_exp & 1) {
126 --x_exp;
127 x_mant <<= 1;
128 }
129
130 // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
131 // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
132 // Notice that the output of sqrt is always in the normal range.
133 // To perform shift-and-add algorithm to find y, let denote:
134 // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
135 // r(n) = 2^n ( x_mant - y(n)^2 ).
136 // That leads to the following recurrence formula:
137 // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
138 // with the initial conditions: y(0) = 1, and r(0) = x - 1.
139 // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
140 // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
141 // 0 otherwise.
142 InStorageType y = ONE;
143 InStorageType r = x_mant - ONE;
144
145 // TODO: Reduce iteration count to OutFPBits::FRACTION_LEN + 2 or + 3.
146 for (InStorageType current_bit = ONE >> 1; current_bit;
147 current_bit >>= 1) {
148 r <<= 1;
149 // 2*y(n - 1) + 2^(-n-1)
150 InStorageType tmp = static_cast<InStorageType>((y << 1) + current_bit);
151 if (r >= tmp) {
152 r -= tmp;
153 y += current_bit;
154 }
155 }
156
157 // We compute one more iteration in order to round correctly.
158 r <<= 2;
159 y <<= 2;
160 InStorageType tmp = y + 1;
161 if (r >= tmp) {
162 r -= tmp;
163 // Rounding bit.
164 y |= 2;
165 }
166 // Sticky bit.
167 y |= static_cast<unsigned int>(r != 0);
168
169 DyadicFloat yd(Sign::POS, (x_exp >> 1) - 2 - InFPBits::FRACTION_LEN, y);
170 return yd.template as<OutType, /*ShouldSignalExceptions=*/true>();
171 }
172 }
173}
174
175} // namespace fputil
176} // namespace LIBC_NAMESPACE_DECL
177
178#endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
179