//===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains miscellaneous utility functions.
//
//===----------------------------------------------------------------------===//

#include "SPIRVUtils.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVSubtarget.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/Demangle/Demangle.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include <queue>
#include <vector>

namespace llvm {

// The following functions are used to add string literals as a series of
// 32-bit integer operands with the correct format, and unpack them if
// necessary when making string comparisons in compiler passes.
// SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment.
static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
  uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars.
  for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) {
    unsigned StrIndex = i + WordIndex;
    uint8_t CharToAdd = 0;       // Initialize char as padding/null.
    if (StrIndex < Str.size()) { // If it's within the string, get a real char.
      CharToAdd = Str[StrIndex];
    }
    Word |= (CharToAdd << (WordIndex * 8));
  }
  return Word;
}

// Get length including padding and null terminator.
static size_t getPaddedLen(const StringRef &Str) {
  return (Str.size() + 4) & ~3;
}

void addStringImm(const StringRef &Str, MCInst &Inst) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add an operand for the 32-bits of chars or padding.
    Inst.addOperand(MCOperand::createImm(convertCharsToWord(Str, i)));
  }
}

void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add an operand for the 32-bits of chars or padding.
    MIB.addImm(convertCharsToWord(Str, i));
  }
}

void addStringImm(const StringRef &Str, IRBuilder<> &B,
                  std::vector<Value *> &Args) {
  const size_t PaddedLen = getPaddedLen(Str);
  for (unsigned i = 0; i < PaddedLen; i += 4) {
    // Add a vector element for the 32-bits of chars or padding.
    Args.push_back(B.getInt32(convertCharsToWord(Str, i)));
  }
}

std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) {
  return getSPIRVStringOperand(MI, StartIndex);
}

std::string getStringValueFromReg(Register Reg, MachineRegisterInfo &MRI) {
  MachineInstr *Def = getVRegDef(MRI, Reg);
  assert(Def && Def->getOpcode() == TargetOpcode::G_GLOBAL_VALUE &&
         "Expected G_GLOBAL_VALUE");
  const GlobalValue *GV = Def->getOperand(1).getGlobal();
  Value *V = GV->getOperand(0);
  const ConstantDataArray *CDA = cast<ConstantDataArray>(V);
  return CDA->getAsCString().str();
}

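// Add an immediate (given as an APInt) to MIB as SPIR-V literal words: a
// single 32-bit operand for widths up to 32 bits (with an extra asm-printer
// flag for 16-bit values), or a low/high pair of words for widths up to 64
// bits. 1-bit values are expected to have been handled by the caller.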
void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
  const auto Bitwidth = Imm.getBitWidth();
  if (Bitwidth == 1)
    return; // Already handled
  else if (Bitwidth <= 32) {
    MIB.addImm(Imm.getZExtValue());
    // The asm printer needs this flag to print a 16-bit floating-point
    // constant correctly.
    if (Bitwidth == 16)
      MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
    return;
  } else if (Bitwidth <= 64) {
    uint64_t FullImm = Imm.getZExtValue();
    uint32_t LowBits = FullImm & 0xffffffff;
    uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
    MIB.addImm(LowBits).addImm(HighBits);
    return;
  }
  report_fatal_error("Unsupported constant bitwidth");
}

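// Emit an OpName instruction for Target when a non-empty Name is provided.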
void buildOpName(Register Target, const StringRef &Name,
                 MachineIRBuilder &MIRBuilder) {
  if (!Name.empty()) {
    auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target);
    addStringImm(Name, MIB);
  }
}

void buildOpName(Register Target, const StringRef &Name, MachineInstr &I,
                 const SPIRVInstrInfo &TII) {
  if (!Name.empty()) {
    auto MIB =
        BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpName))
            .addUse(Target);
    addStringImm(Name, MIB);
  }
}

static void finishBuildOpDecorate(MachineInstrBuilder &MIB,
                                  const std::vector<uint32_t> &DecArgs,
                                  StringRef StrImm) {
  if (!StrImm.empty())
    addStringImm(StrImm, MIB);
  for (const auto &DecArg : DecArgs)
    MIB.addImm(DecArg);
}

void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
                     SPIRV::Decoration::Decoration Dec,
                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
                 .addUse(Reg)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII,
                     SPIRV::Decoration::Decoration Dec,
                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
  MachineBasicBlock &MBB = *I.getParent();
  auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate))
                 .addUse(Reg)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

void buildOpMemberDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
                           SPIRV::Decoration::Decoration Dec, uint32_t Member,
                           const std::vector<uint32_t> &DecArgs,
                           StringRef StrImm) {
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpMemberDecorate)
                 .addUse(Reg)
                 .addImm(Member)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

void buildOpMemberDecorate(Register Reg, MachineInstr &I,
                           const SPIRVInstrInfo &TII,
                           SPIRV::Decoration::Decoration Dec, uint32_t Member,
                           const std::vector<uint32_t> &DecArgs,
                           StringRef StrImm) {
  MachineBasicBlock &MBB = *I.getParent();
  auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpMemberDecorate))
                 .addUse(Reg)
                 .addImm(Member)
                 .addImm(static_cast<uint32_t>(Dec));
  finishBuildOpDecorate(MIB, DecArgs, StrImm);
}

void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder,
                             const MDNode *GVarMD) {
  for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) {
    auto *OpMD = dyn_cast<MDNode>(GVarMD->getOperand(I));
    if (!OpMD)
      report_fatal_error("Invalid decoration");
    if (OpMD->getNumOperands() == 0)
      report_fatal_error("Expect operand(s) of the decoration");
    ConstantInt *DecorationId =
        mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(0));
    if (!DecorationId)
      report_fatal_error("Expect SPIR-V <Decoration> operand to be the first "
                         "element of the decoration");
    auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
                   .addUse(Reg)
                   .addImm(static_cast<uint32_t>(DecorationId->getZExtValue()));
    for (unsigned OpI = 1, OpE = OpMD->getNumOperands(); OpI != OpE; ++OpI) {
      if (ConstantInt *OpV =
              mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(OpI)))
        MIB.addImm(static_cast<uint32_t>(OpV->getZExtValue()));
      else if (MDString *OpV = dyn_cast<MDString>(OpMD->getOperand(OpI)))
        addStringImm(OpV->getString(), MIB);
      else
        report_fatal_error("Unexpected operand of the decoration");
    }
  }
}

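// Return an insertion point in the entry basic block suitable for OpVariable
// instructions, skipping past the function header (OpFunction,
// OpFunctionParameter, ASSIGN_TYPE and OpLabel instructions).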
MachineBasicBlock::iterator getOpVariableMBBIt(MachineInstr &I) {
  MachineFunction *MF = I.getParent()->getParent();
  MachineBasicBlock *MBB = &MF->front();
  MachineBasicBlock::iterator It = MBB->SkipPHIsAndLabels(MBB->begin()),
                              E = MBB->end();
  bool IsHeader = false;
  unsigned Opcode;
  for (; It != E && It != I; ++It) {
    Opcode = It->getOpcode();
    if (Opcode == SPIRV::OpFunction || Opcode == SPIRV::OpFunctionParameter) {
      IsHeader = true;
    } else if (IsHeader &&
               !(Opcode == SPIRV::ASSIGN_TYPE || Opcode == SPIRV::OpLabel)) {
      ++It;
      break;
    }
  }
  return It;
}

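// Return an iterator into MBB that skips backwards over any trailing
// terminator and debug-value instructions.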
MachineBasicBlock::iterator getInsertPtValidEnd(MachineBasicBlock *MBB) {
  MachineBasicBlock::iterator I = MBB->end();
  if (I == MBB->begin())
    return I;
  --I;
  while (I->isTerminator() || I->isDebugValue()) {
    if (I == MBB->begin())
      break;
    --I;
  }
  return I;
}

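// Map an LLVM IR address space to the corresponding SPIR-V storage class.
// Address spaces 5 and 6 map to the Intel USM storage classes only when the
// SPV_INTEL_usm_storage_classes extension is available, and fall back to
// CrossWorkgroup otherwise.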
SPIRV::StorageClass::StorageClass
addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
  switch (AddrSpace) {
  case 0:
    return SPIRV::StorageClass::Function;
  case 1:
    return SPIRV::StorageClass::CrossWorkgroup;
  case 2:
    return SPIRV::StorageClass::UniformConstant;
  case 3:
    return SPIRV::StorageClass::Workgroup;
  case 4:
    return SPIRV::StorageClass::Generic;
  case 5:
    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
               ? SPIRV::StorageClass::DeviceOnlyINTEL
               : SPIRV::StorageClass::CrossWorkgroup;
  case 6:
    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
               ? SPIRV::StorageClass::HostOnlyINTEL
               : SPIRV::StorageClass::CrossWorkgroup;
  case 7:
    return SPIRV::StorageClass::Input;
  case 8:
    return SPIRV::StorageClass::Output;
  case 9:
    return SPIRV::StorageClass::CodeSectionINTEL;
  case 10:
    return SPIRV::StorageClass::Private;
  case 11:
    return SPIRV::StorageClass::StorageBuffer;
  case 12:
    return SPIRV::StorageClass::Uniform;
  default:
    report_fatal_error("Unknown address space");
  }
}

SPIRV::MemorySemantics::MemorySemantics
getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC) {
  switch (SC) {
  case SPIRV::StorageClass::StorageBuffer:
  case SPIRV::StorageClass::Uniform:
    return SPIRV::MemorySemantics::UniformMemory;
  case SPIRV::StorageClass::Workgroup:
    return SPIRV::MemorySemantics::WorkgroupMemory;
  case SPIRV::StorageClass::CrossWorkgroup:
    return SPIRV::MemorySemantics::CrossWorkgroupMemory;
  case SPIRV::StorageClass::AtomicCounter:
    return SPIRV::MemorySemantics::AtomicCounterMemory;
  case SPIRV::StorageClass::Image:
    return SPIRV::MemorySemantics::ImageMemory;
  default:
    return SPIRV::MemorySemantics::None;
  }
}

SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) {
  switch (Ord) {
  case AtomicOrdering::Acquire:
    return SPIRV::MemorySemantics::Acquire;
  case AtomicOrdering::Release:
    return SPIRV::MemorySemantics::Release;
  case AtomicOrdering::AcquireRelease:
    return SPIRV::MemorySemantics::AcquireRelease;
  case AtomicOrdering::SequentiallyConsistent:
    return SPIRV::MemorySemantics::SequentiallyConsistent;
  case AtomicOrdering::Unordered:
  case AtomicOrdering::Monotonic:
  case AtomicOrdering::NotAtomic:
    return SPIRV::MemorySemantics::None;
  }
  llvm_unreachable(nullptr);
}

SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) {
  // Named by
  // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id.
  // We don't need aliases for Invocation and CrossDevice, as we already have
  // them covered by "singlethread" and "" strings respectively (see
  // implementation of LLVMContext::LLVMContext()).
  static const llvm::SyncScope::ID SubGroup =
      Ctx.getOrInsertSyncScopeID("subgroup");
  static const llvm::SyncScope::ID WorkGroup =
      Ctx.getOrInsertSyncScopeID("workgroup");
  static const llvm::SyncScope::ID Device =
      Ctx.getOrInsertSyncScopeID("device");

  if (Id == llvm::SyncScope::SingleThread)
    return SPIRV::Scope::Invocation;
  else if (Id == llvm::SyncScope::System)
    return SPIRV::Scope::CrossDevice;
  else if (Id == SubGroup)
    return SPIRV::Scope::Subgroup;
  else if (Id == WorkGroup)
    return SPIRV::Scope::Workgroup;
  else if (Id == Device)
    return SPIRV::Scope::Device;
  return SPIRV::Scope::CrossDevice;
}

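// Return the defining instruction of ConstReg, looking through G_TRUNC/G_ZEXT,
// spv_track_constant intrinsics and ASSIGN_TYPE pseudos; ConstReg is updated
// to the register that directly holds the constant when one of these wrappers
// is peeled away.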
MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
                                       const MachineRegisterInfo *MRI) {
  MachineInstr *MI = MRI->getVRegDef(ConstReg);
  MachineInstr *ConstInstr =
      MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
          ? MRI->getVRegDef(MI->getOperand(1).getReg())
          : MI;
  if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
    if (GI->is(Intrinsic::spv_track_constant)) {
      ConstReg = ConstInstr->getOperand(2).getReg();
      return MRI->getVRegDef(ConstReg);
    }
  } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) {
    ConstReg = ConstInstr->getOperand(1).getReg();
    return MRI->getVRegDef(ConstReg);
  } else if (ConstInstr->getOpcode() == TargetOpcode::G_CONSTANT ||
             ConstInstr->getOpcode() == TargetOpcode::G_FCONSTANT) {
    ConstReg = ConstInstr->getOperand(0).getReg();
    return ConstInstr;
  }
  return MRI->getVRegDef(ConstReg);
}

uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {
  const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
  assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
  return MI->getOperand(1).getCImm()->getValue().getZExtValue();
}

bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
  if (const auto *GI = dyn_cast<GIntrinsic>(&MI))
    return GI->is(IntrinsicID);
  return false;
}

Type *getMDOperandAsType(const MDNode *N, unsigned I) {
  Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
  return toTypedPointer(ElementTy);
}

// The set of names is borrowed from the SPIR-V translator.
// TODO: may be implemented in SPIRVBuiltins.td.
static bool isPipeOrAddressSpaceCastBI(const StringRef MangledName) {
  return MangledName == "write_pipe_2" || MangledName == "read_pipe_2" ||
         MangledName == "write_pipe_2_bl" || MangledName == "read_pipe_2_bl" ||
         MangledName == "write_pipe_4" || MangledName == "read_pipe_4" ||
         MangledName == "reserve_write_pipe" ||
         MangledName == "reserve_read_pipe" ||
         MangledName == "commit_write_pipe" ||
         MangledName == "commit_read_pipe" ||
         MangledName == "work_group_reserve_write_pipe" ||
         MangledName == "work_group_reserve_read_pipe" ||
         MangledName == "work_group_commit_write_pipe" ||
         MangledName == "work_group_commit_read_pipe" ||
         MangledName == "get_pipe_num_packets_ro" ||
         MangledName == "get_pipe_max_packets_ro" ||
         MangledName == "get_pipe_num_packets_wo" ||
         MangledName == "get_pipe_max_packets_wo" ||
         MangledName == "sub_group_reserve_write_pipe" ||
         MangledName == "sub_group_reserve_read_pipe" ||
         MangledName == "sub_group_commit_write_pipe" ||
         MangledName == "sub_group_commit_read_pipe" ||
         MangledName == "to_global" || MangledName == "to_local" ||
         MangledName == "to_private";
}

static bool isEnqueueKernelBI(const StringRef MangledName) {
  return MangledName == "__enqueue_kernel_basic" ||
         MangledName == "__enqueue_kernel_basic_events" ||
         MangledName == "__enqueue_kernel_varargs" ||
         MangledName == "__enqueue_kernel_events_varargs";
}

static bool isKernelQueryBI(const StringRef MangledName) {
  return MangledName == "__get_kernel_work_group_size_impl" ||
         MangledName == "__get_kernel_sub_group_count_for_ndrange_impl" ||
         MangledName == "__get_kernel_max_sub_group_size_for_ndrange_impl" ||
         MangledName == "__get_kernel_preferred_work_group_size_multiple_impl";
}

static bool isNonMangledOCLBuiltin(StringRef Name) {
  if (!Name.starts_with("__"))
    return false;

  return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) ||
         isPipeOrAddressSpaceCastBI(Name.drop_front(2)) ||
         Name == "__translate_sampler_initializer";
}

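// Return the demangled name of a builtin call. Names that are not
// Itanium-mangled (including "__spirv_" and "__hlsl_" prefixed builtins) are
// returned unchanged; otherwise the Itanium demangler is tried first, with a
// manual fallback for builtins mangled in the cl::__spirv namespace.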
std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
  bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
  bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
  bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
  bool IsMangled = Name.starts_with("_Z");

  // If the name is not mangled, return it as-is.
  if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled)
    return Name.str();

  // Try to use the Itanium demangler.
  if (char *DemangledName = itaniumDemangle(Name.data())) {
    std::string Result = DemangledName;
    free(DemangledName);
    return Result;
  }

  // Autocheck for C++; we may need an explicit check of the source language.
  // OpenCL C++ built-ins are declared in the cl namespace.
  // TODO: consider using the 'St' abbreviation for cl namespace mangling,
  // similar to ::std:: in C++.
  size_t Start, Len = 0;
  size_t DemangledNameLenStart = 2;
  if (Name.starts_with("_ZN")) {
    // Skip CV and ref qualifiers.
    size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3);
    // All built-ins are in the ::cl:: namespace.
    if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv")
      return std::string();
    DemangledNameLenStart = NameSpaceStart + 11;
  }
  Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
  Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart)
      .getAsInteger(10, Len);
  return Name.substr(Start, Len).str();
}

bool hasBuiltinTypePrefix(StringRef Name) {
  if (Name.starts_with("opencl.") || Name.starts_with("ocl_") ||
      Name.starts_with("spirv."))
    return true;
  return false;
}

bool isSpecialOpaqueType(const Type *Ty) {
  if (const TargetExtType *ExtTy = dyn_cast<TargetExtType>(Ty))
    return isTypedPointerWrapper(ExtTy)
               ? false
               : hasBuiltinTypePrefix(ExtTy->getName());

  return false;
}

bool isEntryPoint(const Function &F) {
  // OpenCL handling: any function with the SPIR_KERNEL
  // calling convention is a potential entry point.
  if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
    return true;

  // HLSL handling: special attributes are emitted by the front-end.
  if (F.getFnAttribute("hlsl.shader").isValid())
    return true;

  return false;
}

Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
  TypeName.consume_front("atomic_");
  if (TypeName.consume_front("void"))
    return Type::getVoidTy(Ctx);
  else if (TypeName.consume_front("bool") || TypeName.consume_front("_Bool"))
    return Type::getIntNTy(Ctx, 1);
  else if (TypeName.consume_front("char") ||
           TypeName.consume_front("signed char") ||
           TypeName.consume_front("unsigned char") ||
           TypeName.consume_front("uchar"))
    return Type::getInt8Ty(Ctx);
  else if (TypeName.consume_front("short") ||
           TypeName.consume_front("signed short") ||
           TypeName.consume_front("unsigned short") ||
           TypeName.consume_front("ushort"))
    return Type::getInt16Ty(Ctx);
  else if (TypeName.consume_front("int") ||
           TypeName.consume_front("signed int") ||
           TypeName.consume_front("unsigned int") ||
           TypeName.consume_front("uint"))
    return Type::getInt32Ty(Ctx);
  else if (TypeName.consume_front("long") ||
           TypeName.consume_front("signed long") ||
           TypeName.consume_front("unsigned long") ||
           TypeName.consume_front("ulong"))
    return Type::getInt64Ty(Ctx);
  else if (TypeName.consume_front("half") ||
           TypeName.consume_front("_Float16") ||
           TypeName.consume_front("__fp16"))
    return Type::getHalfTy(Ctx);
  else if (TypeName.consume_front("float"))
    return Type::getFloatTy(Ctx);
  else if (TypeName.consume_front("double"))
    return Type::getDoubleTy(Ctx);

  // Unable to recognize the SPIR-V type name.
  return nullptr;
}

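// Return the set of basic blocks reachable from Start, computed with a
// breadth-first traversal of the CFG that ignores back-edges (successors
// dominating their predecessor).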
std::unordered_set<BasicBlock *>
PartialOrderingVisitor::getReachableFrom(BasicBlock *Start) {
  std::queue<BasicBlock *> ToVisit;
  ToVisit.push(Start);

  std::unordered_set<BasicBlock *> Output;
  while (ToVisit.size() != 0) {
    BasicBlock *BB = ToVisit.front();
    ToVisit.pop();

    if (Output.count(BB) != 0)
      continue;
    Output.insert(BB);

    for (BasicBlock *Successor : successors(BB)) {
      if (DT.dominates(Successor, BB))
        continue;
      ToVisit.push(Successor);
    }
  }

  return Output;
}

bool PartialOrderingVisitor::CanBeVisited(BasicBlock *BB) const {
  for (BasicBlock *P : predecessors(BB)) {
    // Ignore back-edges.
    if (DT.dominates(BB, P))
      continue;

    // One of the predecessors hasn't been visited yet. Not ready.
    if (BlockToOrder.count(P) == 0)
      return false;

    // If the block is a loop exit, the loop must be finished before
    // we can continue.
    Loop *L = LI.getLoopFor(P);
    if (L == nullptr || L->contains(BB))
      continue;

    // SPIR-V requires a single back-edge, and the backend's first
    // step transforms loops into the simplified format. If we have
    // more than one back-edge, something is wrong.
    assert(L->getNumBackEdges() <= 1);

    // If the loop has no latch, the loop's rank won't matter, so we can
    // proceed.
    BasicBlock *Latch = L->getLoopLatch();
    assert(Latch);
    if (Latch == nullptr)
      continue;

    // The latch is not ready yet; let's wait.
    if (BlockToOrder.count(Latch) == 0)
      return false;
  }

  return true;
}

size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const {
  auto It = BlockToOrder.find(BB);
  if (It != BlockToOrder.end())
    return It->second.Rank;

  size_t result = 0;
  for (BasicBlock *P : predecessors(BB)) {
    // Ignore back-edges.
    if (DT.dominates(BB, P))
      continue;

    auto Iterator = BlockToOrder.end();
    Loop *L = LI.getLoopFor(P);
    BasicBlock *Latch = L ? L->getLoopLatch() : nullptr;

    // If the predecessor is either outside a loop, or part of
    // the same loop, simply take its rank + 1.
    if (L == nullptr || L->contains(BB) || Latch == nullptr) {
      Iterator = BlockToOrder.find(P);
    } else {
      // Otherwise, take the loop's rank (highest rank in the loop) as base.
      // Since loops have a single latch, the highest rank is easy to find.
      // If the loop has no latch, then it doesn't matter.
      Iterator = BlockToOrder.find(Latch);
    }

    assert(Iterator != BlockToOrder.end());
    result = std::max(result, Iterator->second.Rank + 1);
  }

  return result;
}

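// Visit the CFG starting at BB, assigning each block a rank and a traversal
// index. A block is ordered only once all of its non-back-edge predecessors
// (and, for loop exits, the loop latch) have been ordered; the queue is
// re-scanned until a visitable candidate is found.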
size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) {
  ToVisit.push(BB);
  Queued.insert(BB);

  size_t QueueIndex = 0;
  while (ToVisit.size() != 0) {
    BasicBlock *BB = ToVisit.front();
    ToVisit.pop();

    if (!CanBeVisited(BB)) {
      ToVisit.push(BB);
      if (QueueIndex >= ToVisit.size())
        llvm::report_fatal_error(
            "No valid candidate in the queue. Is the graph reducible?");
      QueueIndex++;
      continue;
    }

    QueueIndex = 0;
    size_t Rank = GetNodeRank(BB);
    OrderInfo Info = {Rank, BlockToOrder.size()};
    BlockToOrder.emplace(BB, Info);

    for (BasicBlock *S : successors(BB)) {
      if (Queued.count(S) != 0)
        continue;
      ToVisit.push(S);
      Queued.insert(S);
    }
  }

  return 0;
}

PartialOrderingVisitor::PartialOrderingVisitor(Function &F) {
  DT.recalculate(F);
  LI = LoopInfo(DT);

  visit(&*F.begin(), 0);

  Order.reserve(F.size());
  for (auto &[BB, Info] : BlockToOrder)
    Order.emplace_back(BB);

  std::sort(Order.begin(), Order.end(), [&](const auto &LHS, const auto &RHS) {
    return compare(LHS, RHS);
  });
}

bool PartialOrderingVisitor::compare(const BasicBlock *LHS,
                                     const BasicBlock *RHS) const {
  const OrderInfo &InfoLHS = BlockToOrder.at(const_cast<BasicBlock *>(LHS));
  const OrderInfo &InfoRHS = BlockToOrder.at(const_cast<BasicBlock *>(RHS));
  if (InfoLHS.Rank != InfoRHS.Rank)
    return InfoLHS.Rank < InfoRHS.Rank;
  return InfoLHS.TraversalIndex < InfoRHS.TraversalIndex;
}

void PartialOrderingVisitor::partialOrderVisit(
    BasicBlock &Start, std::function<bool(BasicBlock *)> Op) {
  std::unordered_set<BasicBlock *> Reachable = getReachableFrom(&Start);
  assert(BlockToOrder.count(&Start) != 0);

  // Skip blocks with a rank lower than |Start|'s rank.
  auto It = Order.begin();
  while (It != Order.end() && *It != &Start)
    ++It;

  // This is unexpected. Worst case, |Start| is the last block,
  // so It should point to the last block, not past the end.
  assert(It != Order.end());

  // By default, there is no rank limit. Setting it to the maximum value.
  std::optional<size_t> EndRank = std::nullopt;
  for (; It != Order.end(); ++It) {
    if (EndRank.has_value() && BlockToOrder[*It].Rank > *EndRank)
      break;

    if (Reachable.count(*It) == 0) {
      continue;
    }

    if (!Op(*It)) {
      EndRank = BlockToOrder[*It].Rank;
    }
  }
}

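// Reorder the basic blocks of F to match a reverse post-order traversal of
// the CFG. Returns true if any block was actually moved.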
bool sortBlocks(Function &F) {
  if (F.size() == 0)
    return false;

  bool Modified = false;
  std::vector<BasicBlock *> Order;
  Order.reserve(F.size());

  ReversePostOrderTraversal<Function *> RPOT(&F);
  llvm::append_range(Order, RPOT);

  assert(&*F.begin() == Order[0]);
  BasicBlock *LastBlock = &*F.begin();
  for (BasicBlock *BB : Order) {
    if (BB != LastBlock && &*LastBlock->getNextNode() != BB) {
      Modified = true;
      BB->moveAfter(LastBlock);
    }
    LastBlock = BB;
  }

  return Modified;
}

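// Return the defining instruction of Reg, looking through an ASSIGN_TYPE
// pseudo-instruction if present.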
MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg) {
  MachineInstr *MaybeDef = MRI.getVRegDef(Reg);
  if (MaybeDef && MaybeDef->getOpcode() == SPIRV::ASSIGN_TYPE)
    MaybeDef = MRI.getVRegDef(MaybeDef->getOperand(1).getReg());
  return MaybeDef;
}

bool getVacantFunctionName(Module &M, std::string &Name) {
  // It's a bit of paranoia, but we still don't want even a chance that the
  // loop runs for too long.
  constexpr unsigned MaxIters = 1024;
  for (unsigned I = 0; I < MaxIters; ++I) {
    std::string OrdName = Name + Twine(I).str();
    if (!M.getFunction(OrdName)) {
      Name = OrdName;
      return true;
    }
  }
  return false;
}

// Assign the SPIR-V type to the register. If the register has no valid
// assigned class, set the register LLT type and class according to the
// SPIR-V type.
void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                     MachineRegisterInfo *MRI, const MachineFunction &MF,
                     bool Force) {
  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
  if (!MRI->getRegClassOrNull(Reg) || Force) {
    MRI->setRegClass(Reg, GR->getRegClass(SpvType));
    MRI->setType(Reg, GR->getRegType(SpvType));
  }
}

// Create a SPIR-V type and assign it to the register. If the register has no
// valid assigned class, set the register LLT type and class according to the
// SPIR-V type.
void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
                     MachineIRBuilder &MIRBuilder,
                     SPIRV::AccessQualifier::AccessQualifier AccessQual,
                     bool EmitIR, bool Force) {
  setRegClassType(Reg,
                  GR->getOrCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR),
                  GR, MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
}

// Create a virtual register and assign the SPIR-V type to it. Set the
// register LLT type and class according to the SPIR-V type.
Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                               MachineRegisterInfo *MRI,
                               const MachineFunction &MF) {
  Register Reg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
  MRI->setType(Reg, GR->getRegType(SpvType));
  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
  return Reg;
}

// Create a virtual register and assign the SPIR-V type to it. Set the
// register LLT type and class according to the SPIR-V type.
Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
                               MachineIRBuilder &MIRBuilder) {
  return createVirtualRegister(SpvType, GR, MIRBuilder.getMRI(),
                               MIRBuilder.getMF());
}

// Create a SPIR-V type and a virtual register, and assign the SPIR-V type to
// the register. Set the register LLT type and class according to the SPIR-V
// type.
Register createVirtualRegister(
    const Type *Ty, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  return createVirtualRegister(
      GR->getOrCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR), GR,
      MIRBuilder);
}

CallInst *buildIntrWithMD(Intrinsic::ID IntrID, ArrayRef<Type *> Types,
                          Value *Arg, Value *Arg2, ArrayRef<Constant *> Imms,
                          IRBuilder<> &B) {
  SmallVector<Value *, 4> Args;
  Args.push_back(Arg2);
  Args.push_back(buildMD(Arg));
  llvm::append_range(Args, Imms);
  return B.CreateIntrinsic(IntrID, {Types}, Args);
}

// Return true if there is an opaque pointer type nested in the argument.
bool isNestedPointer(const Type *Ty) {
  if (Ty->isPtrOrPtrVectorTy())
    return true;
  if (const FunctionType *RefTy = dyn_cast<FunctionType>(Ty)) {
    if (isNestedPointer(RefTy->getReturnType()))
      return true;
    for (const Type *ArgTy : RefTy->params())
      if (isNestedPointer(ArgTy))
        return true;
    return false;
  }
  if (const ArrayType *RefTy = dyn_cast<ArrayType>(Ty))
    return isNestedPointer(RefTy->getElementType());
  return false;
}

bool isSpvIntrinsic(const Value *Arg) {
  if (const auto *II = dyn_cast<IntrinsicInst>(Arg))
    if (Function *F = II->getCalledFunction())
      if (F->getName().starts_with("llvm.spv."))
        return true;
  return false;
}

// Create continued instructions for the SPV_INTEL_long_composites extension.
SmallVector<MachineInstr *, 4>
createContinuedInstructions(MachineIRBuilder &MIRBuilder, unsigned Opcode,
                            unsigned MinWC, unsigned ContinuedOpcode,
                            ArrayRef<Register> Args, Register ReturnRegister,
                            Register TypeID) {

  SmallVector<MachineInstr *, 4> Instructions;
  constexpr unsigned MaxWordCount = UINT16_MAX;
  const size_t NumElements = Args.size();
  size_t MaxNumElements = MaxWordCount - MinWC;
  size_t SPIRVStructNumElements = NumElements;

  if (NumElements > MaxNumElements) {
    // Do adjustments for continued instructions, which always have a minimum
    // word count of one.
    SPIRVStructNumElements = MaxNumElements;
    MaxNumElements = MaxWordCount - 1;
  }

  auto MIB =
      MIRBuilder.buildInstr(Opcode).addDef(ReturnRegister).addUse(TypeID);

  for (size_t I = 0; I < SPIRVStructNumElements; ++I)
    MIB.addUse(Args[I]);

  Instructions.push_back(MIB.getInstr());

  for (size_t I = SPIRVStructNumElements; I < NumElements;
       I += MaxNumElements) {
    auto MIB = MIRBuilder.buildInstr(ContinuedOpcode);
    for (size_t J = I; J < std::min(I + MaxNumElements, NumElements); ++J)
      MIB.addUse(Args[J]);
    Instructions.push_back(MIB.getInstr());
  }
  return Instructions;
}

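// Translate llvm.loop.unroll.* metadata attached to L into the SPIR-V loop
// control mask and its extra literal operands (e.g. the PartialCount value).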
SmallVector<unsigned, 1> getSpirvLoopControlOperandsFromLoopMetadata(Loop *L) {
  unsigned LC = SPIRV::LoopControl::None;
  // Currently used only to store the PartialCount value. Later, when other
  // LoopControls are added, this map should be sorted before its entries are
  // made loop merge operands, to satisfy the requirements of section 3.23
  // "Loop Control" of the SPIR-V specification.
  std::vector<std::pair<unsigned, unsigned>> MaskToValueMap;
  if (getBooleanLoopAttribute(L, "llvm.loop.unroll.disable")) {
    LC |= SPIRV::LoopControl::DontUnroll;
  } else {
    if (getBooleanLoopAttribute(L, "llvm.loop.unroll.enable") ||
        getBooleanLoopAttribute(L, "llvm.loop.unroll.full")) {
      LC |= SPIRV::LoopControl::Unroll;
    }
    std::optional<int> Count =
        getOptionalIntLoopAttribute(L, "llvm.loop.unroll.count");
    if (Count && Count != 1) {
      LC |= SPIRV::LoopControl::PartialCount;
      MaskToValueMap.emplace_back(
          std::make_pair(SPIRV::LoopControl::PartialCount, *Count));
    }
  }
  SmallVector<unsigned, 1> Result = {LC};
  for (auto &[Mask, Val] : MaskToValueMap)
    Result.push_back(Val);
  return Result;
}

const std::set<unsigned> &getTypeFoldingSupportedOpcodes() {
  // clang-format off
  static const std::set<unsigned> TypeFoldingSupportingOpcs = {
      TargetOpcode::G_ADD,
      TargetOpcode::G_FADD,
      TargetOpcode::G_STRICT_FADD,
      TargetOpcode::G_SUB,
      TargetOpcode::G_FSUB,
      TargetOpcode::G_STRICT_FSUB,
      TargetOpcode::G_MUL,
      TargetOpcode::G_FMUL,
      TargetOpcode::G_STRICT_FMUL,
      TargetOpcode::G_SDIV,
      TargetOpcode::G_UDIV,
      TargetOpcode::G_FDIV,
      TargetOpcode::G_STRICT_FDIV,
      TargetOpcode::G_SREM,
      TargetOpcode::G_UREM,
      TargetOpcode::G_FREM,
      TargetOpcode::G_STRICT_FREM,
      TargetOpcode::G_FNEG,
      TargetOpcode::G_CONSTANT,
      TargetOpcode::G_FCONSTANT,
      TargetOpcode::G_AND,
      TargetOpcode::G_OR,
      TargetOpcode::G_XOR,
      TargetOpcode::G_SHL,
      TargetOpcode::G_ASHR,
      TargetOpcode::G_LSHR,
      TargetOpcode::G_SELECT,
      TargetOpcode::G_EXTRACT_VECTOR_ELT,
  };
  // clang-format on
  return TypeFoldingSupportingOpcs;
}

bool isTypeFoldingSupported(unsigned Opcode) {
  return getTypeFoldingSupportedOpcodes().count(Opcode) > 0;
}

// Traversing [g]MIR accounting for pseudo-instructions.
MachineInstr *passCopy(MachineInstr *Def, const MachineRegisterInfo *MRI) {
  return (Def->getOpcode() == SPIRV::ASSIGN_TYPE ||
          Def->getOpcode() == TargetOpcode::COPY)
             ? MRI->getVRegDef(Def->getOperand(1).getReg())
             : Def;
}

MachineInstr *getDef(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
  if (MachineInstr *Def = MRI->getVRegDef(MO.getReg()))
    return passCopy(Def, MRI);
  return nullptr;
}

MachineInstr *getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
  if (MachineInstr *Def = getDef(MO, MRI)) {
    if (Def->getOpcode() == TargetOpcode::G_CONSTANT ||
        Def->getOpcode() == SPIRV::OpConstantI)
      return Def;
  }
  return nullptr;
}

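// Return the value of an integer constant operand, which must be defined by
// either a G_CONSTANT or an OpConstantI instruction.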
int64_t foldImm(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
  if (MachineInstr *Def = getImm(MO, MRI)) {
    if (Def->getOpcode() == SPIRV::OpConstantI)
      return Def->getOperand(2).getImm();
    if (Def->getOpcode() == TargetOpcode::G_CONSTANT)
      return Def->getOperand(1).getCImm()->getZExtValue();
  }
  llvm_unreachable("Unexpected integer constant pattern");
}

unsigned getArrayComponentCount(const MachineRegisterInfo *MRI,
                                const MachineInstr *ResType) {
  return foldImm(ResType->getOperand(2), MRI);
}

} // namespace llvm