1//===-- Implementation header for cbrtf16 ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_LIBC_SRC___SUPPORT_MATH_CBRTF16_H
10#define LLVM_LIBC_SRC___SUPPORT_MATH_CBRTF16_H
11
12#include "include/llvm-libc-macros/float16-macros.h"
13
14#ifdef LIBC_TYPES_HAS_FLOAT16
15
16#include "src/__support/FPUtil/FEnvImpl.h"
17#include "src/__support/FPUtil/FPBits.h"
18#include "src/__support/FPUtil/cast.h"
19#include "src/__support/FPUtil/multiply_add.h"
20#include "src/__support/FPUtil/rounding_mode.h"
21#include "src/__support/macros/config.h"
22#include "src/__support/macros/optimization.h"
23
24namespace LIBC_NAMESPACE_DECL {
25
26namespace math {
27
28LIBC_INLINE constexpr float16 cbrtf16(float16 x) {
29 // look up table for 2^(i/3) for i = 0, 1, 2 in single precision
30 constexpr float CBRT2[3] = {0x1p0f, 0x1.428a3p0f, 0x1.965feap0f};
31
32 // degree-4 polynomials approximation of ((1 + x)^(1/3) - 1)/x for 0 <= x <= 1
33 // generated by Sollya with:
34 // > display=hexadecimal;
35 // for i from 0 to 15 do {
36 // P = fpminimax(((1 + x)^(1/3) - 1)/x, 4, [|SG...|], [i/16, (i + 1)/16]);
37 // print("{", coeff(P, 0), ",", coeff(P, 1), ",", coeff(P, 2), ",",
38 // coeff(P, 3), coeff(P, 4),"},");
39 // };
40 // Then (1 + x)^(1/3) ~ 1 + x * P(x).
41 // For example: for 0 <= x <= 1/8:
42 // P(x) = 0x1.555556p-2 + x * (-0x1.c71d38p-4 + x * (0x1.f9b95ap-5 + x *
43 // (-0x1.4ebe18p-5 + x * 0x1.9ca9d2p-6)))
44 constexpr float COEFFS[16][5] = {
45 {0x1.555556p-2f, -0x1.c71ea4p-4f, 0x1.faa5f2p-5f, -0x1.64febep-5f,
46 0x1.733a46p-5f},
47 {0x1.55554ep-2f, -0x1.c715f6p-4f, 0x1.f88a9ep-5f, -0x1.4456e8p-5f,
48 0x1.5b5ef2p-6f},
49 {0x1.555508p-2f, -0x1.c6f404p-4f, 0x1.f56b7ap-5f, -0x1.33cff8p-5f,
50 0x1.18f146p-6f},
51 {0x1.5553fcp-2f, -0x1.c69bacp-4f, 0x1.efed98p-5f, -0x1.204706p-5f,
52 0x1.c90976p-7f},
53 {0x1.55517p-2f, -0x1.c5f996p-4f, 0x1.e85932p-5f, -0x1.0c0c0ep-5f,
54 0x1.77c766p-7f},
55 {0x1.554c96p-2f, -0x1.c501d2p-4f, 0x1.df0fc4p-5f, -0x1.f067f2p-6f,
56 0x1.380ab8p-7f},
57 {0x1.55448cp-2f, -0x1.c3ab1ep-4f, 0x1.d45876p-5f, -0x1.ca3988p-6f,
58 0x1.04f38ap-7f},
59 {0x1.5538aap-2f, -0x1.c1f886p-4f, 0x1.c8b11p-5f, -0x1.a6a16cp-6f,
60 0x1.b847c2p-8f},
61 {0x1.55278ap-2f, -0x1.bfd538p-4f, 0x1.bbde6p-5f, -0x1.846a8cp-6f,
62 0x1.73bfcp-8f},
63 {0x1.5511dp-2f, -0x1.bd6c88p-4f, 0x1.af0a3ap-5f, -0x1.660852p-6f,
64 0x1.3dbe34p-8f},
65 {0x1.54f82ap-2f, -0x1.bada56p-4f, 0x1.a2aa0ep-5f, -0x1.4b8c2ap-6f,
66 0x1.13379cp-8f},
67 {0x1.54d512p-2f, -0x1.b7a936p-4f, 0x1.94b91ep-5f, -0x1.30792cp-6f,
68 0x1.d7883cp-9f},
69 {0x1.54a8d8p-2f, -0x1.b3fde2p-4f, 0x1.861aeep-5f, -0x1.169484p-6f,
70 0x1.92b4cap-9f},
71 {0x1.548126p-2f, -0x1.b0f4a8p-4f, 0x1.7af574p-5f, -0x1.04644ep-6f,
72 0x1.662fb6p-9f},
73 {0x1.544b9p-2f, -0x1.ad2124p-4f, 0x1.6dd75p-5f, -0x1.e0cbecp-7f,
74 0x1.387692p-9f},
75 {0x1.5422c6p-2f, -0x1.aa61bp-4f, 0x1.64f4bap-5f, -0x1.c742b2p-7f,
76 0x1.1cf15ap-9f},
77 };
78
79 using FPBits = fputil::FPBits<float16>;
80 using FloatBits = fputil::FPBits<float>;
81
82 FPBits x_bits(x);
83
84 uint16_t x_u = x_bits.uintval();
85 uint16_t x_abs = x_u & 0x7fff;
86 uint32_t sign_bit = static_cast<uint32_t>(x_bits.is_neg())
87 << FloatBits::EXP_LEN;
88
89 // cbrtf16(0) = 0, cbrtf16(NaN) = NaN
90 if (LIBC_UNLIKELY(x_abs == 0 || x_abs >= 0x7C00)) {
91 if (x_bits.is_signaling_nan()) {
92 fputil::raise_except(FE_INVALID);
93 return FPBits::quiet_nan().uintval();
94 }
95 return x;
96 }
97
98 float xf = static_cast<float>(x);
99 FloatBits xf_bits(xf);
100
101 // for single precision float, x_e_biased = x_e + 127
102 // since x_e / 3 will round to 0, we will get incorrect
103 // results for x_e < 0 and x mod 3 != 0, so we take x_e_biased
104 // which is always positive.
105 // to calculate the correct biased exponent of the result,
106 // we need to calculate the exponent as:
107 // out_e = floor(x_e / 3) + 127
108 // now, floor((x_e_biased-1) / 3) = floor((x_e + 127 - 1) / 3)
109 // = floor((x_e + 126) / 3)
110 // = floor(x_e/3 + 42)
111 // = floor(x_e/3) + 42
112 // => out_e = (floor((x_e_biased-1) / 3) - 42) + 127
113 // => out_e = (x_e_biased-1) / 3 + (127 - 42);
114 uint32_t x_e_biased = xf_bits.get_biased_exponent();
115 uint32_t out_e = (x_e_biased - 1) / 3 + (127 - 42);
116 uint32_t shift_e = (x_e_biased - 1) % 3;
117
118 // set x_m = 2^(x_e % 3) * (1 + mantissa)
119 uint32_t x_m = xf_bits.get_mantissa();
120
121 // use the leading 4 bits for look up table
122 unsigned idx = static_cast<unsigned>(x_m >> (FloatBits::FRACTION_LEN - 4));
123
124 x_m |= static_cast<uint32_t>(FloatBits::EXP_BIAS) << FloatBits::FRACTION_LEN;
125
126 float x_reduced = FloatBits(x_m).get_val();
127 float dx = x_reduced - 1.0f;
128
129 float dx_sq = dx * dx;
130
131 // c0 = 1 + x * a0
132 float c0 = fputil::multiply_add(x: dx, y: COEFFS[idx][0], z: 1.0f);
133 // c1 = a1 + x * a2
134 float c1 = fputil::multiply_add(x: dx, y: COEFFS[idx][2], z: COEFFS[idx][1]);
135 // c2 = a3 + x * a4
136 float c2 = fputil::multiply_add(x: dx, y: COEFFS[idx][4], z: COEFFS[idx][3]);
137 // we save a multiply_add operation by decreasing the polynomial degree by 2
138 // i.e. using a degree-4 polynomial instead of degree 6.
139
140 float dx_4 = dx_sq * dx_sq;
141
142 // p0 = c0 + x^2 * c1
143 // p0 = (1 + x * a0) + x^2 * (a1 + x * a2)
144 // p0 = 1 + x * a0 + x^2 * a1 + x^3 * a2
145 float p0 = fputil::multiply_add(x: dx_sq, y: c1, z: c0);
146
147 // p1 = c2
148 // p1 = x * a4
149 float p1 = c2;
150
151 // r = p0 + x^4 * p1
152 // r = (1 + x * a0 + x^2 * a1 + x^3 * a2) + x^4 (x * a4)
153 // r = 1 + x * a0 + x^2 * a1 + x^3 * a2 + x^5 * a4
154 // r = 1 + x * (a0 + a1 * x + a2 * x^2 + a3 * x^3 + a4 * x^4)
155 // r = 1 + x * P(x)
156 float r = fputil::multiply_add(x: dx_4, y: p1, z: p0) * CBRT2[shift_e];
157
158 uint32_t r_m = FloatBits(r).get_mantissa();
159 // for float, mantissa is 23 bits (instead of 52 for double)
160 // check if the output is exact. To be exact, the smallest 1-bit of the
161 // output has to be at least 2^-7 or higher. So we check the lowest 15 bits
162 // to see if they are within 2^(-23 + 3) errors from all zeros, then the
163 // result cube root is exact.
164 if (LIBC_UNLIKELY(((r_m + 4) & 0x7fff) <= 8)) {
165 if ((r_m & 0x7fff) <= 4)
166 r_m &= 0xffff'ffe0;
167 else
168 r_m = (r_m & 0xffff'ffe0) + 0x20; // Round up to next multiple of 0x20
169 fputil::clear_except_if_required(FE_INEXACT);
170 // TODO: investigate this "hack"
171 } else if (LIBC_UNLIKELY(fputil::fenv_is_round_up()) &&
172 x_bits.get_mantissa() == 0x0253U) {
173 r_m -= 1 + x_bits.is_neg();
174 }
175
176 uint32_t r_bits = r_m | (static_cast<uint32_t>(out_e | sign_bit)
177 << FloatBits::FRACTION_LEN);
178 return fputil::cast<float16>(x: FloatBits(r_bits).get_val());
179}
180
181} // namespace math
182} // namespace LIBC_NAMESPACE_DECL
183
184#endif // LIBC_TYPES_HAS_FLOAT16
185
186#endif // LLVM_LIBC_SRC___SUPPORT_MATH_CBRTF16_H
187