1//===- HLSLBufferLayoutBuilder.cpp ----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "HLSLBufferLayoutBuilder.h"
10#include "CGHLSLRuntime.h"
11#include "CodeGenModule.h"
12#include "TargetInfo.h"
13#include "clang/AST/Type.h"
14#include <climits>
15
16//===----------------------------------------------------------------------===//
17// Implementation of constant buffer layout common between DirectX and
18// SPIR/SPIR-V.
19//===----------------------------------------------------------------------===//
20
21using namespace clang;
22using namespace clang::CodeGen;
23
24static const CharUnits CBufferRowSize =
25 CharUnits::fromQuantity(Quantity: llvm::hlsl::CBufferRowSizeInBytes);
26
27namespace clang {
28namespace CodeGen {
29
30llvm::StructType *
31HLSLBufferLayoutBuilder::layOutStruct(const RecordType *RT,
32 const CGHLSLOffsetInfo &OffsetInfo) {
33
34 // check if we already have the layout type for this struct
35 // TODO: Do we need to check for matching OffsetInfo?
36 if (llvm::StructType *Ty = CGM.getHLSLRuntime().getHLSLBufferLayoutType(LayoutStructTy: RT))
37 return Ty;
38
39 // iterate over all fields of the record, including fields on base classes
40 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
41 RecordDecls.push_back(Elt: RT->castAsCXXRecordDecl());
42 while (RecordDecls.back()->getNumBases()) {
43 CXXRecordDecl *D = RecordDecls.back();
44 assert(D->getNumBases() == 1 &&
45 "HLSL doesn't support multiple inheritance");
46 RecordDecls.push_back(Elt: D->bases_begin()->getType()->castAsCXXRecordDecl());
47 }
48
49 SmallVector<std::pair<const FieldDecl *, uint32_t>> FieldsWithOffset;
50 unsigned OffsetIdx = 0;
51 for (const CXXRecordDecl *RD : llvm::reverse(C&: RecordDecls))
52 for (const auto *FD : RD->fields())
53 FieldsWithOffset.emplace_back(Args&: FD, Args: OffsetInfo[OffsetIdx++]);
54
55 if (!OffsetInfo.empty())
56 llvm::stable_sort(Range&: FieldsWithOffset, C: [](const auto &LHS, const auto &RHS) {
57 return CGHLSLOffsetInfo::compareOffsets(LHS: LHS.second, RHS: RHS.second);
58 });
59
60 SmallVector<llvm::Type *> Layout;
61 CharUnits CurrentOffset = CharUnits::Zero();
62 for (auto &[FD, Offset] : FieldsWithOffset) {
63 llvm::Type *LayoutType = layOutType(Type: FD->getType());
64
65 const llvm::DataLayout &DL = CGM.getDataLayout();
66 CharUnits Size =
67 CharUnits::fromQuantity(Quantity: DL.getTypeSizeInBits(Ty: LayoutType) / 8);
68 CharUnits Align = CharUnits::fromQuantity(Quantity: DL.getABITypeAlign(Ty: LayoutType));
69
70 if (LayoutType->isAggregateType() ||
71 (CurrentOffset % CBufferRowSize) + Size > CBufferRowSize)
72 Align = Align.alignTo(Align: CBufferRowSize);
73
74 CharUnits NextOffset = CurrentOffset.alignTo(Align);
75
76 if (Offset != CGHLSLOffsetInfo::Unspecified) {
77 CharUnits PackOffset = CharUnits::fromQuantity(Quantity: Offset);
78 assert(PackOffset >= NextOffset &&
79 "Offset is invalid - would overlap with previous object");
80 NextOffset = PackOffset;
81 }
82
83 if (NextOffset > CurrentOffset) {
84 llvm::Type *Padding = CGM.getTargetCodeGenInfo().getHLSLPadding(
85 CGM, NumBytes: NextOffset - CurrentOffset);
86 assert(Padding && "No padding type for target?");
87 Layout.emplace_back(Args&: Padding);
88 CurrentOffset = NextOffset;
89 }
90 Layout.emplace_back(Args&: LayoutType);
91 CurrentOffset += Size;
92 }
93
94 // Create the layout struct type; anonymous structs have empty name but
95 // non-empty qualified name
96 const auto *Decl = RT->castAsCXXRecordDecl();
97 std::string Name =
98 Decl->getName().empty() ? "anon" : Decl->getQualifiedNameAsString();
99
100 llvm::StructType *NewTy = llvm::StructType::create(Elements: Layout, Name,
101 /*isPacked=*/true);
102 CGM.getHLSLRuntime().addHLSLBufferLayoutType(LayoutStructTy: RT, LayoutTy: NewTy);
103 return NewTy;
104}
105
106llvm::Type *HLSLBufferLayoutBuilder::layOutArray(const ConstantArrayType *AT) {
107 llvm::Type *EltTy = layOutType(Type: AT->getElementType());
108 uint64_t Count = AT->getZExtSize();
109
110 CharUnits EltSize =
111 CharUnits::fromQuantity(Quantity: CGM.getDataLayout().getTypeSizeInBits(Ty: EltTy) / 8);
112 CharUnits Padding = EltSize.alignTo(Align: CBufferRowSize) - EltSize;
113
114 // If we don't have any padding between elements then we just need the array
115 // itself.
116 if (Count < 2 || Padding.isZero())
117 return llvm::ArrayType::get(ElementType: EltTy, NumElements: Count);
118
119 llvm::LLVMContext &Context = CGM.getLLVMContext();
120 llvm::Type *PaddingTy =
121 CGM.getTargetCodeGenInfo().getHLSLPadding(CGM, NumBytes: Padding);
122 assert(PaddingTy && "No padding type for target?");
123 auto *PaddedEltTy =
124 llvm::StructType::get(Context, Elements: {EltTy, PaddingTy}, /*isPacked=*/true);
125 return llvm::StructType::get(
126 Context, Elements: {llvm::ArrayType::get(ElementType: PaddedEltTy, NumElements: Count - 1), EltTy},
127 /*IsPacked=*/isPacked: true);
128}
129
130llvm::Type *HLSLBufferLayoutBuilder::layOutType(QualType Ty) {
131 if (const auto *AT = CGM.getContext().getAsConstantArrayType(T: Ty))
132 return layOutArray(AT);
133
134 if (Ty->isStructureOrClassType()) {
135 CGHLSLOffsetInfo EmptyOffsets;
136 return layOutStruct(RT: Ty->getAsCanonical<RecordType>(), OffsetInfo: EmptyOffsets);
137 }
138
139 return CGM.getTypes().ConvertTypeForMem(T: Ty);
140}
141
142} // namespace CodeGen
143} // namespace clang
144