1//===-- Implementation header for rsqrtf16 ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_LIBC_SRC___SUPPORT_MATH_RSQRTF16_H
10#define LLVM_LIBC_SRC___SUPPORT_MATH_RSQRTF16_H
11
12#include "include/llvm-libc-macros/float16-macros.h"
13
14#ifdef LIBC_TYPES_HAS_FLOAT16
15
16#include "src/__support/CPP/bit.h"
17#include "src/__support/FPUtil/FEnvImpl.h"
18#include "src/__support/FPUtil/FPBits.h"
19#include "src/__support/FPUtil/cast.h"
20#include "src/__support/FPUtil/multiply_add.h"
21#include "src/__support/FPUtil/sqrt.h"
22#include "src/__support/macros/optimization.h"
23
24namespace LIBC_NAMESPACE_DECL {
25namespace math {
26
27#ifndef LIBC_TARGET_CPU_HAS_FPU_FLOAT
28namespace rsqrtf16_internal {
29
30using FPBits = fputil::FPBits<float16>;
31
32// Fixed-point computations below use Q29: the integer N represents
33// N * 2^-29. Multiplying two Q29 values produces a Q58 value, so products are
34// shifted right by RSQRT_FRACTION_BITS to return to Q29.
35LIBC_INLINE_VAR constexpr int RSQRT_FRACTION_BITS = 29;
36LIBC_INLINE_VAR constexpr int64_t ONE = int64_t(1) << RSQRT_FRACTION_BITS;
37LIBC_INLINE_VAR constexpr int64_t THREE_HALVES = 3 * (ONE >> 1);
38
39LIBC_INLINE_VAR constexpr int HALF_FRACTION_LEN = FPBits::FRACTION_LEN;
40LIBC_INLINE_VAR constexpr int HALF_SIGNIFICAND_LEN = HALF_FRACTION_LEN + 1;
41LIBC_INLINE_VAR constexpr int HALF_EXP_BIAS = FPBits::EXP_BIAS;
42LIBC_INLINE_VAR constexpr uint16_t HALF_FRACTION_MASK = FPBits::FRACTION_MASK;
43LIBC_INLINE_VAR constexpr uint16_t HALF_MIN_NORMAL =
44 FPBits::min_normal().uintval();
45LIBC_INLINE_VAR constexpr uint16_t HALF_MAX_NORMAL =
46 FPBits::max_normal().uintval();
47LIBC_INLINE_VAR constexpr uint32_t HALF_HIDDEN_BIT = uint32_t(1)
48 << HALF_FRACTION_LEN;
49LIBC_INLINE_VAR constexpr int UINT32_BITS =
50 8 * static_cast<int>(sizeof(uint32_t));
51
52// Exact representation exponents for values stored as:
53// x = significand * 2^exponent,
54// where normal significands include the hidden bit.
55LIBC_INLINE_VAR constexpr int EXACT_NORMAL_EXP_OFFSET =
56 -HALF_EXP_BIAS - HALF_FRACTION_LEN;
57LIBC_INLINE_VAR constexpr int EXACT_SUBNORMAL_EXP =
58 1 - HALF_EXP_BIAS - HALF_FRACTION_LEN;
59
60// Exponents for the reduced form used by the approximation:
61// x = m * 2^exponent, with 0.5 <= m < 1.
62LIBC_INLINE_VAR constexpr int REDUCED_NORMAL_EXP_OFFSET = 1 - HALF_EXP_BIAS;
63LIBC_INLINE_VAR constexpr int REDUCED_SUBNORMAL_EXP =
64 EXACT_SUBNORMAL_EXP + HALF_SIGNIFICAND_LEN;
65
66LIBC_INLINE_VAR constexpr int RSQRT_APPROX_BITS = 4;
67LIBC_INLINE_VAR constexpr int RSQRT_APPROX_INDEX_SHIFT =
68 HALF_FRACTION_LEN - RSQRT_APPROX_BITS;
69
70// Midpoint lookup table for 1/sqrt(x) on 16 sub-intervals of [0.5;1).
71// Values are stored in Q29 fixed-point format. The Newton step and exact
72// rounding below correct the seed before producing the final half result.
73LIBC_INLINE_VAR constexpr uint32_t RSQRT_APPROX[16] = {
74 0x2c905a6f, 0x2b459b19, 0x2a160d52, 0x28fe28a0, 0x27fb00f0, 0x270a2574,
75 0x262987b2, 0x25576878, 0x24924925, 0x23d8e025, 0x232a0fda, 0x2284df58,
76 0x21e8748c, 0x21540f7b, 0x20c70664, 0x2040c289,
77};
78LIBC_INLINE_VAR constexpr int64_t ONE_OVER_SQRT2 = 0x16a09e60;
79
80LIBC_INLINE constexpr int floor_log2(uint64_t x) {
81 return 63 - cpp::countl_zero(x);
82}
83
84LIBC_INLINE constexpr int64_t initial_approximation(uint32_t x_mant) {
85 return RSQRT_APPROX[(x_mant - HALF_HIDDEN_BIT) >> RSQRT_APPROX_INDEX_SHIFT];
86}
87
88LIBC_INLINE constexpr int64_t newton_raphson(uint32_t m, int64_t y) {
89 // Refine y ~= 1/sqrt(m) with:
90 // y_{n+1} = y_n * (1.5 - 0.5 * m * y_n^2)
91 // where both m and y are stored in Q29.
92 int64_t y2 = (y * y) >> RSQRT_FRACTION_BITS;
93 int64_t my2 = (static_cast<int64_t>(m) * y2) >> RSQRT_FRACTION_BITS;
94 int64_t factor = THREE_HALVES - (my2 >> 1);
95 return (y * factor) >> RSQRT_FRACTION_BITS;
96}
97
98LIBC_INLINE constexpr uint16_t fixed_to_half_bits(uint64_t y, int scale_exp) {
99 // Convert y * 2^scale_exp, with y in Q29, to an approximate positive normal
100 // half bit pattern. This only creates a nearby candidate; exact rounding is
101 // handled by floor_rsqrt and round_result.
102 int y_log2 = floor_log2(y);
103 int out_exp = scale_exp + y_log2 - RSQRT_FRACTION_BITS;
104 int biased_exp = out_exp + HALF_EXP_BIAS;
105
106 uint32_t out_sig =
107 y_log2 >= HALF_FRACTION_LEN
108 ? static_cast<uint32_t>(y >> (y_log2 - HALF_FRACTION_LEN))
109 : static_cast<uint32_t>(y << (HALF_FRACTION_LEN - y_log2));
110
111 if (biased_exp <= 0)
112 return HALF_MIN_NORMAL;
113 if (biased_exp >= FPBits::MAX_BIASED_EXPONENT)
114 return HALF_MAX_NORMAL;
115
116 return static_cast<uint16_t>((biased_exp << HALF_FRACTION_LEN) |
117 (out_sig & HALF_FRACTION_MASK));
118}
119
// Result of approximate_rsqrt: a candidate for the reciprocal square root plus
// the exact integer form of the input, x = x_sig * 2^x_exp. The exact form lets
// later steps compare candidates against the mathematical result using integer
// arithmetic only.
struct ApproxResult {
  // Approximate positive half bit pattern obtained from the table seed, the
  // Newton step, and the exponent scaling.
  uint16_t value;
  // Exact integer significand of the input.
  uint32_t x_sig;
  // Exact power-of-two exponent paired with `x_sig`.
  int x_exp;
};
129
130LIBC_INLINE constexpr ApproxResult approximate_rsqrt(uint16_t x_abs) {
131 uint32_t x_mant = x_abs & HALF_FRACTION_MASK;
132 uint32_t x_sig = x_mant;
133 int x_exp = EXACT_SUBNORMAL_EXP;
134 int exponent = 0;
135
136 // Decompose the finite positive input as:
137 // x = m * 2^exponent, with 0.5 <= m < 1.
138 // `x_sig` and `x_exp` keep the exact input as x_sig * 2^x_exp for the integer
139 // rounding test below.
140 if (x_abs >= HALF_MIN_NORMAL) {
141 int biased_exp = static_cast<int>(x_abs >> HALF_FRACTION_LEN);
142 x_mant |= HALF_HIDDEN_BIT;
143 x_sig = x_mant;
144 x_exp = biased_exp + EXACT_NORMAL_EXP_OFFSET;
145 exponent = biased_exp + REDUCED_NORMAL_EXP_OFFSET;
146 } else {
147 int shift = cpp::countl_zero(x_mant) - (UINT32_BITS - HALF_SIGNIFICAND_LEN);
148 x_mant <<= shift;
149 exponent = REDUCED_SUBNORMAL_EXP - shift;
150 }
151
152 uint32_t m = x_mant << (RSQRT_FRACTION_BITS - HALF_SIGNIFICAND_LEN);
153 int64_t y = newton_raphson(m, initial_approximation(x_mant));
154
155 // Since rsqrt(m * 2^e) = rsqrt(m) * 2^(-e/2), odd exponents need one
156 // extra factor of 1/sqrt(2) before applying the integral power of two.
157 int scale_exp = 0;
158 if (exponent & 1) {
159 y = (y * ONE_OVER_SQRT2) >> RSQRT_FRACTION_BITS;
160 scale_exp = -((exponent - 1) / 2);
161 } else {
162 scale_exp = -(exponent / 2);
163 }
164
165 return {fixed_to_half_bits(static_cast<uint64_t>(y), scale_exp), x_sig,
166 x_exp};
167}
168
169// Compare y = sig * 2^exp with 1 / sqrt(x_sig * 2^x_exp).
170// Return -1 if y is below the exact value, 0 if exact, and 1 if above.
171// Instead of computing a reciprocal square root, square both sides:
172// y <= 1/sqrt(x) <=> y^2 * x <= 1.
173LIBC_INLINE constexpr int compare_with_rsqrt(uint32_t sig, int exp,
174 uint32_t x_sig, int x_exp) {
175 uint64_t lhs = static_cast<uint64_t>(sig) * sig * x_sig;
176 // For all finite positive half inputs and candidates produced by this
177 // algorithm, 2 * exp + x_exp is in [-34, -20].
178 int rshift = -(2 * exp + x_exp);
179 uint64_t rhs = uint64_t(1) << rshift;
180 if (lhs < rhs)
181 return -1;
182 if (lhs > rhs)
183 return 1;
184 return 0;
185}
186
187LIBC_INLINE constexpr int compare_half_with_rsqrt(uint16_t y, uint32_t x_sig,
188 int x_exp) {
189 uint32_t y_sig = HALF_HIDDEN_BIT | (y & HALF_FRACTION_MASK);
190 int y_exp =
191 static_cast<int>(y >> HALF_FRACTION_LEN) + EXACT_NORMAL_EXP_OFFSET;
192 return compare_with_rsqrt(y_sig, y_exp, x_sig, x_exp);
193}
194
// Result of floor_rsqrt.
struct FloorResult {
  // Half bit pattern chosen as the floor candidate.
  uint16_t value;
  // Result of comparing `value` with the exact reciprocal square root:
  // -1 if below, 0 if equal, 1 if above.
  int cmp;
};
199
200LIBC_INLINE constexpr FloorResult floor_rsqrt(uint16_t approx, uint32_t x_sig,
201 int x_exp) {
202 // The table seed and Newton step have been validated exhaustively to produce
203 // a candidate at most one half-precision step below the exact floor.
204 uint16_t y = approx < HALF_MIN_NORMAL ? HALF_MIN_NORMAL : approx;
205 int cmp = compare_half_with_rsqrt(y, x_sig, x_exp);
206 if (LIBC_UNLIKELY(cmp > 0)) {
207 --y;
208 cmp = compare_half_with_rsqrt(y, x_sig, x_exp);
209 } else if (LIBC_UNLIKELY(y < HALF_MAX_NORMAL)) {
210 int next_cmp = compare_half_with_rsqrt(y + 1, x_sig, x_exp);
211 if (LIBC_UNLIKELY(next_cmp <= 0)) {
212 ++y;
213 cmp = next_cmp;
214 }
215 }
216 return {y, cmp};
217}
218
219LIBC_INLINE constexpr uint16_t round_result(FloorResult floor, uint32_t x_sig,
220 int x_exp) {
221 uint16_t y = floor.value;
222 if (floor.cmp == 0)
223 return y;
224
225 // Once `y` is the greatest half value below the exact result, directed
226 // rounding is immediate. Round-to-nearest compares against the midpoint
227 // between `y` and the next half value, then applies ties-to-even.
228 int rounding_mode = FE_TONEAREST;
229 if (!cpp::is_constant_evaluated())
230 rounding_mode = fputil::get_round();
231 if (rounding_mode == FE_UPWARD)
232 return y + 1;
233 if (rounding_mode != FE_TONEAREST)
234 return y;
235
236 uint32_t y_sig = HALF_HIDDEN_BIT | (y & HALF_FRACTION_MASK);
237 int y_exp =
238 static_cast<int>(y >> HALF_FRACTION_LEN) + EXACT_NORMAL_EXP_OFFSET;
239 uint32_t midpoint_sig = (y_sig << 1) | 1;
240 int midpoint_cmp = compare_with_rsqrt(midpoint_sig, y_exp - 1, x_sig, x_exp);
241
242 if (midpoint_cmp < 0)
243 return y + 1;
244 if (midpoint_cmp > 0)
245 return y;
246 return (y & 1) ? static_cast<uint16_t>(y + 1) : y;
247}
248
249LIBC_INLINE constexpr float16 rsqrtf16_no_float(uint16_t x_abs) {
250 ApproxResult approx = approximate_rsqrt(x_abs);
251 FloorResult floor = floor_rsqrt(approx.value, approx.x_sig, approx.x_exp);
252 return fputil::FPBits<float16>(
253 round_result(floor, approx.x_sig, approx.x_exp))
254 .get_val();
255}
256
257} // namespace rsqrtf16_internal
258#endif // !LIBC_TARGET_CPU_HAS_FPU_FLOAT
259
260LIBC_INLINE constexpr float16 rsqrtf16(float16 x) {
261 using FPBits = fputil::FPBits<float16>;
262 FPBits xbits(x);
263
264 uint16_t x_u = xbits.uintval();
265 uint16_t x_abs = x_u & 0x7fff;
266
267 constexpr uint16_t INF_BIT = FPBits::inf().uintval();
268
269 // x is 0, inf/nan, or negative.
270 if (LIBC_UNLIKELY(x_u == 0 || x_u >= INF_BIT)) {
271 // x is NaN
272 if (x_abs > INF_BIT) {
273 if (xbits.is_signaling_nan()) {
274 fputil::raise_except_if_required(FE_INVALID);
275 return FPBits::quiet_nan().get_val();
276 }
277 return x;
278 }
279
280 // |x| = 0
281 if (x_abs == 0) {
282 fputil::raise_except_if_required(FE_DIVBYZERO);
283 fputil::set_errno_if_required(ERANGE);
284 return FPBits::inf(sign: xbits.sign()).get_val();
285 }
286
287 // -inf <= x < 0
288 if (x_u > 0x7fff) {
289 fputil::raise_except_if_required(FE_INVALID);
290 fputil::set_errno_if_required(EDOM);
291 return FPBits::quiet_nan().get_val();
292 }
293
294 // x = +inf => rsqrt(x) = +0
295 return FPBits::zero(sign: xbits.sign()).get_val();
296 }
297
298#ifdef LIBC_TARGET_CPU_HAS_FPU_FLOAT
299 float result = 1.0f / fputil::sqrt<float>(x: fputil::cast<float>(x));
300
301 // Targeted post-corrections to ensure correct rounding in half for specific
302 // mantissa patterns
303 const uint16_t half_mantissa = x_abs & 0x3ff;
304 if (LIBC_UNLIKELY(half_mantissa == 0x011F)) {
305 result = fputil::multiply_add(x: result, y: 0x1.0p-21f, z: result);
306 } else if (LIBC_UNLIKELY(half_mantissa == 0x0313)) {
307 result = fputil::multiply_add(x: result, y: -0x1.0p-21f, z: result);
308 }
309
310 return fputil::cast<float16>(x: result);
311
312#else
313 return rsqrtf16_internal::rsqrtf16_no_float(x_abs);
314#endif
315}
316
317} // namespace math
318} // namespace LIBC_NAMESPACE_DECL
319
320#endif // LIBC_TYPES_HAS_FLOAT16
321
322#endif // LLVM_LIBC_SRC___SUPPORT_MATH_RSQRTF16_H
323