| 1 | //===-- Implementation header for rsqrtf16 ----------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef LLVM_LIBC_SRC___SUPPORT_MATH_RSQRTF16_H |
| 10 | #define LLVM_LIBC_SRC___SUPPORT_MATH_RSQRTF16_H |
| 11 | |
| 12 | #include "include/llvm-libc-macros/float16-macros.h" |
| 13 | |
| 14 | #ifdef LIBC_TYPES_HAS_FLOAT16 |
| 15 | |
| 16 | #include "src/__support/CPP/bit.h" |
| 17 | #include "src/__support/FPUtil/FEnvImpl.h" |
| 18 | #include "src/__support/FPUtil/FPBits.h" |
| 19 | #include "src/__support/FPUtil/cast.h" |
| 20 | #include "src/__support/FPUtil/multiply_add.h" |
| 21 | #include "src/__support/FPUtil/sqrt.h" |
| 22 | #include "src/__support/macros/optimization.h" |
| 23 | |
| 24 | namespace LIBC_NAMESPACE_DECL { |
| 25 | namespace math { |
| 26 | |
| 27 | #ifndef LIBC_TARGET_CPU_HAS_FPU_FLOAT |
| 28 | namespace rsqrtf16_internal { |
| 29 | |
| 30 | using FPBits = fputil::FPBits<float16>; |
| 31 | |
| 32 | // Fixed-point computations below use Q29: the integer N represents |
| 33 | // N * 2^-29. Multiplying two Q29 values produces a Q58 value, so products are |
| 34 | // shifted right by RSQRT_FRACTION_BITS to return to Q29. |
| 35 | LIBC_INLINE_VAR constexpr int RSQRT_FRACTION_BITS = 29; |
| 36 | LIBC_INLINE_VAR constexpr int64_t ONE = int64_t(1) << RSQRT_FRACTION_BITS; |
| 37 | LIBC_INLINE_VAR constexpr int64_t THREE_HALVES = 3 * (ONE >> 1); |
| 38 | |
| 39 | LIBC_INLINE_VAR constexpr int HALF_FRACTION_LEN = FPBits::FRACTION_LEN; |
| 40 | LIBC_INLINE_VAR constexpr int HALF_SIGNIFICAND_LEN = HALF_FRACTION_LEN + 1; |
| 41 | LIBC_INLINE_VAR constexpr int HALF_EXP_BIAS = FPBits::EXP_BIAS; |
| 42 | LIBC_INLINE_VAR constexpr uint16_t HALF_FRACTION_MASK = FPBits::FRACTION_MASK; |
| 43 | LIBC_INLINE_VAR constexpr uint16_t HALF_MIN_NORMAL = |
| 44 | FPBits::min_normal().uintval(); |
| 45 | LIBC_INLINE_VAR constexpr uint16_t HALF_MAX_NORMAL = |
| 46 | FPBits::max_normal().uintval(); |
| 47 | LIBC_INLINE_VAR constexpr uint32_t HALF_HIDDEN_BIT = uint32_t(1) |
| 48 | << HALF_FRACTION_LEN; |
| 49 | LIBC_INLINE_VAR constexpr int UINT32_BITS = |
| 50 | 8 * static_cast<int>(sizeof(uint32_t)); |
| 51 | |
| 52 | // Exact representation exponents for values stored as: |
| 53 | // x = significand * 2^exponent, |
| 54 | // where normal significands include the hidden bit. |
| 55 | LIBC_INLINE_VAR constexpr int EXACT_NORMAL_EXP_OFFSET = |
| 56 | -HALF_EXP_BIAS - HALF_FRACTION_LEN; |
| 57 | LIBC_INLINE_VAR constexpr int EXACT_SUBNORMAL_EXP = |
| 58 | 1 - HALF_EXP_BIAS - HALF_FRACTION_LEN; |
| 59 | |
| 60 | // Exponents for the reduced form used by the approximation: |
| 61 | // x = m * 2^exponent, with 0.5 <= m < 1. |
| 62 | LIBC_INLINE_VAR constexpr int REDUCED_NORMAL_EXP_OFFSET = 1 - HALF_EXP_BIAS; |
| 63 | LIBC_INLINE_VAR constexpr int REDUCED_SUBNORMAL_EXP = |
| 64 | EXACT_SUBNORMAL_EXP + HALF_SIGNIFICAND_LEN; |
| 65 | |
| 66 | LIBC_INLINE_VAR constexpr int RSQRT_APPROX_BITS = 4; |
| 67 | LIBC_INLINE_VAR constexpr int RSQRT_APPROX_INDEX_SHIFT = |
| 68 | HALF_FRACTION_LEN - RSQRT_APPROX_BITS; |
| 69 | |
| 70 | // Midpoint lookup table for 1/sqrt(x) on 16 sub-intervals of [0.5;1). |
| 71 | // Values are stored in Q29 fixed-point format. The Newton step and exact |
| 72 | // rounding below correct the seed before producing the final half result. |
| 73 | LIBC_INLINE_VAR constexpr uint32_t RSQRT_APPROX[16] = { |
| 74 | 0x2c905a6f, 0x2b459b19, 0x2a160d52, 0x28fe28a0, 0x27fb00f0, 0x270a2574, |
| 75 | 0x262987b2, 0x25576878, 0x24924925, 0x23d8e025, 0x232a0fda, 0x2284df58, |
| 76 | 0x21e8748c, 0x21540f7b, 0x20c70664, 0x2040c289, |
| 77 | }; |
| 78 | LIBC_INLINE_VAR constexpr int64_t ONE_OVER_SQRT2 = 0x16a09e60; |
| 79 | |
| 80 | LIBC_INLINE constexpr int floor_log2(uint64_t x) { |
| 81 | return 63 - cpp::countl_zero(x); |
| 82 | } |
| 83 | |
| 84 | LIBC_INLINE constexpr int64_t initial_approximation(uint32_t x_mant) { |
| 85 | return RSQRT_APPROX[(x_mant - HALF_HIDDEN_BIT) >> RSQRT_APPROX_INDEX_SHIFT]; |
| 86 | } |
| 87 | |
| 88 | LIBC_INLINE constexpr int64_t newton_raphson(uint32_t m, int64_t y) { |
| 89 | // Refine y ~= 1/sqrt(m) with: |
| 90 | // y_{n+1} = y_n * (1.5 - 0.5 * m * y_n^2) |
| 91 | // where both m and y are stored in Q29. |
| 92 | int64_t y2 = (y * y) >> RSQRT_FRACTION_BITS; |
| 93 | int64_t my2 = (static_cast<int64_t>(m) * y2) >> RSQRT_FRACTION_BITS; |
| 94 | int64_t factor = THREE_HALVES - (my2 >> 1); |
| 95 | return (y * factor) >> RSQRT_FRACTION_BITS; |
| 96 | } |
| 97 | |
| 98 | LIBC_INLINE constexpr uint16_t fixed_to_half_bits(uint64_t y, int scale_exp) { |
| 99 | // Convert y * 2^scale_exp, with y in Q29, to an approximate positive normal |
| 100 | // half bit pattern. This only creates a nearby candidate; exact rounding is |
| 101 | // handled by floor_rsqrt and round_result. |
| 102 | int y_log2 = floor_log2(y); |
| 103 | int out_exp = scale_exp + y_log2 - RSQRT_FRACTION_BITS; |
| 104 | int biased_exp = out_exp + HALF_EXP_BIAS; |
| 105 | |
| 106 | uint32_t out_sig = |
| 107 | y_log2 >= HALF_FRACTION_LEN |
| 108 | ? static_cast<uint32_t>(y >> (y_log2 - HALF_FRACTION_LEN)) |
| 109 | : static_cast<uint32_t>(y << (HALF_FRACTION_LEN - y_log2)); |
| 110 | |
| 111 | if (biased_exp <= 0) |
| 112 | return HALF_MIN_NORMAL; |
| 113 | if (biased_exp >= FPBits::MAX_BIASED_EXPONENT) |
| 114 | return HALF_MAX_NORMAL; |
| 115 | |
| 116 | return static_cast<uint16_t>((biased_exp << HALF_FRACTION_LEN) | |
| 117 | (out_sig & HALF_FRACTION_MASK)); |
| 118 | } |
| 119 | |
// Output of approximate_rsqrt.
struct ApproxResult {
  // Approximate positive half bit pattern for 1/sqrt(x), obtained from the
  // table seed, one Newton step and exponent scaling; not yet exactly rounded.
  uint16_t value;
  // The exact input, x = x_sig * 2^x_exp. It is kept so that candidates can be
  // compared against the mathematical result using integers only.
  uint32_t x_sig;
  int x_exp;
};
| 129 | |
| 130 | LIBC_INLINE constexpr ApproxResult approximate_rsqrt(uint16_t x_abs) { |
| 131 | uint32_t x_mant = x_abs & HALF_FRACTION_MASK; |
| 132 | uint32_t x_sig = x_mant; |
| 133 | int x_exp = EXACT_SUBNORMAL_EXP; |
| 134 | int exponent = 0; |
| 135 | |
| 136 | // Decompose the finite positive input as: |
| 137 | // x = m * 2^exponent, with 0.5 <= m < 1. |
| 138 | // `x_sig` and `x_exp` keep the exact input as x_sig * 2^x_exp for the integer |
| 139 | // rounding test below. |
| 140 | if (x_abs >= HALF_MIN_NORMAL) { |
| 141 | int biased_exp = static_cast<int>(x_abs >> HALF_FRACTION_LEN); |
| 142 | x_mant |= HALF_HIDDEN_BIT; |
| 143 | x_sig = x_mant; |
| 144 | x_exp = biased_exp + EXACT_NORMAL_EXP_OFFSET; |
| 145 | exponent = biased_exp + REDUCED_NORMAL_EXP_OFFSET; |
| 146 | } else { |
| 147 | int shift = cpp::countl_zero(x_mant) - (UINT32_BITS - HALF_SIGNIFICAND_LEN); |
| 148 | x_mant <<= shift; |
| 149 | exponent = REDUCED_SUBNORMAL_EXP - shift; |
| 150 | } |
| 151 | |
| 152 | uint32_t m = x_mant << (RSQRT_FRACTION_BITS - HALF_SIGNIFICAND_LEN); |
| 153 | int64_t y = newton_raphson(m, initial_approximation(x_mant)); |
| 154 | |
| 155 | // Since rsqrt(m * 2^e) = rsqrt(m) * 2^(-e/2), odd exponents need one |
| 156 | // extra factor of 1/sqrt(2) before applying the integral power of two. |
| 157 | int scale_exp = 0; |
| 158 | if (exponent & 1) { |
| 159 | y = (y * ONE_OVER_SQRT2) >> RSQRT_FRACTION_BITS; |
| 160 | scale_exp = -((exponent - 1) / 2); |
| 161 | } else { |
| 162 | scale_exp = -(exponent / 2); |
| 163 | } |
| 164 | |
| 165 | return {fixed_to_half_bits(static_cast<uint64_t>(y), scale_exp), x_sig, |
| 166 | x_exp}; |
| 167 | } |
| 168 | |
| 169 | // Compare y = sig * 2^exp with 1 / sqrt(x_sig * 2^x_exp). |
| 170 | // Return -1 if y is below the exact value, 0 if exact, and 1 if above. |
| 171 | // Instead of computing a reciprocal square root, square both sides: |
| 172 | // y <= 1/sqrt(x) <=> y^2 * x <= 1. |
| 173 | LIBC_INLINE constexpr int compare_with_rsqrt(uint32_t sig, int exp, |
| 174 | uint32_t x_sig, int x_exp) { |
| 175 | uint64_t lhs = static_cast<uint64_t>(sig) * sig * x_sig; |
| 176 | // For all finite positive half inputs and candidates produced by this |
| 177 | // algorithm, 2 * exp + x_exp is in [-34, -20]. |
| 178 | int rshift = -(2 * exp + x_exp); |
| 179 | uint64_t rhs = uint64_t(1) << rshift; |
| 180 | if (lhs < rhs) |
| 181 | return -1; |
| 182 | if (lhs > rhs) |
| 183 | return 1; |
| 184 | return 0; |
| 185 | } |
| 186 | |
| 187 | LIBC_INLINE constexpr int compare_half_with_rsqrt(uint16_t y, uint32_t x_sig, |
| 188 | int x_exp) { |
| 189 | uint32_t y_sig = HALF_HIDDEN_BIT | (y & HALF_FRACTION_MASK); |
| 190 | int y_exp = |
| 191 | static_cast<int>(y >> HALF_FRACTION_LEN) + EXACT_NORMAL_EXP_OFFSET; |
| 192 | return compare_with_rsqrt(y_sig, y_exp, x_sig, x_exp); |
| 193 | } |
| 194 | |
// Output of floor_rsqrt.
struct FloorResult {
  // Half bit pattern of the floor candidate for 1/sqrt(x).
  uint16_t value;
  // compare_half_with_rsqrt result for `value`: -1 below, 0 exact, 1 above.
  int cmp;
};
| 199 | |
| 200 | LIBC_INLINE constexpr FloorResult floor_rsqrt(uint16_t approx, uint32_t x_sig, |
| 201 | int x_exp) { |
| 202 | // The table seed and Newton step have been validated exhaustively to produce |
| 203 | // a candidate at most one half-precision step below the exact floor. |
| 204 | uint16_t y = approx < HALF_MIN_NORMAL ? HALF_MIN_NORMAL : approx; |
| 205 | int cmp = compare_half_with_rsqrt(y, x_sig, x_exp); |
| 206 | if (LIBC_UNLIKELY(cmp > 0)) { |
| 207 | --y; |
| 208 | cmp = compare_half_with_rsqrt(y, x_sig, x_exp); |
| 209 | } else if (LIBC_UNLIKELY(y < HALF_MAX_NORMAL)) { |
| 210 | int next_cmp = compare_half_with_rsqrt(y + 1, x_sig, x_exp); |
| 211 | if (LIBC_UNLIKELY(next_cmp <= 0)) { |
| 212 | ++y; |
| 213 | cmp = next_cmp; |
| 214 | } |
| 215 | } |
| 216 | return {y, cmp}; |
| 217 | } |
| 218 | |
| 219 | LIBC_INLINE constexpr uint16_t round_result(FloorResult floor, uint32_t x_sig, |
| 220 | int x_exp) { |
| 221 | uint16_t y = floor.value; |
| 222 | if (floor.cmp == 0) |
| 223 | return y; |
| 224 | |
| 225 | // Once `y` is the greatest half value below the exact result, directed |
| 226 | // rounding is immediate. Round-to-nearest compares against the midpoint |
| 227 | // between `y` and the next half value, then applies ties-to-even. |
| 228 | int rounding_mode = FE_TONEAREST; |
| 229 | if (!cpp::is_constant_evaluated()) |
| 230 | rounding_mode = fputil::get_round(); |
| 231 | if (rounding_mode == FE_UPWARD) |
| 232 | return y + 1; |
| 233 | if (rounding_mode != FE_TONEAREST) |
| 234 | return y; |
| 235 | |
| 236 | uint32_t y_sig = HALF_HIDDEN_BIT | (y & HALF_FRACTION_MASK); |
| 237 | int y_exp = |
| 238 | static_cast<int>(y >> HALF_FRACTION_LEN) + EXACT_NORMAL_EXP_OFFSET; |
| 239 | uint32_t midpoint_sig = (y_sig << 1) | 1; |
| 240 | int midpoint_cmp = compare_with_rsqrt(midpoint_sig, y_exp - 1, x_sig, x_exp); |
| 241 | |
| 242 | if (midpoint_cmp < 0) |
| 243 | return y + 1; |
| 244 | if (midpoint_cmp > 0) |
| 245 | return y; |
| 246 | return (y & 1) ? static_cast<uint16_t>(y + 1) : y; |
| 247 | } |
| 248 | |
| 249 | LIBC_INLINE constexpr float16 rsqrtf16_no_float(uint16_t x_abs) { |
| 250 | ApproxResult approx = approximate_rsqrt(x_abs); |
| 251 | FloorResult floor = floor_rsqrt(approx.value, approx.x_sig, approx.x_exp); |
| 252 | return fputil::FPBits<float16>( |
| 253 | round_result(floor, approx.x_sig, approx.x_exp)) |
| 254 | .get_val(); |
| 255 | } |
| 256 | |
| 257 | } // namespace rsqrtf16_internal |
| 258 | #endif // !LIBC_TARGET_CPU_HAS_FPU_FLOAT |
| 259 | |
| 260 | LIBC_INLINE constexpr float16 rsqrtf16(float16 x) { |
| 261 | using FPBits = fputil::FPBits<float16>; |
| 262 | FPBits xbits(x); |
| 263 | |
| 264 | uint16_t x_u = xbits.uintval(); |
| 265 | uint16_t x_abs = x_u & 0x7fff; |
| 266 | |
| 267 | constexpr uint16_t INF_BIT = FPBits::inf().uintval(); |
| 268 | |
| 269 | // x is 0, inf/nan, or negative. |
| 270 | if (LIBC_UNLIKELY(x_u == 0 || x_u >= INF_BIT)) { |
| 271 | // x is NaN |
| 272 | if (x_abs > INF_BIT) { |
| 273 | if (xbits.is_signaling_nan()) { |
| 274 | fputil::raise_except_if_required(FE_INVALID); |
| 275 | return FPBits::quiet_nan().get_val(); |
| 276 | } |
| 277 | return x; |
| 278 | } |
| 279 | |
| 280 | // |x| = 0 |
| 281 | if (x_abs == 0) { |
| 282 | fputil::raise_except_if_required(FE_DIVBYZERO); |
| 283 | fputil::set_errno_if_required(ERANGE); |
| 284 | return FPBits::inf(sign: xbits.sign()).get_val(); |
| 285 | } |
| 286 | |
| 287 | // -inf <= x < 0 |
| 288 | if (x_u > 0x7fff) { |
| 289 | fputil::raise_except_if_required(FE_INVALID); |
| 290 | fputil::set_errno_if_required(EDOM); |
| 291 | return FPBits::quiet_nan().get_val(); |
| 292 | } |
| 293 | |
| 294 | // x = +inf => rsqrt(x) = +0 |
| 295 | return FPBits::zero(sign: xbits.sign()).get_val(); |
| 296 | } |
| 297 | |
| 298 | #ifdef LIBC_TARGET_CPU_HAS_FPU_FLOAT |
| 299 | float result = 1.0f / fputil::sqrt<float>(x: fputil::cast<float>(x)); |
| 300 | |
| 301 | // Targeted post-corrections to ensure correct rounding in half for specific |
| 302 | // mantissa patterns |
| 303 | const uint16_t half_mantissa = x_abs & 0x3ff; |
| 304 | if (LIBC_UNLIKELY(half_mantissa == 0x011F)) { |
| 305 | result = fputil::multiply_add(x: result, y: 0x1.0p-21f, z: result); |
| 306 | } else if (LIBC_UNLIKELY(half_mantissa == 0x0313)) { |
| 307 | result = fputil::multiply_add(x: result, y: -0x1.0p-21f, z: result); |
| 308 | } |
| 309 | |
| 310 | return fputil::cast<float16>(x: result); |
| 311 | |
| 312 | #else |
| 313 | return rsqrtf16_internal::rsqrtf16_no_float(x_abs); |
| 314 | #endif |
| 315 | } |
| 316 | |
| 317 | } // namespace math |
| 318 | } // namespace LIBC_NAMESPACE_DECL |
| 319 | |
| 320 | #endif // LIBC_TYPES_HAS_FLOAT16 |
| 321 | |
| 322 | #endif // LLVM_LIBC_SRC___SUPPORT_MATH_RSQRTF16_H |
| 323 | |