//===---------- SPIRVIRMapping.h - SPIR-V Duplicates Tracker ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// General infrastructure for keeping track of the values that according to
10// the SPIR-V binary layout should be global to the whole module.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H
15#define LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H
16
17#include "MCTargetDesc/SPIRVBaseInfo.h"
18#include "MCTargetDesc/SPIRVMCTargetDesc.h"
19#include "SPIRVUtils.h"
20#include "llvm/ADT/DenseMap.h"
21#include "llvm/ADT/Hashing.h"
22#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
23#include "llvm/CodeGen/MachineModuleInfo.h"
24
25#include <type_traits>
26
27namespace llvm {
28namespace SPIRV {
29
30inline size_t to_hash(const MachineInstr *MI) {
31 hash_code H = llvm::hash_combine(args: MI->getOpcode(), args: MI->getNumOperands());
32 for (unsigned I = MI->getNumDefs(); I < MI->getNumOperands(); ++I) {
33 const MachineOperand &MO = MI->getOperand(i: I);
34 if (MO.getType() == MachineOperand::MO_CImmediate)
35 H = llvm::hash_combine(args: H, args: MO.getType(), args: MO.getCImm());
36 else if (MO.getType() == MachineOperand::MO_FPImmediate)
37 H = llvm::hash_combine(args: H, args: MO.getType(), args: MO.getFPImm());
38 else
39 H = llvm::hash_combine(args: H, args: MO.getType());
40 }
41 return H;
42}
43
/// Snapshot of a defining machine instruction: the instruction pointer, the
/// register taken from its operand 0, and its SPIRV::to_hash() fingerprint.
/// This is the triple that getMIKey() builds.
using MIHandle = std::tuple<const MachineInstr *, Register, size_t>;
45
46inline MIHandle getMIKey(const MachineInstr *MI) {
47 return std::make_tuple(args&: MI, args: MI->getOperand(i: 0).getReg(), args: SPIRV::to_hash(MI));
48}
49
/// Lookup key for an IR-level entity: an opaque pointer (may be nullptr), an
/// unsigned payload whose meaning depends on the kind, and a SpecialTypeKind
/// tag stored as unsigned. See the irhandle_*() / handle() builders.
using IRHandle = std::tuple<const void *, unsigned, unsigned>;
/// An IRHandle paired with a machine function pointer.
using IRHandleMF = std::pair<IRHandle, const MachineFunction *>;
52
53inline IRHandleMF getIRHandleMF(IRHandle Handle, const MachineFunction *MF) {
54 return std::make_pair(x&: Handle, y&: MF);
55}
56
/// Tag stored in the third slot of an IRHandle; it tells which builder produced
/// the handle. The numeric values are spelled out explicitly and match the
/// implicit numbering (0, 1, 2, ...) of the declaration order.
enum SpecialTypeKind {
  STK_Empty = 0,
  STK_Image = 1,              // irhandle_image()
  STK_SampledImage = 2,       // irhandle_sampled_image()
  STK_Sampler = 3,            // irhandle_sampler()
  STK_Pipe = 4,               // irhandle_pipe()
  STK_DeviceEvent = 5,        // irhandle_event()
  STK_ElementPointer = 6,     // irhandle_pointee()
  STK_Type = 7,               // handle(const Type *)
  STK_Value = 8,              // handle(const Value *)
  STK_MachineInstr = 9,       // handle(const MachineInstr *)
  STK_VkBuffer = 10,          // irhandle_vkbuffer()
  STK_ExplictLayoutType = 11, // irhandle_explict_layout_type()
  STK_Last = -1               // Note: -1, not one past the last real kind.
};
72
/// Packs image type attributes into a single unsigned value (Val) by
/// overlaying it with a bit-field struct (Flags).
///
/// Each constructor argument is stored in a bit-field of the width declared
/// below; a value wider than its field keeps only its low bits.
union ImageAttrs {
  struct BitFlags {
    unsigned Dim : 3;
    unsigned Depth : 2;
    unsigned Arrayed : 1;
    unsigned MS : 1;
    unsigned Sampled : 2;
    unsigned ImageFormat : 6;
    unsigned AQ : 2;
  } Flags;
  unsigned Val;

  ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
             unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0)
      : Val(0) { // Zero the whole word first; the fields cover only 17 bits.
    Flags.Dim = Dim;
    Flags.Depth = Depth;
    Flags.Arrayed = Arrayed;
    Flags.MS = MS;
    Flags.Sampled = Sampled;
    Flags.ImageFormat = ImageFormat;
    Flags.AQ = AQ;
  }
};
97
98inline IRHandle irhandle_image(const Type *SampledTy, unsigned Dim,
99 unsigned Depth, unsigned Arrayed, unsigned MS,
100 unsigned Sampled, unsigned ImageFormat,
101 unsigned AQ = 0) {
102 return std::make_tuple(
103 args&: SampledTy,
104 args: ImageAttrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ).Val,
105 args: SpecialTypeKind::STK_Image);
106}
107
108inline IRHandle irhandle_sampled_image(const Type *SampledTy,
109 const MachineInstr *ImageTy) {
110 assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
111 unsigned AC = AccessQualifier::AccessQualifier::None;
112 if (ImageTy->getNumOperands() > 8)
113 AC = ImageTy->getOperand(i: 8).getImm();
114 return std::make_tuple(
115 args&: SampledTy,
116 args: ImageAttrs(
117 ImageTy->getOperand(i: 2).getImm(), ImageTy->getOperand(i: 3).getImm(),
118 ImageTy->getOperand(i: 4).getImm(), ImageTy->getOperand(i: 5).getImm(),
119 ImageTy->getOperand(i: 6).getImm(), ImageTy->getOperand(i: 7).getImm(), AC)
120 .Val,
121 args: SpecialTypeKind::STK_SampledImage);
122}
123
124inline IRHandle irhandle_sampler() {
125 return std::make_tuple(args: nullptr, args: 0U, args: SpecialTypeKind::STK_Sampler);
126}
127
128inline IRHandle irhandle_pipe(uint8_t AQ) {
129 return std::make_tuple(args: nullptr, args&: AQ, args: SpecialTypeKind::STK_Pipe);
130}
131
132inline IRHandle irhandle_event() {
133 return std::make_tuple(args: nullptr, args: 0U, args: SpecialTypeKind::STK_DeviceEvent);
134}
135
136inline IRHandle irhandle_pointee(const Type *ElementType,
137 unsigned AddressSpace) {
138 return std::make_tuple(args: unifyPtrType(Ty: ElementType), args&: AddressSpace,
139 args: SpecialTypeKind::STK_ElementPointer);
140}
141
142inline IRHandle irhandle_ptr(const void *Ptr, unsigned Arg,
143 enum SpecialTypeKind STK) {
144 return std::make_tuple(args&: Ptr, args&: Arg, args&: STK);
145}
146
147inline IRHandle irhandle_vkbuffer(const Type *ElementType,
148 StorageClass::StorageClass SC,
149 bool IsWriteable) {
150 return std::make_tuple(args&: ElementType, args: (SC << 1) | IsWriteable,
151 args: SpecialTypeKind::STK_VkBuffer);
152}
153
154inline IRHandle irhandle_explict_layout_type(const Type *Ty) {
155 const Type *WrpTy = unifyPtrType(Ty);
156 return irhandle_ptr(Ptr: WrpTy, Arg: Ty->getTypeID(), STK: STK_ExplictLayoutType);
157}
158
159inline IRHandle handle(const Type *Ty) {
160 const Type *WrpTy = unifyPtrType(Ty);
161 return irhandle_ptr(Ptr: WrpTy, Arg: Ty->getTypeID(), STK: STK_Type);
162}
163
164inline IRHandle handle(const Value *V) {
165 return irhandle_ptr(Ptr: V, Arg: V->getValueID(), STK: STK_Value);
166}
167
168inline IRHandle handle(const MachineInstr *KeyMI) {
169 return irhandle_ptr(Ptr: KeyMI, Arg: SPIRV::to_hash(MI: KeyMI), STK: STK_MachineInstr);
170}
171
172inline bool type_has_layout_decoration(const Type *T) {
173 return (isa<StructType>(Val: T) || isa<ArrayType>(Val: T));
174}
175
176} // namespace SPIRV
177
// Bi-directional mappings between LLVM entities and (v-reg, machine function)
// pairs. They support managing a unique SPIR-V definition per machine function
// for each LLVM/GlobalISel entity (e.g., Type, Constant, Machine Instruction).
181class SPIRVIRMapping {
182 DenseMap<SPIRV::IRHandleMF, SPIRV::MIHandle> Vregs;
183 DenseMap<const MachineInstr *, SPIRV::IRHandleMF> Defs;
184
185public:
186 bool add(SPIRV::IRHandle Handle, const MachineInstr *MI) {
187 if (auto DefIt = Defs.find(Val: MI); DefIt != Defs.end()) {
188 auto [ExistHandle, ExistMF] = DefIt->second;
189 if (Handle == ExistHandle && MI->getMF() == ExistMF)
190 return false; // already exists
191 // invalidate the record
192 Vregs.erase(Val: DefIt->second);
193 Defs.erase(I: DefIt);
194 }
195 SPIRV::IRHandleMF HandleMF = SPIRV::getIRHandleMF(Handle, MF: MI->getMF());
196 SPIRV::MIHandle MIKey = SPIRV::getMIKey(MI);
197 auto It1 = Vregs.try_emplace(Key: HandleMF, Args&: MIKey);
198 if (!It1.second) {
199 // there is an expired record that we need to invalidate
200 Defs.erase(Val: std::get<0>(t&: It1.first->second));
201 // update the record
202 It1.first->second = MIKey;
203 }
204 [[maybe_unused]] auto It2 = Defs.try_emplace(Key: MI, Args&: HandleMF);
205 assert(It2.second);
206 return true;
207 }
208 bool erase(const MachineInstr *MI) {
209 bool Res = false;
210 if (auto It = Defs.find(Val: MI); It != Defs.end()) {
211 Res = Vregs.erase(Val: It->second);
212 Defs.erase(I: It);
213 }
214 return Res;
215 }
216 const MachineInstr *findMI(SPIRV::IRHandle Handle,
217 const MachineFunction *MF) {
218 SPIRV::IRHandleMF HandleMF = SPIRV::getIRHandleMF(Handle, MF);
219 auto It = Vregs.find(Val: HandleMF);
220 if (It == Vregs.end())
221 return nullptr;
222 auto [MI, Reg, Hash] = It->second;
223 const MachineInstr *Def = MF->getRegInfo().getVRegDef(Reg);
224 if (!Def || Def != MI || SPIRV::to_hash(MI) != Hash) {
225 // there is an expired record that we need to invalidate
226 erase(MI);
227 return nullptr;
228 }
229 assert(Defs.contains(MI) && Defs.find(MI)->second == HandleMF);
230 return MI;
231 }
232 Register find(SPIRV::IRHandle Handle, const MachineFunction *MF) {
233 const MachineInstr *MI = findMI(Handle, MF);
234 return MI ? MI->getOperand(i: 0).getReg() : Register();
235 }
236
237 // helpers
238 bool add(const Type *PointeeTy, unsigned AddressSpace,
239 const MachineInstr *MI) {
240 return add(Handle: SPIRV::irhandle_pointee(ElementType: PointeeTy, AddressSpace), MI);
241 }
242 Register find(const Type *PointeeTy, unsigned AddressSpace,
243 const MachineFunction *MF) {
244 return find(Handle: SPIRV::irhandle_pointee(ElementType: PointeeTy, AddressSpace), MF);
245 }
246 const MachineInstr *findMI(const Type *PointeeTy, unsigned AddressSpace,
247 const MachineFunction *MF) {
248 return findMI(Handle: SPIRV::irhandle_pointee(ElementType: PointeeTy, AddressSpace), MF);
249 }
250
251 bool add(const Value *V, const MachineInstr *MI) {
252 return add(Handle: SPIRV::handle(V), MI);
253 }
254
255 bool add(const Type *T, bool RequiresExplicitLayout, const MachineInstr *MI) {
256 if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) {
257 return add(Handle: SPIRV::irhandle_explict_layout_type(Ty: T), MI);
258 }
259 return add(Handle: SPIRV::handle(Ty: T), MI);
260 }
261
262 bool add(const MachineInstr *Obj, const MachineInstr *MI) {
263 return add(Handle: SPIRV::handle(KeyMI: Obj), MI);
264 }
265
266 Register find(const Value *V, const MachineFunction *MF) {
267 return find(Handle: SPIRV::handle(V), MF);
268 }
269
270 Register find(const Type *T, bool RequiresExplicitLayout,
271 const MachineFunction *MF) {
272 if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T))
273 return find(Handle: SPIRV::irhandle_explict_layout_type(Ty: T), MF);
274 return find(Handle: SPIRV::handle(Ty: T), MF);
275 }
276
277 Register find(const MachineInstr *MI, const MachineFunction *MF) {
278 return find(Handle: SPIRV::handle(KeyMI: MI), MF);
279 }
280
281 const MachineInstr *findMI(const Value *Obj, const MachineFunction *MF) {
282 return findMI(Handle: SPIRV::handle(V: Obj), MF);
283 }
284
285 const MachineInstr *findMI(const Type *T, bool RequiresExplicitLayout,
286 const MachineFunction *MF) {
287 if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T))
288 return findMI(Handle: SPIRV::irhandle_explict_layout_type(Ty: T), MF);
289 return findMI(Handle: SPIRV::handle(Ty: T), MF);
290 }
291
292 const MachineInstr *findMI(const MachineInstr *Obj,
293 const MachineFunction *MF) {
294 return findMI(Handle: SPIRV::handle(KeyMI: Obj), MF);
295 }
296};
297} // namespace llvm
298#endif // LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H
299