//===------------ SPIRVIRMapping.h - SPIR-V Duplicates Tracker --*- C++ -*-===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// General infrastructure for keeping track of the values that, according to
// the SPIR-V binary layout, should be global to the whole module.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H |
15 | #define LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H |
16 | |
17 | #include "MCTargetDesc/SPIRVBaseInfo.h" |
18 | #include "MCTargetDesc/SPIRVMCTargetDesc.h" |
19 | #include "SPIRVUtils.h" |
20 | #include "llvm/ADT/DenseMap.h" |
21 | #include "llvm/ADT/Hashing.h" |
22 | #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" |
23 | #include "llvm/CodeGen/MachineModuleInfo.h" |
24 | |
25 | #include <type_traits> |
26 | |
27 | namespace llvm { |
28 | namespace SPIRV { |
29 | |
30 | inline size_t to_hash(const MachineInstr *MI) { |
31 | hash_code H = llvm::hash_combine(args: MI->getOpcode(), args: MI->getNumOperands()); |
32 | for (unsigned I = MI->getNumDefs(); I < MI->getNumOperands(); ++I) { |
33 | const MachineOperand &MO = MI->getOperand(i: I); |
34 | if (MO.getType() == MachineOperand::MO_CImmediate) |
35 | H = llvm::hash_combine(args: H, args: MO.getType(), args: MO.getCImm()); |
36 | else if (MO.getType() == MachineOperand::MO_FPImmediate) |
37 | H = llvm::hash_combine(args: H, args: MO.getType(), args: MO.getFPImm()); |
38 | else |
39 | H = llvm::hash_combine(args: H, args: MO.getType()); |
40 | } |
41 | return H; |
42 | } |
43 | |
44 | using MIHandle = std::tuple<const MachineInstr *, Register, size_t>; |
45 | |
46 | inline MIHandle getMIKey(const MachineInstr *MI) { |
47 | return std::make_tuple(args&: MI, args: MI->getOperand(i: 0).getReg(), args: SPIRV::to_hash(MI)); |
48 | } |
49 | |
50 | using IRHandle = std::tuple<const void *, unsigned, unsigned>; |
51 | using IRHandleMF = std::pair<IRHandle, const MachineFunction *>; |
52 | |
53 | inline IRHandleMF getIRHandleMF(IRHandle Handle, const MachineFunction *MF) { |
54 | return std::make_pair(x&: Handle, y&: MF); |
55 | } |
56 | |
/// Tag stored in the third element of an IRHandle. It keeps handles built from
/// different kinds of entities apart even when their other elements coincide.
/// The numeric values are spelled out to make the existing numbering explicit.
enum SpecialTypeKind {
  STK_Empty = 0,
  STK_Image = 1,
  STK_SampledImage = 2,
  STK_Sampler = 3,
  STK_Pipe = 4,
  STK_DeviceEvent = 5,
  STK_ElementPointer = 6,
  STK_Type = 7,
  STK_Value = 8,
  STK_MachineInstr = 9,
  STK_VkBuffer = 10,
  STK_ExplictLayoutType = 11,
  STK_Last = -1
};
72 | |
/// Packs the scalar attributes of an image type into a single unsigned value.
///
/// The attributes are written through the Flags bit-fields and read back as one
/// integer through Val. The fields use 17 bits in total. Val is zeroed before
/// any field is written, so bits outside the fields are 0. A value wider than
/// its field is truncated to the field's width, because the fields are unsigned
/// bit-fields.
union ImageAttrs {
  struct BitFlags {
    unsigned Dim : 3;
    unsigned Depth : 2;
    unsigned Arrayed : 1;
    unsigned MS : 1;
    unsigned Sampled : 2;
    unsigned ImageFormat : 6;
    unsigned AQ : 2;
  } Flags;
  unsigned Val;

  ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
             unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0)
      : Val(0) {
    // Each store touches only the bits of its own field, so the rest of Val
    // keeps the zero it was initialized with.
    Flags.Dim = Dim;
    Flags.Depth = Depth;
    Flags.Arrayed = Arrayed;
    Flags.MS = MS;
    Flags.Sampled = Sampled;
    Flags.ImageFormat = ImageFormat;
    Flags.AQ = AQ;
  }
};
97 | |
98 | inline IRHandle irhandle_image(const Type *SampledTy, unsigned Dim, |
99 | unsigned Depth, unsigned Arrayed, unsigned MS, |
100 | unsigned Sampled, unsigned ImageFormat, |
101 | unsigned AQ = 0) { |
102 | return std::make_tuple( |
103 | args&: SampledTy, |
104 | args: ImageAttrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ).Val, |
105 | args: SpecialTypeKind::STK_Image); |
106 | } |
107 | |
108 | inline IRHandle irhandle_sampled_image(const Type *SampledTy, |
109 | const MachineInstr *ImageTy) { |
110 | assert(ImageTy->getOpcode() == SPIRV::OpTypeImage); |
111 | unsigned AC = AccessQualifier::AccessQualifier::None; |
112 | if (ImageTy->getNumOperands() > 8) |
113 | AC = ImageTy->getOperand(i: 8).getImm(); |
114 | return std::make_tuple( |
115 | args&: SampledTy, |
116 | args: ImageAttrs( |
117 | ImageTy->getOperand(i: 2).getImm(), ImageTy->getOperand(i: 3).getImm(), |
118 | ImageTy->getOperand(i: 4).getImm(), ImageTy->getOperand(i: 5).getImm(), |
119 | ImageTy->getOperand(i: 6).getImm(), ImageTy->getOperand(i: 7).getImm(), AC) |
120 | .Val, |
121 | args: SpecialTypeKind::STK_SampledImage); |
122 | } |
123 | |
124 | inline IRHandle irhandle_sampler() { |
125 | return std::make_tuple(args: nullptr, args: 0U, args: SpecialTypeKind::STK_Sampler); |
126 | } |
127 | |
128 | inline IRHandle irhandle_pipe(uint8_t AQ) { |
129 | return std::make_tuple(args: nullptr, args&: AQ, args: SpecialTypeKind::STK_Pipe); |
130 | } |
131 | |
132 | inline IRHandle irhandle_event() { |
133 | return std::make_tuple(args: nullptr, args: 0U, args: SpecialTypeKind::STK_DeviceEvent); |
134 | } |
135 | |
136 | inline IRHandle irhandle_pointee(const Type *ElementType, |
137 | unsigned AddressSpace) { |
138 | return std::make_tuple(args: unifyPtrType(Ty: ElementType), args&: AddressSpace, |
139 | args: SpecialTypeKind::STK_ElementPointer); |
140 | } |
141 | |
142 | inline IRHandle irhandle_ptr(const void *Ptr, unsigned Arg, |
143 | enum SpecialTypeKind STK) { |
144 | return std::make_tuple(args&: Ptr, args&: Arg, args&: STK); |
145 | } |
146 | |
147 | inline IRHandle irhandle_vkbuffer(const Type *ElementType, |
148 | StorageClass::StorageClass SC, |
149 | bool IsWriteable) { |
150 | return std::make_tuple(args&: ElementType, args: (SC << 1) | IsWriteable, |
151 | args: SpecialTypeKind::STK_VkBuffer); |
152 | } |
153 | |
154 | inline IRHandle irhandle_explict_layout_type(const Type *Ty) { |
155 | const Type *WrpTy = unifyPtrType(Ty); |
156 | return irhandle_ptr(Ptr: WrpTy, Arg: Ty->getTypeID(), STK: STK_ExplictLayoutType); |
157 | } |
158 | |
159 | inline IRHandle handle(const Type *Ty) { |
160 | const Type *WrpTy = unifyPtrType(Ty); |
161 | return irhandle_ptr(Ptr: WrpTy, Arg: Ty->getTypeID(), STK: STK_Type); |
162 | } |
163 | |
164 | inline IRHandle handle(const Value *V) { |
165 | return irhandle_ptr(Ptr: V, Arg: V->getValueID(), STK: STK_Value); |
166 | } |
167 | |
168 | inline IRHandle handle(const MachineInstr *KeyMI) { |
169 | return irhandle_ptr(Ptr: KeyMI, Arg: SPIRV::to_hash(MI: KeyMI), STK: STK_MachineInstr); |
170 | } |
171 | |
172 | inline bool type_has_layout_decoration(const Type *T) { |
173 | return (isa<StructType>(Val: T) || isa<ArrayType>(Val: T)); |
174 | } |
175 | |
176 | } // namespace SPIRV |
177 | |
// Bi-directional mappings between LLVM entities and (v-reg, machine function)
// pairs. They are used to keep a unique SPIR-V definition per machine function
// for each LLVM/GlobalISel entity (e.g., Type, Constant, Machine Instruction).
181 | class SPIRVIRMapping { |
182 | DenseMap<SPIRV::IRHandleMF, SPIRV::MIHandle> Vregs; |
183 | DenseMap<const MachineInstr *, SPIRV::IRHandleMF> Defs; |
184 | |
185 | public: |
186 | bool add(SPIRV::IRHandle Handle, const MachineInstr *MI) { |
187 | if (auto DefIt = Defs.find(Val: MI); DefIt != Defs.end()) { |
188 | auto [ExistHandle, ExistMF] = DefIt->second; |
189 | if (Handle == ExistHandle && MI->getMF() == ExistMF) |
190 | return false; // already exists |
191 | // invalidate the record |
192 | Vregs.erase(Val: DefIt->second); |
193 | Defs.erase(I: DefIt); |
194 | } |
195 | SPIRV::IRHandleMF HandleMF = SPIRV::getIRHandleMF(Handle, MF: MI->getMF()); |
196 | SPIRV::MIHandle MIKey = SPIRV::getMIKey(MI); |
197 | auto It1 = Vregs.try_emplace(Key: HandleMF, Args&: MIKey); |
198 | if (!It1.second) { |
199 | // there is an expired record that we need to invalidate |
200 | Defs.erase(Val: std::get<0>(t&: It1.first->second)); |
201 | // update the record |
202 | It1.first->second = MIKey; |
203 | } |
204 | [[maybe_unused]] auto It2 = Defs.try_emplace(Key: MI, Args&: HandleMF); |
205 | assert(It2.second); |
206 | return true; |
207 | } |
208 | bool erase(const MachineInstr *MI) { |
209 | bool Res = false; |
210 | if (auto It = Defs.find(Val: MI); It != Defs.end()) { |
211 | Res = Vregs.erase(Val: It->second); |
212 | Defs.erase(I: It); |
213 | } |
214 | return Res; |
215 | } |
216 | const MachineInstr *findMI(SPIRV::IRHandle Handle, |
217 | const MachineFunction *MF) { |
218 | SPIRV::IRHandleMF HandleMF = SPIRV::getIRHandleMF(Handle, MF); |
219 | auto It = Vregs.find(Val: HandleMF); |
220 | if (It == Vregs.end()) |
221 | return nullptr; |
222 | auto [MI, Reg, Hash] = It->second; |
223 | const MachineInstr *Def = MF->getRegInfo().getVRegDef(Reg); |
224 | if (!Def || Def != MI || SPIRV::to_hash(MI) != Hash) { |
225 | // there is an expired record that we need to invalidate |
226 | erase(MI); |
227 | return nullptr; |
228 | } |
229 | assert(Defs.contains(MI) && Defs.find(MI)->second == HandleMF); |
230 | return MI; |
231 | } |
232 | Register find(SPIRV::IRHandle Handle, const MachineFunction *MF) { |
233 | const MachineInstr *MI = findMI(Handle, MF); |
234 | return MI ? MI->getOperand(i: 0).getReg() : Register(); |
235 | } |
236 | |
237 | // helpers |
238 | bool add(const Type *PointeeTy, unsigned AddressSpace, |
239 | const MachineInstr *MI) { |
240 | return add(Handle: SPIRV::irhandle_pointee(ElementType: PointeeTy, AddressSpace), MI); |
241 | } |
242 | Register find(const Type *PointeeTy, unsigned AddressSpace, |
243 | const MachineFunction *MF) { |
244 | return find(Handle: SPIRV::irhandle_pointee(ElementType: PointeeTy, AddressSpace), MF); |
245 | } |
246 | const MachineInstr *findMI(const Type *PointeeTy, unsigned AddressSpace, |
247 | const MachineFunction *MF) { |
248 | return findMI(Handle: SPIRV::irhandle_pointee(ElementType: PointeeTy, AddressSpace), MF); |
249 | } |
250 | |
251 | bool add(const Value *V, const MachineInstr *MI) { |
252 | return add(Handle: SPIRV::handle(V), MI); |
253 | } |
254 | |
255 | bool add(const Type *T, bool RequiresExplicitLayout, const MachineInstr *MI) { |
256 | if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) { |
257 | return add(Handle: SPIRV::irhandle_explict_layout_type(Ty: T), MI); |
258 | } |
259 | return add(Handle: SPIRV::handle(Ty: T), MI); |
260 | } |
261 | |
262 | bool add(const MachineInstr *Obj, const MachineInstr *MI) { |
263 | return add(Handle: SPIRV::handle(KeyMI: Obj), MI); |
264 | } |
265 | |
266 | Register find(const Value *V, const MachineFunction *MF) { |
267 | return find(Handle: SPIRV::handle(V), MF); |
268 | } |
269 | |
270 | Register find(const Type *T, bool RequiresExplicitLayout, |
271 | const MachineFunction *MF) { |
272 | if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) |
273 | return find(Handle: SPIRV::irhandle_explict_layout_type(Ty: T), MF); |
274 | return find(Handle: SPIRV::handle(Ty: T), MF); |
275 | } |
276 | |
277 | Register find(const MachineInstr *MI, const MachineFunction *MF) { |
278 | return find(Handle: SPIRV::handle(KeyMI: MI), MF); |
279 | } |
280 | |
281 | const MachineInstr *findMI(const Value *Obj, const MachineFunction *MF) { |
282 | return findMI(Handle: SPIRV::handle(V: Obj), MF); |
283 | } |
284 | |
285 | const MachineInstr *findMI(const Type *T, bool RequiresExplicitLayout, |
286 | const MachineFunction *MF) { |
287 | if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) |
288 | return findMI(Handle: SPIRV::irhandle_explict_layout_type(Ty: T), MF); |
289 | return findMI(Handle: SPIRV::handle(Ty: T), MF); |
290 | } |
291 | |
292 | const MachineInstr *findMI(const MachineInstr *Obj, |
293 | const MachineFunction *MF) { |
294 | return findMI(Handle: SPIRV::handle(KeyMI: Obj), MF); |
295 | } |
296 | }; |
297 | } // namespace llvm |
298 | #endif // LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H |
299 | |