//===- AMDGPUMetadataVerifier.cpp - HSA Metadata Verifier -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file
10/// Implements a verifier for AMDGPU HSA metadata.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/StringSwitch.h"
18#include "llvm/BinaryFormat/MsgPackDocument.h"
19
20namespace llvm {
21namespace AMDGPU {
22namespace HSAMD {
23namespace V3 {
24
25bool MetadataVerifier::verifyScalar(
26 msgpack::DocNode &Node, msgpack::Type SKind,
27 function_ref<bool(msgpack::DocNode &)> verifyValue) {
28 if (!Node.isScalar())
29 return false;
30 if (Node.getKind() != SKind) {
31 if (Strict)
32 return false;
33 // If we are not strict, we interpret string values as "implicitly typed"
34 // and attempt to coerce them to the expected type here.
35 if (Node.getKind() != msgpack::Type::String)
36 return false;
37 StringRef StringValue = Node.getString();
38 Node.fromString(S: StringValue);
39 if (Node.getKind() != SKind)
40 return false;
41 }
42 if (verifyValue)
43 return verifyValue(Node);
44 return true;
45}
46
47bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
48 if (!verifyScalar(Node, SKind: msgpack::Type::UInt))
49 if (!verifyScalar(Node, SKind: msgpack::Type::Int))
50 return false;
51 return true;
52}
53
54bool MetadataVerifier::verifyArray(
55 msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
56 std::optional<size_t> Size) {
57 if (!Node.isArray())
58 return false;
59 auto &Array = Node.getArray();
60 if (Size && Array.size() != *Size)
61 return false;
62 return llvm::all_of(Range&: Array, P: verifyNode);
63}
64
65bool MetadataVerifier::verifyEntry(
66 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
67 function_ref<bool(msgpack::DocNode &)> verifyNode) {
68 auto Entry = MapNode.find(Key);
69 if (Entry == MapNode.end())
70 return !Required;
71 return verifyNode(Entry->second);
72}
73
74bool MetadataVerifier::verifyScalarEntry(
75 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
76 msgpack::Type SKind,
77 function_ref<bool(msgpack::DocNode &)> verifyValue) {
78 return verifyEntry(MapNode, Key, Required,
79 verifyNode: [this, SKind, verifyValue](msgpack::DocNode &Node) {
80 return verifyScalar(Node, SKind, verifyValue);
81 });
82}
83
84bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
85 StringRef Key, bool Required) {
86 return verifyEntry(MapNode, Key, Required, verifyNode: [this](msgpack::DocNode &Node) {
87 return verifyInteger(Node);
88 });
89}
90
91bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
92 if (!Node.isMap())
93 return false;
94 auto &ArgsMap = Node.getMap();
95
96 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".name", Required: false,
97 SKind: msgpack::Type::String))
98 return false;
99 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".type_name", Required: false,
100 SKind: msgpack::Type::String))
101 return false;
102 if (!verifyIntegerEntry(MapNode&: ArgsMap, Key: ".size", Required: true))
103 return false;
104 if (!verifyIntegerEntry(MapNode&: ArgsMap, Key: ".offset", Required: true))
105 return false;
106 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".value_kind", Required: true, SKind: msgpack::Type::String,
107 verifyValue: [](msgpack::DocNode &SNode) {
108 return StringSwitch<bool>(SNode.getString())
109 .Case(S: "by_value", Value: true)
110 .Case(S: "global_buffer", Value: true)
111 .Case(S: "dynamic_shared_pointer", Value: true)
112 .Case(S: "sampler", Value: true)
113 .Case(S: "image", Value: true)
114 .Case(S: "pipe", Value: true)
115 .Case(S: "queue", Value: true)
116 .Case(S: "hidden_block_count_x", Value: true)
117 .Case(S: "hidden_block_count_y", Value: true)
118 .Case(S: "hidden_block_count_z", Value: true)
119 .Case(S: "hidden_group_size_x", Value: true)
120 .Case(S: "hidden_group_size_y", Value: true)
121 .Case(S: "hidden_group_size_z", Value: true)
122 .Case(S: "hidden_remainder_x", Value: true)
123 .Case(S: "hidden_remainder_y", Value: true)
124 .Case(S: "hidden_remainder_z", Value: true)
125 .Case(S: "hidden_global_offset_x", Value: true)
126 .Case(S: "hidden_global_offset_y", Value: true)
127 .Case(S: "hidden_global_offset_z", Value: true)
128 .Case(S: "hidden_grid_dims", Value: true)
129 .Case(S: "hidden_none", Value: true)
130 .Case(S: "hidden_printf_buffer", Value: true)
131 .Case(S: "hidden_hostcall_buffer", Value: true)
132 .Case(S: "hidden_heap_v1", Value: true)
133 .Case(S: "hidden_default_queue", Value: true)
134 .Case(S: "hidden_completion_action", Value: true)
135 .Case(S: "hidden_multigrid_sync_arg", Value: true)
136 .Case(S: "hidden_dynamic_lds_size", Value: true)
137 .Case(S: "hidden_private_base", Value: true)
138 .Case(S: "hidden_shared_base", Value: true)
139 .Case(S: "hidden_queue_ptr", Value: true)
140 .Default(Value: false);
141 }))
142 return false;
143 if (!verifyIntegerEntry(MapNode&: ArgsMap, Key: ".pointee_align", Required: false))
144 return false;
145 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".address_space", Required: false,
146 SKind: msgpack::Type::String,
147 verifyValue: [](msgpack::DocNode &SNode) {
148 return StringSwitch<bool>(SNode.getString())
149 .Case(S: "private", Value: true)
150 .Case(S: "global", Value: true)
151 .Case(S: "constant", Value: true)
152 .Case(S: "local", Value: true)
153 .Case(S: "generic", Value: true)
154 .Case(S: "region", Value: true)
155 .Default(Value: false);
156 }))
157 return false;
158 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".access", Required: false,
159 SKind: msgpack::Type::String,
160 verifyValue: [](msgpack::DocNode &SNode) {
161 return StringSwitch<bool>(SNode.getString())
162 .Case(S: "read_only", Value: true)
163 .Case(S: "write_only", Value: true)
164 .Case(S: "read_write", Value: true)
165 .Default(Value: false);
166 }))
167 return false;
168 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".actual_access", Required: false,
169 SKind: msgpack::Type::String,
170 verifyValue: [](msgpack::DocNode &SNode) {
171 return StringSwitch<bool>(SNode.getString())
172 .Case(S: "read_only", Value: true)
173 .Case(S: "write_only", Value: true)
174 .Case(S: "read_write", Value: true)
175 .Default(Value: false);
176 }))
177 return false;
178 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".is_const", Required: false,
179 SKind: msgpack::Type::Boolean))
180 return false;
181 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".is_restrict", Required: false,
182 SKind: msgpack::Type::Boolean))
183 return false;
184 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".is_volatile", Required: false,
185 SKind: msgpack::Type::Boolean))
186 return false;
187 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".is_pipe", Required: false,
188 SKind: msgpack::Type::Boolean))
189 return false;
190
191 return true;
192}
193
194bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
195 if (!Node.isMap())
196 return false;
197 auto &KernelMap = Node.getMap();
198
199 if (!verifyScalarEntry(MapNode&: KernelMap, Key: ".name", Required: true,
200 SKind: msgpack::Type::String))
201 return false;
202 if (!verifyScalarEntry(MapNode&: KernelMap, Key: ".symbol", Required: true,
203 SKind: msgpack::Type::String))
204 return false;
205 if (!verifyScalarEntry(MapNode&: KernelMap, Key: ".language", Required: false,
206 SKind: msgpack::Type::String,
207 verifyValue: [](msgpack::DocNode &SNode) {
208 return StringSwitch<bool>(SNode.getString())
209 .Case(S: "OpenCL C", Value: true)
210 .Case(S: "OpenCL C++", Value: true)
211 .Case(S: "HCC", Value: true)
212 .Case(S: "HIP", Value: true)
213 .Case(S: "OpenMP", Value: true)
214 .Case(S: "Assembler", Value: true)
215 .Default(Value: false);
216 }))
217 return false;
218 if (!verifyEntry(
219 MapNode&: KernelMap, Key: ".language_version", Required: false, verifyNode: [this](msgpack::DocNode &Node) {
220 return verifyArray(
221 Node,
222 verifyNode: [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, Size: 2);
223 }))
224 return false;
225 if (!verifyEntry(MapNode&: KernelMap, Key: ".args", Required: false, verifyNode: [this](msgpack::DocNode &Node) {
226 return verifyArray(Node, verifyNode: [this](msgpack::DocNode &Node) {
227 return verifyKernelArgs(Node);
228 });
229 }))
230 return false;
231 if (!verifyEntry(MapNode&: KernelMap, Key: ".reqd_workgroup_size", Required: false,
232 verifyNode: [this](msgpack::DocNode &Node) {
233 return verifyArray(Node,
234 verifyNode: [this](msgpack::DocNode &Node) {
235 return verifyInteger(Node);
236 },
237 Size: 3);
238 }))
239 return false;
240 if (!verifyEntry(MapNode&: KernelMap, Key: ".workgroup_size_hint", Required: false,
241 verifyNode: [this](msgpack::DocNode &Node) {
242 return verifyArray(Node,
243 verifyNode: [this](msgpack::DocNode &Node) {
244 return verifyInteger(Node);
245 },
246 Size: 3);
247 }))
248 return false;
249 if (!verifyScalarEntry(MapNode&: KernelMap, Key: ".vec_type_hint", Required: false,
250 SKind: msgpack::Type::String))
251 return false;
252 if (!verifyScalarEntry(MapNode&: KernelMap, Key: ".device_enqueue_symbol", Required: false,
253 SKind: msgpack::Type::String))
254 return false;
255 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".kernarg_segment_size", Required: true))
256 return false;
257 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".group_segment_fixed_size", Required: true))
258 return false;
259 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".private_segment_fixed_size", Required: true))
260 return false;
261 if (!verifyScalarEntry(MapNode&: KernelMap, Key: ".uses_dynamic_stack", Required: false,
262 SKind: msgpack::Type::Boolean))
263 return false;
264 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".workgroup_processor_mode", Required: false))
265 return false;
266 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".kernarg_segment_align", Required: true))
267 return false;
268 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".wavefront_size", Required: true))
269 return false;
270 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".sgpr_count", Required: true))
271 return false;
272 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".vgpr_count", Required: true))
273 return false;
274 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".max_flat_workgroup_size", Required: true))
275 return false;
276 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".sgpr_spill_count", Required: false))
277 return false;
278 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".vgpr_spill_count", Required: false))
279 return false;
280 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".uniform_work_group_size", Required: false))
281 return false;
282 if (!verifyEntry(
283 MapNode&: KernelMap, Key: ".cluster_dims", Required: false, verifyNode: [this](msgpack::DocNode &Node) {
284 return verifyArray(
285 Node,
286 verifyNode: [this](msgpack::DocNode &Node) { return verifyInteger(Node); },
287 Size: 3);
288 }))
289 return false;
290
291 return true;
292}
293
294bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
295 if (!HSAMetadataRoot.isMap())
296 return false;
297 auto &RootMap = HSAMetadataRoot.getMap();
298
299 if (!verifyEntry(
300 MapNode&: RootMap, Key: "amdhsa.version", Required: true, verifyNode: [this](msgpack::DocNode &Node) {
301 return verifyArray(
302 Node,
303 verifyNode: [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, Size: 2);
304 }))
305 return false;
306 if (!verifyEntry(
307 MapNode&: RootMap, Key: "amdhsa.printf", Required: false, verifyNode: [this](msgpack::DocNode &Node) {
308 return verifyArray(Node, verifyNode: [this](msgpack::DocNode &Node) {
309 return verifyScalar(Node, SKind: msgpack::Type::String);
310 });
311 }))
312 return false;
313 if (!verifyEntry(MapNode&: RootMap, Key: "amdhsa.kernels", Required: true,
314 verifyNode: [this](msgpack::DocNode &Node) {
315 return verifyArray(Node, verifyNode: [this](msgpack::DocNode &Node) {
316 return verifyKernel(Node);
317 });
318 }))
319 return false;
320
321 return true;
322}
323
324} // end namespace V3
325} // end namespace HSAMD
326} // end namespace AMDGPU
327} // end namespace llvm
328