//===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains miscellaneous utility functions.
//
//===----------------------------------------------------------------------===//

#include "SPIRVUtils.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVSubtarget.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/Demangle/Demangle.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include <queue>
#include <vector>

namespace llvm {
namespace SPIRV {
// These functions restore the original aggregate types of function arguments
// and return values for composite cases: during translation they are lowered
// to i32 to cope with aggregate flattening etc., but the final types should
// still be the aggregates.
// TODO: should these just return nullptr when there's no metadata?
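// For illustration only (names are hypothetical), the named metadata consumed
// below is expected to look roughly like:
//   !spv.cloned_funcs = !{!1}
//   !1 = !{!"foo", !2, !3}
//   !2 = !{i32 -1, %struct.S poison} ; return value was %struct.S
//   !3 = !{i32 0, [2 x i32] poison}  ; argument 0 was [2 x i32]
// where each inner node pairs an argument index (-1 for the return value) with
// a constant whose type is the original aggregate type.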
static FunctionType *extractFunctionTypeFromMetadata(NamedMDNode *NMD,
                                                      FunctionType *FTy,
                                                      StringRef Name) {
  if (!NMD)
    return FTy;

  constexpr auto getConstInt = [](MDNode *MD, unsigned OpId) -> ConstantInt * {
    if (MD->getNumOperands() <= OpId)
      return nullptr;
    if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(OpId)))
      return dyn_cast<ConstantInt>(CMeta->getValue());
    return nullptr;
  };

  auto It = find_if(NMD->operands(), [Name](MDNode *N) {
    if (auto *MDS = dyn_cast_or_null<MDString>(N->getOperand(0)))
      return MDS->getString() == Name;
    return false;
  });

  if (It == NMD->op_end())
    return FTy;

  Type *RetTy = FTy->getReturnType();
  SmallVector<Type *, 4> PTys(FTy->params());

  for (unsigned I = 1; I != (*It)->getNumOperands(); ++I) {
    MDNode *MD = dyn_cast<MDNode>((*It)->getOperand(I));
    assert(MD && "MDNode operand is expected");

    if (auto *Const = getConstInt(MD, 0)) {
      auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
      assert(CMeta && "ConstantAsMetadata operand is expected");
      assert(Const->getSExtValue() >= -1);
      // Currently -1 indicates return value, greater values mean
      // argument numbers.
      if (Const->getSExtValue() == -1)
        RetTy = CMeta->getType();
      else
        PTys[Const->getSExtValue()] = CMeta->getType();
    }
  }

  return FunctionType::get(RetTy, PTys, FTy->isVarArg());
}

FunctionType *getOriginalFunctionType(const Function &F) {
  return extractFunctionTypeFromMetadata(
      F.getParent()->getNamedMetadata("spv.cloned_funcs"), F.getFunctionType(),
      F.getName());
}

FunctionType *getOriginalFunctionType(const CallBase &CB) {
  return extractFunctionTypeFromMetadata(
      CB.getModule()->getNamedMetadata("spv.mutated_callsites"),
      CB.getFunctionType(), CB.getName());
}
} // namespace SPIRV

// The following functions add string literals as a series of 32-bit integer
// operands with the correct format, and unpack them if necessary when making
// string comparisons in compiler passes.
// SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment.
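// For illustration (not part of the lowering itself): the string "abc" has a
// padded length of 4 and is emitted as the single little-endian word
// 0x00636261, i.e. 'a' | ('b' << 8) | ('c' << 16) with a trailing NUL byte.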
static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
  uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars.
  for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) {
    unsigned StrIndex = i + WordIndex;
    uint8_t CharToAdd = 0;       // Initialize char as padding/null.
    if (StrIndex < Str.size()) { // If it's within the string, get a real char.
      CharToAdd = Str[StrIndex];
    }
    Word |= (CharToAdd << (WordIndex * 8));
  }
  return Word;
}

// Get length including padding and null terminator.
static size_t getPaddedLen(const StringRef &Str) {
  return (Str.size() + 4) & ~3;
}

void addStringImm(const StringRef &Str, MCInst &Inst) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add an operand for the 32-bits of chars or padding.
    Inst.addOperand(MCOperand::createImm(convertCharsToWord(Str, i)));
  }
}

void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add an operand for the 32-bits of chars or padding.
    MIB.addImm(convertCharsToWord(Str, i));
  }
}

void addStringImm(const StringRef &Str, IRBuilder<> &B,
                  std::vector<Value *> &Args) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add a vector element for the 32-bits of chars or padding.
    Args.push_back(B.getInt32(convertCharsToWord(Str, i)));
  }
}

std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) {
  return getSPIRVStringOperand(MI, StartIndex);
}

std::string getStringValueFromReg(Register Reg, MachineRegisterInfo &MRI) {
  MachineInstr *Def = getVRegDef(MRI, Reg);
  assert(Def && Def->getOpcode() == TargetOpcode::G_GLOBAL_VALUE &&
         "Expected G_GLOBAL_VALUE");
  const GlobalValue *GV = Def->getOperand(1).getGlobal();
  Value *V = GV->getOperand(0);
  const ConstantDataArray *CDA = cast<ConstantDataArray>(V);
  return CDA->getAsCString().str();
}

void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
  const auto Bitwidth = Imm.getBitWidth();
  if (Bitwidth == 1)
    return; // Already handled
  else if (Bitwidth <= 32) {
    MIB.addImm(Imm.getZExtValue());
    // The asm printer needs this info to print 16-bit floating-point operands
    // correctly.
    if (Bitwidth == 16)
      MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
    return;
  } else if (Bitwidth <= 64) {
    uint64_t FullImm = Imm.getZExtValue();
    uint32_t LowBits = FullImm & 0xffffffff;
    uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
    MIB.addImm(LowBits).addImm(HighBits);
    // The asm printer needs this info to print 64-bit operands correctly.
    MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH64);
    return;
  } else if (Bitwidth <= 128) {
    uint32_t LowBits = Imm.getRawData()[0] & 0xffffffff;
    uint32_t MidBits0 = (Imm.getRawData()[0] >> 32) & 0xffffffff;
    uint32_t MidBits1 = Imm.getRawData()[1] & 0xffffffff;
    uint32_t HighBits = (Imm.getRawData()[1] >> 32) & 0xffffffff;
    MIB.addImm(LowBits).addImm(MidBits0).addImm(MidBits1).addImm(HighBits);
    return;
  }
  report_fatal_error("Unsupported constant bitwidth");
}
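// For illustration: a 64-bit immediate such as 0x1122334455667788 is emitted
// as two words, 0x55667788 followed by 0x11223344 (low word first).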

void buildOpName(Register Target, const StringRef &Name,
                 MachineIRBuilder &MIRBuilder) {
  if (!Name.empty()) {
    auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target);
    addStringImm(Name, MIB);
  }
}

void buildOpName(Register Target, const StringRef &Name, MachineInstr &I,
                 const SPIRVInstrInfo &TII) {
  if (!Name.empty()) {
    auto MIB =
        BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpName))
            .addUse(Target);
    addStringImm(Name, MIB);
  }
}

static void finishBuildOpDecorate(MachineInstrBuilder &MIB,
                                  const std::vector<uint32_t> &DecArgs,
                                  StringRef StrImm) {
  if (!StrImm.empty())
    addStringImm(StrImm, MIB);
  for (const auto &DecArg : DecArgs)
    MIB.addImm(DecArg);
}

void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
                     SPIRV::Decoration::Decoration Dec,
                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
                 .addUse(Reg)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII,
                     SPIRV::Decoration::Decoration Dec,
                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
  MachineBasicBlock &MBB = *I.getParent();
  auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate))
                 .addUse(Reg)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

void buildOpMemberDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
                           SPIRV::Decoration::Decoration Dec, uint32_t Member,
                           const std::vector<uint32_t> &DecArgs,
                           StringRef StrImm) {
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpMemberDecorate)
                 .addUse(Reg)
                 .addImm(Member)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

void buildOpMemberDecorate(Register Reg, MachineInstr &I,
                           const SPIRVInstrInfo &TII,
                           SPIRV::Decoration::Decoration Dec, uint32_t Member,
                           const std::vector<uint32_t> &DecArgs,
                           StringRef StrImm) {
  MachineBasicBlock &MBB = *I.getParent();
  auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpMemberDecorate))
                 .addUse(Reg)
                 .addImm(Member)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder,
                             const MDNode *GVarMD, const SPIRVSubtarget &ST) {
  for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) {
    auto *OpMD = dyn_cast<MDNode>(GVarMD->getOperand(I));
    if (!OpMD)
      report_fatal_error("Invalid decoration");
    if (OpMD->getNumOperands() == 0)
      report_fatal_error("Expect operand(s) of the decoration");
    ConstantInt *DecorationId =
        mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(0));
    if (!DecorationId)
      report_fatal_error("Expect SPIR-V <Decoration> operand to be the first "
                         "element of the decoration");

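    // For illustration only: each operand of the !spirv.Decorations node is
    // itself an MDNode of the form !{i32 <decoration id>, <literal>...},
    // where the trailing literals, if any, are integer constants or strings.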
    // The goal of `spirv.Decorations` metadata is to provide a way to
    // represent SPIR-V entities that do not map to LLVM in an obvious way.
    // FP flags do have obvious matches between LLVM IR and SPIR-V.
    // Additionally, we have no guarantee at this point that the flags passed
    // through the decoration are not violated already in the optimizer passes.
    // Therefore, we simply ignore FP flags, including NoContraction, and
    // FPFastMathMode.
    if (DecorationId->getZExtValue() ==
            static_cast<uint32_t>(SPIRV::Decoration::NoContraction) ||
        DecorationId->getZExtValue() ==
            static_cast<uint32_t>(SPIRV::Decoration::FPFastMathMode)) {
      continue; // Ignored.
    }
    auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
                   .addUse(Reg)
                   .addImm(static_cast<uint32_t>(DecorationId->getZExtValue()));
    for (unsigned OpI = 1, OpE = OpMD->getNumOperands(); OpI != OpE; ++OpI) {
      if (ConstantInt *OpV =
              mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(OpI)))
        MIB.addImm(static_cast<uint32_t>(OpV->getZExtValue()));
      else if (MDString *OpV = dyn_cast<MDString>(OpMD->getOperand(OpI)))
        addStringImm(OpV->getString(), MIB);
      else
        report_fatal_error("Unexpected operand of the decoration");
    }
  }
}

MachineBasicBlock::iterator getOpVariableMBBIt(MachineInstr &I) {
  MachineFunction *MF = I.getParent()->getParent();
  MachineBasicBlock *MBB = &MF->front();
  MachineBasicBlock::iterator It = MBB->SkipPHIsAndLabels(MBB->begin()),
                              E = MBB->end();
  bool IsHeader = false;
  unsigned Opcode;
  for (; It != E && It != I; ++It) {
    Opcode = It->getOpcode();
    if (Opcode == SPIRV::OpFunction || Opcode == SPIRV::OpFunctionParameter) {
      IsHeader = true;
    } else if (IsHeader &&
               !(Opcode == SPIRV::ASSIGN_TYPE || Opcode == SPIRV::OpLabel)) {
      ++It;
      break;
    }
  }
  return It;
}

MachineBasicBlock::iterator getInsertPtValidEnd(MachineBasicBlock *MBB) {
  MachineBasicBlock::iterator I = MBB->end();
  if (I == MBB->begin())
    return I;
  --I;
  while (I->isTerminator() || I->isDebugValue()) {
    if (I == MBB->begin())
      break;
    --I;
  }
  return I;
}

SPIRV::StorageClass::StorageClass
addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
  switch (AddrSpace) {
  case 0:
    return SPIRV::StorageClass::Function;
  case 1:
    return SPIRV::StorageClass::CrossWorkgroup;
  case 2:
    return SPIRV::StorageClass::UniformConstant;
  case 3:
    return SPIRV::StorageClass::Workgroup;
  case 4:
    return SPIRV::StorageClass::Generic;
  case 5:
    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
               ? SPIRV::StorageClass::DeviceOnlyINTEL
               : SPIRV::StorageClass::CrossWorkgroup;
  case 6:
    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
               ? SPIRV::StorageClass::HostOnlyINTEL
               : SPIRV::StorageClass::CrossWorkgroup;
  case 7:
    return SPIRV::StorageClass::Input;
  case 8:
    return SPIRV::StorageClass::Output;
  case 9:
    return SPIRV::StorageClass::CodeSectionINTEL;
  case 10:
    return SPIRV::StorageClass::Private;
  case 11:
    return SPIRV::StorageClass::StorageBuffer;
  case 12:
    return SPIRV::StorageClass::Uniform;
  case 13:
    return SPIRV::StorageClass::PushConstant;
  default:
    report_fatal_error("Unknown address space");
  }
}

SPIRV::MemorySemantics::MemorySemantics
getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC) {
  switch (SC) {
  case SPIRV::StorageClass::StorageBuffer:
  case SPIRV::StorageClass::Uniform:
    return SPIRV::MemorySemantics::UniformMemory;
  case SPIRV::StorageClass::Workgroup:
    return SPIRV::MemorySemantics::WorkgroupMemory;
  case SPIRV::StorageClass::CrossWorkgroup:
    return SPIRV::MemorySemantics::CrossWorkgroupMemory;
  case SPIRV::StorageClass::AtomicCounter:
    return SPIRV::MemorySemantics::AtomicCounterMemory;
  case SPIRV::StorageClass::Image:
    return SPIRV::MemorySemantics::ImageMemory;
  default:
    return SPIRV::MemorySemantics::None;
  }
}

SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) {
  switch (Ord) {
  case AtomicOrdering::Acquire:
    return SPIRV::MemorySemantics::Acquire;
  case AtomicOrdering::Release:
    return SPIRV::MemorySemantics::Release;
  case AtomicOrdering::AcquireRelease:
    return SPIRV::MemorySemantics::AcquireRelease;
  case AtomicOrdering::SequentiallyConsistent:
    return SPIRV::MemorySemantics::SequentiallyConsistent;
  case AtomicOrdering::Unordered:
  case AtomicOrdering::Monotonic:
  case AtomicOrdering::NotAtomic:
    return SPIRV::MemorySemantics::None;
  }
  llvm_unreachable(nullptr);
}

SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) {
  // Named by
  // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id.
  // We don't need aliases for Invocation and CrossDevice, as we already have
  // them covered by "singlethread" and "" strings respectively (see
  // implementation of LLVMContext::LLVMContext()).
  static const llvm::SyncScope::ID SubGroup =
      Ctx.getOrInsertSyncScopeID("subgroup");
  static const llvm::SyncScope::ID WorkGroup =
      Ctx.getOrInsertSyncScopeID("workgroup");
  static const llvm::SyncScope::ID Device =
      Ctx.getOrInsertSyncScopeID("device");

  if (Id == llvm::SyncScope::SingleThread)
    return SPIRV::Scope::Invocation;
  else if (Id == llvm::SyncScope::System)
    return SPIRV::Scope::CrossDevice;
  else if (Id == SubGroup)
    return SPIRV::Scope::Subgroup;
  else if (Id == WorkGroup)
    return SPIRV::Scope::Workgroup;
  else if (Id == Device)
    return SPIRV::Scope::Device;
  return SPIRV::Scope::CrossDevice;
}

MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
                                       const MachineRegisterInfo *MRI) {
  MachineInstr *MI = MRI->getVRegDef(ConstReg);
  MachineInstr *ConstInstr =
      MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
          ? MRI->getVRegDef(MI->getOperand(1).getReg())
          : MI;
  if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
    if (GI->is(Intrinsic::spv_track_constant)) {
      ConstReg = ConstInstr->getOperand(2).getReg();
      return MRI->getVRegDef(ConstReg);
    }
  } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) {
    ConstReg = ConstInstr->getOperand(1).getReg();
    return MRI->getVRegDef(ConstReg);
  } else if (ConstInstr->getOpcode() == TargetOpcode::G_CONSTANT ||
             ConstInstr->getOpcode() == TargetOpcode::G_FCONSTANT) {
    ConstReg = ConstInstr->getOperand(0).getReg();
    return ConstInstr;
  }
  return MRI->getVRegDef(ConstReg);
}

uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {
  const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
  assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
  return MI->getOperand(1).getCImm()->getValue().getZExtValue();
}

int64_t getIConstValSext(Register ConstReg, const MachineRegisterInfo *MRI) {
  const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
  assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
  return MI->getOperand(1).getCImm()->getSExtValue();
}

bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
  if (const auto *GI = dyn_cast<GIntrinsic>(&MI))
    return GI->is(IntrinsicID);
  return false;
}

Type *getMDOperandAsType(const MDNode *N, unsigned I) {
  Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
  return toTypedPointer(ElementTy);
}

// The set of names is borrowed from the SPIR-V translator.
// TODO: may be implemented in SPIRVBuiltins.td.
static bool isPipeOrAddressSpaceCastBI(const StringRef MangledName) {
  return MangledName == "write_pipe_2" || MangledName == "read_pipe_2" ||
         MangledName == "write_pipe_2_bl" || MangledName == "read_pipe_2_bl" ||
         MangledName == "write_pipe_4" || MangledName == "read_pipe_4" ||
         MangledName == "reserve_write_pipe" ||
         MangledName == "reserve_read_pipe" ||
         MangledName == "commit_write_pipe" ||
         MangledName == "commit_read_pipe" ||
         MangledName == "work_group_reserve_write_pipe" ||
         MangledName == "work_group_reserve_read_pipe" ||
         MangledName == "work_group_commit_write_pipe" ||
         MangledName == "work_group_commit_read_pipe" ||
         MangledName == "get_pipe_num_packets_ro" ||
         MangledName == "get_pipe_max_packets_ro" ||
         MangledName == "get_pipe_num_packets_wo" ||
         MangledName == "get_pipe_max_packets_wo" ||
         MangledName == "sub_group_reserve_write_pipe" ||
         MangledName == "sub_group_reserve_read_pipe" ||
         MangledName == "sub_group_commit_write_pipe" ||
         MangledName == "sub_group_commit_read_pipe" ||
         MangledName == "to_global" || MangledName == "to_local" ||
         MangledName == "to_private";
}

static bool isEnqueueKernelBI(const StringRef MangledName) {
  return MangledName == "__enqueue_kernel_basic" ||
         MangledName == "__enqueue_kernel_basic_events" ||
         MangledName == "__enqueue_kernel_varargs" ||
         MangledName == "__enqueue_kernel_events_varargs";
}

static bool isKernelQueryBI(const StringRef MangledName) {
  return MangledName == "__get_kernel_work_group_size_impl" ||
         MangledName == "__get_kernel_sub_group_count_for_ndrange_impl" ||
         MangledName == "__get_kernel_max_sub_group_size_for_ndrange_impl" ||
         MangledName == "__get_kernel_preferred_work_group_size_multiple_impl";
}

static bool isNonMangledOCLBuiltin(StringRef Name) {
  if (!Name.starts_with("__"))
    return false;

  return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) ||
         isPipeOrAddressSpaceCastBI(Name.drop_front(2)) ||
         Name == "__translate_sampler_initializer";
}

std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
  bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
  bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
  bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
  bool IsMangled = Name.starts_with("_Z");

  // Otherwise use simple demangling to return the function name.
  if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled)
    return Name.str();

  // Try to use the itanium demangler.
  if (char *DemangledName = itaniumDemangle(Name.data())) {
    std::string Result = DemangledName;
    free(DemangledName);
    return Result;
  }

  // Autocheck C++, maybe need to do explicit check of the source language.
  // OpenCL C++ built-ins are declared in the cl namespace.
  // TODO: consider using the 'St' abbreviation for cl namespace mangling,
  // similar to ::std:: in C++.
  size_t Start, Len = 0;
  size_t DemangledNameLenStart = 2;
  if (Name.starts_with("_ZN")) {
    // Skip CV and ref qualifiers.
    size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3);
    // All built-ins are in the ::cl::__spirv namespace.
    if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv")
      return std::string();
    DemangledNameLenStart = NameSpaceStart + 11;
  }
  Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
  [[maybe_unused]] bool Error =
      Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart)
          .getAsInteger(10, Len);
  assert(!Error && "Failed to parse demangled name length");
  return Name.substr(Start, Len).str();
}
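// For illustration: the fallback parser above would turn a mangled name such
// as "_Z13get_global_idj" into "get_global_id" by reading the explicit length
// prefix, while "_ZN..." names outside the ::cl::__spirv namespace yield an
// empty string.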

bool hasBuiltinTypePrefix(StringRef Name) {
  if (Name.starts_with("opencl.") || Name.starts_with("ocl_") ||
      Name.starts_with("spirv."))
    return true;
  return false;
}

bool isSpecialOpaqueType(const Type *Ty) {
  if (const TargetExtType *ExtTy = dyn_cast<TargetExtType>(Ty))
    return isTypedPointerWrapper(ExtTy)
               ? false
               : hasBuiltinTypePrefix(ExtTy->getName());

  return false;
}

bool isEntryPoint(const Function &F) {
  // OpenCL handling: any function with the SPIR_KERNEL
  // calling convention will be a potential entry point.
  if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
    return true;

  // HLSL handling: special attributes are emitted from the
  // front-end.
  if (F.getFnAttribute("hlsl.shader").isValid())
    return true;

  return false;
}

Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
  TypeName.consume_front("atomic_");
  if (TypeName.consume_front("void"))
    return Type::getVoidTy(Ctx);
  else if (TypeName.consume_front("bool") || TypeName.consume_front("_Bool"))
    return Type::getIntNTy(Ctx, 1);
  else if (TypeName.consume_front("char") ||
           TypeName.consume_front("signed char") ||
           TypeName.consume_front("unsigned char") ||
           TypeName.consume_front("uchar"))
    return Type::getInt8Ty(Ctx);
  else if (TypeName.consume_front("short") ||
           TypeName.consume_front("signed short") ||
           TypeName.consume_front("unsigned short") ||
           TypeName.consume_front("ushort"))
    return Type::getInt16Ty(Ctx);
  else if (TypeName.consume_front("int") ||
           TypeName.consume_front("signed int") ||
           TypeName.consume_front("unsigned int") ||
           TypeName.consume_front("uint"))
    return Type::getInt32Ty(Ctx);
  else if (TypeName.consume_front("long") ||
           TypeName.consume_front("signed long") ||
           TypeName.consume_front("unsigned long") ||
           TypeName.consume_front("ulong"))
    return Type::getInt64Ty(Ctx);
  else if (TypeName.consume_front("half") ||
           TypeName.consume_front("_Float16") ||
           TypeName.consume_front("__fp16"))
    return Type::getHalfTy(Ctx);
  else if (TypeName.consume_front("float"))
    return Type::getFloatTy(Ctx);
  else if (TypeName.consume_front("double"))
    return Type::getDoubleTy(Ctx);

  // Unable to recognize SPIRV type name
  return nullptr;
}
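// For illustration: parseBasicTypeName consumes the matched prefix in place,
// so for a TypeName of "atomic_uint" it returns the i32 type and leaves
// TypeName empty; an unrecognized name such as "image2d_t" returns nullptr
// with TypeName unchanged.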

std::unordered_set<BasicBlock *>
PartialOrderingVisitor::getReachableFrom(BasicBlock *Start) {
  std::queue<BasicBlock *> ToVisit;
  ToVisit.push(Start);

  std::unordered_set<BasicBlock *> Output;
  while (ToVisit.size() != 0) {
    BasicBlock *BB = ToVisit.front();
    ToVisit.pop();

    if (Output.count(BB) != 0)
      continue;
    Output.insert(BB);

    for (BasicBlock *Successor : successors(BB)) {
      if (DT.dominates(Successor, BB))
        continue;
      ToVisit.push(Successor);
    }
  }

  return Output;
}

bool PartialOrderingVisitor::CanBeVisited(BasicBlock *BB) const {
  for (BasicBlock *P : predecessors(BB)) {
    // Ignore back-edges.
    if (DT.dominates(BB, P))
      continue;

    // One of the predecessors hasn't been visited. Not ready yet.
    if (BlockToOrder.count(P) == 0)
      return false;

    // If the block is a loop exit, the loop must be finished before
    // we can continue.
    Loop *L = LI.getLoopFor(P);
    if (L == nullptr || L->contains(BB))
      continue;

    // SPIR-V requires a single back-edge, and the backend's first step
    // transforms loops into this simplified format. If we have more than
    // one back-edge, something is wrong.
    assert(L->getNumBackEdges() <= 1);

    // If the loop has no latch, the loop's rank won't matter, so we can
    // proceed.
    BasicBlock *Latch = L->getLoopLatch();
    assert(Latch);
    if (Latch == nullptr)
      continue;

    // The latch is not ready yet, let's wait.
    if (BlockToOrder.count(Latch) == 0)
      return false;
  }

  return true;
}

size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const {
  auto It = BlockToOrder.find(BB);
  if (It != BlockToOrder.end())
    return It->second.Rank;

  size_t result = 0;
  for (BasicBlock *P : predecessors(BB)) {
    // Ignore back-edges.
    if (DT.dominates(BB, P))
      continue;

    auto Iterator = BlockToOrder.end();
    Loop *L = LI.getLoopFor(P);
    BasicBlock *Latch = L ? L->getLoopLatch() : nullptr;

    // If the predecessor is either outside a loop, or part of
    // the same loop, simply take its rank + 1.
    if (L == nullptr || L->contains(BB) || Latch == nullptr) {
      Iterator = BlockToOrder.find(P);
    } else {
      // Otherwise, take the loop's rank (highest rank in the loop) as base.
      // Since loops have a single latch, highest rank is easy to find.
      // If the loop has no latch, then it doesn't matter.
      Iterator = BlockToOrder.find(Latch);
    }

    assert(Iterator != BlockToOrder.end());
    result = std::max(result, Iterator->second.Rank + 1);
  }

  return result;
}
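// For illustration: in a simple diamond CFG A -> {B, C} -> D with no loops,
// A gets rank 0, B and C both get rank 1, and D gets rank 2 (one more than
// the highest-ranked predecessor).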

size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) {
  ToVisit.push(BB);
  Queued.insert(BB);

  size_t QueueIndex = 0;
  while (ToVisit.size() != 0) {
    BasicBlock *BB = ToVisit.front();
    ToVisit.pop();

    if (!CanBeVisited(BB)) {
      ToVisit.push(BB);
      if (QueueIndex >= ToVisit.size())
        llvm::report_fatal_error(
            "No valid candidate in the queue. Is the graph reducible?");
      QueueIndex++;
      continue;
    }

    QueueIndex = 0;
    size_t Rank = GetNodeRank(BB);
    OrderInfo Info = {Rank, BlockToOrder.size()};
    BlockToOrder.emplace(BB, Info);

    for (BasicBlock *S : successors(BB)) {
      if (Queued.count(S) != 0)
        continue;
      ToVisit.push(S);
      Queued.insert(S);
    }
  }

  return 0;
}

PartialOrderingVisitor::PartialOrderingVisitor(Function &F) {
  DT.recalculate(F);
  LI = LoopInfo(DT);

  visit(&*F.begin(), 0);

  Order.reserve(F.size());
  for (auto &[BB, Info] : BlockToOrder)
    Order.emplace_back(BB);

  std::sort(Order.begin(), Order.end(), [&](const auto &LHS, const auto &RHS) {
    return compare(LHS, RHS);
  });
}

bool PartialOrderingVisitor::compare(const BasicBlock *LHS,
                                     const BasicBlock *RHS) const {
  const OrderInfo &InfoLHS = BlockToOrder.at(const_cast<BasicBlock *>(LHS));
  const OrderInfo &InfoRHS = BlockToOrder.at(const_cast<BasicBlock *>(RHS));
  if (InfoLHS.Rank != InfoRHS.Rank)
    return InfoLHS.Rank < InfoRHS.Rank;
  return InfoLHS.TraversalIndex < InfoRHS.TraversalIndex;
}

void PartialOrderingVisitor::partialOrderVisit(
    BasicBlock &Start, std::function<bool(BasicBlock *)> Op) {
  std::unordered_set<BasicBlock *> Reachable = getReachableFrom(&Start);
  assert(BlockToOrder.count(&Start) != 0);

  // Skip blocks with a rank lower than |Start|'s rank.
  auto It = Order.begin();
  while (It != Order.end() && *It != &Start)
    ++It;

  // Not finding |Start| in Order is unexpected: worst case |Start| is the
  // last block, so It should point at the last block, never past-the-end.
  assert(It != Order.end());

  // By default, there is no rank limit (EndRank stays unset).
  std::optional<size_t> EndRank = std::nullopt;
  for (; It != Order.end(); ++It) {
    if (EndRank.has_value() && BlockToOrder[*It].Rank > *EndRank)
      break;

    if (Reachable.count(*It) == 0) {
      continue;
    }

    if (!Op(*It)) {
      EndRank = BlockToOrder[*It].Rank;
    }
  }
}

bool sortBlocks(Function &F) {
  if (F.size() == 0)
    return false;

  bool Modified = false;
  std::vector<BasicBlock *> Order;
  Order.reserve(F.size());

  ReversePostOrderTraversal<Function *> RPOT(&F);
  llvm::append_range(Order, RPOT);

  assert(&*F.begin() == Order[0]);
  BasicBlock *LastBlock = &*F.begin();
  for (BasicBlock *BB : Order) {
    if (BB != LastBlock && &*LastBlock->getNextNode() != BB) {
      Modified = true;
      BB->moveAfter(LastBlock);
    }
    LastBlock = BB;
  }

  return Modified;
}

MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg) {
  MachineInstr *MaybeDef = MRI.getVRegDef(Reg);
  if (MaybeDef && MaybeDef->getOpcode() == SPIRV::ASSIGN_TYPE)
    MaybeDef = MRI.getVRegDef(MaybeDef->getOperand(1).getReg());
  return MaybeDef;
}

bool getVacantFunctionName(Module &M, std::string &Name) {
  // It's a bit of paranoia, but we still don't want to risk the loop running
  // for too long.
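  // E.g., if Name is "foo", this tries "foo0", "foo1", ... and keeps the
  // first name that is not already taken by a function in the module.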
  constexpr unsigned MaxIters = 1024;
  for (unsigned I = 0; I < MaxIters; ++I) {
    std::string OrdName = Name + Twine(I).str();
    if (!M.getFunction(OrdName)) {
      Name = std::move(OrdName);
      return true;
    }
  }
  return false;
}

// Assign SPIR-V type to the register. If the register has no valid assigned
// class, set register LLT type and class according to the SPIR-V type.
void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                     MachineRegisterInfo *MRI, const MachineFunction &MF,
                     bool Force) {
  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
  if (!MRI->getRegClassOrNull(Reg) || Force) {
    MRI->setRegClass(Reg, GR->getRegClass(SpvType));
    MRI->setType(Reg, GR->getRegType(SpvType));
  }
}

// Create a SPIR-V type, assign SPIR-V type to the register. If the register
// has no valid assigned class, set register LLT type and class according to
// the SPIR-V type.
void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
                     MachineIRBuilder &MIRBuilder,
                     SPIRV::AccessQualifier::AccessQualifier AccessQual,
                     bool EmitIR, bool Force) {
  setRegClassType(Reg,
                  GR->getOrCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR),
                  GR, MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
}

// Create a virtual register and assign SPIR-V type to the register. Set
// register LLT type and class according to the SPIR-V type.
Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                               MachineRegisterInfo *MRI,
                               const MachineFunction &MF) {
  Register Reg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
  MRI->setType(Reg, GR->getRegType(SpvType));
  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
  return Reg;
}

// Create a virtual register and assign SPIR-V type to the register. Set
// register LLT type and class according to the SPIR-V type.
Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                               MachineIRBuilder &MIRBuilder) {
  return createVirtualRegister(SpvType, GR, MIRBuilder.getMRI(),
                               MIRBuilder.getMF());
}

// Create a SPIR-V type, virtual register and assign SPIR-V type to the
// register. Set register LLT type and class according to the SPIR-V type.
Register createVirtualRegister(
    const Type *Ty, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  return createVirtualRegister(
      GR->getOrCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR), GR,
      MIRBuilder);
}

CallInst *buildIntrWithMD(Intrinsic::ID IntrID, ArrayRef<Type *> Types,
                          Value *Arg, Value *Arg2, ArrayRef<Constant *> Imms,
                          IRBuilder<> &B) {
  SmallVector<Value *, 4> Args;
  Args.push_back(Arg2);
  Args.push_back(buildMD(Arg));
  llvm::append_range(Args, Imms);
  return B.CreateIntrinsic(IntrID, {Types}, Args);
}

// Return true if there is an opaque pointer type nested in the argument.
bool isNestedPointer(const Type *Ty) {
  if (Ty->isPtrOrPtrVectorTy())
    return true;
  if (const FunctionType *RefTy = dyn_cast<FunctionType>(Ty)) {
    if (isNestedPointer(RefTy->getReturnType()))
      return true;
    for (const Type *ArgTy : RefTy->params())
      if (isNestedPointer(ArgTy))
        return true;
    return false;
  }
  if (const ArrayType *RefTy = dyn_cast<ArrayType>(Ty))
    return isNestedPointer(RefTy->getElementType());
  return false;
}

bool isSpvIntrinsic(const Value *Arg) {
  if (const auto *II = dyn_cast<IntrinsicInst>(Arg))
    if (Function *F = II->getCalledFunction())
      if (F->getName().starts_with("llvm.spv."))
        return true;
  return false;
}

// Create continued instructions for the SPV_INTEL_long_composites extension.
SmallVector<MachineInstr *, 4>
createContinuedInstructions(MachineIRBuilder &MIRBuilder, unsigned Opcode,
                            unsigned MinWC, unsigned ContinuedOpcode,
                            ArrayRef<Register> Args, Register ReturnRegister,
                            Register TypeID) {

  SmallVector<MachineInstr *, 4> Instructions;
  constexpr unsigned MaxWordCount = UINT16_MAX;
  const size_t NumElements = Args.size();
  size_t MaxNumElements = MaxWordCount - MinWC;
  size_t SPIRVStructNumElements = NumElements;

  if (NumElements > MaxNumElements) {
    // Adjust for continued instructions, which always have a minimum word
    // count of one.
    SPIRVStructNumElements = MaxNumElements;
    MaxNumElements = MaxWordCount - 1;
  }

  auto MIB =
      MIRBuilder.buildInstr(Opcode).addDef(ReturnRegister).addUse(TypeID);

  for (size_t I = 0; I < SPIRVStructNumElements; ++I)
    MIB.addUse(Args[I]);

  Instructions.push_back(MIB.getInstr());

  for (size_t I = SPIRVStructNumElements; I < NumElements;
       I += MaxNumElements) {
    auto MIB = MIRBuilder.buildInstr(ContinuedOpcode);
    for (size_t J = I; J < std::min(I + MaxNumElements, NumElements); ++J)
      MIB.addUse(Args[J]);
    Instructions.push_back(MIB.getInstr());
  }
  return Instructions;
}

SmallVector<unsigned, 1>
getSpirvLoopControlOperandsFromLoopMetadata(MDNode *LoopMD) {
  unsigned LC = SPIRV::LoopControl::None;
  // Currently used only to store the PartialCount value. Later, when other
  // loop controls are added, this map should be sorted before its entries are
  // emitted as loop_merge operands, to satisfy the "3.23. Loop Control"
  // ordering requirements.
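  // For illustration: a loop annotated only with
  //   !{!"llvm.loop.unroll.count", i32 4}
  // yields {LoopControl::PartialCount, 4}, while "llvm.loop.unroll.disable"
  // yields just {LoopControl::DontUnroll}.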
  std::vector<std::pair<unsigned, unsigned>> MaskToValueMap;
  if (findOptionMDForLoopID(LoopMD, "llvm.loop.unroll.disable")) {
    LC |= SPIRV::LoopControl::DontUnroll;
  } else {
    if (findOptionMDForLoopID(LoopMD, "llvm.loop.unroll.enable") ||
        findOptionMDForLoopID(LoopMD, "llvm.loop.unroll.full")) {
      LC |= SPIRV::LoopControl::Unroll;
    }
    if (MDNode *CountMD =
            findOptionMDForLoopID(LoopMD, "llvm.loop.unroll.count")) {
      if (auto *CI =
              mdconst::extract_or_null<ConstantInt>(CountMD->getOperand(1))) {
        unsigned Count = CI->getZExtValue();
        if (Count != 1) {
          LC |= SPIRV::LoopControl::PartialCount;
          MaskToValueMap.emplace_back(
              std::make_pair(SPIRV::LoopControl::PartialCount, Count));
        }
      }
    }
  }
  SmallVector<unsigned, 1> Result = {LC};
  for (auto &[Mask, Val] : MaskToValueMap)
    Result.push_back(Val);
  return Result;
}

SmallVector<unsigned, 1> getSpirvLoopControlOperandsFromLoopMetadata(Loop *L) {
  return getSpirvLoopControlOperandsFromLoopMetadata(L->getLoopID());
}

const std::set<unsigned> &getTypeFoldingSupportedOpcodes() {
  // clang-format off
  static const std::set<unsigned> TypeFoldingSupportingOpcs = {
      TargetOpcode::G_ADD,
      TargetOpcode::G_FADD,
      TargetOpcode::G_STRICT_FADD,
      TargetOpcode::G_SUB,
      TargetOpcode::G_FSUB,
      TargetOpcode::G_STRICT_FSUB,
      TargetOpcode::G_MUL,
      TargetOpcode::G_FMUL,
      TargetOpcode::G_STRICT_FMUL,
      TargetOpcode::G_SDIV,
      TargetOpcode::G_UDIV,
      TargetOpcode::G_FDIV,
      TargetOpcode::G_STRICT_FDIV,
      TargetOpcode::G_SREM,
      TargetOpcode::G_UREM,
      TargetOpcode::G_FREM,
      TargetOpcode::G_STRICT_FREM,
      TargetOpcode::G_FNEG,
      TargetOpcode::G_CONSTANT,
      TargetOpcode::G_FCONSTANT,
      TargetOpcode::G_AND,
      TargetOpcode::G_OR,
      TargetOpcode::G_XOR,
      TargetOpcode::G_SHL,
      TargetOpcode::G_ASHR,
      TargetOpcode::G_LSHR,
      TargetOpcode::G_SELECT,
      TargetOpcode::G_EXTRACT_VECTOR_ELT,
  };
  // clang-format on
  return TypeFoldingSupportingOpcs;
}

bool isTypeFoldingSupported(unsigned Opcode) {
  return getTypeFoldingSupportedOpcodes().count(Opcode) > 0;
}

// Traversing [g]MIR accounting for pseudo-instructions.
MachineInstr *passCopy(MachineInstr *Def, const MachineRegisterInfo *MRI) {
  return (Def->getOpcode() == SPIRV::ASSIGN_TYPE ||
          Def->getOpcode() == TargetOpcode::COPY)
             ? MRI->getVRegDef(Def->getOperand(1).getReg())
             : Def;
}

MachineInstr *getDef(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
  if (MachineInstr *Def = MRI->getVRegDef(MO.getReg()))
    return passCopy(Def, MRI);
  return nullptr;
}

MachineInstr *getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
  if (MachineInstr *Def = getDef(MO, MRI)) {
    if (Def->getOpcode() == TargetOpcode::G_CONSTANT ||
        Def->getOpcode() == SPIRV::OpConstantI)
      return Def;
  }
  return nullptr;
}

int64_t foldImm(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
  if (MachineInstr *Def = getImm(MO, MRI)) {
    if (Def->getOpcode() == SPIRV::OpConstantI)
      return Def->getOperand(2).getImm();
    if (Def->getOpcode() == TargetOpcode::G_CONSTANT)
      return Def->getOperand(1).getCImm()->getZExtValue();
  }
  llvm_unreachable("Unexpected integer constant pattern");
}

unsigned getArrayComponentCount(const MachineRegisterInfo *MRI,
                                const MachineInstr *ResType) {
  return foldImm(ResType->getOperand(2), MRI);
}

MachineBasicBlock::iterator
getFirstValidInstructionInsertPoint(MachineBasicBlock &BB) {
  // Find the position to insert the OpVariable instruction.
  // We will insert it after the last OpFunctionParameter, if any, or
  // after OpFunction otherwise.
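  // For illustration, in a block that starts with
  //   OpFunction, OpFunctionParameter, OpFunctionParameter, OpLabel, ...
  // the returned iterator points just past the OpLabel.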
  MachineBasicBlock::iterator VarPos = BB.begin();
  while (VarPos != BB.end() && VarPos->getOpcode() != SPIRV::OpFunction) {
    ++VarPos;
  }
  // Advance VarPos to the next instruction after OpFunction, it will either
  // be an OpFunctionParameter, so that we can start the next loop, or the
  // position to insert the OpVariable instruction.
  ++VarPos;
  while (VarPos != BB.end() &&
         VarPos->getOpcode() == SPIRV::OpFunctionParameter) {
    ++VarPos;
  }
  // VarPos is now pointing at after the last OpFunctionParameter, if any,
  // or after OpFunction, if no parameters.
  return VarPos != BB.end() && VarPos->getOpcode() == SPIRV::OpLabel ? ++VarPos
                                                                     : VarPos;
}

bool matchPeeledArrayPattern(const StructType *Ty, Type *&OriginalElementType,
                             uint64_t &TotalSize) {
  // An array of N padded structs is represented as {[N-1 x <{T, pad}>], T}.
  if (Ty->getStructNumElements() != 2)
    return false;

  Type *FirstElement = Ty->getStructElementType(0);
  Type *SecondElement = Ty->getStructElementType(1);

  if (!FirstElement->isArrayTy())
    return false;

  Type *ArrayElementType = FirstElement->getArrayElementType();
  if (!ArrayElementType->isStructTy() ||
      ArrayElementType->getStructNumElements() != 2)
    return false;

  Type *T_in_struct = ArrayElementType->getStructElementType(0);
  if (T_in_struct != SecondElement)
    return false;

  auto *Padding_in_struct =
      dyn_cast<TargetExtType>(ArrayElementType->getStructElementType(1));
  if (!Padding_in_struct || Padding_in_struct->getName() != "spirv.Padding")
    return false;

  const uint64_t ArraySize = FirstElement->getArrayNumElements();
  TotalSize = ArraySize + 1;
  OriginalElementType = ArrayElementType;
  return true;
}

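// For illustration: a peeled type such as
//   { [3 x <{ %T, target("spirv.Padding") }>], %T }
// is reconstituted below into [4 x <{ %T, target("spirv.Padding") }>], and the
// transformation is applied recursively to nested struct elements.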
Type *reconstitutePeeledArrayType(Type *Ty) {
  if (!Ty->isStructTy())
    return Ty;

  auto *STy = cast<StructType>(Ty);
  Type *OriginalElementType = nullptr;
  uint64_t TotalSize = 0;
  if (matchPeeledArrayPattern(STy, OriginalElementType, TotalSize)) {
    Type *ResultTy = ArrayType::get(
        reconstitutePeeledArrayType(OriginalElementType), TotalSize);
    return ResultTy;
  }

  SmallVector<Type *, 4> NewElementTypes;
  bool Changed = false;
  for (Type *ElementTy : STy->elements()) {
    Type *NewElementTy = reconstitutePeeledArrayType(ElementTy);
    if (NewElementTy != ElementTy)
      Changed = true;
    NewElementTypes.push_back(NewElementTy);
  }

  if (!Changed)
    return Ty;

  Type *ResultTy;
  if (STy->isLiteral())
    ResultTy =
        StructType::get(STy->getContext(), NewElementTypes, STy->isPacked());
  else {
    auto *NewTy = StructType::create(STy->getContext(), STy->getName());
    NewTy->setBody(NewElementTypes, STy->isPacked());
    ResultTy = NewTy;
  }
  return ResultTy;
}

std::optional<SPIRV::LinkageType::LinkageType>
getSpirvLinkageTypeFor(const SPIRVSubtarget &ST, const GlobalValue &GV) {
  if (GV.hasLocalLinkage() || GV.hasHiddenVisibility())
    return std::nullopt;

  if (GV.isDeclarationForLinker())
    return SPIRV::LinkageType::Import;

  if (GV.hasLinkOnceODRLinkage() &&
      ST.canUseExtension(SPIRV::Extension::SPV_KHR_linkonce_odr))
    return SPIRV::LinkageType::LinkOnceODR;

  return SPIRV::LinkageType::Export;
}

} // namespace llvm