1//===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file
10/// Implements a verifier for AMDGPU HSA metadata.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/StringSwitch.h"
18#include "llvm/BinaryFormat/MsgPackDocument.h"
19
20#include <utility>
21
22namespace llvm {
23namespace AMDGPU {
24namespace HSAMD {
25namespace V3 {
26
27bool MetadataVerifier::verifyScalar(
28 msgpack::DocNode &Node, msgpack::Type SKind,
29 function_ref<bool(msgpack::DocNode &)> verifyValue) {
30 if (!Node.isScalar())
31 return false;
32 if (Node.getKind() != SKind) {
33 if (Strict)
34 return false;
35 // If we are not strict, we interpret string values as "implicitly typed"
36 // and attempt to coerce them to the expected type here.
37 if (Node.getKind() != msgpack::Type::String)
38 return false;
39 StringRef StringValue = Node.getString();
40 Node.fromString(S: StringValue);
41 if (Node.getKind() != SKind)
42 return false;
43 }
44 if (verifyValue)
45 return verifyValue(Node);
46 return true;
47}
48
49bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
50 if (!verifyScalar(Node, SKind: msgpack::Type::UInt))
51 if (!verifyScalar(Node, SKind: msgpack::Type::Int))
52 return false;
53 return true;
54}
55
56bool MetadataVerifier::verifyArray(
57 msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
58 std::optional<size_t> Size) {
59 if (!Node.isArray())
60 return false;
61 auto &Array = Node.getArray();
62 if (Size && Array.size() != *Size)
63 return false;
64 return llvm::all_of(Range&: Array, P: verifyNode);
65}
66
67bool MetadataVerifier::verifyEntry(
68 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
69 function_ref<bool(msgpack::DocNode &)> verifyNode) {
70 auto Entry = MapNode.find(Key);
71 if (Entry == MapNode.end())
72 return !Required;
73 return verifyNode(Entry->second);
74}
75
76bool MetadataVerifier::verifyScalarEntry(
77 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
78 msgpack::Type SKind,
79 function_ref<bool(msgpack::DocNode &)> verifyValue) {
80 return verifyEntry(MapNode, Key, Required, verifyNode: [=](msgpack::DocNode &Node) {
81 return verifyScalar(Node, SKind, verifyValue);
82 });
83}
84
85bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
86 StringRef Key, bool Required) {
87 return verifyEntry(MapNode, Key, Required, verifyNode: [this](msgpack::DocNode &Node) {
88 return verifyInteger(Node);
89 });
90}
91
92bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
93 if (!Node.isMap())
94 return false;
95 auto &ArgsMap = Node.getMap();
96
97 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".name", Required: false,
98 SKind: msgpack::Type::String))
99 return false;
100 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".type_name", Required: false,
101 SKind: msgpack::Type::String))
102 return false;
103 if (!verifyIntegerEntry(MapNode&: ArgsMap, Key: ".size", Required: true))
104 return false;
105 if (!verifyIntegerEntry(MapNode&: ArgsMap, Key: ".offset", Required: true))
106 return false;
107 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".value_kind", Required: true, SKind: msgpack::Type::String,
108 verifyValue: [](msgpack::DocNode &SNode) {
109 return StringSwitch<bool>(SNode.getString())
110 .Case(S: "by_value", Value: true)
111 .Case(S: "global_buffer", Value: true)
112 .Case(S: "dynamic_shared_pointer", Value: true)
113 .Case(S: "sampler", Value: true)
114 .Case(S: "image", Value: true)
115 .Case(S: "pipe", Value: true)
116 .Case(S: "queue", Value: true)
117 .Case(S: "hidden_block_count_x", Value: true)
118 .Case(S: "hidden_block_count_y", Value: true)
119 .Case(S: "hidden_block_count_z", Value: true)
120 .Case(S: "hidden_group_size_x", Value: true)
121 .Case(S: "hidden_group_size_y", Value: true)
122 .Case(S: "hidden_group_size_z", Value: true)
123 .Case(S: "hidden_remainder_x", Value: true)
124 .Case(S: "hidden_remainder_y", Value: true)
125 .Case(S: "hidden_remainder_z", Value: true)
126 .Case(S: "hidden_global_offset_x", Value: true)
127 .Case(S: "hidden_global_offset_y", Value: true)
128 .Case(S: "hidden_global_offset_z", Value: true)
129 .Case(S: "hidden_grid_dims", Value: true)
130 .Case(S: "hidden_none", Value: true)
131 .Case(S: "hidden_printf_buffer", Value: true)
132 .Case(S: "hidden_hostcall_buffer", Value: true)
133 .Case(S: "hidden_heap_v1", Value: true)
134 .Case(S: "hidden_default_queue", Value: true)
135 .Case(S: "hidden_completion_action", Value: true)
136 .Case(S: "hidden_multigrid_sync_arg", Value: true)
137 .Case(S: "hidden_dynamic_lds_size", Value: true)
138 .Case(S: "hidden_private_base", Value: true)
139 .Case(S: "hidden_shared_base", Value: true)
140 .Case(S: "hidden_queue_ptr", Value: true)
141 .Default(Value: false);
142 }))
143 return false;
144 if (!verifyIntegerEntry(MapNode&: ArgsMap, Key: ".pointee_align", Required: false))
145 return false;
146 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".address_space", Required: false,
147 SKind: msgpack::Type::String,
148 verifyValue: [](msgpack::DocNode &SNode) {
149 return StringSwitch<bool>(SNode.getString())
150 .Case(S: "private", Value: true)
151 .Case(S: "global", Value: true)
152 .Case(S: "constant", Value: true)
153 .Case(S: "local", Value: true)
154 .Case(S: "generic", Value: true)
155 .Case(S: "region", Value: true)
156 .Default(Value: false);
157 }))
158 return false;
159 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".access", Required: false,
160 SKind: msgpack::Type::String,
161 verifyValue: [](msgpack::DocNode &SNode) {
162 return StringSwitch<bool>(SNode.getString())
163 .Case(S: "read_only", Value: true)
164 .Case(S: "write_only", Value: true)
165 .Case(S: "read_write", Value: true)
166 .Default(Value: false);
167 }))
168 return false;
169 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".actual_access", Required: false,
170 SKind: msgpack::Type::String,
171 verifyValue: [](msgpack::DocNode &SNode) {
172 return StringSwitch<bool>(SNode.getString())
173 .Case(S: "read_only", Value: true)
174 .Case(S: "write_only", Value: true)
175 .Case(S: "read_write", Value: true)
176 .Default(Value: false);
177 }))
178 return false;
179 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".is_const", Required: false,
180 SKind: msgpack::Type::Boolean))
181 return false;
182 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".is_restrict", Required: false,
183 SKind: msgpack::Type::Boolean))
184 return false;
185 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".is_volatile", Required: false,
186 SKind: msgpack::Type::Boolean))
187 return false;
188 if (!verifyScalarEntry(MapNode&: ArgsMap, Key: ".is_pipe", Required: false,
189 SKind: msgpack::Type::Boolean))
190 return false;
191
192 return true;
193}
194
195bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
196 if (!Node.isMap())
197 return false;
198 auto &KernelMap = Node.getMap();
199
200 if (!verifyScalarEntry(MapNode&: KernelMap, Key: ".name", Required: true,
201 SKind: msgpack::Type::String))
202 return false;
203 if (!verifyScalarEntry(MapNode&: KernelMap, Key: ".symbol", Required: true,
204 SKind: msgpack::Type::String))
205 return false;
206 if (!verifyScalarEntry(MapNode&: KernelMap, Key: ".language", Required: false,
207 SKind: msgpack::Type::String,
208 verifyValue: [](msgpack::DocNode &SNode) {
209 return StringSwitch<bool>(SNode.getString())
210 .Case(S: "OpenCL C", Value: true)
211 .Case(S: "OpenCL C++", Value: true)
212 .Case(S: "HCC", Value: true)
213 .Case(S: "HIP", Value: true)
214 .Case(S: "OpenMP", Value: true)
215 .Case(S: "Assembler", Value: true)
216 .Default(Value: false);
217 }))
218 return false;
219 if (!verifyEntry(
220 MapNode&: KernelMap, Key: ".language_version", Required: false, verifyNode: [this](msgpack::DocNode &Node) {
221 return verifyArray(
222 Node,
223 verifyNode: [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, Size: 2);
224 }))
225 return false;
226 if (!verifyEntry(MapNode&: KernelMap, Key: ".args", Required: false, verifyNode: [this](msgpack::DocNode &Node) {
227 return verifyArray(Node, verifyNode: [this](msgpack::DocNode &Node) {
228 return verifyKernelArgs(Node);
229 });
230 }))
231 return false;
232 if (!verifyEntry(MapNode&: KernelMap, Key: ".reqd_workgroup_size", Required: false,
233 verifyNode: [this](msgpack::DocNode &Node) {
234 return verifyArray(Node,
235 verifyNode: [this](msgpack::DocNode &Node) {
236 return verifyInteger(Node);
237 },
238 Size: 3);
239 }))
240 return false;
241 if (!verifyEntry(MapNode&: KernelMap, Key: ".workgroup_size_hint", Required: false,
242 verifyNode: [this](msgpack::DocNode &Node) {
243 return verifyArray(Node,
244 verifyNode: [this](msgpack::DocNode &Node) {
245 return verifyInteger(Node);
246 },
247 Size: 3);
248 }))
249 return false;
250 if (!verifyScalarEntry(MapNode&: KernelMap, Key: ".vec_type_hint", Required: false,
251 SKind: msgpack::Type::String))
252 return false;
253 if (!verifyScalarEntry(MapNode&: KernelMap, Key: ".device_enqueue_symbol", Required: false,
254 SKind: msgpack::Type::String))
255 return false;
256 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".kernarg_segment_size", Required: true))
257 return false;
258 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".group_segment_fixed_size", Required: true))
259 return false;
260 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".private_segment_fixed_size", Required: true))
261 return false;
262 if (!verifyScalarEntry(MapNode&: KernelMap, Key: ".uses_dynamic_stack", Required: false,
263 SKind: msgpack::Type::Boolean))
264 return false;
265 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".workgroup_processor_mode", Required: false))
266 return false;
267 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".kernarg_segment_align", Required: true))
268 return false;
269 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".wavefront_size", Required: true))
270 return false;
271 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".sgpr_count", Required: true))
272 return false;
273 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".vgpr_count", Required: true))
274 return false;
275 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".max_flat_workgroup_size", Required: true))
276 return false;
277 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".sgpr_spill_count", Required: false))
278 return false;
279 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".vgpr_spill_count", Required: false))
280 return false;
281 if (!verifyIntegerEntry(MapNode&: KernelMap, Key: ".uniform_work_group_size", Required: false))
282 return false;
283
284
285 return true;
286}
287
288bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
289 if (!HSAMetadataRoot.isMap())
290 return false;
291 auto &RootMap = HSAMetadataRoot.getMap();
292
293 if (!verifyEntry(
294 MapNode&: RootMap, Key: "amdhsa.version", Required: true, verifyNode: [this](msgpack::DocNode &Node) {
295 return verifyArray(
296 Node,
297 verifyNode: [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, Size: 2);
298 }))
299 return false;
300 if (!verifyEntry(
301 MapNode&: RootMap, Key: "amdhsa.printf", Required: false, verifyNode: [this](msgpack::DocNode &Node) {
302 return verifyArray(Node, verifyNode: [this](msgpack::DocNode &Node) {
303 return verifyScalar(Node, SKind: msgpack::Type::String);
304 });
305 }))
306 return false;
307 if (!verifyEntry(MapNode&: RootMap, Key: "amdhsa.kernels", Required: true,
308 verifyNode: [this](msgpack::DocNode &Node) {
309 return verifyArray(Node, verifyNode: [this](msgpack::DocNode &Node) {
310 return verifyKernel(Node);
311 });
312 }))
313 return false;
314
315 return true;
316}
317
318} // end namespace V3
319} // end namespace HSAMD
320} // end namespace AMDGPU
321} // end namespace llvm
322