1//===-- A class to store high precision floating point numbers --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
10#define LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
11
12#include "FEnvImpl.h"
13#include "FPBits.h"
14#include "hdr/errno_macros.h"
15#include "hdr/fenv_macros.h"
16#include "multiply_add.h"
17#include "rounding_mode.h"
18#include "src/__support/CPP/type_traits.h"
19#include "src/__support/big_int.h"
20#include "src/__support/macros/attributes.h"
21#include "src/__support/macros/config.h"
22#include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
23#include "src/__support/macros/properties/types.h"
24
25#include <stddef.h>
26
27namespace LIBC_NAMESPACE_DECL {
28namespace fputil {
29
30// Decide whether to round a UInt up, down or not at all at a given bit
31// position, based on the current rounding mode. The assumption is that the
32// caller is going to make the integer `value >> rshift`, and then might need
33// to round it up by 1 depending on the value of the bits shifted off the
34// bottom.
35//
36// `logical_sign` causes the behavior of FE_DOWNWARD and FE_UPWARD to
37// be reversed, which is what you'd want if this is the mantissa of a
38// negative floating-point number.
39//
40// Return value is +1 if the value should be rounded up; -1 if it should be
41// rounded down; 0 if it's exact and needs no rounding.
42template <size_t Bits>
43LIBC_INLINE LIBC_CONSTEXPR_DEFAULT int
44rounding_direction(const LIBC_NAMESPACE::UInt<Bits> &value, size_t rshift,
45 [[maybe_unused]] Sign logical_sign) {
46 // logical_sign only affects FE_DOWNWARD and FE_UPWARD rounding modes. In the
47 // case of LIBC_MATH_HAS_ASSUME_ROUND_NEAREST_ONLY being enabled, this option
48 // is no-op
49
50 if (rshift == 0 || (rshift < Bits && (value << (Bits - rshift)) == 0) ||
51 (rshift >= Bits && value == 0))
52 return 0; // exact
53
54#ifdef LIBC_MATH_HAS_ASSUME_ROUND_NEAREST_ONLY
55 if (rshift > 0 && rshift <= Bits && value.get_bit(rshift - 1)) {
56 // We round up, unless the value is an exact halfway case and
57 // the bit that will end up in the units place is 0, in which
58 // case tie-break-to-even says round down.
59 bool round_bit = rshift < Bits ? value.get_bit(rshift) : 0;
60 return round_bit != 0 || (value << (Bits - rshift + 1)) != 0 ? +1 : -1;
61 } else {
62 return -1;
63 }
64#else // !LIBC_MATH_HAS_ASSUME_ROUND_NEAREST_ONLY
65 switch (quick_get_round()) {
66 case FE_TONEAREST:
67 if (rshift > 0 && rshift <= Bits && value.get_bit(rshift - 1)) {
68 // We round up, unless the value is an exact halfway case and
69 // the bit that will end up in the units place is 0, in which
70 // case tie-break-to-even says round down.
71 bool round_bit = rshift < Bits ? value.get_bit(rshift) : 0;
72 return round_bit != 0 || (value << (Bits - rshift + 1)) != 0 ? +1 : -1;
73 } else {
74 return -1;
75 }
76 case FE_TOWARDZERO:
77 return -1;
78 case FE_DOWNWARD:
79 return logical_sign.is_neg() &&
80 (rshift < Bits && (value << (Bits - rshift)) != 0)
81 ? +1
82 : -1;
83 case FE_UPWARD:
84 return logical_sign.is_pos() &&
85 (rshift < Bits && (value << (Bits - rshift)) != 0)
86 ? +1
87 : -1;
88 default:
89 __builtin_unreachable();
90 }
91#endif // LIBC_MATH_HAS_ASSUME_ROUND_NEAREST_ONLY
92}
93
94// A generic class to perform computations of high precision floating points.
95// We store the value in dyadic format, including 3 fields:
96// sign : boolean value - false means positive, true means negative
97// exponent: the exponent value of the least significant bit of the mantissa.
98// mantissa: unsigned integer of length `Bits`.
99// So the real value that is stored is:
100// real value = (-1)^sign * 2^exponent * (mantissa as unsigned integer)
101// The stored data is normal if for non-zero mantissa, the leading bit is 1.
102// The outputs of the constructors and most functions will be normalized.
103// To simplify and improve the efficiency, many functions will assume that the
104// inputs are normal.
105template <size_t Bits> struct DyadicFloat {
106 using MantissaType = LIBC_NAMESPACE::UInt<Bits>;
107
108 Sign sign = Sign::POS;
109 int exponent = 0;
110 MantissaType mantissa = MantissaType(0);
111
112 LIBC_INLINE constexpr DyadicFloat() = default;
113
114 template <typename T, cpp::enable_if_t<cpp::is_floating_point_v<T>, int> = 0>
115 LIBC_INLINE LIBC_BIT_CAST_CONSTEXPR DyadicFloat(T x) {
116 static_assert(FPBits<T>::FRACTION_LEN < Bits);
117 FPBits<T> x_bits(x);
118 sign = x_bits.sign();
119 exponent = x_bits.get_explicit_exponent() - FPBits<T>::FRACTION_LEN;
120 mantissa = MantissaType(x_bits.get_explicit_mantissa());
121 normalize();
122 }
123
124 LIBC_INLINE constexpr DyadicFloat(Sign s, int e, const MantissaType &m)
125 : sign(s), exponent(e), mantissa(m) {
126 normalize();
127 }
128
129 // Normalizing the mantissa, bringing the leading 1 bit to the most
130 // significant bit.
131 LIBC_INLINE constexpr DyadicFloat &normalize() {
132 if (!mantissa.is_zero()) {
133 int shift_length = cpp::countl_zero(mantissa);
134 exponent -= shift_length;
135 mantissa <<= static_cast<size_t>(shift_length);
136 }
137 return *this;
138 }
139
140 // Used for aligning exponents. Output might not be normalized.
141 LIBC_INLINE constexpr DyadicFloat &shift_left(unsigned shift_length) {
142 if (shift_length < Bits) {
143 exponent -= static_cast<int>(shift_length);
144 mantissa <<= shift_length;
145 } else {
146 exponent = 0;
147 mantissa = MantissaType(0);
148 }
149 return *this;
150 }
151
152 // Used for aligning exponents. Output might not be normalized.
153 LIBC_INLINE constexpr DyadicFloat &shift_right(unsigned shift_length) {
154 if (shift_length < Bits) {
155 exponent += static_cast<int>(shift_length);
156 mantissa >>= shift_length;
157 } else {
158 exponent = 0;
159 mantissa = MantissaType(0);
160 }
161 return *this;
162 }
163
164 // Assume that it is already normalized. Output the unbiased exponent.
165 LIBC_INLINE constexpr int get_unbiased_exponent() const {
166 return exponent + (Bits - 1);
167 }
168
169 // Produce a correctly rounded DyadicFloat from a too-large mantissa,
170 // by shifting it down and rounding if necessary.
171 template <size_t MantissaBits>
172 LIBC_INLINE LIBC_CONSTEXPR_DEFAULT static DyadicFloat<Bits>
173 round(Sign result_sign, int result_exponent,
174 const LIBC_NAMESPACE::UInt<MantissaBits> &input_mantissa,
175 size_t rshift) {
176 MantissaType result_mantissa(input_mantissa >> rshift);
177 if (rounding_direction(input_mantissa, rshift, result_sign) > 0) {
178 ++result_mantissa;
179 if (result_mantissa == 0) {
180 // Rounding up made the mantissa integer wrap round to 0,
181 // carrying a bit off the top. So we've rounded up to the next
182 // exponent.
183 result_mantissa.set_bit(Bits - 1);
184 ++result_exponent;
185 }
186 }
187 return DyadicFloat(result_sign, result_exponent, result_mantissa);
188 }
189
190 template <typename T, bool ShouldSignalExceptions>
191 LIBC_INLINE LIBC_CONSTEXPR_DEFAULT cpp::enable_if_t<
192 cpp::is_floating_point_v<T> && (FPBits<T>::FRACTION_LEN < Bits), T>
193 generic_as() const {
194 using FPBits = FPBits<T>;
195 using StorageType = typename FPBits::StorageType;
196
197 constexpr int EXTRA_FRACTION_LEN = Bits - 1 - FPBits::FRACTION_LEN;
198
199 if (mantissa == 0)
200 return FPBits::zero(sign).get_val();
201
202 int unbiased_exp = get_unbiased_exponent();
203
204 if (unbiased_exp + FPBits::EXP_BIAS >= FPBits::MAX_BIASED_EXPONENT) {
205 if constexpr (ShouldSignalExceptions) {
206 set_errno_if_required(ERANGE);
207 raise_except_if_required(FE_OVERFLOW | FE_INEXACT);
208 }
209
210#ifdef LIBC_MATH_HAS_ASSUME_ROUND_NEAREST_ONLY
211 return FPBits::inf(sign).get_val();
212#else // !LIBC_MATH_HAS_ASSUME_ROUND_NEAREST_ONLY
213 switch (quick_get_round()) {
214 case FE_TONEAREST:
215 return FPBits::inf(sign).get_val();
216 case FE_TOWARDZERO:
217 return FPBits::max_normal(sign).get_val();
218 case FE_DOWNWARD:
219 if (sign.is_pos())
220 return FPBits::max_normal(Sign::POS).get_val();
221 return FPBits::inf(Sign::NEG).get_val();
222 case FE_UPWARD:
223 if (sign.is_neg())
224 return FPBits::max_normal(Sign::NEG).get_val();
225 return FPBits::inf(Sign::POS).get_val();
226 default:
227 __builtin_unreachable();
228 }
229#endif // LIBC_MATH_HAS_ASSUME_ROUND_NEAREST_ONLY
230 }
231
232 StorageType out_biased_exp = 0;
233 StorageType out_mantissa = 0;
234 bool round = false;
235 bool sticky = false;
236 bool underflow = false;
237
238 if (unbiased_exp < -FPBits::EXP_BIAS - FPBits::FRACTION_LEN) {
239 sticky = true;
240 underflow = true;
241 } else if (unbiased_exp == -FPBits::EXP_BIAS - FPBits::FRACTION_LEN) {
242 round = true;
243 // underflow is detected pre-rounding FE_UNDERFLOW may be raised
244 // even if rounding produces a non-underflow result
245 underflow = true;
246 MantissaType sticky_mask = (MantissaType(1) << (Bits - 1)) - 1;
247 sticky = (mantissa & sticky_mask) != 0;
248 } else {
249 int extra_fraction_len = EXTRA_FRACTION_LEN;
250
251 if (unbiased_exp < 1 - FPBits::EXP_BIAS) {
252 underflow = true;
253 extra_fraction_len += 1 - FPBits::EXP_BIAS - unbiased_exp;
254 } else {
255 out_biased_exp =
256 static_cast<StorageType>(unbiased_exp + FPBits::EXP_BIAS);
257 }
258
259 MantissaType round_mask = MantissaType(1) << (extra_fraction_len - 1);
260 round = (mantissa & round_mask) != 0;
261 MantissaType sticky_mask = round_mask - 1;
262 sticky = (mantissa & sticky_mask) != 0;
263
264 out_mantissa = static_cast<StorageType>(mantissa >> extra_fraction_len);
265 }
266
267 bool lsb = (out_mantissa & 1) != 0;
268
269 StorageType result =
270 FPBits::create_value(sign, out_biased_exp, out_mantissa).uintval();
271
272#ifdef LIBC_MATH_HAS_ASSUME_ROUND_NEAREST_ONLY
273 if (round && (lsb || sticky))
274 ++result;
275#else
276 switch (quick_get_round()) {
277 case FE_TONEAREST:
278 if (round && (lsb || sticky))
279 ++result;
280 break;
281 case FE_DOWNWARD:
282 if (sign.is_neg() && (round || sticky))
283 ++result;
284 break;
285 case FE_UPWARD:
286 if (sign.is_pos() && (round || sticky))
287 ++result;
288 break;
289 default:
290 break;
291 }
292#endif // LIBC_MATH_HAS_ASSUME_ROUND_NEAREST_ONLY
293
294 if (ShouldSignalExceptions && (round || sticky)) {
295 int excepts = FE_INEXACT;
296 if (FPBits(result).is_inf()) {
297 set_errno_if_required(ERANGE);
298 excepts |= FE_OVERFLOW;
299 } else if (underflow) {
300 set_errno_if_required(ERANGE);
301 excepts |= FE_UNDERFLOW;
302 }
303 raise_except_if_required(excepts);
304 }
305
306 return FPBits(result).get_val();
307 }
308
309 template <typename T, bool ShouldSignalExceptions,
310 typename = cpp::enable_if_t<cpp::is_floating_point_v<T> &&
311 (FPBits<T>::FRACTION_LEN < Bits),
312 void>>
313 LIBC_INLINE LIBC_CONSTEXPR_DEFAULT T fast_as() const {
314 if (LIBC_UNLIKELY(mantissa.is_zero()))
315 return FPBits<T>::zero(sign).get_val();
316
317 // Assume that it is normalized, and output is also normal.
318 constexpr uint32_t PRECISION = FPBits<T>::FRACTION_LEN + 1;
319 using output_bits_t = typename FPBits<T>::StorageType;
320 constexpr output_bits_t IMPLICIT_MASK =
321 FPBits<T>::SIG_MASK - FPBits<T>::FRACTION_MASK;
322
323 int exp_hi = exponent + static_cast<int>((Bits - 1) + FPBits<T>::EXP_BIAS);
324
325 if (LIBC_UNLIKELY(exp_hi > 2 * FPBits<T>::EXP_BIAS)) {
326 // Results overflow.
327 T d_hi =
328 FPBits<T>::create_value(sign, 2 * FPBits<T>::EXP_BIAS, IMPLICIT_MASK)
329 .get_val();
330 // volatile prevents constant propagation that would result in infinity
331 // always being returned no matter the current rounding mode.
332 volatile T two = static_cast<T>(2.0);
333 T r = two * d_hi;
334
335 // TODO: Whether rounding down the absolute value to max_normal should
336 // also raise FE_OVERFLOW and set ERANGE is debatable.
337 if (ShouldSignalExceptions && FPBits<T>(r).is_inf())
338 set_errno_if_required(ERANGE);
339
340 return r;
341 }
342
343 bool denorm = false;
344 uint32_t shift = Bits - PRECISION;
345 if (LIBC_UNLIKELY(exp_hi <= 0)) {
346 // Output is denormal.
347 denorm = true;
348 shift = (Bits - PRECISION) + static_cast<uint32_t>(1 - exp_hi);
349
350 exp_hi = FPBits<T>::EXP_BIAS;
351 }
352
353 int exp_lo = exp_hi - static_cast<int>(PRECISION) - 1;
354
355 MantissaType m_hi =
356 shift >= MantissaType::BITS ? MantissaType(0) : mantissa >> shift;
357
358 T d_hi = FPBits<T>::create_value(
359 sign, static_cast<output_bits_t>(exp_hi),
360 (static_cast<output_bits_t>(m_hi) & FPBits<T>::SIG_MASK) |
361 IMPLICIT_MASK)
362 .get_val();
363
364 MantissaType round_mask =
365 shift - 1 >= MantissaType::BITS ? 0 : MantissaType(1) << (shift - 1);
366 MantissaType sticky_mask = round_mask - MantissaType(1);
367
368 bool round_bit = !(mantissa & round_mask).is_zero();
369 bool sticky_bit = !(mantissa & sticky_mask).is_zero();
370 int round_and_sticky = int(round_bit) * 2 + int(sticky_bit);
371
372 T d_lo;
373
374 if (LIBC_UNLIKELY(exp_lo <= 0)) {
375 // d_lo is denormal, but the output is normal.
376 int scale_up_exponent = 1 - exp_lo;
377 T scale_up_factor =
378 FPBits<T>::create_value(Sign::POS,
379 static_cast<output_bits_t>(
380 FPBits<T>::EXP_BIAS + scale_up_exponent),
381 IMPLICIT_MASK)
382 .get_val();
383 T scale_down_factor =
384 FPBits<T>::create_value(Sign::POS,
385 static_cast<output_bits_t>(
386 FPBits<T>::EXP_BIAS - scale_up_exponent),
387 IMPLICIT_MASK)
388 .get_val();
389
390 d_lo = FPBits<T>::create_value(
391 sign, static_cast<output_bits_t>(exp_lo + scale_up_exponent),
392 IMPLICIT_MASK)
393 .get_val();
394
395 return multiply_add(d_lo, T(round_and_sticky), d_hi * scale_up_factor) *
396 scale_down_factor;
397 }
398
399 d_lo = FPBits<T>::create_value(sign, static_cast<output_bits_t>(exp_lo),
400 IMPLICIT_MASK)
401 .get_val();
402
403 // Still correct without FMA instructions if `d_lo` is not underflow.
404 T r = multiply_add(d_lo, T(round_and_sticky), d_hi);
405
406 if (LIBC_UNLIKELY(denorm)) {
407 // Exponent before rounding is in denormal range, simply clear the
408 // exponent field.
409 output_bits_t clear_exp = static_cast<output_bits_t>(
410 output_bits_t(exp_hi) << FPBits<T>::SIG_LEN);
411 output_bits_t r_bits = FPBits<T>(r).uintval() - clear_exp;
412
413 if (!(r_bits & FPBits<T>::EXP_MASK)) {
414 // Output is denormal after rounding, clear the implicit bit for
415 // 80-bit long double.
416 r_bits -= IMPLICIT_MASK;
417
418 // TODO: IEEE Std 754-2019 lets implementers choose whether to check
419 // for "tininess" before or after rounding for base-2 formats, as long
420 // as the same choice is made for all operations. Our choice to check
421 // after rounding might not be the same as the hardware's.
422 if (ShouldSignalExceptions && round_and_sticky) {
423 set_errno_if_required(ERANGE);
424 raise_except_if_required(FE_UNDERFLOW);
425 }
426 }
427
428 return FPBits<T>(r_bits).get_val();
429 }
430
431 return r;
432 }
433
434 // Assume that it is already normalized.
435 // Output is rounded correctly with respect to the current rounding mode.
436 template <typename T, bool ShouldSignalExceptions,
437 typename = cpp::enable_if_t<cpp::is_floating_point_v<T> &&
438 (FPBits<T>::FRACTION_LEN < Bits),
439 void>>
440 LIBC_INLINE LIBC_CONSTEXPR_DEFAULT T as() const {
441 if constexpr (cpp::is_same_v<T, bfloat16>
442#if defined(LIBC_TYPES_HAS_FLOAT16) && !defined(__LIBC_USE_FLOAT16_CONVERSION)
443 || cpp::is_same_v<T, float16>
444#endif
445#if defined(LIBC_TYPES_HAS_FLOAT128)
446 || cpp::is_same_v<T, float128>
447#endif
448 )
449 return generic_as<T, ShouldSignalExceptions>();
450 else
451 return fast_as<T, ShouldSignalExceptions>();
452 }
453
454 template <typename T,
455 typename = cpp::enable_if_t<cpp::is_floating_point_v<T> &&
456 (FPBits<T>::FRACTION_LEN < Bits),
457 void>>
458 LIBC_INLINE explicit constexpr operator T() const {
459 return as<T, /*ShouldSignalExceptions=*/false>();
460 }
461
462 LIBC_INLINE constexpr MantissaType as_mantissa_type() const {
463 if (mantissa.is_zero())
464 return 0;
465
466 MantissaType new_mant = mantissa;
467 if (exponent > 0) {
468 new_mant <<= exponent;
469 } else {
470 // Cast the exponent to size_t before negating it, rather than after,
471 // to avoid undefined behavior negating INT_MIN as an integer (although
472 // exponents coming in to this function _shouldn't_ be that large). The
473 // result should always end up as a positive size_t.
474 size_t shift = -static_cast<size_t>(exponent);
475 new_mant >>= shift;
476 }
477
478 if (sign.is_neg()) {
479 new_mant = (~new_mant) + 1;
480 }
481
482 return new_mant;
483 }
484
485 LIBC_INLINE LIBC_CONSTEXPR_DEFAULT MantissaType
486 as_mantissa_type_rounded(int *round_dir_out = nullptr) const {
487 int round_dir = 0;
488 MantissaType new_mant;
489 if (mantissa.is_zero()) {
490 new_mant = 0;
491 } else {
492 new_mant = mantissa;
493 if (exponent > 0) {
494 new_mant <<= exponent;
495 } else if (exponent < 0) {
496 // Cast the exponent to size_t before negating it, rather than after,
497 // to avoid undefined behavior negating INT_MIN as an integer
498 // (although exponents coming in to this function _shouldn't_ be that
499 // large). The result should always end up as a positive size_t.
500 size_t shift = -static_cast<size_t>(exponent);
501 if (shift >= Bits)
502 new_mant = 0;
503 else
504 new_mant >>= shift;
505 round_dir = rounding_direction(mantissa, shift, sign);
506 if (round_dir > 0)
507 ++new_mant;
508 }
509
510 if (sign.is_neg()) {
511 new_mant = (~new_mant) + 1;
512 }
513 }
514
515 if (round_dir_out)
516 *round_dir_out = round_dir;
517
518 return new_mant;
519 }
520
521 LIBC_INLINE constexpr DyadicFloat operator-() const {
522 return DyadicFloat(sign.negate(), exponent, mantissa);
523 }
524};
525
526// Quick add - Add 2 dyadic floats with rounding toward 0 and then normalize
527// the output:
528// - Align the exponents so that:
529// new a.exponent = new b.exponent = max(a.exponent, b.exponent)
530// - Add or subtract the mantissas depending on the signs.
531// - Normalize the result.
532// The absolute errors compared to the mathematical sum is bounded by:
533// | quick_add(a, b) - (a + b) | < MSB(a + b) * 2^(-Bits + 2),
534// i.e., errors are up to 2 ULPs.
535// Assume inputs are normalized (by constructors or other functions) so that
536// we don't need to normalize the inputs again in this function. If the
537// inputs are not normalized, the results might lose precision significantly.
538template <size_t Bits>
539LIBC_INLINE constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
540 DyadicFloat<Bits> b) {
541 if (LIBC_UNLIKELY(a.mantissa.is_zero()))
542 return b;
543 if (LIBC_UNLIKELY(b.mantissa.is_zero()))
544 return a;
545
546 // Align exponents
547 if (a.exponent > b.exponent)
548 b.shift_right(static_cast<unsigned>(a.exponent - b.exponent));
549 else if (b.exponent > a.exponent)
550 a.shift_right(static_cast<unsigned>(b.exponent - a.exponent));
551
552 DyadicFloat<Bits> result;
553
554 if (a.sign == b.sign) {
555 // Addition
556 result.sign = a.sign;
557 result.exponent = a.exponent;
558 result.mantissa = a.mantissa;
559 if (result.mantissa.add_overflow(b.mantissa)) {
560 // Mantissa addition overflow.
561 result.shift_right(1);
562 result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] |=
563 (uint64_t(1) << 63);
564 }
565 // Result is already normalized.
566 return result;
567 }
568
569 // Subtraction
570 if (a.mantissa >= b.mantissa) {
571 result.sign = a.sign;
572 result.exponent = a.exponent;
573 result.mantissa = a.mantissa - b.mantissa;
574 } else {
575 result.sign = b.sign;
576 result.exponent = b.exponent;
577 result.mantissa = b.mantissa - a.mantissa;
578 }
579
580 return result.normalize();
581}
582
583template <size_t Bits>
584LIBC_INLINE constexpr DyadicFloat<Bits> quick_sub(DyadicFloat<Bits> a,
585 DyadicFloat<Bits> b) {
586 return quick_add(a, -b);
587}
588
589// Quick Mul - Slightly less accurate but efficient multiplication of 2 dyadic
590// floats with rounding toward 0 and then normalize the output:
591// result.exponent = a.exponent + b.exponent + Bits,
592// result.mantissa = quick_mul_hi(a.mantissa + b.mantissa)
593// ~ (full product a.mantissa * b.mantissa) >> Bits.
594// The errors compared to the mathematical product is bounded by:
595// 2 * errors of quick_mul_hi = 2 * (UInt<Bits>::WORD_COUNT - 1) in ULPs.
596// Assume inputs are normalized (by constructors or other functions) so that
597// we don't need to normalize the inputs again in this function. If the
598// inputs are not normalized, the results might lose precision significantly.
599template <size_t Bits>
600LIBC_INLINE constexpr DyadicFloat<Bits> quick_mul(const DyadicFloat<Bits> &a,
601 const DyadicFloat<Bits> &b) {
602 DyadicFloat<Bits> result;
603 result.sign = (a.sign != b.sign) ? Sign::NEG : Sign::POS;
604 result.exponent = a.exponent + b.exponent + static_cast<int>(Bits);
605
606 if (!(a.mantissa.is_zero() || b.mantissa.is_zero())) {
607 result.mantissa = a.mantissa.quick_mul_hi(b.mantissa);
608 // Check the leading bit directly, should be faster than using clz in
609 // normalize().
610 if (result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] >>
611 (DyadicFloat<Bits>::MantissaType::WORD_SIZE - 1) ==
612 0)
613 result.shift_left(1);
614 } else {
615 result.mantissa = (typename DyadicFloat<Bits>::MantissaType)(0);
616 }
617 return result;
618}
619
620// Correctly rounded multiplication of 2 dyadic floats, assuming the
621// exponent remains within range.
622template <size_t Bits>
623LIBC_INLINE constexpr DyadicFloat<Bits>
624rounded_mul(const DyadicFloat<Bits> &a, const DyadicFloat<Bits> &b) {
625 using DblMant = LIBC_NAMESPACE::UInt<(2 * Bits)>;
626 Sign result_sign = (a.sign != b.sign) ? Sign::NEG : Sign::POS;
627 int result_exponent = a.exponent + b.exponent + static_cast<int>(Bits);
628 auto product = DblMant(a.mantissa) * DblMant(b.mantissa);
629 // As in quick_mul(), renormalize by 1 bit manually rather than countl_zero
630 if (product.get_bit(2 * Bits - 1) == 0) {
631 product <<= 1;
632 result_exponent -= 1;
633 }
634
635 return DyadicFloat<Bits>::round(result_sign, result_exponent, product, Bits);
636}
637
638// Approximate reciprocal - given a nonzero a, make a good approximation to
639// 1/a. The method is Newton-Raphson iteration, based on quick_mul.
640template <size_t Bits, typename = cpp::enable_if_t<(Bits >= 32)>>
641LIBC_INLINE constexpr DyadicFloat<Bits>
642approx_reciprocal(const DyadicFloat<Bits> &a) {
643 // Given an approximation x to 1/a, a better one is x' = x(2-ax).
644 //
645 // You can derive this by using the Newton-Raphson formula with the function
646 // f(x) = 1/x - a. But another way to see that it works is to say: suppose
647 // that ax = 1-e for some small error e. Then ax' = ax(2-ax) = (1-e)(1+e) =
648 // 1-e^2. So the error in x' is the square of the error in x, i.e. the
649 // number of correct bits in x' is double the number in x.
650
651 // An initial approximation to the reciprocal
652 DyadicFloat<Bits> x(Sign::POS, -32 - a.exponent - int(Bits),
653 uint64_t(0xFFFFFFFFFFFFFFFF) /
654 static_cast<uint64_t>(a.mantissa >> (Bits - 32)));
655
656 // The constant 2, which we'll need in every iteration
657 DyadicFloat<Bits> two(Sign::POS, 1, 1);
658
659 // We expect at least 31 correct bits from our 32-bit starting approximation
660 size_t ok_bits = 31;
661
662 // The number of good bits doubles in each iteration, except that rounding
663 // errors introduce a little extra each time. Subtract a bit from our
664 // accuracy assessment to account for that.
665 while (ok_bits < Bits) {
666 x = quick_mul(x, quick_sub(two, quick_mul(a, x)));
667 ok_bits = 2 * ok_bits - 1;
668 }
669
670 return x;
671}
672
673// Correctly rounded division of 2 dyadic floats, assuming the
674// exponent remains within range.
675template <size_t Bits>
676LIBC_INLINE constexpr DyadicFloat<Bits>
677rounded_div(const DyadicFloat<Bits> &af, const DyadicFloat<Bits> &bf) {
678 using DblMant = LIBC_NAMESPACE::UInt<(Bits * 2 + 64)>;
679
680 // Make an approximation to the quotient as a * (1/b). Both the
681 // multiplication and the reciprocal are a bit sloppy, which doesn't
682 // matter, because we're going to correct for that below.
683 auto qf = fputil::quick_mul(af, fputil::approx_reciprocal(bf));
684
685 // Switch to BigInt and stop using quick_add and quick_mul: now
686 // we're working in exact integers so as to get the true remainder.
687 DblMant a = af.mantissa, b = bf.mantissa, q = qf.mantissa;
688 q <<= 2; // leave room for a round bit, even if exponent decreases
689 a <<= af.exponent - bf.exponent - qf.exponent + 2;
690 DblMant qb = q * b;
691 if (qb < a) {
692 DblMant too_small = a - b;
693 while (qb <= too_small) {
694 qb += b;
695 ++q;
696 }
697 } else {
698 while (qb > a) {
699 qb -= b;
700 --q;
701 }
702 }
703
704 DyadicFloat<(Bits * 2)> qbig(qf.sign, qf.exponent - 2, q);
705 return DyadicFloat<Bits>::round(qbig.sign, qbig.exponent + Bits,
706 qbig.mantissa, Bits);
707}
708
709// Simple polynomial approximation.
710template <size_t Bits>
711LIBC_INLINE constexpr DyadicFloat<Bits>
712multiply_add(const DyadicFloat<Bits> &a, const DyadicFloat<Bits> &b,
713 const DyadicFloat<Bits> &c) {
714 return quick_add(c, quick_mul(a, b));
715}
716
717// Simple exponentiation implementation for printf. Only handles positive
718// exponents, since division isn't implemented.
719template <size_t Bits>
720LIBC_INLINE constexpr DyadicFloat<Bits> pow_n(const DyadicFloat<Bits> &a,
721 uint32_t power) {
722 DyadicFloat<Bits> result = 1.0;
723 DyadicFloat<Bits> cur_power = a;
724
725 while (power > 0) {
726 if ((power % 2) > 0) {
727 result = quick_mul(result, cur_power);
728 }
729 power = power >> 1;
730 cur_power = quick_mul(cur_power, cur_power);
731 }
732 return result;
733}
734
735template <size_t Bits>
736LIBC_INLINE constexpr DyadicFloat<Bits> mul_pow_2(const DyadicFloat<Bits> &a,
737 int32_t pow_2) {
738 DyadicFloat<Bits> result = a;
739 result.exponent += pow_2;
740 return result;
741}
742
743} // namespace fputil
744} // namespace LIBC_NAMESPACE_DECL
745
746#endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
747