1//===- HLSLRootSignature.cpp - HLSL Root Signature helpers ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file This file contains helpers for working with HLSL Root Signatures.
10///
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
14#include "llvm/Support/DXILABI.h"
15#include "llvm/Support/InterleavedRange.h"
16#include "llvm/Support/ScopedPrinter.h"
17
18namespace llvm {
19namespace hlsl {
20namespace rootsig {
21
22template <typename T>
23static raw_ostream &printFlags(raw_ostream &OS, const T Value,
24 EnumStrings<T> Flags) {
25 bool FlagSet = false;
26 unsigned Remaining = llvm::to_underlying(Value);
27 while (Remaining) {
28 unsigned Bit = 1u << llvm::countr_zero(Val: Remaining);
29 if (Remaining & Bit) {
30 if (FlagSet)
31 OS << " | ";
32
33 StringRef MaybeFlag = Flags.toString(T(Bit));
34 if (!MaybeFlag.empty())
35 OS << MaybeFlag;
36 else
37 OS << "invalid: " << Bit;
38
39 FlagSet = true;
40 }
41 Remaining &= ~Bit;
42 }
43
44 if (!FlagSet)
45 OS << "None";
46 return OS;
47}
48
49static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
50 constexpr EnumStringDef<RegisterType> RegisterNameDefs[] = {
51 {.Names: {"b"}, .Value: RegisterType::BReg},
52 {.Names: {"t"}, .Value: RegisterType::TReg},
53 {.Names: {"u"}, .Value: RegisterType::UReg},
54 {.Names: {"s"}, .Value: RegisterType::SReg},
55 };
56 static constexpr auto RegisterNames = BUILD_ENUM_STRINGS(RegisterNameDefs);
57 OS << EnumStrings(RegisterNames).toString(Value: Reg.ViewType) << Reg.Number;
58 return OS;
59}
60
61static raw_ostream &operator<<(raw_ostream &OS,
62 const llvm::dxbc::ShaderVisibility &Visibility) {
63 OS << dxbc::getShaderVisibility().toString(Value: Visibility);
64
65 return OS;
66}
67
68static raw_ostream &operator<<(raw_ostream &OS,
69 const llvm::dxbc::SamplerFilter &Filter) {
70 OS << dxbc::getSamplerFilters().toString(Value: Filter);
71
72 return OS;
73}
74
75static raw_ostream &operator<<(raw_ostream &OS,
76 const dxbc::TextureAddressMode &Address) {
77 OS << dxbc::getTextureAddressModes().toString(Value: Address);
78
79 return OS;
80}
81
82static raw_ostream &operator<<(raw_ostream &OS,
83 const dxbc::ComparisonFunc &CompFunc) {
84 OS << dxbc::getComparisonFuncs().toString(Value: CompFunc);
85
86 return OS;
87}
88
89static raw_ostream &operator<<(raw_ostream &OS,
90 const dxbc::StaticBorderColor &BorderColor) {
91 OS << dxbc::getStaticBorderColors().toString(Value: BorderColor);
92
93 return OS;
94}
95
96static raw_ostream &operator<<(raw_ostream &OS,
97 const dxil::ResourceClass &Type) {
98 OS << dxil::getResourceClassName(RC: Type);
99 return OS;
100}
101
102static raw_ostream &operator<<(raw_ostream &OS,
103 const dxbc::RootDescriptorFlags &Flags) {
104 printFlags(OS, Value: Flags, Flags: dxbc::getRootDescriptorFlags());
105
106 return OS;
107}
108
109static raw_ostream &operator<<(raw_ostream &OS,
110 const llvm::dxbc::DescriptorRangeFlags &Flags) {
111 printFlags(OS, Value: Flags, Flags: dxbc::getDescriptorRangeFlags());
112
113 return OS;
114}
115
116static raw_ostream &operator<<(raw_ostream &OS,
117 const llvm::dxbc::StaticSamplerFlags &Flags) {
118 printFlags(OS, Value: Flags, Flags: dxbc::getStaticSamplerFlags());
119
120 return OS;
121}
122
123raw_ostream &operator<<(raw_ostream &OS, const dxbc::RootFlags &Flags) {
124 OS << "RootFlags(";
125 printFlags(OS, Value: Flags, Flags: dxbc::getRootFlags());
126 OS << ")";
127
128 return OS;
129}
130
131raw_ostream &operator<<(raw_ostream &OS, const RootConstants &Constants) {
132 OS << "RootConstants(num32BitConstants = " << Constants.Num32BitConstants
133 << ", " << Constants.Reg << ", space = " << Constants.Space
134 << ", visibility = " << Constants.Visibility << ")";
135
136 return OS;
137}
138
139raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table) {
140 OS << "DescriptorTable(numClauses = " << Table.NumClauses
141 << ", visibility = " << Table.Visibility << ")";
142
143 return OS;
144}
145
146raw_ostream &operator<<(raw_ostream &OS, const DescriptorTableClause &Clause) {
147 OS << Clause.Type << "(" << Clause.Reg << ", numDescriptors = ";
148 if (Clause.NumDescriptors == NumDescriptorsUnbounded)
149 OS << "unbounded";
150 else
151 OS << Clause.NumDescriptors;
152 OS << ", space = " << Clause.Space << ", offset = ";
153 if (Clause.Offset == DescriptorTableOffsetAppend)
154 OS << "DescriptorTableOffsetAppend";
155 else
156 OS << Clause.Offset;
157 OS << ", flags = " << Clause.Flags << ")";
158
159 return OS;
160}
161
162raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) {
163 OS << "Root" << Descriptor.Type << "(" << Descriptor.Reg
164 << ", space = " << Descriptor.Space
165 << ", visibility = " << Descriptor.Visibility
166 << ", flags = " << Descriptor.Flags << ")";
167
168 return OS;
169}
170
171raw_ostream &operator<<(raw_ostream &OS, const StaticSampler &Sampler) {
172 OS << "StaticSampler(" << Sampler.Reg << ", filter = " << Sampler.Filter
173 << ", addressU = " << Sampler.AddressU
174 << ", addressV = " << Sampler.AddressV
175 << ", addressW = " << Sampler.AddressW
176 << ", mipLODBias = " << Sampler.MipLODBias
177 << ", maxAnisotropy = " << Sampler.MaxAnisotropy
178 << ", comparisonFunc = " << Sampler.CompFunc
179 << ", borderColor = " << Sampler.BorderColor
180 << ", minLOD = " << Sampler.MinLOD << ", maxLOD = " << Sampler.MaxLOD
181 << ", space = " << Sampler.Space << ", visibility = " << Sampler.Visibility
182 << ", flags = " << Sampler.Flags << ")";
183 return OS;
184}
185
namespace {

// OverloadedVisit merges a set of callables into a single visitor by
// inheriting the call operator of each one. We use it with std::visit to
// ensure the compiler catches the case where a new RootElement variant type is
// added but its operator<< isn't handled.
template <typename... Callables> struct OverloadedVisit : Callables... {
  using Callables::operator()...;
};
// Deduction guide so that OverloadedVisit{...} deduces its base classes from
// the callables it is initialized with.
template <typename... Callables>
OverloadedVisit(Callables...) -> OverloadedVisit<Callables...>;

} // namespace
196
197raw_ostream &operator<<(raw_ostream &OS, const RootElement &Element) {
198 const auto Visitor = OverloadedVisit{
199 [&OS](const dxbc::RootFlags &Flags) { OS << Flags; },
200 [&OS](const RootConstants &Constants) { OS << Constants; },
201 [&OS](const RootDescriptor &Descriptor) { OS << Descriptor; },
202 [&OS](const DescriptorTableClause &Clause) { OS << Clause; },
203 [&OS](const DescriptorTable &Table) { OS << Table; },
204 [&OS](const StaticSampler &Sampler) { OS << Sampler; },
205 };
206 std::visit(visitor: Visitor, variants: Element);
207 return OS;
208}
209
210void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements) {
211 OS << " RootElements" << interleaved(R: Elements, Separator: ", ", Prefix: "{", Suffix: "}");
212}
213
214} // namespace rootsig
215} // namespace hlsl
216} // namespace llvm
217