1//===- HLSLRootSignatureUtils.cpp - HLSL Root Signature helpers -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file This file contains helpers for working with HLSL Root Signatures.
10///
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Frontend/HLSL/HLSLRootSignatureUtils.h"
14#include "llvm/ADT/SmallString.h"
15#include "llvm/ADT/bit.h"
16#include "llvm/IR/IRBuilder.h"
17#include "llvm/IR/Metadata.h"
18#include "llvm/Support/ScopedPrinter.h"
19
20namespace llvm {
21namespace hlsl {
22namespace rootsig {
23
24template <typename T>
25static std::optional<StringRef> getEnumName(const T Value,
26 ArrayRef<EnumEntry<T>> Enums) {
27 for (const auto &EnumItem : Enums)
28 if (EnumItem.Value == Value)
29 return EnumItem.Name;
30 return std::nullopt;
31}
32
33template <typename T>
34static raw_ostream &printEnum(raw_ostream &OS, const T Value,
35 ArrayRef<EnumEntry<T>> Enums) {
36 auto MaybeName = getEnumName(Value, Enums);
37 if (MaybeName)
38 OS << *MaybeName;
39 return OS;
40}
41
42template <typename T>
43static raw_ostream &printFlags(raw_ostream &OS, const T Value,
44 ArrayRef<EnumEntry<T>> Flags) {
45 bool FlagSet = false;
46 unsigned Remaining = llvm::to_underlying(Value);
47 while (Remaining) {
48 unsigned Bit = 1u << llvm::countr_zero(Val: Remaining);
49 if (Remaining & Bit) {
50 if (FlagSet)
51 OS << " | ";
52
53 auto MaybeFlag = getEnumName(T(Bit), Flags);
54 if (MaybeFlag)
55 OS << *MaybeFlag;
56 else
57 OS << "invalid: " << Bit;
58
59 FlagSet = true;
60 }
61 Remaining &= ~Bit;
62 }
63
64 if (!FlagSet)
65 OS << "None";
66 return OS;
67}
68
// Single-letter prefixes printed in front of a register number when a Register
// is written to a stream (e.g. "b" for RegisterType::BReg).
static const EnumEntry<RegisterType> RegisterNames[] = {
    {"b", RegisterType::BReg},
    {"t", RegisterType::TReg},
    {"u", RegisterType::UReg},
    {"s", RegisterType::SReg},
};
75
76static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
77 printEnum(OS, Value: Reg.ViewType, Enums: ArrayRef(RegisterNames));
78 OS << Reg.Number;
79
80 return OS;
81}
82
83static raw_ostream &operator<<(raw_ostream &OS,
84 const llvm::dxbc::ShaderVisibility &Visibility) {
85 printEnum(OS, Value: Visibility, Enums: dxbc::getShaderVisibility());
86
87 return OS;
88}
89
90static raw_ostream &operator<<(raw_ostream &OS,
91 const llvm::dxbc::SamplerFilter &Filter) {
92 printEnum(OS, Value: Filter, Enums: dxbc::getSamplerFilters());
93
94 return OS;
95}
96
97static raw_ostream &operator<<(raw_ostream &OS,
98 const dxbc::TextureAddressMode &Address) {
99 printEnum(OS, Value: Address, Enums: dxbc::getTextureAddressModes());
100
101 return OS;
102}
103
104static raw_ostream &operator<<(raw_ostream &OS,
105 const dxbc::ComparisonFunc &CompFunc) {
106 printEnum(OS, Value: CompFunc, Enums: dxbc::getComparisonFuncs());
107
108 return OS;
109}
110
111static raw_ostream &operator<<(raw_ostream &OS,
112 const dxbc::StaticBorderColor &BorderColor) {
113 printEnum(OS, Value: BorderColor, Enums: dxbc::getStaticBorderColors());
114
115 return OS;
116}
117
// Display names for each resource class. Besides being used for printing,
// these strings become the tag of descriptor-table clause metadata and, with a
// "Root" prefix, the tag of root descriptor metadata.
static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
    {"CBV", dxil::ResourceClass::CBuffer},
    {"SRV", dxil::ResourceClass::SRV},
    {"UAV", dxil::ResourceClass::UAV},
    {"Sampler", dxil::ResourceClass::Sampler},
};
124
125static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
126 printEnum(OS, Value: dxil::ResourceClass(llvm::to_underlying(E: Type)),
127 Enums: ArrayRef(ResourceClassNames));
128
129 return OS;
130}
131
132static raw_ostream &operator<<(raw_ostream &OS,
133 const dxbc::RootDescriptorFlags &Flags) {
134 printFlags(OS, Value: Flags, Flags: dxbc::getRootDescriptorFlags());
135
136 return OS;
137}
138
139static raw_ostream &operator<<(raw_ostream &OS,
140 const llvm::dxbc::DescriptorRangeFlags &Flags) {
141 printFlags(OS, Value: Flags, Flags: dxbc::getDescriptorRangeFlags());
142
143 return OS;
144}
145
146raw_ostream &operator<<(raw_ostream &OS, const dxbc::RootFlags &Flags) {
147 OS << "RootFlags(";
148 printFlags(OS, Value: Flags, Flags: dxbc::getRootFlags());
149 OS << ")";
150
151 return OS;
152}
153
154raw_ostream &operator<<(raw_ostream &OS, const RootConstants &Constants) {
155 OS << "RootConstants(num32BitConstants = " << Constants.Num32BitConstants
156 << ", " << Constants.Reg << ", space = " << Constants.Space
157 << ", visibility = " << Constants.Visibility << ")";
158
159 return OS;
160}
161
162raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table) {
163 OS << "DescriptorTable(numClauses = " << Table.NumClauses
164 << ", visibility = " << Table.Visibility << ")";
165
166 return OS;
167}
168
169raw_ostream &operator<<(raw_ostream &OS, const DescriptorTableClause &Clause) {
170 OS << Clause.Type << "(" << Clause.Reg << ", numDescriptors = ";
171 if (Clause.NumDescriptors == NumDescriptorsUnbounded)
172 OS << "unbounded";
173 else
174 OS << Clause.NumDescriptors;
175 OS << ", space = " << Clause.Space << ", offset = ";
176 if (Clause.Offset == DescriptorTableOffsetAppend)
177 OS << "DescriptorTableOffsetAppend";
178 else
179 OS << Clause.Offset;
180 OS << ", flags = " << Clause.Flags << ")";
181
182 return OS;
183}
184
185raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) {
186 ClauseType Type = ClauseType(llvm::to_underlying(E: Descriptor.Type));
187 OS << "Root" << Type << "(" << Descriptor.Reg
188 << ", space = " << Descriptor.Space
189 << ", visibility = " << Descriptor.Visibility
190 << ", flags = " << Descriptor.Flags << ")";
191
192 return OS;
193}
194
195raw_ostream &operator<<(raw_ostream &OS, const StaticSampler &Sampler) {
196 OS << "StaticSampler(" << Sampler.Reg << ", filter = " << Sampler.Filter
197 << ", addressU = " << Sampler.AddressU
198 << ", addressV = " << Sampler.AddressV
199 << ", addressW = " << Sampler.AddressW
200 << ", mipLODBias = " << Sampler.MipLODBias
201 << ", maxAnisotropy = " << Sampler.MaxAnisotropy
202 << ", comparisonFunc = " << Sampler.CompFunc
203 << ", borderColor = " << Sampler.BorderColor
204 << ", minLOD = " << Sampler.MinLOD << ", maxLOD = " << Sampler.MaxLOD
205 << ", space = " << Sampler.Space << ", visibility = " << Sampler.Visibility
206 << ")";
207 return OS;
208}
209
namespace {

// We use OverloadedVisit with std::visit to ensure the compiler catches if a
// new RootElement variant type is added but its operator<< or metadata
// generation isn't handled.
//
// OverloadedVisit inherits from every lambda passed to it and brings all of
// their call operators into a single overload set.
template <class... Ts> struct OverloadedVisit : Ts... {
  using Ts::operator()...;
};
// Deduction guide: lets `OverloadedVisit{Lambda1, Lambda2, ...}` deduce Ts
// from the lambda types.
template <class... Ts> OverloadedVisit(Ts...) -> OverloadedVisit<Ts...>;

} // namespace
221
222raw_ostream &operator<<(raw_ostream &OS, const RootElement &Element) {
223 const auto Visitor = OverloadedVisit{
224 [&OS](const dxbc::RootFlags &Flags) { OS << Flags; },
225 [&OS](const RootConstants &Constants) { OS << Constants; },
226 [&OS](const RootDescriptor &Descriptor) { OS << Descriptor; },
227 [&OS](const DescriptorTableClause &Clause) { OS << Clause; },
228 [&OS](const DescriptorTable &Table) { OS << Table; },
229 [&OS](const StaticSampler &Sampler) { OS << Sampler; },
230 };
231 std::visit(visitor: Visitor, variants: Element);
232 return OS;
233}
234
235void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements) {
236 OS << " RootElements{";
237 bool First = true;
238 for (const RootElement &Element : Elements) {
239 if (!First)
240 OS << ",";
241 OS << " " << Element;
242 First = false;
243 }
244 OS << "}";
245}
246
247MDNode *MetadataBuilder::BuildRootSignature() {
248 const auto Visitor = OverloadedVisit{
249 [this](const dxbc::RootFlags &Flags) -> MDNode * {
250 return BuildRootFlags(Flags);
251 },
252 [this](const RootConstants &Constants) -> MDNode * {
253 return BuildRootConstants(Constants);
254 },
255 [this](const RootDescriptor &Descriptor) -> MDNode * {
256 return BuildRootDescriptor(Descriptor);
257 },
258 [this](const DescriptorTableClause &Clause) -> MDNode * {
259 return BuildDescriptorTableClause(Clause);
260 },
261 [this](const DescriptorTable &Table) -> MDNode * {
262 return BuildDescriptorTable(Table);
263 },
264 [this](const StaticSampler &Sampler) -> MDNode * {
265 return BuildStaticSampler(Sampler);
266 },
267 };
268
269 for (const RootElement &Element : Elements) {
270 MDNode *ElementMD = std::visit(visitor: Visitor, variants: Element);
271 assert(ElementMD != nullptr &&
272 "Root Element must be initialized and validated");
273 GeneratedMetadata.push_back(Elt: ElementMD);
274 }
275
276 return MDNode::get(Context&: Ctx, MDs: GeneratedMetadata);
277}
278
279MDNode *MetadataBuilder::BuildRootFlags(const dxbc::RootFlags &Flags) {
280 IRBuilder<> Builder(Ctx);
281 Metadata *Operands[] = {
282 MDString::get(Context&: Ctx, Str: "RootFlags"),
283 ConstantAsMetadata::get(C: Builder.getInt32(C: llvm::to_underlying(E: Flags))),
284 };
285 return MDNode::get(Context&: Ctx, MDs: Operands);
286}
287
288MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
289 IRBuilder<> Builder(Ctx);
290 Metadata *Operands[] = {
291 MDString::get(Context&: Ctx, Str: "RootConstants"),
292 ConstantAsMetadata::get(
293 C: Builder.getInt32(C: llvm::to_underlying(E: Constants.Visibility))),
294 ConstantAsMetadata::get(C: Builder.getInt32(C: Constants.Reg.Number)),
295 ConstantAsMetadata::get(C: Builder.getInt32(C: Constants.Space)),
296 ConstantAsMetadata::get(C: Builder.getInt32(C: Constants.Num32BitConstants)),
297 };
298 return MDNode::get(Context&: Ctx, MDs: Operands);
299}
300
301MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
302 IRBuilder<> Builder(Ctx);
303 std::optional<StringRef> TypeName =
304 getEnumName(Value: dxil::ResourceClass(llvm::to_underlying(E: Descriptor.Type)),
305 Enums: ArrayRef(ResourceClassNames));
306 assert(TypeName && "Provided an invalid Resource Class");
307 llvm::SmallString<7> Name({"Root", *TypeName});
308 Metadata *Operands[] = {
309 MDString::get(Context&: Ctx, Str: Name),
310 ConstantAsMetadata::get(
311 C: Builder.getInt32(C: llvm::to_underlying(E: Descriptor.Visibility))),
312 ConstantAsMetadata::get(C: Builder.getInt32(C: Descriptor.Reg.Number)),
313 ConstantAsMetadata::get(C: Builder.getInt32(C: Descriptor.Space)),
314 ConstantAsMetadata::get(
315 C: Builder.getInt32(C: llvm::to_underlying(E: Descriptor.Flags))),
316 };
317 return MDNode::get(Context&: Ctx, MDs: Operands);
318}
319
320MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
321 IRBuilder<> Builder(Ctx);
322 SmallVector<Metadata *> TableOperands;
323 // Set the mandatory arguments
324 TableOperands.push_back(Elt: MDString::get(Context&: Ctx, Str: "DescriptorTable"));
325 TableOperands.push_back(Elt: ConstantAsMetadata::get(
326 C: Builder.getInt32(C: llvm::to_underlying(E: Table.Visibility))));
327
328 // Remaining operands are references to the table's clauses. The in-memory
329 // representation of the Root Elements created from parsing will ensure that
330 // the previous N elements are the clauses for this table.
331 assert(Table.NumClauses <= GeneratedMetadata.size() &&
332 "Table expected all owned clauses to be generated already");
333 // So, add a refence to each clause to our operands
334 TableOperands.append(in_start: GeneratedMetadata.end() - Table.NumClauses,
335 in_end: GeneratedMetadata.end());
336 // Then, remove those clauses from the general list of Root Elements
337 GeneratedMetadata.pop_back_n(NumItems: Table.NumClauses);
338
339 return MDNode::get(Context&: Ctx, MDs: TableOperands);
340}
341
342MDNode *MetadataBuilder::BuildDescriptorTableClause(
343 const DescriptorTableClause &Clause) {
344 IRBuilder<> Builder(Ctx);
345 std::optional<StringRef> Name =
346 getEnumName(Value: dxil::ResourceClass(llvm::to_underlying(E: Clause.Type)),
347 Enums: ArrayRef(ResourceClassNames));
348 assert(Name && "Provided an invalid Resource Class");
349 Metadata *Operands[] = {
350 MDString::get(Context&: Ctx, Str: *Name),
351 ConstantAsMetadata::get(C: Builder.getInt32(C: Clause.NumDescriptors)),
352 ConstantAsMetadata::get(C: Builder.getInt32(C: Clause.Reg.Number)),
353 ConstantAsMetadata::get(C: Builder.getInt32(C: Clause.Space)),
354 ConstantAsMetadata::get(C: Builder.getInt32(C: Clause.Offset)),
355 ConstantAsMetadata::get(
356 C: Builder.getInt32(C: llvm::to_underlying(E: Clause.Flags))),
357 };
358 return MDNode::get(Context&: Ctx, MDs: Operands);
359}
360
361MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
362 IRBuilder<> Builder(Ctx);
363 Metadata *Operands[] = {
364 MDString::get(Context&: Ctx, Str: "StaticSampler"),
365 ConstantAsMetadata::get(
366 C: Builder.getInt32(C: llvm::to_underlying(E: Sampler.Filter))),
367 ConstantAsMetadata::get(
368 C: Builder.getInt32(C: llvm::to_underlying(E: Sampler.AddressU))),
369 ConstantAsMetadata::get(
370 C: Builder.getInt32(C: llvm::to_underlying(E: Sampler.AddressV))),
371 ConstantAsMetadata::get(
372 C: Builder.getInt32(C: llvm::to_underlying(E: Sampler.AddressW))),
373 ConstantAsMetadata::get(C: llvm::ConstantFP::get(Ty: llvm::Type::getFloatTy(C&: Ctx),
374 V: Sampler.MipLODBias)),
375 ConstantAsMetadata::get(C: Builder.getInt32(C: Sampler.MaxAnisotropy)),
376 ConstantAsMetadata::get(
377 C: Builder.getInt32(C: llvm::to_underlying(E: Sampler.CompFunc))),
378 ConstantAsMetadata::get(
379 C: Builder.getInt32(C: llvm::to_underlying(E: Sampler.BorderColor))),
380 ConstantAsMetadata::get(
381 C: llvm::ConstantFP::get(Ty: llvm::Type::getFloatTy(C&: Ctx), V: Sampler.MinLOD)),
382 ConstantAsMetadata::get(
383 C: llvm::ConstantFP::get(Ty: llvm::Type::getFloatTy(C&: Ctx), V: Sampler.MaxLOD)),
384 ConstantAsMetadata::get(C: Builder.getInt32(C: Sampler.Reg.Number)),
385 ConstantAsMetadata::get(C: Builder.getInt32(C: Sampler.Space)),
386 ConstantAsMetadata::get(
387 C: Builder.getInt32(C: llvm::to_underlying(E: Sampler.Visibility))),
388 };
389 return MDNode::get(Context&: Ctx, MDs: Operands);
390}
391
392std::optional<const RangeInfo *>
393ResourceRange::getOverlapping(const RangeInfo &Info) const {
394 MapT::const_iterator Interval = Intervals.find(x: Info.LowerBound);
395 if (!Interval.valid() || Info.UpperBound < Interval.start())
396 return std::nullopt;
397 return Interval.value();
398}
399
400const RangeInfo *ResourceRange::lookup(uint32_t X) const {
401 return Intervals.lookup(x: X, NotFound: nullptr);
402}
403
404void ResourceRange::clear() { return Intervals.clear(); }
405
/// Records \p Info over [Info.LowerBound, Info.UpperBound].
///
/// Existing intervals are handled as follows:
///   - An interval that contains the whole range: nothing is inserted.
///   - An interval that partially overlaps the range: it keeps its extent, and
///     the range being inserted is trimmed against it.
///   - An interval lying entirely inside the range: it is erased and absorbed
///     into the single interval inserted at the end.
///
/// \returns the first existing RangeInfo found to overlap the range, or
/// std::nullopt if nothing overlapped.
///
/// NOTE(review): the map stores the address of \p Info, so \p Info presumably
/// must outlive this ResourceRange (or until clear()). Confirm with callers.
std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) {
  uint32_t LowerBound = Info.LowerBound;
  uint32_t UpperBound = Info.UpperBound;

  std::optional<const RangeInfo *> Res = std::nullopt;
  MapT::iterator Interval = Intervals.begin();

  while (true) {
    // The trimming steps below never make the range empty. This can only
    // trigger when Info itself is inverted, which the assert after the loop
    // then reports.
    if (UpperBound < LowerBound)
      break;

    // Move forward to the first interval whose end is at or after LowerBound.
    Interval.advanceTo(LowerBound);
    if (!Interval.valid()) // No interval found
      break;

    // Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that
    // a <= y implicitly from Interval.advanceTo(LowerBound)
    if (UpperBound < Interval.start())
      break; // found interval does not overlap with inserted one

    if (!Res.has_value()) // Update to be the first found intersection
      Res = Interval.value();

    if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) {
      // x <= a <= b <= y implies that [a;b] is covered by [x;y]
      // -> so we don't need to insert this, report an overlap
      return Res;
    } else if (LowerBound <= Interval.start() &&
               Interval.stop() <= UpperBound) {
      // a <= x <= y <= b implies that [x;y] is covered by [a;b]
      // -> so remove the existing interval that we will cover with the
      // overwrite; erase() leaves the iterator on the following interval
      Interval.erase();
    } else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) {
      // a < x <= b <= y implies that [a;x-1] is not covered but [x;b] is
      // -> so set b = x - 1 such that [a;x-1] is now the interval to insert
      // (a < x guarantees x - 1 neither underflows nor drops below a)
      UpperBound = Interval.start() - 1;
    } else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) {
      // x <= a <= y < b implies that [y+1;b] is not covered but [a;y] is
      // -> so set a = y + 1 such that [y+1;b] is now the interval to insert
      // (y < b guarantees y + 1 neither overflows nor exceeds b)
      LowerBound = Interval.stop() + 1;
    }
  }

  assert(LowerBound <= UpperBound && "Attempting to insert an empty interval");
  Intervals.insert(LowerBound, UpperBound, &Info);
  return Res;
}
454
455} // namespace rootsig
456} // namespace hlsl
457} // namespace llvm
458