1//===- HLSLRootSignature.cpp - HLSL Root Signature helpers ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file This file contains helpers for working with HLSL Root Signatures.
10///
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
14#include "llvm/Support/DXILABI.h"
15#include "llvm/Support/InterleavedRange.h"
16#include "llvm/Support/ScopedPrinter.h"
17
18namespace llvm {
19namespace hlsl {
20namespace rootsig {
21
22template <typename T>
23static raw_ostream &printFlags(raw_ostream &OS, const T Value,
24 ArrayRef<EnumEntry<T>> Flags) {
25 bool FlagSet = false;
26 unsigned Remaining = llvm::to_underlying(Value);
27 while (Remaining) {
28 unsigned Bit = 1u << llvm::countr_zero(Val: Remaining);
29 if (Remaining & Bit) {
30 if (FlagSet)
31 OS << " | ";
32
33 StringRef MaybeFlag = enumToStringRef(T(Bit), Flags);
34 if (!MaybeFlag.empty())
35 OS << MaybeFlag;
36 else
37 OS << "invalid: " << Bit;
38
39 FlagSet = true;
40 }
41 Remaining &= ~Bit;
42 }
43
44 if (!FlagSet)
45 OS << "None";
46 return OS;
47}
48
49static const EnumEntry<RegisterType> RegisterNames[] = {
50 {"b", RegisterType::BReg},
51 {"t", RegisterType::TReg},
52 {"u", RegisterType::UReg},
53 {"s", RegisterType::SReg},
54};
55
56static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
57 OS << enumToStringRef(Value: Reg.ViewType, EnumValues: ArrayRef(RegisterNames)) << Reg.Number;
58
59 return OS;
60}
61
62static raw_ostream &operator<<(raw_ostream &OS,
63 const llvm::dxbc::ShaderVisibility &Visibility) {
64 OS << enumToStringRef(Value: Visibility, EnumValues: dxbc::getShaderVisibility());
65
66 return OS;
67}
68
69static raw_ostream &operator<<(raw_ostream &OS,
70 const llvm::dxbc::SamplerFilter &Filter) {
71 OS << enumToStringRef(Value: Filter, EnumValues: dxbc::getSamplerFilters());
72
73 return OS;
74}
75
76static raw_ostream &operator<<(raw_ostream &OS,
77 const dxbc::TextureAddressMode &Address) {
78 OS << enumToStringRef(Value: Address, EnumValues: dxbc::getTextureAddressModes());
79
80 return OS;
81}
82
83static raw_ostream &operator<<(raw_ostream &OS,
84 const dxbc::ComparisonFunc &CompFunc) {
85 OS << enumToStringRef(Value: CompFunc, EnumValues: dxbc::getComparisonFuncs());
86
87 return OS;
88}
89
90static raw_ostream &operator<<(raw_ostream &OS,
91 const dxbc::StaticBorderColor &BorderColor) {
92 OS << enumToStringRef(Value: BorderColor, EnumValues: dxbc::getStaticBorderColors());
93
94 return OS;
95}
96
97static raw_ostream &operator<<(raw_ostream &OS,
98 const dxil::ResourceClass &Type) {
99 OS << dxil::getResourceClassName(RC: Type);
100 return OS;
101}
102
103static raw_ostream &operator<<(raw_ostream &OS,
104 const dxbc::RootDescriptorFlags &Flags) {
105 printFlags(OS, Value: Flags, Flags: dxbc::getRootDescriptorFlags());
106
107 return OS;
108}
109
110static raw_ostream &operator<<(raw_ostream &OS,
111 const llvm::dxbc::DescriptorRangeFlags &Flags) {
112 printFlags(OS, Value: Flags, Flags: dxbc::getDescriptorRangeFlags());
113
114 return OS;
115}
116
117static raw_ostream &operator<<(raw_ostream &OS,
118 const llvm::dxbc::StaticSamplerFlags &Flags) {
119 printFlags(OS, Value: Flags, Flags: dxbc::getStaticSamplerFlags());
120
121 return OS;
122}
123
124raw_ostream &operator<<(raw_ostream &OS, const dxbc::RootFlags &Flags) {
125 OS << "RootFlags(";
126 printFlags(OS, Value: Flags, Flags: dxbc::getRootFlags());
127 OS << ")";
128
129 return OS;
130}
131
132raw_ostream &operator<<(raw_ostream &OS, const RootConstants &Constants) {
133 OS << "RootConstants(num32BitConstants = " << Constants.Num32BitConstants
134 << ", " << Constants.Reg << ", space = " << Constants.Space
135 << ", visibility = " << Constants.Visibility << ")";
136
137 return OS;
138}
139
140raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table) {
141 OS << "DescriptorTable(numClauses = " << Table.NumClauses
142 << ", visibility = " << Table.Visibility << ")";
143
144 return OS;
145}
146
147raw_ostream &operator<<(raw_ostream &OS, const DescriptorTableClause &Clause) {
148 OS << Clause.Type << "(" << Clause.Reg << ", numDescriptors = ";
149 if (Clause.NumDescriptors == NumDescriptorsUnbounded)
150 OS << "unbounded";
151 else
152 OS << Clause.NumDescriptors;
153 OS << ", space = " << Clause.Space << ", offset = ";
154 if (Clause.Offset == DescriptorTableOffsetAppend)
155 OS << "DescriptorTableOffsetAppend";
156 else
157 OS << Clause.Offset;
158 OS << ", flags = " << Clause.Flags << ")";
159
160 return OS;
161}
162
163raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) {
164 OS << "Root" << Descriptor.Type << "(" << Descriptor.Reg
165 << ", space = " << Descriptor.Space
166 << ", visibility = " << Descriptor.Visibility
167 << ", flags = " << Descriptor.Flags << ")";
168
169 return OS;
170}
171
172raw_ostream &operator<<(raw_ostream &OS, const StaticSampler &Sampler) {
173 OS << "StaticSampler(" << Sampler.Reg << ", filter = " << Sampler.Filter
174 << ", addressU = " << Sampler.AddressU
175 << ", addressV = " << Sampler.AddressV
176 << ", addressW = " << Sampler.AddressW
177 << ", mipLODBias = " << Sampler.MipLODBias
178 << ", maxAnisotropy = " << Sampler.MaxAnisotropy
179 << ", comparisonFunc = " << Sampler.CompFunc
180 << ", borderColor = " << Sampler.BorderColor
181 << ", minLOD = " << Sampler.MinLOD << ", maxLOD = " << Sampler.MaxLOD
182 << ", space = " << Sampler.Space << ", visibility = " << Sampler.Visibility
183 << ", flags = " << Sampler.Flags << ")";
184 return OS;
185}
186
namespace {

// OverloadedVisit merges a set of callables into a single overload set for use
// with std::visit. Using it (rather than one generic lambda) ensures the
// compiler catches the case where a new RootElement variant type is added but
// its operator<< isn't handled.
template <typename... Callables> struct OverloadedVisit : Callables... {
  using Callables::operator()...;
};

// Deduction guide so that `OverloadedVisit{F1, F2, ...}` deduces its template
// arguments from the callables passed in.
template <typename... Callables>
OverloadedVisit(Callables...) -> OverloadedVisit<Callables...>;

} // namespace
197
198raw_ostream &operator<<(raw_ostream &OS, const RootElement &Element) {
199 const auto Visitor = OverloadedVisit{
200 [&OS](const dxbc::RootFlags &Flags) { OS << Flags; },
201 [&OS](const RootConstants &Constants) { OS << Constants; },
202 [&OS](const RootDescriptor &Descriptor) { OS << Descriptor; },
203 [&OS](const DescriptorTableClause &Clause) { OS << Clause; },
204 [&OS](const DescriptorTable &Table) { OS << Table; },
205 [&OS](const StaticSampler &Sampler) { OS << Sampler; },
206 };
207 std::visit(visitor: Visitor, variants: Element);
208 return OS;
209}
210
211void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements) {
212 OS << " RootElements" << interleaved(R: Elements, Separator: ", ", Prefix: "{", Suffix: "}");
213}
214
215} // namespace rootsig
216} // namespace hlsl
217} // namespace llvm
218