//===- RootSignatureMetadata.cpp - HLSL Root Signature helpers ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file This file implements a library for working with HLSL Root Signatures
10/// and their metadata representation.
11///
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
15#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
16#include "llvm/IR/IRBuilder.h"
17#include "llvm/IR/Metadata.h"
18#include "llvm/Support/DXILABI.h"
19#include "llvm/Support/ScopedPrinter.h"
20
21using namespace llvm;
22
23namespace llvm {
24namespace hlsl {
25namespace rootsig {
26
// Out-of-line definition of the static ID member of
// RootSignatureValidationError (declared in a header not shown here).
// Presumably this is the address-identity tag used by LLVM's ErrorInfo-style
// RTTI, which make_error<RootSignatureValidationError> relies on below.
char RootSignatureValidationError::ID;
28
29static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
30 unsigned int OpId) {
31 if (auto *CI =
32 mdconst::dyn_extract<ConstantInt>(MD: Node->getOperand(I: OpId).get()))
33 return CI->getZExtValue();
34 return std::nullopt;
35}
36
37static std::optional<float> extractMdFloatValue(MDNode *Node,
38 unsigned int OpId) {
39 if (auto *CI = mdconst::dyn_extract<ConstantFP>(MD: Node->getOperand(I: OpId).get()))
40 return CI->getValueAPF().convertToFloat();
41 return std::nullopt;
42}
43
44static std::optional<StringRef> extractMdStringValue(MDNode *Node,
45 unsigned int OpId) {
46 MDString *NodeText = dyn_cast<MDString>(Val: Node->getOperand(I: OpId));
47 if (NodeText == nullptr)
48 return std::nullopt;
49 return NodeText->getString();
50}
51
52namespace {
53
54// We use the OverloadVisit with std::visit to ensure the compiler catches if a
55// new RootElement variant type is added but it's metadata generation isn't
56// handled.
57template <class... Ts> struct OverloadedVisit : Ts... {
58 using Ts::operator()...;
59};
60template <class... Ts> OverloadedVisit(Ts...) -> OverloadedVisit<Ts...>;
61
62struct FmtRange {
63 dxil::ResourceClass Type;
64 uint32_t Register;
65 uint32_t Space;
66
67 FmtRange(const mcdxbc::DescriptorRange &Range)
68 : Type(Range.RangeType), Register(Range.BaseShaderRegister),
69 Space(Range.RegisterSpace) {}
70};
71
72raw_ostream &operator<<(llvm::raw_ostream &OS, const FmtRange &Range) {
73 OS << getResourceClassName(RC: Range.Type) << "(register=" << Range.Register
74 << ", space=" << Range.Space << ")";
75 return OS;
76}
77
78struct FmtMDNode {
79 const MDNode *Node;
80
81 FmtMDNode(const MDNode *Node) : Node(Node) {}
82};
83
84raw_ostream &operator<<(llvm::raw_ostream &OS, FmtMDNode Fmt) {
85 Fmt.Node->printTree(OS);
86 return OS;
87}
88
89static Error makeRSError(const Twine &Msg) {
90 return make_error<RootSignatureValidationError>(Args: Msg);
91}
92} // namespace
93
94template <typename T, typename = std::enable_if_t<
95 std::is_enum_v<T> &&
96 std::is_same_v<std::underlying_type_t<T>, uint32_t>>>
97static Expected<T>
98extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText,
99 llvm::function_ref<bool(uint32_t)> VerifyFn) {
100 if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
101 if (!VerifyFn(*Val))
102 return makeRSError(Msg: formatv(Fmt: "Invalid value for {0}: {1}", Vals&: ErrText, Vals&: Val));
103 return static_cast<T>(*Val);
104 }
105 return makeRSError(Msg: formatv(Fmt: "Invalid value for {0}:", Vals&: ErrText));
106}
107
108MDNode *MetadataBuilder::BuildRootSignature() {
109 const auto Visitor = OverloadedVisit{
110 [this](const dxbc::RootFlags &Flags) -> MDNode * {
111 return BuildRootFlags(Flags);
112 },
113 [this](const RootConstants &Constants) -> MDNode * {
114 return BuildRootConstants(Constants);
115 },
116 [this](const RootDescriptor &Descriptor) -> MDNode * {
117 return BuildRootDescriptor(Descriptor);
118 },
119 [this](const DescriptorTableClause &Clause) -> MDNode * {
120 return BuildDescriptorTableClause(Clause);
121 },
122 [this](const DescriptorTable &Table) -> MDNode * {
123 return BuildDescriptorTable(Table);
124 },
125 [this](const StaticSampler &Sampler) -> MDNode * {
126 return BuildStaticSampler(Sampler);
127 },
128 };
129
130 for (const RootElement &Element : Elements) {
131 MDNode *ElementMD = std::visit(visitor: Visitor, variants: Element);
132 assert(ElementMD != nullptr &&
133 "Root Element must be initialized and validated");
134 GeneratedMetadata.push_back(Elt: ElementMD);
135 }
136
137 return MDNode::get(Context&: Ctx, MDs: GeneratedMetadata);
138}
139
140MDNode *MetadataBuilder::BuildRootFlags(const dxbc::RootFlags &Flags) {
141 IRBuilder<> Builder(Ctx);
142 Metadata *Operands[] = {
143 MDString::get(Context&: Ctx, Str: "RootFlags"),
144 ConstantAsMetadata::get(C: Builder.getInt32(C: to_underlying(E: Flags))),
145 };
146 return MDNode::get(Context&: Ctx, MDs: Operands);
147}
148
149MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
150 IRBuilder<> Builder(Ctx);
151 Metadata *Operands[] = {
152 MDString::get(Context&: Ctx, Str: "RootConstants"),
153 ConstantAsMetadata::get(
154 C: Builder.getInt32(C: to_underlying(E: Constants.Visibility))),
155 ConstantAsMetadata::get(C: Builder.getInt32(C: Constants.Reg.Number)),
156 ConstantAsMetadata::get(C: Builder.getInt32(C: Constants.Space)),
157 ConstantAsMetadata::get(C: Builder.getInt32(C: Constants.Num32BitConstants)),
158 };
159 return MDNode::get(Context&: Ctx, MDs: Operands);
160}
161
162MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
163 IRBuilder<> Builder(Ctx);
164 StringRef ResName = dxil::getResourceClassName(RC: Descriptor.Type);
165 assert(!ResName.empty() && "Provided an invalid Resource Class");
166 SmallString<7> Name({"Root", ResName});
167 Metadata *Operands[] = {
168 MDString::get(Context&: Ctx, Str: Name),
169 ConstantAsMetadata::get(
170 C: Builder.getInt32(C: to_underlying(E: Descriptor.Visibility))),
171 ConstantAsMetadata::get(C: Builder.getInt32(C: Descriptor.Reg.Number)),
172 ConstantAsMetadata::get(C: Builder.getInt32(C: Descriptor.Space)),
173 ConstantAsMetadata::get(
174 C: Builder.getInt32(C: to_underlying(E: Descriptor.Flags))),
175 };
176 return MDNode::get(Context&: Ctx, MDs: Operands);
177}
178
179MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
180 IRBuilder<> Builder(Ctx);
181 SmallVector<Metadata *> TableOperands;
182 // Set the mandatory arguments
183 TableOperands.push_back(Elt: MDString::get(Context&: Ctx, Str: "DescriptorTable"));
184 TableOperands.push_back(Elt: ConstantAsMetadata::get(
185 C: Builder.getInt32(C: to_underlying(E: Table.Visibility))));
186
187 // Remaining operands are references to the table's clauses. The in-memory
188 // representation of the Root Elements created from parsing will ensure that
189 // the previous N elements are the clauses for this table.
190 assert(Table.NumClauses <= GeneratedMetadata.size() &&
191 "Table expected all owned clauses to be generated already");
192 // So, add a refence to each clause to our operands
193 TableOperands.append(in_start: GeneratedMetadata.end() - Table.NumClauses,
194 in_end: GeneratedMetadata.end());
195 // Then, remove those clauses from the general list of Root Elements
196 GeneratedMetadata.pop_back_n(NumItems: Table.NumClauses);
197
198 return MDNode::get(Context&: Ctx, MDs: TableOperands);
199}
200
201MDNode *MetadataBuilder::BuildDescriptorTableClause(
202 const DescriptorTableClause &Clause) {
203 IRBuilder<> Builder(Ctx);
204 StringRef ResName = dxil::getResourceClassName(RC: Clause.Type);
205 assert(!ResName.empty() && "Provided an invalid Resource Class");
206 Metadata *Operands[] = {
207 MDString::get(Context&: Ctx, Str: ResName),
208 ConstantAsMetadata::get(C: Builder.getInt32(C: Clause.NumDescriptors)),
209 ConstantAsMetadata::get(C: Builder.getInt32(C: Clause.Reg.Number)),
210 ConstantAsMetadata::get(C: Builder.getInt32(C: Clause.Space)),
211 ConstantAsMetadata::get(C: Builder.getInt32(C: Clause.Offset)),
212 ConstantAsMetadata::get(C: Builder.getInt32(C: to_underlying(E: Clause.Flags))),
213 };
214 return MDNode::get(Context&: Ctx, MDs: Operands);
215}
216
217MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
218 IRBuilder<> Builder(Ctx);
219 Metadata *Operands[] = {
220 MDString::get(Context&: Ctx, Str: "StaticSampler"),
221 ConstantAsMetadata::get(C: Builder.getInt32(C: to_underlying(E: Sampler.Filter))),
222 ConstantAsMetadata::get(
223 C: Builder.getInt32(C: to_underlying(E: Sampler.AddressU))),
224 ConstantAsMetadata::get(
225 C: Builder.getInt32(C: to_underlying(E: Sampler.AddressV))),
226 ConstantAsMetadata::get(
227 C: Builder.getInt32(C: to_underlying(E: Sampler.AddressW))),
228 ConstantAsMetadata::get(
229 C: ConstantFP::get(Ty: Type::getFloatTy(C&: Ctx), V: Sampler.MipLODBias)),
230 ConstantAsMetadata::get(C: Builder.getInt32(C: Sampler.MaxAnisotropy)),
231 ConstantAsMetadata::get(
232 C: Builder.getInt32(C: to_underlying(E: Sampler.CompFunc))),
233 ConstantAsMetadata::get(
234 C: Builder.getInt32(C: to_underlying(E: Sampler.BorderColor))),
235 ConstantAsMetadata::get(
236 C: ConstantFP::get(Ty: Type::getFloatTy(C&: Ctx), V: Sampler.MinLOD)),
237 ConstantAsMetadata::get(
238 C: ConstantFP::get(Ty: Type::getFloatTy(C&: Ctx), V: Sampler.MaxLOD)),
239 ConstantAsMetadata::get(C: Builder.getInt32(C: Sampler.Reg.Number)),
240 ConstantAsMetadata::get(C: Builder.getInt32(C: Sampler.Space)),
241 ConstantAsMetadata::get(
242 C: Builder.getInt32(C: to_underlying(E: Sampler.Visibility))),
243 ConstantAsMetadata::get(C: Builder.getInt32(C: to_underlying(E: Sampler.Flags))),
244 };
245 return MDNode::get(Context&: Ctx, MDs: Operands);
246}
247
248Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
249 MDNode *RootFlagNode) {
250 if (RootFlagNode->getNumOperands() != 2)
251 return makeRSError(Msg: "Invalid format for RootFlags Element");
252
253 if (std::optional<uint32_t> Val = extractMdIntValue(Node: RootFlagNode, OpId: 1))
254 RSD.Flags = *Val;
255 else
256 return makeRSError(Msg: "Invalid value for RootFlag");
257
258 return Error::success();
259}
260
261Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
262 MDNode *RootConstantNode) {
263 if (RootConstantNode->getNumOperands() != 5)
264 return makeRSError(Msg: "Invalid format for RootConstants Element");
265
266 Expected<dxbc::ShaderVisibility> Visibility =
267 extractEnumValue<dxbc::ShaderVisibility>(Node: RootConstantNode, OpId: 1,
268 ErrText: "ShaderVisibility",
269 VerifyFn: dxbc::isValidShaderVisibility);
270 if (auto E = Visibility.takeError())
271 return Error(std::move(E));
272
273 mcdxbc::RootConstants Constants;
274 if (std::optional<uint32_t> Val = extractMdIntValue(Node: RootConstantNode, OpId: 2))
275 Constants.ShaderRegister = *Val;
276 else
277 return makeRSError(Msg: "Invalid value for ShaderRegister");
278
279 if (std::optional<uint32_t> Val = extractMdIntValue(Node: RootConstantNode, OpId: 3))
280 Constants.RegisterSpace = *Val;
281 else
282 return makeRSError(Msg: "Invalid value for RegisterSpace");
283
284 if (std::optional<uint32_t> Val = extractMdIntValue(Node: RootConstantNode, OpId: 4))
285 Constants.Num32BitValues = *Val;
286 else
287 return makeRSError(Msg: "Invalid value for Num32BitValues");
288
289 RSD.ParametersContainer.addParameter(Type: dxbc::RootParameterType::Constants32Bit,
290 Visibility: *Visibility, Constant: Constants);
291
292 return Error::success();
293}
294
295Error MetadataParser::parseRootDescriptors(
296 mcdxbc::RootSignatureDesc &RSD, MDNode *RootDescriptorNode,
297 RootSignatureElementKind ElementKind) {
298 assert((ElementKind == RootSignatureElementKind::SRV ||
299 ElementKind == RootSignatureElementKind::UAV ||
300 ElementKind == RootSignatureElementKind::CBV) &&
301 "parseRootDescriptors should only be called with RootDescriptor "
302 "element kind.");
303 if (RootDescriptorNode->getNumOperands() != 5)
304 return makeRSError(Msg: "Invalid format for Root Descriptor Element");
305
306 dxbc::RootParameterType Type;
307 switch (ElementKind) {
308 case RootSignatureElementKind::SRV:
309 Type = dxbc::RootParameterType::SRV;
310 break;
311 case RootSignatureElementKind::UAV:
312 Type = dxbc::RootParameterType::UAV;
313 break;
314 case RootSignatureElementKind::CBV:
315 Type = dxbc::RootParameterType::CBV;
316 break;
317 default:
318 llvm_unreachable("invalid Root Descriptor kind");
319 break;
320 }
321
322 Expected<dxbc::ShaderVisibility> Visibility =
323 extractEnumValue<dxbc::ShaderVisibility>(Node: RootDescriptorNode, OpId: 1,
324 ErrText: "ShaderVisibility",
325 VerifyFn: dxbc::isValidShaderVisibility);
326 if (auto E = Visibility.takeError())
327 return Error(std::move(E));
328
329 mcdxbc::RootDescriptor Descriptor;
330 if (std::optional<uint32_t> Val = extractMdIntValue(Node: RootDescriptorNode, OpId: 2))
331 Descriptor.ShaderRegister = *Val;
332 else
333 return makeRSError(Msg: "Invalid value for ShaderRegister");
334
335 if (std::optional<uint32_t> Val = extractMdIntValue(Node: RootDescriptorNode, OpId: 3))
336 Descriptor.RegisterSpace = *Val;
337 else
338 return makeRSError(Msg: "Invalid value for RegisterSpace");
339
340 if (std::optional<uint32_t> Val = extractMdIntValue(Node: RootDescriptorNode, OpId: 4))
341 Descriptor.Flags = *Val;
342 else
343 return makeRSError(Msg: "Invalid value for Root Descriptor Flags");
344
345 RSD.ParametersContainer.addParameter(Type, Visibility: *Visibility, Descriptor);
346 return Error::success();
347}
348
349Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
350 MDNode *RangeDescriptorNode) {
351 if (RangeDescriptorNode->getNumOperands() != 6)
352 return makeRSError(Msg: "Invalid format for Descriptor Range");
353
354 mcdxbc::DescriptorRange Range;
355
356 std::optional<StringRef> ElementText =
357 extractMdStringValue(Node: RangeDescriptorNode, OpId: 0);
358
359 if (!ElementText.has_value())
360 return makeRSError(Msg: "Invalid format for Descriptor Range");
361
362 if (*ElementText == "CBV")
363 Range.RangeType = dxil::ResourceClass::CBuffer;
364 else if (*ElementText == "SRV")
365 Range.RangeType = dxil::ResourceClass::SRV;
366 else if (*ElementText == "UAV")
367 Range.RangeType = dxil::ResourceClass::UAV;
368 else if (*ElementText == "Sampler")
369 Range.RangeType = dxil::ResourceClass::Sampler;
370 else
371 return makeRSError(Msg: formatv(Fmt: "Invalid Descriptor Range type.\n{0}",
372 Vals: FmtMDNode{RangeDescriptorNode}));
373
374 if (std::optional<uint32_t> Val = extractMdIntValue(Node: RangeDescriptorNode, OpId: 1))
375 Range.NumDescriptors = *Val;
376 else
377 return makeRSError(Msg: formatv(Fmt: "Invalid number of Descriptor in Range.\n{0}",
378 Vals: FmtMDNode{RangeDescriptorNode}));
379
380 if (std::optional<uint32_t> Val = extractMdIntValue(Node: RangeDescriptorNode, OpId: 2))
381 Range.BaseShaderRegister = *Val;
382 else
383 return makeRSError(Msg: "Invalid value for BaseShaderRegister");
384
385 if (std::optional<uint32_t> Val = extractMdIntValue(Node: RangeDescriptorNode, OpId: 3))
386 Range.RegisterSpace = *Val;
387 else
388 return makeRSError(Msg: "Invalid value for RegisterSpace");
389
390 if (std::optional<uint32_t> Val = extractMdIntValue(Node: RangeDescriptorNode, OpId: 4))
391 Range.OffsetInDescriptorsFromTableStart = *Val;
392 else
393 return makeRSError(Msg: "Invalid value for OffsetInDescriptorsFromTableStart");
394
395 if (std::optional<uint32_t> Val = extractMdIntValue(Node: RangeDescriptorNode, OpId: 5))
396 Range.Flags = *Val;
397 else
398 return makeRSError(Msg: "Invalid value for Descriptor Range Flags");
399
400 Table.Ranges.push_back(Elt: Range);
401 return Error::success();
402}
403
404Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
405 MDNode *DescriptorTableNode) {
406 const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
407 if (NumOperands < 2)
408 return makeRSError(Msg: "Invalid format for Descriptor Table");
409
410 Expected<dxbc::ShaderVisibility> Visibility =
411 extractEnumValue<dxbc::ShaderVisibility>(Node: DescriptorTableNode, OpId: 1,
412 ErrText: "ShaderVisibility",
413 VerifyFn: dxbc::isValidShaderVisibility);
414 if (auto E = Visibility.takeError())
415 return Error(std::move(E));
416
417 mcdxbc::DescriptorTable Table;
418
419 for (unsigned int I = 2; I < NumOperands; I++) {
420 MDNode *Element = dyn_cast<MDNode>(Val: DescriptorTableNode->getOperand(I));
421 if (Element == nullptr)
422 return makeRSError(Msg: formatv(Fmt: "Missing Root Element Metadata Node.\n{0}",
423 Vals: FmtMDNode{DescriptorTableNode}));
424
425 if (auto Err = parseDescriptorRange(Table, RangeDescriptorNode: Element))
426 return Err;
427 }
428
429 RSD.ParametersContainer.addParameter(Type: dxbc::RootParameterType::DescriptorTable,
430 Visibility: *Visibility, Table);
431 return Error::success();
432}
433
434Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
435 MDNode *StaticSamplerNode) {
436 if (StaticSamplerNode->getNumOperands() != 15)
437 return makeRSError(Msg: "Invalid format for Static Sampler");
438
439 mcdxbc::StaticSampler Sampler;
440
441 Expected<dxbc::SamplerFilter> Filter = extractEnumValue<dxbc::SamplerFilter>(
442 Node: StaticSamplerNode, OpId: 1, ErrText: "Filter", VerifyFn: dxbc::isValidSamplerFilter);
443 if (auto E = Filter.takeError())
444 return Error(std::move(E));
445 Sampler.Filter = *Filter;
446
447 Expected<dxbc::TextureAddressMode> AddressU =
448 extractEnumValue<dxbc::TextureAddressMode>(
449 Node: StaticSamplerNode, OpId: 2, ErrText: "AddressU", VerifyFn: dxbc::isValidAddress);
450 if (auto E = AddressU.takeError())
451 return Error(std::move(E));
452 Sampler.AddressU = *AddressU;
453
454 Expected<dxbc::TextureAddressMode> AddressV =
455 extractEnumValue<dxbc::TextureAddressMode>(
456 Node: StaticSamplerNode, OpId: 3, ErrText: "AddressV", VerifyFn: dxbc::isValidAddress);
457 if (auto E = AddressV.takeError())
458 return Error(std::move(E));
459 Sampler.AddressV = *AddressV;
460
461 Expected<dxbc::TextureAddressMode> AddressW =
462 extractEnumValue<dxbc::TextureAddressMode>(
463 Node: StaticSamplerNode, OpId: 4, ErrText: "AddressW", VerifyFn: dxbc::isValidAddress);
464 if (auto E = AddressW.takeError())
465 return Error(std::move(E));
466 Sampler.AddressW = *AddressW;
467
468 if (std::optional<float> Val = extractMdFloatValue(Node: StaticSamplerNode, OpId: 5))
469 Sampler.MipLODBias = *Val;
470 else
471 return makeRSError(Msg: "Invalid value for MipLODBias");
472
473 if (std::optional<uint32_t> Val = extractMdIntValue(Node: StaticSamplerNode, OpId: 6))
474 Sampler.MaxAnisotropy = *Val;
475 else
476 return makeRSError(Msg: "Invalid value for MaxAnisotropy");
477
478 Expected<dxbc::ComparisonFunc> ComparisonFunc =
479 extractEnumValue<dxbc::ComparisonFunc>(
480 Node: StaticSamplerNode, OpId: 7, ErrText: "ComparisonFunc", VerifyFn: dxbc::isValidComparisonFunc);
481 if (auto E = ComparisonFunc.takeError())
482 return Error(std::move(E));
483 Sampler.ComparisonFunc = *ComparisonFunc;
484
485 Expected<dxbc::StaticBorderColor> BorderColor =
486 extractEnumValue<dxbc::StaticBorderColor>(
487 Node: StaticSamplerNode, OpId: 8, ErrText: "BorderColor", VerifyFn: dxbc::isValidBorderColor);
488 if (auto E = BorderColor.takeError())
489 return Error(std::move(E));
490 Sampler.BorderColor = *BorderColor;
491
492 if (std::optional<float> Val = extractMdFloatValue(Node: StaticSamplerNode, OpId: 9))
493 Sampler.MinLOD = *Val;
494 else
495 return makeRSError(Msg: "Invalid value for MinLOD");
496
497 if (std::optional<float> Val = extractMdFloatValue(Node: StaticSamplerNode, OpId: 10))
498 Sampler.MaxLOD = *Val;
499 else
500 return makeRSError(Msg: "Invalid value for MaxLOD");
501
502 if (std::optional<uint32_t> Val = extractMdIntValue(Node: StaticSamplerNode, OpId: 11))
503 Sampler.ShaderRegister = *Val;
504 else
505 return makeRSError(Msg: "Invalid value for ShaderRegister");
506
507 if (std::optional<uint32_t> Val = extractMdIntValue(Node: StaticSamplerNode, OpId: 12))
508 Sampler.RegisterSpace = *Val;
509 else
510 return makeRSError(Msg: "Invalid value for RegisterSpace");
511
512 Expected<dxbc::ShaderVisibility> Visibility =
513 extractEnumValue<dxbc::ShaderVisibility>(Node: StaticSamplerNode, OpId: 13,
514 ErrText: "ShaderVisibility",
515 VerifyFn: dxbc::isValidShaderVisibility);
516 if (auto E = Visibility.takeError())
517 return Error(std::move(E));
518 Sampler.ShaderVisibility = *Visibility;
519
520 if (std::optional<uint32_t> Val = extractMdIntValue(Node: StaticSamplerNode, OpId: 14))
521 Sampler.Flags = *Val;
522 else
523 return makeRSError(Msg: "Invalid value for Static Sampler Flags");
524
525 RSD.StaticSamplers.push_back(Elt: Sampler);
526 return Error::success();
527}
528
529Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
530 MDNode *Element) {
531 std::optional<StringRef> ElementText = extractMdStringValue(Node: Element, OpId: 0);
532 if (!ElementText.has_value())
533 return makeRSError(Msg: "Invalid format for Root Element");
534
535 RootSignatureElementKind ElementKind =
536 StringSwitch<RootSignatureElementKind>(*ElementText)
537 .Case(S: "RootFlags", Value: RootSignatureElementKind::RootFlags)
538 .Case(S: "RootConstants", Value: RootSignatureElementKind::RootConstants)
539 .Case(S: "RootCBV", Value: RootSignatureElementKind::CBV)
540 .Case(S: "RootSRV", Value: RootSignatureElementKind::SRV)
541 .Case(S: "RootUAV", Value: RootSignatureElementKind::UAV)
542 .Case(S: "DescriptorTable", Value: RootSignatureElementKind::DescriptorTable)
543 .Case(S: "StaticSampler", Value: RootSignatureElementKind::StaticSamplers)
544 .Default(Value: RootSignatureElementKind::Error);
545
546 switch (ElementKind) {
547
548 case RootSignatureElementKind::RootFlags:
549 return parseRootFlags(RSD, RootFlagNode: Element);
550 case RootSignatureElementKind::RootConstants:
551 return parseRootConstants(RSD, RootConstantNode: Element);
552 case RootSignatureElementKind::CBV:
553 case RootSignatureElementKind::SRV:
554 case RootSignatureElementKind::UAV:
555 return parseRootDescriptors(RSD, RootDescriptorNode: Element, ElementKind);
556 case RootSignatureElementKind::DescriptorTable:
557 return parseDescriptorTable(RSD, DescriptorTableNode: Element);
558 case RootSignatureElementKind::StaticSamplers:
559 return parseStaticSampler(RSD, StaticSamplerNode: Element);
560 case RootSignatureElementKind::Error:
561 return makeRSError(
562 Msg: formatv(Fmt: "Invalid Root Signature Element\n{0}", Vals: FmtMDNode{Element}));
563 }
564
565 llvm_unreachable("Unhandled RootSignatureElementKind enum.");
566}
567
568static Error
569validateDescriptorTableSamplerMixin(const mcdxbc::DescriptorTable &Table,
570 uint32_t Location) {
571 dxil::ResourceClass CurrRC = dxil::ResourceClass::Sampler;
572 for (const mcdxbc::DescriptorRange &Range : Table.Ranges) {
573 if (Range.RangeType == dxil::ResourceClass::Sampler &&
574 CurrRC != dxil::ResourceClass::Sampler)
575 return makeRSError(
576 Msg: formatv(Fmt: "Samplers cannot be mixed with other resource types in a "
577 "descriptor table, {0}(location={1})",
578 Vals: getResourceClassName(RC: CurrRC), Vals&: Location));
579 CurrRC = Range.RangeType;
580 }
581 return Error::success();
582}
583
584static Error
585validateDescriptorTableRegisterOverflow(const mcdxbc::DescriptorTable &Table,
586 uint32_t Location) {
587 uint64_t Offset = 0;
588 bool IsPrevUnbound = false;
589 for (const mcdxbc::DescriptorRange &Range : Table.Ranges) {
590 // Validation of NumDescriptors should have happened by this point.
591 if (Range.NumDescriptors == 0)
592 continue;
593
594 const uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
595 Offset: Range.BaseShaderRegister, Size: Range.NumDescriptors);
596
597 if (!verifyNoOverflowedOffset(Offset: RangeBound))
598 return makeRSError(
599 Msg: formatv(Fmt: "Overflow for shader register range: {0}", Vals: FmtRange{Range}));
600
601 bool IsAppending =
602 Range.OffsetInDescriptorsFromTableStart == DescriptorTableOffsetAppend;
603 if (!IsAppending)
604 Offset = Range.OffsetInDescriptorsFromTableStart;
605
606 if (IsPrevUnbound && IsAppending)
607 return makeRSError(
608 Msg: formatv(Fmt: "Range {0} cannot be appended after an unbounded range",
609 Vals: FmtRange{Range}));
610
611 const uint64_t OffsetBound =
612 llvm::hlsl::rootsig::computeRangeBound(Offset, Size: Range.NumDescriptors);
613
614 if (!verifyNoOverflowedOffset(Offset: OffsetBound))
615 return makeRSError(Msg: formatv(Fmt: "Offset overflow for descriptor range: {0}.",
616 Vals: FmtRange{Range}));
617
618 Offset = OffsetBound + 1;
619 IsPrevUnbound =
620 Range.NumDescriptors == llvm::hlsl::rootsig::NumDescriptorsUnbounded;
621 }
622
623 return Error::success();
624}
625
626Error MetadataParser::validateRootSignature(
627 const mcdxbc::RootSignatureDesc &RSD) {
628 Error DeferredErrs = Error::success();
629 if (!hlsl::rootsig::verifyVersion(Version: RSD.Version)) {
630 DeferredErrs = joinErrors(
631 E1: std::move(DeferredErrs),
632 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for Version: {0}", Vals: RSD.Version)));
633 }
634
635 if (!hlsl::rootsig::verifyRootFlag(Flags: RSD.Flags)) {
636 DeferredErrs = joinErrors(
637 E1: std::move(DeferredErrs),
638 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for RootFlags: {0}", Vals: RSD.Flags)));
639 }
640
641 for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
642
643 switch (Info.Type) {
644 case dxbc::RootParameterType::Constants32Bit:
645 break;
646
647 case dxbc::RootParameterType::CBV:
648 case dxbc::RootParameterType::UAV:
649 case dxbc::RootParameterType::SRV: {
650 const mcdxbc::RootDescriptor &Descriptor =
651 RSD.ParametersContainer.getRootDescriptor(Index: Info.Location);
652 if (!hlsl::rootsig::verifyRegisterValue(RegisterValue: Descriptor.ShaderRegister))
653 DeferredErrs = joinErrors(
654 E1: std::move(DeferredErrs),
655 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for ShaderRegister: {0}",
656 Vals: Descriptor.ShaderRegister)));
657
658 if (!hlsl::rootsig::verifyRegisterSpace(RegisterSpace: Descriptor.RegisterSpace))
659 DeferredErrs = joinErrors(
660 E1: std::move(DeferredErrs),
661 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for RegisterSpace: {0}",
662 Vals: Descriptor.RegisterSpace)));
663
664 bool IsValidFlag =
665 dxbc::isValidRootDesciptorFlags(V: Descriptor.Flags) &&
666 hlsl::rootsig::verifyRootDescriptorFlag(
667 Version: RSD.Version, Flags: dxbc::RootDescriptorFlags(Descriptor.Flags));
668 if (!IsValidFlag)
669 DeferredErrs = joinErrors(
670 E1: std::move(DeferredErrs),
671 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for RootDescriptorFlag: {0}",
672 Vals: Descriptor.Flags)));
673 break;
674 }
675 case dxbc::RootParameterType::DescriptorTable: {
676 const mcdxbc::DescriptorTable &Table =
677 RSD.ParametersContainer.getDescriptorTable(Index: Info.Location);
678 for (const mcdxbc::DescriptorRange &Range : Table) {
679 if (!hlsl::rootsig::verifyRegisterSpace(RegisterSpace: Range.RegisterSpace))
680 DeferredErrs = joinErrors(
681 E1: std::move(DeferredErrs),
682 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for RegisterSpace: {0}",
683 Vals: Range.RegisterSpace)));
684
685 if (!hlsl::rootsig::verifyNumDescriptors(NumDescriptors: Range.NumDescriptors))
686 DeferredErrs = joinErrors(
687 E1: std::move(DeferredErrs),
688 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for NumDescriptors: {0}",
689 Vals: Range.NumDescriptors)));
690
691 bool IsValidFlag = dxbc::isValidDescriptorRangeFlags(V: Range.Flags) &&
692 hlsl::rootsig::verifyDescriptorRangeFlag(
693 Version: RSD.Version, Type: Range.RangeType,
694 Flags: dxbc::DescriptorRangeFlags(Range.Flags));
695 if (!IsValidFlag)
696 DeferredErrs = joinErrors(
697 E1: std::move(DeferredErrs),
698 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for DescriptorFlag: {0}",
699 Vals: Range.Flags)));
700
701 if (Error Err =
702 validateDescriptorTableSamplerMixin(Table, Location: Info.Location))
703 DeferredErrs = joinErrors(E1: std::move(DeferredErrs), E2: std::move(Err));
704
705 if (Error Err =
706 validateDescriptorTableRegisterOverflow(Table, Location: Info.Location))
707 DeferredErrs = joinErrors(E1: std::move(DeferredErrs), E2: std::move(Err));
708 }
709 break;
710 }
711 }
712 }
713
714 for (const mcdxbc::StaticSampler &Sampler : RSD.StaticSamplers) {
715
716 if (!hlsl::rootsig::verifyMipLODBias(MipLODBias: Sampler.MipLODBias))
717 DeferredErrs =
718 joinErrors(E1: std::move(DeferredErrs),
719 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for MipLODBias: {0:e}",
720 Vals: Sampler.MipLODBias)));
721
722 if (!hlsl::rootsig::verifyMaxAnisotropy(MaxAnisotropy: Sampler.MaxAnisotropy))
723 DeferredErrs =
724 joinErrors(E1: std::move(DeferredErrs),
725 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for MaxAnisotropy: {0}",
726 Vals: Sampler.MaxAnisotropy)));
727
728 if (!hlsl::rootsig::verifyLOD(LOD: Sampler.MinLOD))
729 DeferredErrs =
730 joinErrors(E1: std::move(DeferredErrs),
731 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for MinLOD: {0}",
732 Vals: Sampler.MinLOD)));
733
734 if (!hlsl::rootsig::verifyLOD(LOD: Sampler.MaxLOD))
735 DeferredErrs =
736 joinErrors(E1: std::move(DeferredErrs),
737 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for MaxLOD: {0}",
738 Vals: Sampler.MaxLOD)));
739
740 if (!hlsl::rootsig::verifyRegisterValue(RegisterValue: Sampler.ShaderRegister))
741 DeferredErrs = joinErrors(
742 E1: std::move(DeferredErrs),
743 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for ShaderRegister: {0}",
744 Vals: Sampler.ShaderRegister)));
745
746 if (!hlsl::rootsig::verifyRegisterSpace(RegisterSpace: Sampler.RegisterSpace))
747 DeferredErrs =
748 joinErrors(E1: std::move(DeferredErrs),
749 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for RegisterSpace: {0}",
750 Vals: Sampler.RegisterSpace)));
751 bool IsValidFlag =
752 dxbc::isValidStaticSamplerFlags(V: Sampler.Flags) &&
753 hlsl::rootsig::verifyStaticSamplerFlags(
754 Version: RSD.Version, Flags: dxbc::StaticSamplerFlags(Sampler.Flags));
755 if (!IsValidFlag)
756 DeferredErrs = joinErrors(
757 E1: std::move(DeferredErrs),
758 E2: makeRSError(Msg: formatv(Fmt: "Invalid value for Static Sampler Flag: {0}",
759 Vals: Sampler.Flags)));
760 }
761
762 return DeferredErrs;
763}
764
765Expected<mcdxbc::RootSignatureDesc>
766MetadataParser::ParseRootSignature(uint32_t Version) {
767 Error DeferredErrs = Error::success();
768 mcdxbc::RootSignatureDesc RSD;
769 RSD.Version = Version;
770 for (const auto &Operand : Root->operands()) {
771 MDNode *Element = dyn_cast<MDNode>(Val: Operand);
772 if (Element == nullptr)
773 return joinErrors(
774 E1: std::move(DeferredErrs),
775 E2: makeRSError(Msg: formatv(Fmt: "Missing Root Element Metadata Node.")));
776
777 if (auto Err = parseRootSignatureElement(RSD, Element))
778 DeferredErrs = joinErrors(E1: std::move(DeferredErrs), E2: std::move(Err));
779 }
780
781 if (auto Err = validateRootSignature(RSD))
782 DeferredErrs = joinErrors(E1: std::move(DeferredErrs), E2: std::move(Err));
783
784 if (DeferredErrs)
785 return std::move(DeferredErrs);
786
787 return std::move(RSD);
788}
789} // namespace rootsig
790} // namespace hlsl
791} // namespace llvm
792