| 1 | //===----------- SPIRVIRMapping.h - SPIR-V Duplicates Tracker ---*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // General infrastructure for keeping track of the values that according to |
| 10 | // the SPIR-V binary layout should be global to the whole module. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H |
| 15 | #define LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H |
| 16 | |
| 17 | #include "MCTargetDesc/SPIRVBaseInfo.h" |
| 18 | #include "MCTargetDesc/SPIRVMCTargetDesc.h" |
| 19 | #include "SPIRVUtils.h" |
| 20 | #include "llvm/ADT/DenseMap.h" |
| 21 | #include "llvm/ADT/Hashing.h" |
| 22 | #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" |
| 23 | #include "llvm/CodeGen/MachineModuleInfo.h" |
| 24 | |
| 25 | #include <type_traits> |
| 26 | |
| 27 | namespace llvm { |
| 28 | namespace SPIRV { |
| 29 | |
| 30 | inline size_t to_hash(const MachineInstr *MI) { |
| 31 | hash_code H = llvm::hash_combine(args: MI->getOpcode(), args: MI->getNumOperands()); |
| 32 | for (unsigned I = MI->getNumDefs(); I < MI->getNumOperands(); ++I) { |
| 33 | const MachineOperand &MO = MI->getOperand(i: I); |
| 34 | if (MO.getType() == MachineOperand::MO_CImmediate) |
| 35 | H = llvm::hash_combine(args: H, args: MO.getType(), args: MO.getCImm()); |
| 36 | else if (MO.getType() == MachineOperand::MO_FPImmediate) |
| 37 | H = llvm::hash_combine(args: H, args: MO.getType(), args: MO.getFPImm()); |
| 38 | else |
| 39 | H = llvm::hash_combine(args: H, args: MO.getType()); |
| 40 | } |
| 41 | return H; |
| 42 | } |
| 43 | |
| 44 | using MIHandle = std::tuple<const MachineInstr *, Register, size_t>; |
| 45 | |
| 46 | inline MIHandle getMIKey(const MachineInstr *MI) { |
| 47 | return std::make_tuple(args&: MI, args: MI->getOperand(i: 0).getReg(), args: SPIRV::to_hash(MI)); |
| 48 | } |
| 49 | |
| 50 | using IRHandle = std::tuple<const void *, unsigned, unsigned>; |
| 51 | using IRHandleMF = std::pair<IRHandle, const MachineFunction *>; |
| 52 | |
| 53 | inline IRHandleMF getIRHandleMF(IRHandle Handle, const MachineFunction *MF) { |
| 54 | return std::make_pair(x&: Handle, y&: MF); |
| 55 | } |
| 56 | |
// Discriminator stored as the last element of an IRHandle. It keeps handles
// of different entity kinds apart even when their pointer and argument
// components coincide. The numeric values are spelled out explicitly.
enum SpecialTypeKind {
  STK_Empty = 0,
  STK_Image = 1,
  STK_SampledImage = 2,
  STK_Sampler = 3,
  STK_Pipe = 4,
  STK_DeviceEvent = 5,
  STK_ElementPointer = 6,
  STK_Type = 7,
  STK_Value = 8,
  STK_MachineInstr = 9,
  STK_VkBuffer = 10,
  STK_ExplictLayoutType = 11,
  STK_Last = -1
};
| 72 | |
// Packs the attributes of an image type into one unsigned value. The
// attributes are written through the bit-field view (Flags) and the packed
// result is read back through Val.
union ImageAttrs {
  struct BitFlags {
    unsigned Dim : 3;
    unsigned Depth : 2;
    unsigned Arrayed : 1;
    unsigned MS : 1;
    unsigned Sampled : 2;
    unsigned ImageFormat : 6;
    unsigned AQ : 2;
  } Flags;
  unsigned Val;

  // Each argument is truncated to the width of its bit-field. AQ defaults
  // to 0.
  ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
             unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0)
      : Val(0) {
    // Val is zeroed before any field is stored so that the bits not covered
    // by a bit-field stay 0 in the packed value.
    Flags.Dim = Dim;
    Flags.Depth = Depth;
    Flags.Arrayed = Arrayed;
    Flags.MS = MS;
    Flags.Sampled = Sampled;
    Flags.ImageFormat = ImageFormat;
    Flags.AQ = AQ;
  }
};
| 97 | |
| 98 | inline IRHandle irhandle_image(const Type *SampledTy, unsigned Dim, |
| 99 | unsigned Depth, unsigned Arrayed, unsigned MS, |
| 100 | unsigned Sampled, unsigned ImageFormat, |
| 101 | unsigned AQ = 0) { |
| 102 | return std::make_tuple( |
| 103 | args&: SampledTy, |
| 104 | args: ImageAttrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ).Val, |
| 105 | args: SpecialTypeKind::STK_Image); |
| 106 | } |
| 107 | |
| 108 | inline IRHandle irhandle_sampled_image(const Type *SampledTy, |
| 109 | const MachineInstr *ImageTy) { |
| 110 | assert(ImageTy->getOpcode() == SPIRV::OpTypeImage); |
| 111 | unsigned AC = AccessQualifier::AccessQualifier::None; |
| 112 | if (ImageTy->getNumOperands() > 8) |
| 113 | AC = ImageTy->getOperand(i: 8).getImm(); |
| 114 | return std::make_tuple( |
| 115 | args&: SampledTy, |
| 116 | args: ImageAttrs( |
| 117 | ImageTy->getOperand(i: 2).getImm(), ImageTy->getOperand(i: 3).getImm(), |
| 118 | ImageTy->getOperand(i: 4).getImm(), ImageTy->getOperand(i: 5).getImm(), |
| 119 | ImageTy->getOperand(i: 6).getImm(), ImageTy->getOperand(i: 7).getImm(), AC) |
| 120 | .Val, |
| 121 | args: SpecialTypeKind::STK_SampledImage); |
| 122 | } |
| 123 | |
| 124 | inline IRHandle irhandle_sampler() { |
| 125 | return std::make_tuple(args: nullptr, args: 0U, args: SpecialTypeKind::STK_Sampler); |
| 126 | } |
| 127 | |
| 128 | inline IRHandle irhandle_pipe(uint8_t AQ) { |
| 129 | return std::make_tuple(args: nullptr, args&: AQ, args: SpecialTypeKind::STK_Pipe); |
| 130 | } |
| 131 | |
| 132 | inline IRHandle irhandle_event() { |
| 133 | return std::make_tuple(args: nullptr, args: 0U, args: SpecialTypeKind::STK_DeviceEvent); |
| 134 | } |
| 135 | |
| 136 | inline IRHandle irhandle_pointee(const Type *ElementType, |
| 137 | unsigned AddressSpace) { |
| 138 | return std::make_tuple(args: unifyPtrType(Ty: ElementType), args&: AddressSpace, |
| 139 | args: SpecialTypeKind::STK_ElementPointer); |
| 140 | } |
| 141 | |
| 142 | inline IRHandle irhandle_ptr(const void *Ptr, unsigned Arg, |
| 143 | enum SpecialTypeKind STK) { |
| 144 | return std::make_tuple(args&: Ptr, args&: Arg, args&: STK); |
| 145 | } |
| 146 | |
| 147 | inline IRHandle irhandle_vkbuffer(const Type *ElementType, |
| 148 | StorageClass::StorageClass SC, |
| 149 | bool IsWriteable) { |
| 150 | return std::make_tuple(args&: ElementType, args: (SC << 1) | IsWriteable, |
| 151 | args: SpecialTypeKind::STK_VkBuffer); |
| 152 | } |
| 153 | |
| 154 | inline IRHandle irhandle_explict_layout_type(const Type *Ty) { |
| 155 | const Type *WrpTy = unifyPtrType(Ty); |
| 156 | return irhandle_ptr(Ptr: WrpTy, Arg: Ty->getTypeID(), STK: STK_ExplictLayoutType); |
| 157 | } |
| 158 | |
| 159 | inline IRHandle handle(const Type *Ty) { |
| 160 | const Type *WrpTy = unifyPtrType(Ty); |
| 161 | return irhandle_ptr(Ptr: WrpTy, Arg: Ty->getTypeID(), STK: STK_Type); |
| 162 | } |
| 163 | |
| 164 | inline IRHandle handle(const Value *V) { |
| 165 | return irhandle_ptr(Ptr: V, Arg: V->getValueID(), STK: STK_Value); |
| 166 | } |
| 167 | |
| 168 | inline IRHandle handle(const MachineInstr *KeyMI) { |
| 169 | return irhandle_ptr(Ptr: KeyMI, Arg: SPIRV::to_hash(MI: KeyMI), STK: STK_MachineInstr); |
| 170 | } |
| 171 | |
| 172 | inline bool type_has_layout_decoration(const Type *T) { |
| 173 | return (isa<StructType>(Val: T) || isa<ArrayType>(Val: T)); |
| 174 | } |
| 175 | |
| 176 | } // namespace SPIRV |
| 177 | |
| 178 | // Bi-directional mappings between LLVM entities and (v-reg, machine function) |
| 179 | // pairs support management of unique SPIR-V definitions per machine function |
| 180 | // per an LLVM/GlobalISel entity (e.g., Type, Constant, Machine Instruction). |
| 181 | class SPIRVIRMapping { |
| 182 | DenseMap<SPIRV::IRHandleMF, SPIRV::MIHandle> Vregs; |
| 183 | DenseMap<const MachineInstr *, SPIRV::IRHandleMF> Defs; |
| 184 | |
| 185 | public: |
| 186 | bool add(SPIRV::IRHandle Handle, const MachineInstr *MI) { |
| 187 | if (auto DefIt = Defs.find(Val: MI); DefIt != Defs.end()) { |
| 188 | auto [ExistHandle, ExistMF] = DefIt->second; |
| 189 | if (Handle == ExistHandle && MI->getMF() == ExistMF) |
| 190 | return false; // already exists |
| 191 | // invalidate the record |
| 192 | Vregs.erase(Val: DefIt->second); |
| 193 | Defs.erase(I: DefIt); |
| 194 | } |
| 195 | SPIRV::IRHandleMF HandleMF = SPIRV::getIRHandleMF(Handle, MF: MI->getMF()); |
| 196 | SPIRV::MIHandle MIKey = SPIRV::getMIKey(MI); |
| 197 | auto It1 = Vregs.try_emplace(Key: HandleMF, Args&: MIKey); |
| 198 | if (!It1.second) { |
| 199 | // there is an expired record that we need to invalidate |
| 200 | Defs.erase(Val: std::get<0>(t&: It1.first->second)); |
| 201 | // update the record |
| 202 | It1.first->second = MIKey; |
| 203 | } |
| 204 | [[maybe_unused]] auto It2 = Defs.try_emplace(Key: MI, Args&: HandleMF); |
| 205 | assert(It2.second); |
| 206 | return true; |
| 207 | } |
| 208 | bool erase(const MachineInstr *MI) { |
| 209 | bool Res = false; |
| 210 | if (auto It = Defs.find(Val: MI); It != Defs.end()) { |
| 211 | Res = Vregs.erase(Val: It->second); |
| 212 | Defs.erase(I: It); |
| 213 | } |
| 214 | return Res; |
| 215 | } |
| 216 | const MachineInstr *findMI(SPIRV::IRHandle Handle, |
| 217 | const MachineFunction *MF) { |
| 218 | SPIRV::IRHandleMF HandleMF = SPIRV::getIRHandleMF(Handle, MF); |
| 219 | auto It = Vregs.find(Val: HandleMF); |
| 220 | if (It == Vregs.end()) |
| 221 | return nullptr; |
| 222 | auto [MI, Reg, Hash] = It->second; |
| 223 | const MachineInstr *Def = MF->getRegInfo().getVRegDef(Reg); |
| 224 | if (!Def || Def != MI || SPIRV::to_hash(MI) != Hash) { |
| 225 | // there is an expired record that we need to invalidate |
| 226 | erase(MI); |
| 227 | return nullptr; |
| 228 | } |
| 229 | assert(Defs.contains(MI) && Defs.find(MI)->second == HandleMF); |
| 230 | return MI; |
| 231 | } |
| 232 | Register find(SPIRV::IRHandle Handle, const MachineFunction *MF) { |
| 233 | const MachineInstr *MI = findMI(Handle, MF); |
| 234 | return MI ? MI->getOperand(i: 0).getReg() : Register(); |
| 235 | } |
| 236 | |
| 237 | // helpers |
| 238 | bool add(const Type *PointeeTy, unsigned AddressSpace, |
| 239 | const MachineInstr *MI) { |
| 240 | return add(Handle: SPIRV::irhandle_pointee(ElementType: PointeeTy, AddressSpace), MI); |
| 241 | } |
| 242 | Register find(const Type *PointeeTy, unsigned AddressSpace, |
| 243 | const MachineFunction *MF) { |
| 244 | return find(Handle: SPIRV::irhandle_pointee(ElementType: PointeeTy, AddressSpace), MF); |
| 245 | } |
| 246 | const MachineInstr *findMI(const Type *PointeeTy, unsigned AddressSpace, |
| 247 | const MachineFunction *MF) { |
| 248 | return findMI(Handle: SPIRV::irhandle_pointee(ElementType: PointeeTy, AddressSpace), MF); |
| 249 | } |
| 250 | |
| 251 | bool add(const Value *V, const MachineInstr *MI) { |
| 252 | return add(Handle: SPIRV::handle(V), MI); |
| 253 | } |
| 254 | |
| 255 | bool add(const Type *T, bool RequiresExplicitLayout, const MachineInstr *MI) { |
| 256 | if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) { |
| 257 | return add(Handle: SPIRV::irhandle_explict_layout_type(Ty: T), MI); |
| 258 | } |
| 259 | return add(Handle: SPIRV::handle(Ty: T), MI); |
| 260 | } |
| 261 | |
| 262 | bool add(const MachineInstr *Obj, const MachineInstr *MI) { |
| 263 | return add(Handle: SPIRV::handle(KeyMI: Obj), MI); |
| 264 | } |
| 265 | |
| 266 | Register find(const Value *V, const MachineFunction *MF) { |
| 267 | return find(Handle: SPIRV::handle(V), MF); |
| 268 | } |
| 269 | |
| 270 | Register find(const Type *T, bool RequiresExplicitLayout, |
| 271 | const MachineFunction *MF) { |
| 272 | if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) |
| 273 | return find(Handle: SPIRV::irhandle_explict_layout_type(Ty: T), MF); |
| 274 | return find(Handle: SPIRV::handle(Ty: T), MF); |
| 275 | } |
| 276 | |
| 277 | Register find(const MachineInstr *MI, const MachineFunction *MF) { |
| 278 | return find(Handle: SPIRV::handle(KeyMI: MI), MF); |
| 279 | } |
| 280 | |
| 281 | const MachineInstr *findMI(const Value *Obj, const MachineFunction *MF) { |
| 282 | return findMI(Handle: SPIRV::handle(V: Obj), MF); |
| 283 | } |
| 284 | |
| 285 | const MachineInstr *findMI(const Type *T, bool RequiresExplicitLayout, |
| 286 | const MachineFunction *MF) { |
| 287 | if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) |
| 288 | return findMI(Handle: SPIRV::irhandle_explict_layout_type(Ty: T), MF); |
| 289 | return findMI(Handle: SPIRV::handle(Ty: T), MF); |
| 290 | } |
| 291 | |
| 292 | const MachineInstr *findMI(const MachineInstr *Obj, |
| 293 | const MachineFunction *MF) { |
| 294 | return findMI(Handle: SPIRV::handle(KeyMI: Obj), MF); |
| 295 | } |
| 296 | }; |
| 297 | } // namespace llvm |
| 298 | #endif // LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H |
| 299 | |