1//===- llvm/ADT/PointerUnion.h - Pointer Type Union -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file defines the PointerUnion class, which is a discriminated union of
11/// pointer types.
12///
13//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_ADT_POINTERUNION_H
16#define LLVM_ADT_POINTERUNION_H
17
18#include "llvm/ADT/DenseMapInfo.h"
19#include "llvm/ADT/PointerIntPair.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/Casting.h"
22#include "llvm/Support/PointerLikeTypeTraits.h"
23#include <algorithm>
24#include <array>
25#include <cassert>
26#include <cstddef>
27#include <cstdint>
28#include <optional>
29
30namespace llvm {
31
32namespace pointer_union_detail {
33
34/// Determine the number of bits required to store values in [0, NumValues).
35/// This is ceil(log2(NumValues)).
36constexpr int bitsRequired(unsigned NumValues) {
37 return NumValues == 0 ? 0 : llvm::bit_width_constexpr(Value: NumValues - 1);
38}
39
40template <typename... Ts> constexpr int lowBitsAvailable() {
41 return std::min(
42 {static_cast<int>(PointerLikeTypeTraits<Ts>::NumLowBitsAvailable)...});
43}
44
45/// True if all types have enough low bits for a fixed-width tag.
46template <typename... PTs> constexpr bool useFixedWidthTags() {
47 return lowBitsAvailable<PTs...>() >= bitsRequired(NumValues: sizeof...(PTs));
48}
49
50/// True if types are in non-decreasing NumLowBitsAvailable order.
51// TODO: Switch to llvm::is_sorted when it becomes constexpr.
52template <typename... PTs> constexpr bool typesInNonDecreasingBitOrder() {
53 int Bits[] = {PointerLikeTypeTraits<PTs>::NumLowBitsAvailable...};
54 for (size_t I = 1; I < sizeof...(PTs); ++I)
55 if (Bits[I] < Bits[I - 1])
56 return false;
57 return true;
58}
59
/// Describes how one alternative of the union is tagged.
struct TagEntry {
  /// Bit pattern stored in the low bits for this alternative.
  uintptr_t Value;
  /// Mask covering all tag bits for this entry.
  uintptr_t Mask;
};
65
66/// Compute fixed-width tag table (all types have enough bits for the tag).
67/// For example, with 4 types and 3 available bits, the tag is 2 bits wide
68/// (values 0-3) and each entry has the same mask of 0x3.
69template <typename... PTs>
70constexpr std::array<TagEntry, sizeof...(PTs)> computeFixedTags() {
71 constexpr size_t N = sizeof...(PTs);
72 constexpr uintptr_t TagMask = (uintptr_t(1) << bitsRequired(NumValues: N)) - 1;
73 std::array<TagEntry, N> Result = {};
74 for (size_t I = 0; I < N; ++I) {
75 Result[I].Value = uintptr_t(I);
76 Result[I].Mask = TagMask;
77 }
78 return Result;
79}
80
81/// Compute variable-width tag table, or return std::nullopt if the types
82/// don't fit. Types must be in non-decreasing NumLowBitsAvailable order.
83/// Groups types by available bits into tiers; each non-final tier reserves
84/// its highest code as an escape prefix.
85///
86/// Example with 3 tiers (2-bit, 3-bit, 5-bit types):
87/// Tier 0 (2 bits): codes 0b00, 0b01, 0b10; escape = 0b11
88/// Tier 1 (3 bits): codes 0b011, escape = 0b111
89/// Tier 2 (5 bits): codes 0b00111, 0b01111, 0b10111, 0b11111
90template <typename... PTs>
91constexpr std::optional<std::array<TagEntry, sizeof...(PTs)>>
92computeExtendedTags() {
93 constexpr size_t N = sizeof...(PTs);
94 std::array<TagEntry, N> Result = {};
95 int Bits[] = {PointerLikeTypeTraits<PTs>::NumLowBitsAvailable...};
96 uintptr_t EscapePrefix = 0;
97 int PrevBits = 0;
98 size_t I = 0;
99 // Walk tiers (groups of types with the same NumLowBitsAvailable). For each
100 // tier, assign tag values using the new bits introduced by this tier,
101 // prefixed by the accumulated escape codes from previous tiers. Non-final
102 // tiers reserve their highest code as an escape to the next tier.
103 while (I < N) {
104 int TierBits = Bits[I];
105 if (TierBits < PrevBits)
106 return std::nullopt;
107 int NewBits = TierBits - PrevBits;
108 size_t TierEnd = I;
109 while (TierEnd < N && Bits[TierEnd] == TierBits)
110 ++TierEnd;
111 bool IsLastTier = (TierEnd == N);
112 size_t TypesInTier = TierEnd - I;
113 size_t Capacity =
114 IsLastTier ? (size_t(1) << NewBits) : ((size_t(1) << NewBits) - 1);
115 if (TypesInTier > Capacity)
116 return std::nullopt;
117 for (size_t J = 0; J < TypesInTier; ++J) {
118 Result[I + J].Value = EscapePrefix | (uintptr_t(J) << PrevBits);
119 Result[I + J].Mask = (uintptr_t(1) << TierBits) - 1;
120 }
121 uintptr_t EscapeCode = (uintptr_t(1) << NewBits) - 1;
122 EscapePrefix |= EscapeCode << PrevBits;
123 PrevBits = TierBits;
124 I = TierEnd;
125 }
126 return Result;
127}
128
129/// CRTP base that generates non-template constructors and assignment operators
130/// for each type in the union. Non-template constructors allow implicit
131/// conversions (derived-to-base, non-const-to-const).
132template <typename Derived, int Idx, typename... Types>
133class PointerUnionMembers;
134
135template <typename Derived, int Idx> class PointerUnionMembers<Derived, Idx> {
136protected:
137 detail::PunnedPointer<void *> Val;
138 PointerUnionMembers() : Val(uintptr_t(0)) {}
139
140 template <typename To, typename From, typename Enable>
141 friend struct ::llvm::CastInfo;
142 template <typename> friend struct ::llvm::PointerLikeTypeTraits;
143};
144
145template <typename Derived, int Idx, typename Type, typename... Types>
146class PointerUnionMembers<Derived, Idx, Type, Types...>
147 : public PointerUnionMembers<Derived, Idx + 1, Types...> {
148 using Base = PointerUnionMembers<Derived, Idx + 1, Types...>;
149
150public:
151 using Base::Base;
152 PointerUnionMembers() = default;
153
154 PointerUnionMembers(Type V) { this->Val = Derived::encode(V); }
155
156 using Base::operator=;
157 Derived &operator=(Type V) {
158 this->Val = Derived::encode(V);
159 return static_cast<Derived &>(*this);
160 }
161};
162
163} // end namespace pointer_union_detail
164
165/// A discriminated union of two or more pointer types, with the discriminator
166/// in the low bits of the pointer.
167///
168/// This implementation is extremely efficient in space due to leveraging the
169/// low bits of the pointer, while exposing a natural and type-safe API.
170///
171/// When all types have enough alignment for a fixed-width tag,
172/// the tag is placed in the high end of the available low bits, leaving spare
173/// low bits for nesting in PointerIntPair or SmallPtrSet. When types have
174/// heterogeneous alignment, a variable-length escape-encoded tag
175/// is used; in that case, types must be listed in non-decreasing
176/// NumLowBitsAvailable order.
177///
178/// Common use patterns would be something like this:
179/// PointerUnion<int*, float*> P;
180/// P = (int*)0;
181/// printf("%d %d", P.is<int*>(), P.is<float*>()); // prints "1 0"
182/// X = P.get<int*>(); // ok.
183/// Y = P.get<float*>(); // runtime assertion failure.
184/// Z = P.get<double*>(); // compile time failure.
185/// P = (float*)0;
186/// Y = P.get<float*>(); // ok.
187/// X = P.get<int*>(); // runtime assertion failure.
188/// PointerUnion<int*, int*> Q; // compile time failure.
189template <typename... PTs>
190class PointerUnion
191 : public pointer_union_detail::PointerUnionMembers<PointerUnion<PTs...>, 0,
192 PTs...> {
193 static_assert(sizeof...(PTs) > 0, "PointerUnion must have at least one type");
194 static_assert(TypesAreDistinct<PTs...>::value,
195 "PointerUnion alternative types cannot be repeated");
196
197 using Base = typename PointerUnion::PointerUnionMembers;
198 using First = TypeAtIndex<0, PTs...>;
199
200 template <typename, int, typename...>
201 friend class pointer_union_detail::PointerUnionMembers;
202 template <typename To, typename From, typename Enable> friend struct CastInfo;
203 template <typename> friend struct PointerLikeTypeTraits;
204
205 // These are constexpr functions rather than static constexpr data members
206 // so that alignof() on potentially incomplete types is not evaluated at
207 // class-definition time.
208
209 static constexpr bool useFixedWidthTags() {
210 return pointer_union_detail::useFixedWidthTags<PTs...>();
211 }
212
213 static constexpr int minLowBitsAvailable() {
214 return pointer_union_detail::lowBitsAvailable<PTs...>();
215 }
216
217 static constexpr int tagBits() {
218 return pointer_union_detail::bitsRequired(NumValues: sizeof...(PTs));
219 }
220
221 /// When using fixed-width tags, the tag is shifted to the high end of the
222 /// available low bits so that the lowest bits remain free for nesting. With
223 /// variable-width encoding mode, the tag starts at bit 0.
224 static constexpr int tagShift() {
225 return useFixedWidthTags() ? (minLowBitsAvailable() - tagBits()) : 0;
226 }
227
228 using TagTable = std::array<pointer_union_detail::TagEntry, sizeof...(PTs)>;
229
230 /// Returns the tag lookup table for this union's encoding scheme.
231 static constexpr TagTable getTagTable() {
232 if constexpr (useFixedWidthTags()) {
233 return pointer_union_detail::computeFixedTags<PTs...>();
234 } else {
235 static_assert(
236 pointer_union_detail::typesInNonDecreasingBitOrder<PTs...>(),
237 "Variable-width PointerUnion types must be in non-decreasing "
238 "NumLowBitsAvailable order");
239 constexpr auto Table =
240 pointer_union_detail::computeExtendedTags<PTs...>();
241 static_assert(Table.has_value(),
242 "Too many types for the available low bits");
243 return *Table;
244 }
245 }
246
247 // Variable-width isNull: check membership in the sparse set of tag values.
248 // A single threshold comparison does not work here because lower-tier
249 // non-null pointers can encode to values below higher-tier thresholds.
250 template <size_t... Is>
251 static constexpr bool isNullVariableImpl(uintptr_t V,
252 std::index_sequence<Is...>) {
253 constexpr TagTable Table = getTagTable();
254 static_assert(tagShift() == 0,
255 "isNullVariableImpl assumes tag starts at bit 0");
256 return ((V == Table[Is].Value) || ...);
257 }
258
259 template <typename T> static uintptr_t encode(T V) {
260 constexpr TagTable Table = getTagTable();
261 constexpr int Shift = tagShift();
262 constexpr size_t Idx = FirstIndexOfType<T, PTs...>::value;
263 static_assert(Table[0].Value == 0,
264 "First type must have tag value 0 for getAddrOfPtr1");
265 uintptr_t PtrInt = reinterpret_cast<uintptr_t>(
266 PointerLikeTypeTraits<T>::getAsVoidPointer(V));
267 assert((PtrInt & (Table[Idx].Mask << Shift)) == 0 &&
268 "Pointer low bits collide with tag");
269 return PtrInt | (Table[Idx].Value << Shift);
270 }
271
272public:
273 PointerUnion() = default;
274 PointerUnion(std::nullptr_t) : PointerUnion() {}
275 using Base::Base;
276 using Base::operator=;
277
278 /// Assignment from nullptr clears the union, resetting to the first type.
279 const PointerUnion &operator=(std::nullptr_t) {
280 this->Val = uintptr_t(0);
281 return *this;
282 }
283
284 /// Test if the pointer held in the union is null, regardless of
285 /// which type it is.
286 bool isNull() const {
287 if constexpr (useFixedWidthTags()) {
288 return (static_cast<uintptr_t>(this->Val.asInt()) >>
289 minLowBitsAvailable()) == 0;
290 } else {
291 return isNullVariableImpl(static_cast<uintptr_t>(this->Val.asInt()),
292 std::index_sequence_for<PTs...>{});
293 }
294 }
295
296 explicit operator bool() const { return !isNull(); }
297
298 // FIXME: Replace the uses of is(), get() and dyn_cast() with
299 // isa<T>, cast<T> and the llvm::dyn_cast<T>
300
301 /// Test if the Union currently holds the type matching T.
302 template <typename T> [[deprecated("Use isa instead")]] bool is() const {
303 return isa<T>(*this);
304 }
305
306 /// Returns the value of the specified pointer type.
307 ///
308 /// If the specified pointer type is incorrect, assert.
309 template <typename T> [[deprecated("Use cast instead")]] T get() const {
310 assert(isa<T>(*this) && "Invalid accessor called");
311 return cast<T>(*this);
312 }
313
314 /// Returns the current pointer if it is of the specified pointer type,
315 /// otherwise returns null.
316 template <typename T> inline T dyn_cast() const {
317 return llvm::dyn_cast_if_present<T>(*this);
318 }
319
320 /// If the union is set to the first pointer type get an address pointing to
321 /// it.
322 First const *getAddrOfPtr1() const {
323 return const_cast<PointerUnion *>(this)->getAddrOfPtr1();
324 }
325
326 /// If the union is set to the first pointer type get an address pointing to
327 /// it.
328 First *getAddrOfPtr1() {
329 static_assert(FirstIndexOfType<First, PTs...>::value == 0,
330 "First type must have tag value 0 for getAddrOfPtr1");
331 assert(isa<First>(*this) && "Val is not the first pointer");
332 // tag == 0 for first type, so asInt() is the raw pointer value.
333 assert(
334 PointerLikeTypeTraits<First>::getAsVoidPointer(cast<First>(*this)) ==
335 reinterpret_cast<void *>(this->Val.asInt()) &&
336 "Can't get the address because PointerLikeTypeTraits changes the ptr");
337 return const_cast<First *>(
338 reinterpret_cast<const First *>(this->Val.getPointerAddress()));
339 }
340
341 void *getOpaqueValue() const {
342 return reinterpret_cast<void *>(this->Val.asInt());
343 }
344
345 static inline PointerUnion getFromOpaqueValue(void *VP) {
346 PointerUnion V;
347 V.Val = reinterpret_cast<intptr_t>(VP);
348 return V;
349 }
350
351 friend bool operator==(PointerUnion lhs, PointerUnion rhs) {
352 return lhs.getOpaqueValue() == rhs.getOpaqueValue();
353 }
354
355 friend bool operator!=(PointerUnion lhs, PointerUnion rhs) {
356 return lhs.getOpaqueValue() != rhs.getOpaqueValue();
357 }
358
359 friend bool operator<(PointerUnion lhs, PointerUnion rhs) {
360 return lhs.getOpaqueValue() < rhs.getOpaqueValue();
361 }
362};
363
364// Specialization of CastInfo for PointerUnion.
365template <typename To, typename... PTs>
366struct CastInfo<To, PointerUnion<PTs...>>
367 : public DefaultDoCastIfPossible<To, PointerUnion<PTs...>,
368 CastInfo<To, PointerUnion<PTs...>>> {
369 using From = PointerUnion<PTs...>;
370
371 static inline bool isPossible(From &F) {
372 constexpr std::array<pointer_union_detail::TagEntry, sizeof...(PTs)> Table =
373 From::getTagTable();
374 constexpr int Shift = From::tagShift();
375 constexpr size_t Idx = FirstIndexOfType<To, PTs...>::value;
376 auto V = reinterpret_cast<uintptr_t>(F.getOpaqueValue());
377 constexpr uintptr_t TagMask = Table[Idx].Mask << Shift;
378 constexpr uintptr_t TagValue = Table[Idx].Value << Shift;
379 return (V & TagMask) == TagValue;
380 }
381
382 static To doCast(From &F) {
383 assert(isPossible(F) && "cast to an incompatible type!");
384 constexpr std::array<pointer_union_detail::TagEntry, sizeof...(PTs)> Table =
385 From::getTagTable();
386 constexpr int Shift = From::tagShift();
387 constexpr size_t Idx = FirstIndexOfType<To, PTs...>::value;
388 constexpr uintptr_t PtrMask = ~(uintptr_t(Table[Idx].Mask) << Shift);
389 void *Ptr = reinterpret_cast<void *>(
390 reinterpret_cast<uintptr_t>(F.getOpaqueValue()) & PtrMask);
391 return PointerLikeTypeTraits<To>::getFromVoidPointer(Ptr);
392 }
393
394 static inline To castFailed() { return To(); }
395};
396
397template <typename To, typename... PTs>
398struct CastInfo<To, const PointerUnion<PTs...>>
399 : public ConstStrippingForwardingCast<To, const PointerUnion<PTs...>,
400 CastInfo<To, PointerUnion<PTs...>>> {
401};
402
403// Teach SmallPtrSet that PointerUnion is "basically a pointer".
404// Spare low bits below the tag are available for nesting.
405// This specialization is only instantiated when used (lazy), so
406// PointerLikeTypeTraits<PTs> / alignof() are not evaluated for
407// incomplete types.
408template <typename... PTs> struct PointerLikeTypeTraits<PointerUnion<PTs...>> {
409 using Union = PointerUnion<PTs...>;
410
411 static inline void *getAsVoidPointer(const Union &P) {
412 return P.getOpaqueValue();
413 }
414
415 static inline Union getFromVoidPointer(void *P) {
416 return Union::getFromOpaqueValue(P);
417 }
418
419 // The number of bits available are the min of the pointer types minus the
420 // bits needed for the discriminator.
421 static constexpr int NumLowBitsAvailable = Union::tagShift();
422};
423
424// Teach DenseMap how to use PointerUnions as keys.
425template <typename... PTs> struct DenseMapInfo<PointerUnion<PTs...>> {
426 using Union = PointerUnion<PTs...>;
427 using FirstInfo = DenseMapInfo<TypeAtIndex<0, PTs...>>;
428
429 static inline Union getEmptyKey() { return Union(FirstInfo::getEmptyKey()); }
430
431 static inline Union getTombstoneKey() {
432 return Union(FirstInfo::getTombstoneKey());
433 }
434
435 static unsigned getHashValue(const Union &UnionVal) {
436 auto Key = reinterpret_cast<uintptr_t>(UnionVal.getOpaqueValue());
437 return DenseMapInfo<uintptr_t>::getHashValue(Val: Key);
438 }
439
440 static bool isEqual(const Union &LHS, const Union &RHS) {
441 return LHS == RHS;
442 }
443};
444
445} // end namespace llvm
446
447#endif // LLVM_ADT_POINTERUNION_H
448