1//===- HLSLBufferLayoutBuilder.cpp ----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "HLSLBufferLayoutBuilder.h"
10#include "CGHLSLRuntime.h"
11#include "CodeGenModule.h"
12#include "TargetInfo.h"
13#include "clang/AST/Type.h"
14#include <climits>
15
16//===----------------------------------------------------------------------===//
17// Implementation of constant buffer layout common between DirectX and
18// SPIR/SPIR-V.
19//===----------------------------------------------------------------------===//
20
21using namespace clang;
22using namespace clang::CodeGen;
23
24static const CharUnits CBufferRowSize =
25 CharUnits::fromQuantity(Quantity: llvm::hlsl::CBufferRowSizeInBytes);
26
27namespace clang {
28namespace CodeGen {
29
30llvm::StructType *
31HLSLBufferLayoutBuilder::layOutStruct(const RecordType *RT,
32 const CGHLSLOffsetInfo &OffsetInfo) {
33
34 // check if we already have the layout type for this struct
35 // TODO: Do we need to check for matching OffsetInfo?
36 if (llvm::StructType *Ty = CGM.getHLSLRuntime().getHLSLBufferLayoutType(LayoutStructTy: RT))
37 return Ty;
38
39 // iterate over all fields of the record, including fields on base classes
40 llvm::SmallVector<CXXRecordDecl *> RecordDecls;
41 RecordDecls.push_back(Elt: RT->castAsCXXRecordDecl());
42 while (RecordDecls.back()->getNumBases()) {
43 CXXRecordDecl *D = RecordDecls.back();
44 assert(D->getNumBases() == 1 &&
45 "HLSL doesn't support multiple inheritance");
46 RecordDecls.push_back(Elt: D->bases_begin()->getType()->castAsCXXRecordDecl());
47 }
48
49 SmallVector<std::pair<const FieldDecl *, uint32_t>> FieldsWithOffset;
50 unsigned OffsetIdx = 0;
51 for (const CXXRecordDecl *RD : llvm::reverse(C&: RecordDecls))
52 for (const auto *FD : RD->fields())
53 FieldsWithOffset.emplace_back(Args&: FD, Args: OffsetInfo[OffsetIdx++]);
54
55 if (!OffsetInfo.empty())
56 llvm::stable_sort(Range&: FieldsWithOffset, C: [](const auto &LHS, const auto &RHS) {
57 return CGHLSLOffsetInfo::compareOffsets(LHS: LHS.second, RHS: RHS.second);
58 });
59
60 SmallVector<llvm::Type *> Layout;
61 CharUnits CurrentOffset = CharUnits::Zero();
62 for (auto &[FD, Offset] : FieldsWithOffset) {
63 llvm::Type *LayoutType = layOutType(Type: FD->getType());
64
65 const llvm::DataLayout &DL = CGM.getDataLayout();
66 CharUnits Size =
67 CharUnits::fromQuantity(Quantity: DL.getTypeSizeInBits(Ty: LayoutType) / 8);
68 CharUnits Align = CharUnits::fromQuantity(Quantity: DL.getABITypeAlign(Ty: LayoutType));
69
70 if (LayoutType->isAggregateType() ||
71 (CurrentOffset % CBufferRowSize) + Size > CBufferRowSize)
72 Align = Align.alignTo(Align: CBufferRowSize);
73
74 CharUnits NextOffset = CurrentOffset.alignTo(Align);
75
76 if (Offset != CGHLSLOffsetInfo::Unspecified) {
77 CharUnits PackOffset = CharUnits::fromQuantity(Quantity: Offset);
78 assert(PackOffset >= NextOffset &&
79 "Offset is invalid - would overlap with previous object");
80 NextOffset = PackOffset;
81 }
82
83 if (NextOffset > CurrentOffset) {
84 llvm::Type *Padding = CGM.getTargetCodeGenInfo().getHLSLPadding(
85 CGM, NumBytes: NextOffset - CurrentOffset);
86 assert(Padding && "No padding type for target?");
87 Layout.emplace_back(Args&: Padding);
88 CurrentOffset = NextOffset;
89 }
90 Layout.emplace_back(Args&: LayoutType);
91 CurrentOffset += Size;
92 }
93
94 // Create the layout struct type; anonymous structs have empty name but
95 // non-empty qualified name
96 const auto *Decl = RT->castAsCXXRecordDecl();
97 std::string Name =
98 Decl->getName().empty() ? "anon" : Decl->getQualifiedNameAsString();
99
100 llvm::StructType *NewTy = llvm::StructType::create(Elements: Layout, Name,
101 /*isPacked=*/true);
102 CGM.getHLSLRuntime().addHLSLBufferLayoutType(LayoutStructTy: RT, LayoutTy: NewTy);
103 return NewTy;
104}
105
106llvm::Type *HLSLBufferLayoutBuilder::padArrayElements(llvm::Type *EltTy,
107 uint64_t Count) {
108 CharUnits EltSize =
109 CharUnits::fromQuantity(Quantity: CGM.getDataLayout().getTypeSizeInBits(Ty: EltTy) / 8);
110 CharUnits Padding = EltSize.alignTo(Align: CBufferRowSize) - EltSize;
111
112 // If we don't have any padding between elements then we just need the array
113 // itself.
114 if (Count < 2 || Padding.isZero())
115 return llvm::ArrayType::get(ElementType: EltTy, NumElements: Count);
116
117 llvm::LLVMContext &Context = CGM.getLLVMContext();
118 llvm::Type *PaddingTy =
119 CGM.getTargetCodeGenInfo().getHLSLPadding(CGM, NumBytes: Padding);
120 assert(PaddingTy && "No padding type for target?");
121 auto *PaddedEltTy =
122 llvm::StructType::get(Context, Elements: {EltTy, PaddingTy}, /*isPacked=*/true);
123 return llvm::StructType::get(
124 Context, Elements: {llvm::ArrayType::get(ElementType: PaddedEltTy, NumElements: Count - 1), EltTy},
125 /*IsPacked=*/isPacked: true);
126}
127
128llvm::Type *HLSLBufferLayoutBuilder::layOutArray(const ConstantArrayType *AT) {
129 llvm::Type *EltTy = layOutType(Type: AT->getElementType());
130 uint64_t Count = AT->getZExtSize();
131 return padArrayElements(EltTy, Count);
132}
133
134llvm::Type *
135HLSLBufferLayoutBuilder::layOutMatrix(const ConstantMatrixType *MT) {
136 // ConvertTypeForMem already handles row/column-major layout and bool
137 // promotion, producing [Count x <VecLen x EltTy>]. We just need to add
138 // cbuffer padding between the array elements.
139 llvm::ArrayType *MemTy =
140 cast<llvm::ArrayType>(Val: CGM.getTypes().ConvertTypeForMem(T: QualType(MT, 0)));
141 return padArrayElements(EltTy: MemTy->getElementType(), Count: MemTy->getNumElements());
142}
143
144llvm::Type *HLSLBufferLayoutBuilder::layOutType(QualType Ty) {
145 if (const auto *AT = CGM.getContext().getAsConstantArrayType(T: Ty))
146 return layOutArray(AT);
147
148 if (Ty->isStructureOrClassType()) {
149 CGHLSLOffsetInfo EmptyOffsets;
150 return layOutStruct(RT: Ty->getAsCanonical<RecordType>(), OffsetInfo: EmptyOffsets);
151 }
152
153 if (Ty->isConstantMatrixType()) {
154 const auto *MT = Ty->castAs<ConstantMatrixType>();
155 return layOutMatrix(MT);
156 }
157
158 return CGM.getTypes().ConvertTypeForMem(T: Ty);
159}
160
161} // namespace CodeGen
162} // namespace clang
163