1 | //===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains miscellaneous utility functions. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "SPIRVUtils.h" |
14 | #include "MCTargetDesc/SPIRVBaseInfo.h" |
15 | #include "SPIRV.h" |
16 | #include "SPIRVGlobalRegistry.h" |
17 | #include "SPIRVInstrInfo.h" |
18 | #include "SPIRVSubtarget.h" |
19 | #include "llvm/ADT/StringRef.h" |
20 | #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" |
21 | #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" |
22 | #include "llvm/CodeGen/MachineInstr.h" |
23 | #include "llvm/CodeGen/MachineInstrBuilder.h" |
24 | #include "llvm/Demangle/Demangle.h" |
25 | #include "llvm/IR/IntrinsicInst.h" |
26 | #include "llvm/IR/IntrinsicsSPIRV.h" |
27 | #include <queue> |
28 | #include <vector> |
29 | |
30 | namespace llvm { |
31 | |
// The following functions are used to add string literals as a series of
// 32-bit integer operands with the correct format, and to unpack them when
// making string comparisons in compiler passes.
35 | // SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment. |
36 | static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) { |
37 | uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars. |
38 | for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) { |
39 | unsigned StrIndex = i + WordIndex; |
uint8_t CharToAdd = 0; // Initialize char as padding/null.
41 | if (StrIndex < Str.size()) { // If it's within the string, get a real char. |
42 | CharToAdd = Str[StrIndex]; |
43 | } |
44 | Word |= (CharToAdd << (WordIndex * 8)); |
45 | } |
46 | return Word; |
47 | } |
48 | |
49 | // Get length including padding and null terminator. |
50 | static size_t getPaddedLen(const StringRef &Str) { |
51 | return (Str.size() + 4) & ~3; |
52 | } |
53 | |
54 | void addStringImm(const StringRef &Str, MCInst &Inst) { |
55 | const size_t PaddedLen = getPaddedLen(Str); |
56 | for (unsigned i = 0; i < PaddedLen; i += 4) { |
57 | // Add an operand for the 32-bits of chars or padding. |
Inst.addOperand(MCOperand::createImm(convertCharsToWord(Str, i)));
59 | } |
60 | } |
61 | |
62 | void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) { |
63 | const size_t PaddedLen = getPaddedLen(Str); |
64 | for (unsigned i = 0; i < PaddedLen; i += 4) { |
65 | // Add an operand for the 32-bits of chars or padding. |
MIB.addImm(convertCharsToWord(Str, i));
67 | } |
68 | } |
69 | |
70 | void addStringImm(const StringRef &Str, IRBuilder<> &B, |
71 | std::vector<Value *> &Args) { |
72 | const size_t PaddedLen = getPaddedLen(Str); |
73 | for (unsigned i = 0; i < PaddedLen; i += 4) { |
74 | // Add a vector element for the 32-bits of chars or padding. |
Args.push_back(B.getInt32(convertCharsToWord(Str, i)));
76 | } |
77 | } |
78 | |
79 | std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) { |
80 | return getSPIRVStringOperand(MI, StartIndex); |
81 | } |
82 | |
83 | std::string getStringValueFromReg(Register Reg, MachineRegisterInfo &MRI) { |
84 | MachineInstr *Def = getVRegDef(MRI, Reg); |
85 | assert(Def && Def->getOpcode() == TargetOpcode::G_GLOBAL_VALUE && |
86 | "Expected G_GLOBAL_VALUE" ); |
87 | const GlobalValue *GV = Def->getOperand(i: 1).getGlobal(); |
88 | Value *V = GV->getOperand(i: 0); |
89 | const ConstantDataArray *CDA = cast<ConstantDataArray>(Val: V); |
90 | return CDA->getAsCString().str(); |
91 | } |
92 | |
93 | void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) { |
94 | const auto Bitwidth = Imm.getBitWidth(); |
95 | if (Bitwidth == 1) |
96 | return; // Already handled |
97 | else if (Bitwidth <= 32) { |
MIB.addImm(Imm.getZExtValue());
// The asm printer needs this info to print 16-bit floating-point values
// correctly.
if (Bitwidth == 16)
MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
102 | return; |
103 | } else if (Bitwidth <= 64) { |
104 | uint64_t FullImm = Imm.getZExtValue(); |
105 | uint32_t LowBits = FullImm & 0xffffffff; |
106 | uint32_t HighBits = (FullImm >> 32) & 0xffffffff; |
MIB.addImm(LowBits).addImm(HighBits);
108 | return; |
109 | } |
110 | report_fatal_error(reason: "Unsupported constant bitwidth" ); |
111 | } |
112 | |
113 | void buildOpName(Register Target, const StringRef &Name, |
114 | MachineIRBuilder &MIRBuilder) { |
115 | if (!Name.empty()) { |
auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target);
addStringImm(Name, MIB);
118 | } |
119 | } |
120 | |
121 | void buildOpName(Register Target, const StringRef &Name, MachineInstr &I, |
122 | const SPIRVInstrInfo &TII) { |
123 | if (!Name.empty()) { |
124 | auto MIB = |
125 | BuildMI(BB&: *I.getParent(), I, MIMD: I.getDebugLoc(), MCID: TII.get(Opcode: SPIRV::OpName)) |
126 | .addUse(RegNo: Target); |
127 | addStringImm(Str: Name, MIB); |
128 | } |
129 | } |
130 | |
131 | static void finishBuildOpDecorate(MachineInstrBuilder &MIB, |
132 | const std::vector<uint32_t> &DecArgs, |
133 | StringRef StrImm) { |
134 | if (!StrImm.empty()) |
addStringImm(StrImm, MIB);
for (const auto &DecArg : DecArgs)
MIB.addImm(DecArg);
138 | } |
139 | |
140 | void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder, |
141 | SPIRV::Decoration::Decoration Dec, |
142 | const std::vector<uint32_t> &DecArgs, StringRef StrImm) { |
auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
.addUse(Reg)
.addImm(static_cast<uint32_t>(Dec));
146 | finishBuildOpDecorate(MIB, DecArgs, StrImm); |
147 | } |
148 | |
149 | void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII, |
150 | SPIRV::Decoration::Decoration Dec, |
151 | const std::vector<uint32_t> &DecArgs, StringRef StrImm) { |
152 | MachineBasicBlock &MBB = *I.getParent(); |
153 | auto MIB = BuildMI(BB&: MBB, I, MIMD: I.getDebugLoc(), MCID: TII.get(Opcode: SPIRV::OpDecorate)) |
154 | .addUse(RegNo: Reg) |
155 | .addImm(Val: static_cast<uint32_t>(Dec)); |
156 | finishBuildOpDecorate(MIB, DecArgs, StrImm); |
157 | } |
158 | |
159 | void buildOpMemberDecorate(Register Reg, MachineIRBuilder &MIRBuilder, |
160 | SPIRV::Decoration::Decoration Dec, uint32_t Member, |
161 | const std::vector<uint32_t> &DecArgs, |
162 | StringRef StrImm) { |
auto MIB = MIRBuilder.buildInstr(SPIRV::OpMemberDecorate)
.addUse(Reg)
.addImm(Member)
.addImm(static_cast<uint32_t>(Dec));
167 | finishBuildOpDecorate(MIB, DecArgs, StrImm); |
168 | } |
169 | |
170 | void buildOpMemberDecorate(Register Reg, MachineInstr &I, |
171 | const SPIRVInstrInfo &TII, |
172 | SPIRV::Decoration::Decoration Dec, uint32_t Member, |
173 | const std::vector<uint32_t> &DecArgs, |
174 | StringRef StrImm) { |
175 | MachineBasicBlock &MBB = *I.getParent(); |
176 | auto MIB = BuildMI(BB&: MBB, I, MIMD: I.getDebugLoc(), MCID: TII.get(Opcode: SPIRV::OpMemberDecorate)) |
177 | .addUse(RegNo: Reg) |
178 | .addImm(Val: Member) |
179 | .addImm(Val: static_cast<uint32_t>(Dec)); |
180 | finishBuildOpDecorate(MIB, DecArgs, StrImm); |
181 | } |
182 | |
183 | void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder, |
184 | const MDNode *GVarMD) { |
185 | for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) { |
auto *OpMD = dyn_cast<MDNode>(GVarMD->getOperand(I));
if (!OpMD)
report_fatal_error("Invalid decoration");
if (OpMD->getNumOperands() == 0)
report_fatal_error("Expect operand(s) of the decoration");
ConstantInt *DecorationId =
mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(0));
if (!DecorationId)
report_fatal_error("Expect SPIR-V <Decoration> operand to be the first "
"element of the decoration");
auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
.addUse(Reg)
.addImm(static_cast<uint32_t>(DecorationId->getZExtValue()));
for (unsigned OpI = 1, OpE = OpMD->getNumOperands(); OpI != OpE; ++OpI) {
if (ConstantInt *OpV =
mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(OpI)))
MIB.addImm(static_cast<uint32_t>(OpV->getZExtValue()));
else if (MDString *OpV = dyn_cast<MDString>(OpMD->getOperand(OpI)))
addStringImm(OpV->getString(), MIB);
else
report_fatal_error("Unexpected operand of the decoration");
}
208 | } |
209 | } |
210 | |
211 | MachineBasicBlock::iterator getOpVariableMBBIt(MachineInstr &I) { |
212 | MachineFunction *MF = I.getParent()->getParent(); |
213 | MachineBasicBlock *MBB = &MF->front(); |
MachineBasicBlock::iterator It = MBB->SkipPHIsAndLabels(MBB->begin()),
E = MBB->end();
bool IsHeader = false;
217 | unsigned Opcode; |
218 | for (; It != E && It != I; ++It) { |
219 | Opcode = It->getOpcode(); |
220 | if (Opcode == SPIRV::OpFunction || Opcode == SPIRV::OpFunctionParameter) { |
221 | IsHeader = true; |
222 | } else if (IsHeader && |
223 | !(Opcode == SPIRV::ASSIGN_TYPE || Opcode == SPIRV::OpLabel)) { |
224 | ++It; |
225 | break; |
226 | } |
227 | } |
228 | return It; |
229 | } |
230 | |
231 | MachineBasicBlock::iterator getInsertPtValidEnd(MachineBasicBlock *MBB) { |
232 | MachineBasicBlock::iterator I = MBB->end(); |
233 | if (I == MBB->begin()) |
234 | return I; |
235 | --I; |
236 | while (I->isTerminator() || I->isDebugValue()) { |
237 | if (I == MBB->begin()) |
238 | break; |
239 | --I; |
240 | } |
241 | return I; |
242 | } |
243 | |
244 | SPIRV::StorageClass::StorageClass |
245 | addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) { |
246 | switch (AddrSpace) { |
247 | case 0: |
248 | return SPIRV::StorageClass::Function; |
249 | case 1: |
250 | return SPIRV::StorageClass::CrossWorkgroup; |
251 | case 2: |
252 | return SPIRV::StorageClass::UniformConstant; |
253 | case 3: |
254 | return SPIRV::StorageClass::Workgroup; |
255 | case 4: |
256 | return SPIRV::StorageClass::Generic; |
257 | case 5: |
return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
259 | ? SPIRV::StorageClass::DeviceOnlyINTEL |
260 | : SPIRV::StorageClass::CrossWorkgroup; |
261 | case 6: |
return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
263 | ? SPIRV::StorageClass::HostOnlyINTEL |
264 | : SPIRV::StorageClass::CrossWorkgroup; |
265 | case 7: |
266 | return SPIRV::StorageClass::Input; |
267 | case 8: |
268 | return SPIRV::StorageClass::Output; |
269 | case 9: |
270 | return SPIRV::StorageClass::CodeSectionINTEL; |
271 | case 10: |
272 | return SPIRV::StorageClass::Private; |
273 | case 11: |
274 | return SPIRV::StorageClass::StorageBuffer; |
275 | case 12: |
276 | return SPIRV::StorageClass::Uniform; |
277 | default: |
278 | report_fatal_error(reason: "Unknown address space" ); |
279 | } |
280 | } |
281 | |
282 | SPIRV::MemorySemantics::MemorySemantics |
283 | getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC) { |
284 | switch (SC) { |
285 | case SPIRV::StorageClass::StorageBuffer: |
286 | case SPIRV::StorageClass::Uniform: |
287 | return SPIRV::MemorySemantics::UniformMemory; |
288 | case SPIRV::StorageClass::Workgroup: |
289 | return SPIRV::MemorySemantics::WorkgroupMemory; |
290 | case SPIRV::StorageClass::CrossWorkgroup: |
291 | return SPIRV::MemorySemantics::CrossWorkgroupMemory; |
292 | case SPIRV::StorageClass::AtomicCounter: |
293 | return SPIRV::MemorySemantics::AtomicCounterMemory; |
294 | case SPIRV::StorageClass::Image: |
295 | return SPIRV::MemorySemantics::ImageMemory; |
296 | default: |
297 | return SPIRV::MemorySemantics::None; |
298 | } |
299 | } |
300 | |
301 | SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) { |
302 | switch (Ord) { |
303 | case AtomicOrdering::Acquire: |
304 | return SPIRV::MemorySemantics::Acquire; |
305 | case AtomicOrdering::Release: |
306 | return SPIRV::MemorySemantics::Release; |
307 | case AtomicOrdering::AcquireRelease: |
308 | return SPIRV::MemorySemantics::AcquireRelease; |
309 | case AtomicOrdering::SequentiallyConsistent: |
310 | return SPIRV::MemorySemantics::SequentiallyConsistent; |
311 | case AtomicOrdering::Unordered: |
312 | case AtomicOrdering::Monotonic: |
313 | case AtomicOrdering::NotAtomic: |
314 | return SPIRV::MemorySemantics::None; |
315 | } |
316 | llvm_unreachable(nullptr); |
317 | } |
318 | |
319 | SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) { |
320 | // Named by |
321 | // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id. |
322 | // We don't need aliases for Invocation and CrossDevice, as we already have |
323 | // them covered by "singlethread" and "" strings respectively (see |
324 | // implementation of LLVMContext::LLVMContext()). |
static const llvm::SyncScope::ID SubGroup =
Ctx.getOrInsertSyncScopeID("subgroup");
static const llvm::SyncScope::ID WorkGroup =
Ctx.getOrInsertSyncScopeID("workgroup");
static const llvm::SyncScope::ID Device =
Ctx.getOrInsertSyncScopeID("device");
331 | |
332 | if (Id == llvm::SyncScope::SingleThread) |
333 | return SPIRV::Scope::Invocation; |
334 | else if (Id == llvm::SyncScope::System) |
335 | return SPIRV::Scope::CrossDevice; |
336 | else if (Id == SubGroup) |
337 | return SPIRV::Scope::Subgroup; |
338 | else if (Id == WorkGroup) |
339 | return SPIRV::Scope::Workgroup; |
340 | else if (Id == Device) |
341 | return SPIRV::Scope::Device; |
342 | return SPIRV::Scope::CrossDevice; |
343 | } |
344 | |
345 | MachineInstr *getDefInstrMaybeConstant(Register &ConstReg, |
346 | const MachineRegisterInfo *MRI) { |
MachineInstr *MI = MRI->getVRegDef(ConstReg);
MachineInstr *ConstInstr =
MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
? MRI->getVRegDef(MI->getOperand(1).getReg())
: MI;
if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
if (GI->is(Intrinsic::spv_track_constant)) {
ConstReg = ConstInstr->getOperand(2).getReg();
return MRI->getVRegDef(ConstReg);
}
} else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) {
ConstReg = ConstInstr->getOperand(1).getReg();
return MRI->getVRegDef(ConstReg);
} else if (ConstInstr->getOpcode() == TargetOpcode::G_CONSTANT ||
ConstInstr->getOpcode() == TargetOpcode::G_FCONSTANT) {
ConstReg = ConstInstr->getOperand(0).getReg();
return ConstInstr;
}
return MRI->getVRegDef(ConstReg);
366 | } |
367 | |
368 | uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) { |
369 | const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI); |
370 | assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT); |
return MI->getOperand(1).getCImm()->getValue().getZExtValue();
372 | } |
373 | |
374 | bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) { |
if (const auto *GI = dyn_cast<GIntrinsic>(&MI))
return GI->is(IntrinsicID);
377 | return false; |
378 | } |
379 | |
380 | Type *getMDOperandAsType(const MDNode *N, unsigned I) { |
Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
return toTypedPointer(ElementTy);
383 | } |
384 | |
385 | // The set of names is borrowed from the SPIR-V translator. |
// TODO: maybe implement this in SPIRVBuiltins.td.
387 | static bool isPipeOrAddressSpaceCastBI(const StringRef MangledName) { |
388 | return MangledName == "write_pipe_2" || MangledName == "read_pipe_2" || |
389 | MangledName == "write_pipe_2_bl" || MangledName == "read_pipe_2_bl" || |
390 | MangledName == "write_pipe_4" || MangledName == "read_pipe_4" || |
391 | MangledName == "reserve_write_pipe" || |
392 | MangledName == "reserve_read_pipe" || |
393 | MangledName == "commit_write_pipe" || |
394 | MangledName == "commit_read_pipe" || |
395 | MangledName == "work_group_reserve_write_pipe" || |
396 | MangledName == "work_group_reserve_read_pipe" || |
397 | MangledName == "work_group_commit_write_pipe" || |
398 | MangledName == "work_group_commit_read_pipe" || |
399 | MangledName == "get_pipe_num_packets_ro" || |
400 | MangledName == "get_pipe_max_packets_ro" || |
401 | MangledName == "get_pipe_num_packets_wo" || |
402 | MangledName == "get_pipe_max_packets_wo" || |
403 | MangledName == "sub_group_reserve_write_pipe" || |
404 | MangledName == "sub_group_reserve_read_pipe" || |
405 | MangledName == "sub_group_commit_write_pipe" || |
406 | MangledName == "sub_group_commit_read_pipe" || |
407 | MangledName == "to_global" || MangledName == "to_local" || |
408 | MangledName == "to_private" ; |
409 | } |
410 | |
411 | static bool isEnqueueKernelBI(const StringRef MangledName) { |
412 | return MangledName == "__enqueue_kernel_basic" || |
413 | MangledName == "__enqueue_kernel_basic_events" || |
414 | MangledName == "__enqueue_kernel_varargs" || |
415 | MangledName == "__enqueue_kernel_events_varargs" ; |
416 | } |
417 | |
418 | static bool isKernelQueryBI(const StringRef MangledName) { |
419 | return MangledName == "__get_kernel_work_group_size_impl" || |
420 | MangledName == "__get_kernel_sub_group_count_for_ndrange_impl" || |
421 | MangledName == "__get_kernel_max_sub_group_size_for_ndrange_impl" || |
422 | MangledName == "__get_kernel_preferred_work_group_size_multiple_impl" ; |
423 | } |
424 | |
425 | static bool isNonMangledOCLBuiltin(StringRef Name) { |
if (!Name.starts_with("__"))
return false;

return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) ||
isPipeOrAddressSpaceCastBI(Name.drop_front(2)) ||
Name == "__translate_sampler_initializer";
432 | } |
433 | |
434 | std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) { |
435 | bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name); |
bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
bool IsMangled = Name.starts_with("_Z");
439 | |
440 | // Otherwise use simple demangling to return the function name. |
441 | if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled) |
442 | return Name.str(); |
443 | |
444 | // Try to use the itanium demangler. |
if (char *DemangledName = itaniumDemangle(Name.data())) {
std::string Result = DemangledName;
free(DemangledName);
448 | return Result; |
449 | } |
450 | |
// Assume C++ here; we may need an explicit check of the source language.
// OpenCL C++ built-ins are declared in the cl namespace.
// TODO: consider using the 'St' abbreviation for cl namespace mangling.
454 | // Similar to ::std:: in C++. |
455 | size_t Start, Len = 0; |
456 | size_t DemangledNameLenStart = 2; |
457 | if (Name.starts_with(Prefix: "_ZN" )) { |
458 | // Skip CV and ref qualifiers. |
459 | size_t NameSpaceStart = Name.find_first_not_of(Chars: "rVKRO" , From: 3); |
460 | // All built-ins are in the ::cl:: namespace. |
461 | if (Name.substr(Start: NameSpaceStart, N: 11) != "2cl7__spirv" ) |
462 | return std::string(); |
463 | DemangledNameLenStart = NameSpaceStart + 11; |
464 | } |
465 | Start = Name.find_first_not_of(Chars: "0123456789" , From: DemangledNameLenStart); |
466 | Name.substr(Start: DemangledNameLenStart, N: Start - DemangledNameLenStart) |
467 | .getAsInteger(Radix: 10, Result&: Len); |
468 | return Name.substr(Start, N: Len).str(); |
469 | } |
470 | |
471 | bool hasBuiltinTypePrefix(StringRef Name) { |
472 | if (Name.starts_with(Prefix: "opencl." ) || Name.starts_with(Prefix: "ocl_" ) || |
473 | Name.starts_with(Prefix: "spirv." )) |
474 | return true; |
475 | return false; |
476 | } |
477 | |
478 | bool isSpecialOpaqueType(const Type *Ty) { |
if (const TargetExtType *ExtTy = dyn_cast<TargetExtType>(Ty))
return isTypedPointerWrapper(ExtTy)
? false
: hasBuiltinTypePrefix(ExtTy->getName());
483 | |
484 | return false; |
485 | } |
486 | |
487 | bool isEntryPoint(const Function &F) { |
488 | // OpenCL handling: any function with the SPIR_KERNEL |
489 | // calling convention will be a potential entry point. |
490 | if (F.getCallingConv() == CallingConv::SPIR_KERNEL) |
491 | return true; |
492 | |
// HLSL handling: a special attribute is emitted by the
// front-end.
if (F.getFnAttribute("hlsl.shader").isValid())
496 | return true; |
497 | |
498 | return false; |
499 | } |
500 | |
501 | Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) { |
502 | TypeName.consume_front(Prefix: "atomic_" ); |
503 | if (TypeName.consume_front(Prefix: "void" )) |
504 | return Type::getVoidTy(C&: Ctx); |
505 | else if (TypeName.consume_front(Prefix: "bool" ) || TypeName.consume_front(Prefix: "_Bool" )) |
506 | return Type::getIntNTy(C&: Ctx, N: 1); |
507 | else if (TypeName.consume_front(Prefix: "char" ) || |
508 | TypeName.consume_front(Prefix: "signed char" ) || |
509 | TypeName.consume_front(Prefix: "unsigned char" ) || |
510 | TypeName.consume_front(Prefix: "uchar" )) |
511 | return Type::getInt8Ty(C&: Ctx); |
512 | else if (TypeName.consume_front(Prefix: "short" ) || |
513 | TypeName.consume_front(Prefix: "signed short" ) || |
514 | TypeName.consume_front(Prefix: "unsigned short" ) || |
515 | TypeName.consume_front(Prefix: "ushort" )) |
516 | return Type::getInt16Ty(C&: Ctx); |
517 | else if (TypeName.consume_front(Prefix: "int" ) || |
518 | TypeName.consume_front(Prefix: "signed int" ) || |
519 | TypeName.consume_front(Prefix: "unsigned int" ) || |
520 | TypeName.consume_front(Prefix: "uint" )) |
521 | return Type::getInt32Ty(C&: Ctx); |
522 | else if (TypeName.consume_front(Prefix: "long" ) || |
523 | TypeName.consume_front(Prefix: "signed long" ) || |
524 | TypeName.consume_front(Prefix: "unsigned long" ) || |
525 | TypeName.consume_front(Prefix: "ulong" )) |
526 | return Type::getInt64Ty(C&: Ctx); |
527 | else if (TypeName.consume_front(Prefix: "half" ) || |
528 | TypeName.consume_front(Prefix: "_Float16" ) || |
529 | TypeName.consume_front(Prefix: "__fp16" )) |
530 | return Type::getHalfTy(C&: Ctx); |
531 | else if (TypeName.consume_front(Prefix: "float" )) |
532 | return Type::getFloatTy(C&: Ctx); |
533 | else if (TypeName.consume_front(Prefix: "double" )) |
534 | return Type::getDoubleTy(C&: Ctx); |
535 | |
536 | // Unable to recognize SPIRV type name |
537 | return nullptr; |
538 | } |
539 | |
540 | std::unordered_set<BasicBlock *> |
541 | PartialOrderingVisitor::getReachableFrom(BasicBlock *Start) { |
542 | std::queue<BasicBlock *> ToVisit; |
ToVisit.push(Start);
544 | |
545 | std::unordered_set<BasicBlock *> Output; |
546 | while (ToVisit.size() != 0) { |
547 | BasicBlock *BB = ToVisit.front(); |
548 | ToVisit.pop(); |
549 | |
if (Output.count(BB) != 0)
continue;
Output.insert(BB);

for (BasicBlock *Successor : successors(BB)) {
if (DT.dominates(Successor, BB))
continue;
ToVisit.push(Successor);
558 | } |
559 | } |
560 | |
561 | return Output; |
562 | } |
563 | |
564 | bool PartialOrderingVisitor::CanBeVisited(BasicBlock *BB) const { |
565 | for (BasicBlock *P : predecessors(BB)) { |
566 | // Ignore back-edges. |
if (DT.dominates(BB, P))
568 | continue; |
569 | |
// One of the predecessors hasn't been visited. Not ready yet.
if (BlockToOrder.count(P) == 0)
572 | return false; |
573 | |
574 | // If the block is a loop exit, the loop must be finished before |
575 | // we can continue. |
Loop *L = LI.getLoopFor(P);
577 | if (L == nullptr || L->contains(BB)) |
578 | continue; |
579 | |
// SPIR-V requires a single back-edge, and the backend's first step
// transforms loops into this simplified form. If we have more than one
// back-edge, something is wrong.
583 | assert(L->getNumBackEdges() <= 1); |
584 | |
// If the loop has no latch, the loop's rank won't matter, so we can
// proceed.
587 | BasicBlock *Latch = L->getLoopLatch(); |
588 | assert(Latch); |
589 | if (Latch == nullptr) |
590 | continue; |
591 | |
592 | // The latch is not ready yet, let's wait. |
if (BlockToOrder.count(Latch) == 0)
594 | return false; |
595 | } |
596 | |
597 | return true; |
598 | } |
599 | |
600 | size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const { |
auto It = BlockToOrder.find(BB);
602 | if (It != BlockToOrder.end()) |
603 | return It->second.Rank; |
604 | |
605 | size_t result = 0; |
606 | for (BasicBlock *P : predecessors(BB)) { |
607 | // Ignore back-edges. |
if (DT.dominates(BB, P))
609 | continue; |
610 | |
611 | auto Iterator = BlockToOrder.end(); |
Loop *L = LI.getLoopFor(P);
613 | BasicBlock *Latch = L ? L->getLoopLatch() : nullptr; |
614 | |
615 | // If the predecessor is either outside a loop, or part of |
616 | // the same loop, simply take its rank + 1. |
617 | if (L == nullptr || L->contains(BB) || Latch == nullptr) { |
Iterator = BlockToOrder.find(P);
619 | } else { |
620 | // Otherwise, take the loop's rank (highest rank in the loop) as base. |
621 | // Since loops have a single latch, highest rank is easy to find. |
622 | // If the loop has no latch, then it doesn't matter. |
Iterator = BlockToOrder.find(Latch);
624 | } |
625 | |
626 | assert(Iterator != BlockToOrder.end()); |
result = std::max(result, Iterator->second.Rank + 1);
628 | } |
629 | |
630 | return result; |
631 | } |
632 | |
633 | size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) { |
ToVisit.push(BB);
Queued.insert(BB);
636 | |
637 | size_t QueueIndex = 0; |
638 | while (ToVisit.size() != 0) { |
639 | BasicBlock *BB = ToVisit.front(); |
640 | ToVisit.pop(); |
641 | |
642 | if (!CanBeVisited(BB)) { |
ToVisit.push(BB);
if (QueueIndex >= ToVisit.size())
llvm::report_fatal_error(
"No valid candidate in the queue. Is the graph reducible?");
647 | QueueIndex++; |
648 | continue; |
649 | } |
650 | |
651 | QueueIndex = 0; |
652 | size_t Rank = GetNodeRank(BB); |
OrderInfo Info = {Rank, BlockToOrder.size()};
BlockToOrder.emplace(BB, Info);
655 | |
656 | for (BasicBlock *S : successors(BB)) { |
if (Queued.count(S) != 0)
continue;
ToVisit.push(S);
Queued.insert(S);
661 | } |
662 | } |
663 | |
664 | return 0; |
665 | } |
666 | |
667 | PartialOrderingVisitor::PartialOrderingVisitor(Function &F) { |
668 | DT.recalculate(Func&: F); |
669 | LI = LoopInfo(DT); |
670 | |
671 | visit(BB: &*F.begin(), Unused: 0); |
672 | |
Order.reserve(F.size());
for (auto &[BB, Info] : BlockToOrder)
Order.emplace_back(BB);

std::sort(Order.begin(), Order.end(), [&](const auto &LHS, const auto &RHS) {
678 | return compare(LHS, RHS); |
679 | }); |
680 | } |
681 | |
682 | bool PartialOrderingVisitor::compare(const BasicBlock *LHS, |
683 | const BasicBlock *RHS) const { |
684 | const OrderInfo &InfoLHS = BlockToOrder.at(k: const_cast<BasicBlock *>(LHS)); |
685 | const OrderInfo &InfoRHS = BlockToOrder.at(k: const_cast<BasicBlock *>(RHS)); |
686 | if (InfoLHS.Rank != InfoRHS.Rank) |
687 | return InfoLHS.Rank < InfoRHS.Rank; |
688 | return InfoLHS.TraversalIndex < InfoRHS.TraversalIndex; |
689 | } |
690 | |
691 | void PartialOrderingVisitor::partialOrderVisit( |
692 | BasicBlock &Start, std::function<bool(BasicBlock *)> Op) { |
std::unordered_set<BasicBlock *> Reachable = getReachableFrom(&Start);
assert(BlockToOrder.count(&Start) != 0);

// Skip blocks with a rank lower than |Start|'s rank.
697 | auto It = Order.begin(); |
698 | while (It != Order.end() && *It != &Start) |
699 | ++It; |
700 | |
701 | // This is unexpected. Worst case |Start| is the last block, |
702 | // so It should point to the last block, not past-end. |
703 | assert(It != Order.end()); |
704 | |
// By default there is no rank limit; a limit is set once Op returns false.
std::optional<size_t> EndRank = std::nullopt;
707 | for (; It != Order.end(); ++It) { |
708 | if (EndRank.has_value() && BlockToOrder[*It].Rank > *EndRank) |
709 | break; |
710 | |
if (Reachable.count(*It) == 0) {
712 | continue; |
713 | } |
714 | |
715 | if (!Op(*It)) { |
716 | EndRank = BlockToOrder[*It].Rank; |
717 | } |
718 | } |
719 | } |
720 | |
721 | bool sortBlocks(Function &F) { |
722 | if (F.size() == 0) |
723 | return false; |
724 | |
725 | bool Modified = false; |
726 | std::vector<BasicBlock *> Order; |
Order.reserve(F.size());
728 | |
729 | ReversePostOrderTraversal<Function *> RPOT(&F); |
730 | llvm::append_range(C&: Order, R&: RPOT); |
731 | |
732 | assert(&*F.begin() == Order[0]); |
733 | BasicBlock *LastBlock = &*F.begin(); |
734 | for (BasicBlock *BB : Order) { |
735 | if (BB != LastBlock && &*LastBlock->getNextNode() != BB) { |
736 | Modified = true; |
BB->moveAfter(LastBlock);
738 | } |
739 | LastBlock = BB; |
740 | } |
741 | |
742 | return Modified; |
743 | } |
744 | |
745 | MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg) { |
746 | MachineInstr *MaybeDef = MRI.getVRegDef(Reg); |
747 | if (MaybeDef && MaybeDef->getOpcode() == SPIRV::ASSIGN_TYPE) |
MaybeDef = MRI.getVRegDef(MaybeDef->getOperand(1).getReg());
749 | return MaybeDef; |
750 | } |
751 | |
752 | bool getVacantFunctionName(Module &M, std::string &Name) { |
// A bit of paranoia, but we don't want any chance of this loop running
// for too long.
755 | constexpr unsigned MaxIters = 1024; |
756 | for (unsigned I = 0; I < MaxIters; ++I) { |
757 | std::string OrdName = Name + Twine(I).str(); |
if (!M.getFunction(OrdName)) {
759 | Name = OrdName; |
760 | return true; |
761 | } |
762 | } |
763 | return false; |
764 | } |
765 | |
766 | // Assign SPIR-V type to the register. If the register has no valid assigned |
767 | // class, set register LLT type and class according to the SPIR-V type. |
768 | void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR, |
769 | MachineRegisterInfo *MRI, const MachineFunction &MF, |
770 | bool Force) { |
GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
if (!MRI->getRegClassOrNull(Reg) || Force) {
MRI->setRegClass(Reg, GR->getRegClass(SpvType));
MRI->setType(Reg, GR->getRegType(SpvType));
775 | } |
776 | } |
777 | |
// Create a SPIR-V type and assign it to the register. If the register has
// no valid assigned class, set the register's LLT type and class according
// to the SPIR-V type.
781 | void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR, |
782 | MachineIRBuilder &MIRBuilder, |
783 | SPIRV::AccessQualifier::AccessQualifier AccessQual, |
784 | bool EmitIR, bool Force) { |
setRegClassType(Reg,
GR->getOrCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR),
GR, MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
788 | } |
789 | |
790 | // Create a virtual register and assign SPIR-V type to the register. Set |
791 | // register LLT type and class according to the SPIR-V type. |
792 | Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR, |
793 | MachineRegisterInfo *MRI, |
794 | const MachineFunction &MF) { |
Register Reg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
MRI->setType(Reg, GR->getRegType(SpvType));
GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
798 | return Reg; |
799 | } |
800 | |
801 | // Create a virtual register and assign SPIR-V type to the register. Set |
802 | // register LLT type and class according to the SPIR-V type. |
803 | Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR, |
804 | MachineIRBuilder &MIRBuilder) { |
return createVirtualRegister(SpvType, GR, MIRBuilder.getMRI(),
MIRBuilder.getMF());
807 | } |
808 | |
// Create a SPIR-V type and a virtual register, and assign the type to the
// register. Set the register's LLT type and class according to the SPIR-V
// type.
811 | Register createVirtualRegister( |
812 | const Type *Ty, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIRBuilder, |
813 | SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { |
814 | return createVirtualRegister( |
GR->getOrCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR), GR,
816 | MIRBuilder); |
817 | } |
818 | |
819 | CallInst *buildIntrWithMD(Intrinsic::ID IntrID, ArrayRef<Type *> Types, |
820 | Value *Arg, Value *Arg2, ArrayRef<Constant *> Imms, |
821 | IRBuilder<> &B) { |
822 | SmallVector<Value *, 4> Args; |
Args.push_back(Arg2);
Args.push_back(buildMD(Arg));
llvm::append_range(Args, Imms);
return B.CreateIntrinsic(IntrID, {Types}, Args);
827 | } |
828 | |
829 | // Return true if there is an opaque pointer type nested in the argument. |
830 | bool isNestedPointer(const Type *Ty) { |
831 | if (Ty->isPtrOrPtrVectorTy()) |
832 | return true; |
if (const FunctionType *RefTy = dyn_cast<FunctionType>(Ty)) {
if (isNestedPointer(RefTy->getReturnType()))
return true;
for (const Type *ArgTy : RefTy->params())
if (isNestedPointer(ArgTy))
return true;
return false;
}
if (const ArrayType *RefTy = dyn_cast<ArrayType>(Ty))
return isNestedPointer(RefTy->getElementType());
843 | return false; |
844 | } |
845 | |
846 | bool isSpvIntrinsic(const Value *Arg) { |
if (const auto *II = dyn_cast<IntrinsicInst>(Arg))
if (Function *F = II->getCalledFunction())
if (F->getName().starts_with("llvm.spv."))
850 | return true; |
851 | return false; |
852 | } |
853 | |
// Function to create continued instructions for the
// SPV_INTEL_long_composites extension.
856 | SmallVector<MachineInstr *, 4> |
857 | createContinuedInstructions(MachineIRBuilder &MIRBuilder, unsigned Opcode, |
858 | unsigned MinWC, unsigned ContinuedOpcode, |
859 | ArrayRef<Register> Args, Register ReturnRegister, |
860 | Register TypeID) { |
861 | |
862 | SmallVector<MachineInstr *, 4> Instructions; |
863 | constexpr unsigned MaxWordCount = UINT16_MAX; |
864 | const size_t NumElements = Args.size(); |
865 | size_t MaxNumElements = MaxWordCount - MinWC; |
866 | size_t SPIRVStructNumElements = NumElements; |
867 | |
868 | if (NumElements > MaxNumElements) { |
// Do adjustments for continued instructions, which always have a
// minimum word count of one.
871 | SPIRVStructNumElements = MaxNumElements; |
872 | MaxNumElements = MaxWordCount - 1; |
873 | } |
874 | |
875 | auto MIB = |
MIRBuilder.buildInstr(Opcode).addDef(ReturnRegister).addUse(TypeID);
877 | |
878 | for (size_t I = 0; I < SPIRVStructNumElements; ++I) |
MIB.addUse(Args[I]);
880 | |
Instructions.push_back(MIB.getInstr());
882 | |
883 | for (size_t I = SPIRVStructNumElements; I < NumElements; |
884 | I += MaxNumElements) { |
auto MIB = MIRBuilder.buildInstr(ContinuedOpcode);
for (size_t J = I; J < std::min(I + MaxNumElements, NumElements); ++J)
MIB.addUse(Args[J]);
Instructions.push_back(MIB.getInstr());
889 | } |
890 | return Instructions; |
891 | } |
892 | |
893 | SmallVector<unsigned, 1> getSpirvLoopControlOperandsFromLoopMetadata(Loop *L) { |
894 | unsigned LC = SPIRV::LoopControl::None; |
// Currently used only to store the PartialCount value. When other
// LoopControls are added later, this map should be sorted before its values
// become loop_merge operands, to satisfy the "3.23. Loop Control"
// requirements.
898 | std::vector<std::pair<unsigned, unsigned>> MaskToValueMap; |
if (getBooleanLoopAttribute(L, "llvm.loop.unroll.disable")) {
900 | LC |= SPIRV::LoopControl::DontUnroll; |
901 | } else { |
if (getBooleanLoopAttribute(L, "llvm.loop.unroll.enable") ||
getBooleanLoopAttribute(L, "llvm.loop.unroll.full")) {
904 | LC |= SPIRV::LoopControl::Unroll; |
905 | } |
906 | std::optional<int> Count = |
getOptionalIntLoopAttribute(L, "llvm.loop.unroll.count");
908 | if (Count && Count != 1) { |
909 | LC |= SPIRV::LoopControl::PartialCount; |
910 | MaskToValueMap.emplace_back( |
std::make_pair(SPIRV::LoopControl::PartialCount, *Count));
912 | } |
913 | } |
914 | SmallVector<unsigned, 1> Result = {LC}; |
915 | for (auto &[Mask, Val] : MaskToValueMap) |
Result.push_back(Val);
917 | return Result; |
918 | } |
919 | |
920 | const std::set<unsigned> &getTypeFoldingSupportedOpcodes() { |
921 | // clang-format off |
922 | static const std::set<unsigned> TypeFoldingSupportingOpcs = { |
923 | TargetOpcode::G_ADD, |
924 | TargetOpcode::G_FADD, |
925 | TargetOpcode::G_STRICT_FADD, |
926 | TargetOpcode::G_SUB, |
927 | TargetOpcode::G_FSUB, |
928 | TargetOpcode::G_STRICT_FSUB, |
929 | TargetOpcode::G_MUL, |
930 | TargetOpcode::G_FMUL, |
931 | TargetOpcode::G_STRICT_FMUL, |
932 | TargetOpcode::G_SDIV, |
933 | TargetOpcode::G_UDIV, |
934 | TargetOpcode::G_FDIV, |
935 | TargetOpcode::G_STRICT_FDIV, |
936 | TargetOpcode::G_SREM, |
937 | TargetOpcode::G_UREM, |
938 | TargetOpcode::G_FREM, |
939 | TargetOpcode::G_STRICT_FREM, |
940 | TargetOpcode::G_FNEG, |
941 | TargetOpcode::G_CONSTANT, |
942 | TargetOpcode::G_FCONSTANT, |
943 | TargetOpcode::G_AND, |
944 | TargetOpcode::G_OR, |
945 | TargetOpcode::G_XOR, |
946 | TargetOpcode::G_SHL, |
947 | TargetOpcode::G_ASHR, |
948 | TargetOpcode::G_LSHR, |
949 | TargetOpcode::G_SELECT, |
950 | TargetOpcode::G_EXTRACT_VECTOR_ELT, |
951 | }; |
952 | // clang-format on |
953 | return TypeFoldingSupportingOpcs; |
954 | } |
955 | |
956 | bool isTypeFoldingSupported(unsigned Opcode) { |
return getTypeFoldingSupportedOpcodes().count(Opcode) > 0;
958 | } |
959 | |
960 | // Traversing [g]MIR accounting for pseudo-instructions. |
961 | MachineInstr *passCopy(MachineInstr *Def, const MachineRegisterInfo *MRI) { |
962 | return (Def->getOpcode() == SPIRV::ASSIGN_TYPE || |
963 | Def->getOpcode() == TargetOpcode::COPY) |
? MRI->getVRegDef(Def->getOperand(1).getReg())
965 | : Def; |
966 | } |
967 | |
968 | MachineInstr *getDef(const MachineOperand &MO, const MachineRegisterInfo *MRI) { |
if (MachineInstr *Def = MRI->getVRegDef(MO.getReg()))
970 | return passCopy(Def, MRI); |
971 | return nullptr; |
972 | } |
973 | |
974 | MachineInstr *getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI) { |
975 | if (MachineInstr *Def = getDef(MO, MRI)) { |
976 | if (Def->getOpcode() == TargetOpcode::G_CONSTANT || |
977 | Def->getOpcode() == SPIRV::OpConstantI) |
978 | return Def; |
979 | } |
980 | return nullptr; |
981 | } |
982 | |
983 | int64_t foldImm(const MachineOperand &MO, const MachineRegisterInfo *MRI) { |
984 | if (MachineInstr *Def = getImm(MO, MRI)) { |
985 | if (Def->getOpcode() == SPIRV::OpConstantI) |
return Def->getOperand(2).getImm();
if (Def->getOpcode() == TargetOpcode::G_CONSTANT)
return Def->getOperand(1).getCImm()->getZExtValue();
}
llvm_unreachable("Unexpected integer constant pattern");
991 | } |
992 | |
993 | unsigned getArrayComponentCount(const MachineRegisterInfo *MRI, |
994 | const MachineInstr *ResType) { |
return foldImm(ResType->getOperand(2), MRI);
996 | } |
997 | |
998 | } // namespace llvm |
999 | |