| 1 | //===-- Implementation header for log_bf16 ----------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef LLVM_LIBC_SRC___SUPPORT_MATH_LOG_BF16_H |
| 10 | #define LLVM_LIBC_SRC___SUPPORT_MATH_LOG_BF16_H |
| 11 | |
| 12 | #include "src/__support/FPUtil/FPBits.h" |
| 13 | #include "src/__support/FPUtil/bfloat16.h" |
| 14 | #include "src/__support/FPUtil/cast.h" |
| 15 | #include "src/__support/FPUtil/multiply_add.h" |
| 16 | #include "src/__support/common.h" |
| 17 | #include "src/__support/macros/config.h" |
| 18 | #include "src/__support/macros/optimization.h" |
| 19 | #include "src/__support/macros/properties/cpu_features.h" |
| 20 | |
| 21 | namespace LIBC_NAMESPACE_DECL { |
| 22 | |
| 23 | namespace math { |
| 24 | |
| 25 | LIBC_INLINE bfloat16 log_bf16(bfloat16 x) { |
| 26 | |
| 27 | // Generated by Sollya with the following commands: |
| 28 | // > display = hexadecimal; |
| 29 | // > round(log(2), SG, RN); |
| 30 | constexpr float BF16_LOGF_2 = 0x1.62e43p-1f; |
| 31 | |
| 32 | // Generated by Sollya with the following commands: |
| 33 | // > display = hexadecimal; |
| 34 | // > for i from 0 to 127 do print(round(log(1 + i * 2^-7), SG, RN)); |
| 35 | constexpr float LOG_1_PLUS_M[128] = { |
| 36 | 0x0.0p0f, 0x1.fe02a6p-8f, 0x1.fc0a8cp-7f, 0x1.7b91bp-6f, |
| 37 | 0x1.f829bp-6f, 0x1.39e87cp-5f, 0x1.77459p-5f, 0x1.b42dd8p-5f, |
| 38 | 0x1.f0a30cp-5f, 0x1.16536ep-4f, 0x1.341d7ap-4f, 0x1.51b074p-4f, |
| 39 | 0x1.6f0d28p-4f, 0x1.8c345ep-4f, 0x1.a926d4p-4f, 0x1.c5e548p-4f, |
| 40 | 0x1.e27076p-4f, 0x1.fec914p-4f, 0x1.0d77e8p-3f, 0x1.1b72aep-3f, |
| 41 | 0x1.29553p-3f, 0x1.371fc2p-3f, 0x1.44d2b6p-3f, 0x1.526e5ep-3f, |
| 42 | 0x1.5ff308p-3f, 0x1.6d60fep-3f, 0x1.7ab89p-3f, 0x1.87fa06p-3f, |
| 43 | 0x1.9525aap-3f, 0x1.a23bc2p-3f, 0x1.af3c94p-3f, 0x1.bc2868p-3f, |
| 44 | 0x1.c8ff7cp-3f, 0x1.d5c216p-3f, 0x1.e27076p-3f, 0x1.ef0adcp-3f, |
| 45 | 0x1.fb9186p-3f, 0x1.04025ap-2f, 0x1.0a324ep-2f, 0x1.1058cp-2f, |
| 46 | 0x1.1675cap-2f, 0x1.1c898cp-2f, 0x1.22942p-2f, 0x1.2895a2p-2f, |
| 47 | 0x1.2e8e2cp-2f, 0x1.347ddap-2f, 0x1.3a64c6p-2f, 0x1.404308p-2f, |
| 48 | 0x1.4618bcp-2f, 0x1.4be5fap-2f, 0x1.51aad8p-2f, 0x1.576772p-2f, |
| 49 | 0x1.5d1bdcp-2f, 0x1.62c83p-2f, 0x1.686c82p-2f, 0x1.6e08eap-2f, |
| 50 | 0x1.739d8p-2f, 0x1.792a56p-2f, 0x1.7eaf84p-2f, 0x1.842d1ep-2f, |
| 51 | 0x1.89a338p-2f, 0x1.8f11e8p-2f, 0x1.947942p-2f, 0x1.99d958p-2f, |
| 52 | 0x1.9f323ep-2f, 0x1.a4840ap-2f, 0x1.a9cecap-2f, 0x1.af1294p-2f, |
| 53 | 0x1.b44f78p-2f, 0x1.b9858ap-2f, 0x1.beb4dap-2f, 0x1.c3dd7ap-2f, |
| 54 | 0x1.c8ff7cp-2f, 0x1.ce1afp-2f, 0x1.d32fe8p-2f, 0x1.d83e72p-2f, |
| 55 | 0x1.dd46ap-2f, 0x1.e24882p-2f, 0x1.e74426p-2f, 0x1.ec399ep-2f, |
| 56 | 0x1.f128f6p-2f, 0x1.f6124p-2f, 0x1.faf588p-2f, 0x1.ffd2ep-2f, |
| 57 | 0x1.02552ap-1f, 0x1.04bdfap-1f, 0x1.0723e6p-1f, 0x1.0986f4p-1f, |
| 58 | 0x1.0be72ep-1f, 0x1.0e4498p-1f, 0x1.109f3ap-1f, 0x1.12f71ap-1f, |
| 59 | 0x1.154c3ep-1f, 0x1.179eacp-1f, 0x1.19ee6cp-1f, 0x1.1c3b82p-1f, |
| 60 | 0x1.1e85f6p-1f, 0x1.20cdcep-1f, 0x1.23130ep-1f, 0x1.2555bcp-1f, |
| 61 | 0x1.2795e2p-1f, 0x1.29d38p-1f, 0x1.2c0e9ep-1f, 0x1.2e4744p-1f, |
| 62 | 0x1.307d74p-1f, 0x1.32b134p-1f, 0x1.34e28ap-1f, 0x1.37117cp-1f, |
| 63 | 0x1.393e0ep-1f, 0x1.3b6844p-1f, 0x1.3d9026p-1f, 0x1.3fb5b8p-1f, |
| 64 | 0x1.41d8fep-1f, 0x1.43f9fep-1f, 0x1.4618bcp-1f, 0x1.48353ep-1f, |
| 65 | 0x1.4a4f86p-1f, 0x1.4c679ap-1f, 0x1.4e7d82p-1f, 0x1.50913cp-1f, |
| 66 | 0x1.52a2d2p-1f, 0x1.54b246p-1f, 0x1.56bf9ep-1f, 0x1.58cadcp-1f, |
| 67 | 0x1.5ad404p-1f, 0x1.5cdb1ep-1f, 0x1.5ee02ap-1f, 0x1.60e33p-1f, |
| 68 | }; |
| 69 | using FPBits = fputil::FPBits<bfloat16>; |
| 70 | FPBits x_bits(x); |
| 71 | |
| 72 | uint16_t x_u = x_bits.uintval(); |
| 73 | |
| 74 | // If x <= 0, or x is 1, or x is +inf, or x is NaN. |
| 75 | if (LIBC_UNLIKELY(x_u == 0U || x_u == 0x3f80U || x_u >= 0x7f80U)) { |
| 76 | // log(NaN) = NaN |
| 77 | if (x_bits.is_nan()) { |
| 78 | if (x_bits.is_signaling_nan()) { |
| 79 | fputil::raise_except_if_required(FE_INVALID); |
| 80 | return FPBits::quiet_nan().get_val(); |
| 81 | } |
| 82 | |
| 83 | return x; |
| 84 | } |
| 85 | |
| 86 | // log(+/-0) = −inf |
| 87 | if ((x_u & 0x7fffU) == 0U) { |
| 88 | fputil::raise_except_if_required(FE_DIVBYZERO); |
| 89 | return FPBits::inf(sign: Sign::NEG).get_val(); |
| 90 | } |
| 91 | |
| 92 | // log(1) = 0 |
| 93 | if (x_u == 0x3f80U) |
| 94 | return FPBits::zero().get_val(); |
| 95 | |
| 96 | // x < 0 |
| 97 | if (x_u > 0x8000U) { |
| 98 | fputil::set_errno_if_required(EDOM); |
| 99 | fputil::raise_except_if_required(FE_INVALID); |
| 100 | return FPBits::quiet_nan().get_val(); |
| 101 | } |
| 102 | |
| 103 | // log(+inf) = +inf |
| 104 | return FPBits::inf().get_val(); |
| 105 | } |
| 106 | |
| 107 | #ifndef LIBC_TARGET_CPU_HAS_FMA |
| 108 | // log(0.00000000000000171390679426508540927898138761520386) |
| 109 | // ~= -34.00000095 |
| 110 | if (LIBC_UNLIKELY(x_u == 0x26F7U)) |
| 111 | return bfloat16(-34.0000009); |
| 112 | #endif // LIBC_TARGET_CPU_HAS_FMA |
| 113 | |
| 114 | int e = -FPBits::EXP_BIAS; |
| 115 | |
| 116 | // When x is subnormal, normalize it. |
| 117 | if ((x_u & FPBits::EXP_MASK) == 0U) { |
| 118 | // Can't pass an integer to fputil::cast directly. |
| 119 | constexpr float NORMALIZE_EXP = 1U << FPBits::FRACTION_LEN; |
| 120 | x_bits = FPBits(x_bits.get_val() * fputil::cast<bfloat16>(x: NORMALIZE_EXP)); |
| 121 | x_u = x_bits.uintval(); |
| 122 | e -= FPBits::FRACTION_LEN; |
| 123 | } |
| 124 | |
| 125 | // To compute log(x), we perform the following range reduction: |
| 126 | // x = 2^e * (1 + m), |
| 127 | // log(x) = e * log(2) + log(1 + m). |
| 128 | // for BFloat16, mantissa is at most 7 explicit bits, so we lookup |
| 129 | // log(1 + m) in LOG_1_PLUS_M table using `m` as key. |
| 130 | |
| 131 | // Get the 7-bit mantissa directly as the table index |
| 132 | uint16_t m = x_bits.get_mantissa(); |
| 133 | |
| 134 | // Get unbiased exponent |
| 135 | e += x_u >> FPBits::FRACTION_LEN; |
| 136 | |
| 137 | return fputil::cast<bfloat16>(x: fputil::multiply_add( |
| 138 | x: static_cast<float>(e), y: BF16_LOGF_2, z: LOG_1_PLUS_M[m])); |
| 139 | } |
| 140 | |
| 141 | } // namespace math |
| 142 | |
| 143 | } // namespace LIBC_NAMESPACE_DECL |
| 144 | |
| 145 | #endif // LLVM_LIBC_SRC___SUPPORT_MATH_LOG_BF16_H |
| 146 | |