1//===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains miscellaneous utility functions.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVUtils.h"
14#include "MCTargetDesc/SPIRVBaseInfo.h"
15#include "SPIRV.h"
16#include "SPIRVGlobalRegistry.h"
17#include "SPIRVInstrInfo.h"
18#include "SPIRVSubtarget.h"
19#include "llvm/ADT/StringRef.h"
20#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
21#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
22#include "llvm/CodeGen/MachineInstr.h"
23#include "llvm/CodeGen/MachineInstrBuilder.h"
24#include "llvm/Demangle/Demangle.h"
25#include "llvm/IR/IntrinsicInst.h"
26#include "llvm/IR/IntrinsicsSPIRV.h"
27#include <queue>
28#include <vector>
29
30namespace llvm {
31namespace SPIRV {
32// This code restores function args/retvalue types for composite cases
33// because the final types should still be aggregate whereas they're i32
34// during the translation to cope with aggregate flattening etc.
35// TODO: should these just return nullptr when there's no metadata?
36static FunctionType *extractFunctionTypeFromMetadata(NamedMDNode *NMD,
37 FunctionType *FTy,
38 StringRef Name) {
39 if (!NMD)
40 return FTy;
41
42 constexpr auto getConstInt = [](MDNode *MD, unsigned OpId) -> ConstantInt * {
43 if (MD->getNumOperands() <= OpId)
44 return nullptr;
45 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MD->getOperand(I: OpId)))
46 return dyn_cast<ConstantInt>(Val: CMeta->getValue());
47 return nullptr;
48 };
49
50 auto It = find_if(Range: NMD->operands(), P: [Name](MDNode *N) {
51 if (auto *MDS = dyn_cast_or_null<MDString>(Val: N->getOperand(I: 0)))
52 return MDS->getString() == Name;
53 return false;
54 });
55
56 if (It == NMD->op_end())
57 return FTy;
58
59 Type *RetTy = FTy->getReturnType();
60 SmallVector<Type *, 4> PTys(FTy->params());
61
62 for (unsigned I = 1; I != (*It)->getNumOperands(); ++I) {
63 MDNode *MD = dyn_cast<MDNode>(Val: (*It)->getOperand(I));
64 assert(MD && "MDNode operand is expected");
65
66 if (auto *Const = getConstInt(MD, 0)) {
67 auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MD->getOperand(I: 1));
68 assert(CMeta && "ConstantAsMetadata operand is expected");
69 int64_t Idx = Const->getSExtValue();
70 // Currently -1 indicates return value, greater values mean
71 // argument numbers.
72 if (Idx == -1) {
73 RetTy = CMeta->getType();
74 continue;
75 }
76 if (Idx >= 0 && static_cast<uint64_t>(Idx) < PTys.size()) {
77 PTys[Idx] = CMeta->getType();
78 continue;
79 }
80 report_fatal_error(reason: "invalid argument index in function type metadata");
81 }
82 }
83
84 return FunctionType::get(Result: RetTy, Params: PTys, isVarArg: FTy->isVarArg());
85}
86
87static StringRef extractAsmConstraintsFromMetadata(NamedMDNode *NMD,
88 StringRef Constraints,
89 StringRef Name) {
90 // TODO: unify the extractors.
91 if (!NMD)
92 return Constraints;
93
94 auto It = find_if(Range: NMD->operands(), P: [Name](MDNode *N) {
95 if (auto *MDS = dyn_cast_or_null<MDString>(Val: N->getOperand(I: 0)))
96 return MDS->getString() == Name;
97 return false;
98 });
99
100 if (It == NMD->op_end())
101 return Constraints;
102
103 // By convention, the constraints string is stored in the final MD operand.
104 MDNode *MD = dyn_cast<MDNode>(Val: (*It)->getOperand(I: (*It)->getNumOperands() - 1));
105 assert(MD && "MDNode operand is expected");
106
107 if (auto *MDS = dyn_cast<MDString>(Val: MD->getOperand(I: 0)))
108 Constraints = MDS->getString();
109
110 return Constraints;
111}
112
113FunctionType *getOriginalFunctionType(const Function &F) {
114 return extractFunctionTypeFromMetadata(
115 NMD: F.getParent()->getNamedMetadata(Name: "spv.cloned_funcs"), FTy: F.getFunctionType(),
116 Name: F.getName());
117}
118
119// Keyed via instruction metadata, not a name.
120static std::optional<StringRef> getMutatedCallsiteKey(const CallBase &CB) {
121 if (MDNode *MD = CB.getMetadata(Kind: "spv.mutated_callsite"))
122 if (MD->getNumOperands() > 0)
123 if (auto *MDS = dyn_cast<MDString>(Val: MD->getOperand(I: 0)))
124 return MDS->getString();
125 return std::nullopt;
126}
127
128FunctionType *getOriginalFunctionType(const CallBase &CB) {
129 std::optional<StringRef> Key = getMutatedCallsiteKey(CB);
130 if (!Key)
131 return CB.getFunctionType();
132 return extractFunctionTypeFromMetadata(
133 NMD: CB.getModule()->getNamedMetadata(Name: "spv.mutated_callsites"),
134 FTy: CB.getFunctionType(), Name: *Key);
135}
136
137StringRef getOriginalAsmConstraints(const CallBase &CB) {
138 StringRef Constraints =
139 cast<InlineAsm>(Val: CB.getCalledOperand())->getConstraintString();
140 std::optional<StringRef> Key = getMutatedCallsiteKey(CB);
141 if (!Key)
142 return Constraints;
143 return extractAsmConstraintsFromMetadata(
144 NMD: CB.getModule()->getNamedMetadata(Name: "spv.mutated_callsites"), Constraints,
145 Name: *Key);
146}
147} // Namespace SPIRV
148
// The following functions encode string literals as a series of 32-bit
// integer operands in the format SPIR-V requires (null-terminated UTF-8,
// padded to 32-bit alignment), and decode them again when compiler passes
// need to compare strings.
153static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
154 uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars.
155 for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) {
156 unsigned StrIndex = i + WordIndex;
157 uint8_t CharToAdd = 0; // Initilize char as padding/null.
158 if (StrIndex < Str.size()) { // If it's within the string, get a real char.
159 CharToAdd = Str[StrIndex];
160 }
161 Word |= (CharToAdd << (WordIndex * 8));
162 }
163 return Word;
164}
165
166// Get length including padding and null terminator.
167static size_t getPaddedLen(const StringRef &Str) {
168 return (Str.size() + 4) & ~3;
169}
170
171void addStringImm(const StringRef &Str, MCInst &Inst) {
172 const size_t PaddedLen = getPaddedLen(Str);
173 for (unsigned i = 0; i < PaddedLen; i += 4) {
174 // Add an operand for the 32-bits of chars or padding.
175 Inst.addOperand(Op: MCOperand::createImm(Val: convertCharsToWord(Str, i)));
176 }
177}
178
179void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) {
180 const size_t PaddedLen = getPaddedLen(Str);
181 for (unsigned i = 0; i < PaddedLen; i += 4) {
182 // Add an operand for the 32-bits of chars or padding.
183 MIB.addImm(Val: convertCharsToWord(Str, i));
184 }
185}
186
187void addStringImm(const StringRef &Str, IRBuilder<> &B,
188 std::vector<Value *> &Args) {
189 const size_t PaddedLen = getPaddedLen(Str);
190 for (unsigned i = 0; i < PaddedLen; i += 4) {
191 // Add a vector element for the 32-bits of chars or padding.
192 Args.push_back(x: B.getInt32(C: convertCharsToWord(Str, i)));
193 }
194}
195
196std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) {
197 return getSPIRVStringOperand(MI, StartIndex);
198}
199
200std::string getStringValueFromReg(Register Reg, MachineRegisterInfo &MRI) {
201 MachineInstr *Def = getVRegDef(MRI, Reg);
202 assert(Def && Def->getOpcode() == TargetOpcode::G_GLOBAL_VALUE &&
203 "Expected G_GLOBAL_VALUE");
204 const GlobalValue *GV = Def->getOperand(i: 1).getGlobal();
205 Value *V = GV->getOperand(i: 0);
206 const ConstantDataArray *CDA = cast<ConstantDataArray>(Val: V);
207 return CDA->getAsCString().str();
208}
209
210void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
211 const auto Bitwidth = Imm.getBitWidth();
212 if (Bitwidth == 1)
213 return; // Already handled
214 else if (Bitwidth <= 32) {
215 MIB.addImm(Val: Imm.getZExtValue());
216 // Asm Printer needs this info to print floating-type correctly
217 if (Bitwidth == 16)
218 MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
219 return;
220 } else if (Bitwidth <= 64) {
221 uint64_t FullImm = Imm.getZExtValue();
222 uint32_t LowBits = FullImm & 0xffffffff;
223 uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
224 MIB.addImm(Val: LowBits).addImm(Val: HighBits);
225 // Asm Printer needs this info to print 64-bit operands correctly
226 MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH64);
227 return;
228 } else {
229 // Emit ceil(Bitwidth / 32) words to conform SPIR-V spec.
230 unsigned NumWords = (Bitwidth + 31) / 32;
231 for (unsigned I = 0; I < NumWords; ++I) {
232 unsigned LimbIdx = I / 2;
233 unsigned LimbShift = (I % 2) * 32;
234 uint32_t Word = (Imm.getRawData()[LimbIdx] >> LimbShift) & 0xffffffff;
235 MIB.addImm(Val: Word);
236 }
237 return;
238 }
239}
240
241void buildOpName(Register Target, const StringRef &Name,
242 MachineIRBuilder &MIRBuilder) {
243 if (!Name.empty()) {
244 auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpName).addUse(RegNo: Target);
245 addStringImm(Str: Name, MIB);
246 }
247}
248
249void buildOpName(Register Target, const StringRef &Name, MachineInstr &I,
250 const SPIRVInstrInfo &TII) {
251 if (!Name.empty()) {
252 auto MIB =
253 BuildMI(BB&: *I.getParent(), I, MIMD: I.getDebugLoc(), MCID: TII.get(Opcode: SPIRV::OpName))
254 .addUse(RegNo: Target);
255 addStringImm(Str: Name, MIB);
256 }
257}
258
259static void finishBuildOpDecorate(MachineInstrBuilder &MIB,
260 ArrayRef<uint32_t> DecArgs,
261 StringRef StrImm) {
262 if (!StrImm.empty())
263 addStringImm(Str: StrImm, MIB);
264 for (const auto &DecArg : DecArgs)
265 MIB.addImm(Val: DecArg);
266}
267
268void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
269 SPIRV::Decoration::Decoration Dec,
270 ArrayRef<uint32_t> DecArgs, StringRef StrImm) {
271 auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpDecorate)
272 .addUse(RegNo: Reg)
273 .addImm(Val: static_cast<uint32_t>(Dec));
274 finishBuildOpDecorate(MIB, DecArgs, StrImm);
275}
276
277void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII,
278 SPIRV::Decoration::Decoration Dec,
279 ArrayRef<uint32_t> DecArgs, StringRef StrImm) {
280 MachineBasicBlock &MBB = *I.getParent();
281 auto MIB = BuildMI(BB&: MBB, I, MIMD: I.getDebugLoc(), MCID: TII.get(Opcode: SPIRV::OpDecorate))
282 .addUse(RegNo: Reg)
283 .addImm(Val: static_cast<uint32_t>(Dec));
284 finishBuildOpDecorate(MIB, DecArgs, StrImm);
285}
286
287void buildOpMemberDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
288 SPIRV::Decoration::Decoration Dec, uint32_t Member,
289 ArrayRef<uint32_t> DecArgs, StringRef StrImm) {
290 auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpMemberDecorate)
291 .addUse(RegNo: Reg)
292 .addImm(Val: Member)
293 .addImm(Val: static_cast<uint32_t>(Dec));
294 finishBuildOpDecorate(MIB, DecArgs, StrImm);
295}
296
297void buildOpMemberDecorate(Register Reg, MachineInstr &I,
298 const SPIRVInstrInfo &TII,
299 SPIRV::Decoration::Decoration Dec, uint32_t Member,
300 ArrayRef<uint32_t> DecArgs, StringRef StrImm) {
301 MachineBasicBlock &MBB = *I.getParent();
302 auto MIB = BuildMI(BB&: MBB, I, MIMD: I.getDebugLoc(), MCID: TII.get(Opcode: SPIRV::OpMemberDecorate))
303 .addUse(RegNo: Reg)
304 .addImm(Val: Member)
305 .addImm(Val: static_cast<uint32_t>(Dec));
306 finishBuildOpDecorate(MIB, DecArgs, StrImm);
307}
308
309void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder,
310 const MDNode *GVarMD, const SPIRVSubtarget &ST) {
311 for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) {
312 auto *OpMD = dyn_cast<MDNode>(Val: GVarMD->getOperand(I));
313 if (!OpMD)
314 report_fatal_error(reason: "Invalid decoration");
315 if (OpMD->getNumOperands() == 0)
316 report_fatal_error(reason: "Expect operand(s) of the decoration");
317 ConstantInt *DecorationId =
318 mdconst::dyn_extract<ConstantInt>(MD: OpMD->getOperand(I: 0));
319 if (!DecorationId)
320 report_fatal_error(reason: "Expect SPIR-V <Decoration> operand to be the first "
321 "element of the decoration");
322
323 // The goal of `spirv.Decorations` metadata is to provide a way to
324 // represent SPIR-V entities that do not map to LLVM in an obvious way.
325 // FP flags do have obvious matches between LLVM IR and SPIR-V.
326 // Additionally, we have no guarantee at this point that the flags passed
327 // through the decoration are not violated already in the optimizer passes.
328 // Therefore, we simply ignore FP flags, including NoContraction, and
329 // FPFastMathMode.
330 if (DecorationId->getZExtValue() ==
331 static_cast<uint32_t>(SPIRV::Decoration::NoContraction) ||
332 DecorationId->getZExtValue() ==
333 static_cast<uint32_t>(SPIRV::Decoration::FPFastMathMode)) {
334 continue; // Ignored.
335 }
336 auto MIB = MIRBuilder.buildInstr(Opcode: SPIRV::OpDecorate)
337 .addUse(RegNo: Reg)
338 .addImm(Val: static_cast<uint32_t>(DecorationId->getZExtValue()));
339 for (unsigned OpI = 1, OpE = OpMD->getNumOperands(); OpI != OpE; ++OpI) {
340 if (ConstantInt *OpV =
341 mdconst::dyn_extract<ConstantInt>(MD: OpMD->getOperand(I: OpI)))
342 MIB.addImm(Val: static_cast<uint32_t>(OpV->getZExtValue()));
343 else if (MDString *OpV = dyn_cast<MDString>(Val: OpMD->getOperand(I: OpI)))
344 addStringImm(Str: OpV->getString(), MIB);
345 else
346 report_fatal_error(reason: "Unexpected operand of the decoration");
347 }
348 }
349}
350
351MachineBasicBlock::iterator getOpVariableMBBIt(MachineFunction &MF) {
352 MachineBasicBlock &MBB = MF.front();
353 // Find the position to insert the OpVariable instruction.
354 // We will insert it after the last OpFunctionParameter, if any, or
355 // after OpFunction otherwise.
356 auto IsPreamble = [](const MachineInstr &MI) {
357 switch (MI.getOpcode()) {
358 case SPIRV::OpFunction:
359 case SPIRV::OpFunctionParameter:
360 case SPIRV::OpLabel:
361 case SPIRV::ASSIGN_TYPE:
362 return true;
363 default:
364 return false;
365 }
366 };
367 MachineBasicBlock::iterator VarPos = MBB.SkipPHIsAndLabels(I: MBB.begin());
368 while (VarPos != MBB.end() && VarPos->getOpcode() != SPIRV::OpFunction)
369 ++VarPos;
370 // Advance past the preamble.
371 while (VarPos != MBB.end() && IsPreamble(*VarPos))
372 ++VarPos;
373 return VarPos;
374}
375
376MachineBasicBlock::iterator getInsertPtValidEnd(MachineBasicBlock *MBB) {
377 MachineBasicBlock::iterator I = MBB->end();
378 if (I == MBB->begin())
379 return I;
380 --I;
381 while (I->isTerminator() || I->isDebugValue()) {
382 if (I == MBB->begin())
383 break;
384 --I;
385 }
386 return I;
387}
388
389SPIRV::StorageClass::StorageClass
390addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
391 switch (AddrSpace) {
392 case 0:
393 return SPIRV::StorageClass::Function;
394 case 1:
395 return SPIRV::StorageClass::CrossWorkgroup;
396 case 2:
397 return SPIRV::StorageClass::UniformConstant;
398 case 3:
399 return SPIRV::StorageClass::Workgroup;
400 case 4:
401 return SPIRV::StorageClass::Generic;
402 case 5:
403 return STI.canUseExtension(E: SPIRV::Extension::SPV_INTEL_usm_storage_classes)
404 ? SPIRV::StorageClass::DeviceOnlyINTEL
405 : SPIRV::StorageClass::CrossWorkgroup;
406 case 6:
407 return STI.canUseExtension(E: SPIRV::Extension::SPV_INTEL_usm_storage_classes)
408 ? SPIRV::StorageClass::HostOnlyINTEL
409 : SPIRV::StorageClass::CrossWorkgroup;
410 case 7:
411 return SPIRV::StorageClass::Input;
412 case 8:
413 return SPIRV::StorageClass::Output;
414 case 9:
415 return SPIRV::StorageClass::CodeSectionINTEL;
416 case 10:
417 return SPIRV::StorageClass::Private;
418 case 11:
419 return SPIRV::StorageClass::StorageBuffer;
420 case 12:
421 return SPIRV::StorageClass::Uniform;
422 case 13:
423 return SPIRV::StorageClass::PushConstant;
424 default:
425 report_fatal_error(reason: "Unknown address space");
426 }
427}
428
429SPIRV::MemorySemantics::MemorySemantics
430getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC) {
431 switch (SC) {
432 case SPIRV::StorageClass::StorageBuffer:
433 case SPIRV::StorageClass::Uniform:
434 return SPIRV::MemorySemantics::UniformMemory;
435 case SPIRV::StorageClass::Workgroup:
436 return SPIRV::MemorySemantics::WorkgroupMemory;
437 case SPIRV::StorageClass::CrossWorkgroup:
438 return SPIRV::MemorySemantics::CrossWorkgroupMemory;
439 case SPIRV::StorageClass::AtomicCounter:
440 return SPIRV::MemorySemantics::AtomicCounterMemory;
441 case SPIRV::StorageClass::Image:
442 return SPIRV::MemorySemantics::ImageMemory;
443 default:
444 return SPIRV::MemorySemantics::None;
445 }
446}
447
448SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) {
449 switch (Ord) {
450 case AtomicOrdering::Acquire:
451 return SPIRV::MemorySemantics::Acquire;
452 case AtomicOrdering::Release:
453 return SPIRV::MemorySemantics::Release;
454 case AtomicOrdering::AcquireRelease:
455 return SPIRV::MemorySemantics::AcquireRelease;
456 case AtomicOrdering::SequentiallyConsistent:
457 return SPIRV::MemorySemantics::SequentiallyConsistent;
458 case AtomicOrdering::Unordered:
459 case AtomicOrdering::Monotonic:
460 case AtomicOrdering::NotAtomic:
461 return SPIRV::MemorySemantics::None;
462 }
463 llvm_unreachable(nullptr);
464}
465
466SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) {
467 // Named by
468 // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id.
469 // We don't need aliases for Invocation and CrossDevice, as we already have
470 // them covered by "singlethread" and "" strings respectively (see
471 // implementation of LLVMContext::LLVMContext()).
472 static const llvm::SyncScope::ID SubGroup =
473 Ctx.getOrInsertSyncScopeID(SSN: "subgroup");
474 static const llvm::SyncScope::ID WorkGroup =
475 Ctx.getOrInsertSyncScopeID(SSN: "workgroup");
476 static const llvm::SyncScope::ID Device =
477 Ctx.getOrInsertSyncScopeID(SSN: "device");
478
479 if (Id == llvm::SyncScope::SingleThread)
480 return SPIRV::Scope::Invocation;
481 else if (Id == llvm::SyncScope::System)
482 return SPIRV::Scope::CrossDevice;
483 else if (Id == SubGroup)
484 return SPIRV::Scope::Subgroup;
485 else if (Id == WorkGroup)
486 return SPIRV::Scope::Workgroup;
487 else if (Id == Device)
488 return SPIRV::Scope::Device;
489 return SPIRV::Scope::CrossDevice;
490}
491
492MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
493 const MachineRegisterInfo *MRI) {
494 MachineInstr *MI = MRI->getVRegDef(Reg: ConstReg);
495 MachineInstr *ConstInstr =
496 MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
497 ? MRI->getVRegDef(Reg: MI->getOperand(i: 1).getReg())
498 : MI;
499 if (auto *GI = dyn_cast<GIntrinsic>(Val: ConstInstr)) {
500 if (GI->is(ID: Intrinsic::spv_track_constant)) {
501 ConstReg = ConstInstr->getOperand(i: 2).getReg();
502 return MRI->getVRegDef(Reg: ConstReg);
503 }
504 } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) {
505 ConstReg = ConstInstr->getOperand(i: 1).getReg();
506 return MRI->getVRegDef(Reg: ConstReg);
507 } else if (ConstInstr->getOpcode() == TargetOpcode::G_CONSTANT ||
508 ConstInstr->getOpcode() == TargetOpcode::G_FCONSTANT) {
509 ConstReg = ConstInstr->getOperand(i: 0).getReg();
510 return ConstInstr;
511 }
512 return MRI->getVRegDef(Reg: ConstReg);
513}
514
515uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {
516 const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
517 assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
518 return MI->getOperand(i: 1).getCImm()->getValue().getZExtValue();
519}
520
521int64_t getIConstValSext(Register ConstReg, const MachineRegisterInfo *MRI) {
522 const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
523 assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
524 return MI->getOperand(i: 1).getCImm()->getSExtValue();
525}
526
527bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
528 if (const auto *GI = dyn_cast<GIntrinsic>(Val: &MI))
529 return GI->is(ID: IntrinsicID);
530 return false;
531}
532
533Type *getMDOperandAsType(const MDNode *N, unsigned I) {
534 Type *ElementTy = cast<ValueAsMetadata>(Val: N->getOperand(I))->getType();
535 return toTypedPointer(Ty: ElementTy);
536}
537
538// The set of names is borrowed from the SPIR-V translator.
539// TODO: may be implemented in SPIRVBuiltins.td.
540static bool isPipeOrAddressSpaceCastBI(const StringRef MangledName) {
541 return MangledName == "write_pipe_2" || MangledName == "read_pipe_2" ||
542 MangledName == "write_pipe_2_bl" || MangledName == "read_pipe_2_bl" ||
543 MangledName == "write_pipe_4" || MangledName == "read_pipe_4" ||
544 MangledName == "reserve_write_pipe" ||
545 MangledName == "reserve_read_pipe" ||
546 MangledName == "commit_write_pipe" ||
547 MangledName == "commit_read_pipe" ||
548 MangledName == "work_group_reserve_write_pipe" ||
549 MangledName == "work_group_reserve_read_pipe" ||
550 MangledName == "work_group_commit_write_pipe" ||
551 MangledName == "work_group_commit_read_pipe" ||
552 MangledName == "get_pipe_num_packets_ro" ||
553 MangledName == "get_pipe_max_packets_ro" ||
554 MangledName == "get_pipe_num_packets_wo" ||
555 MangledName == "get_pipe_max_packets_wo" ||
556 MangledName == "sub_group_reserve_write_pipe" ||
557 MangledName == "sub_group_reserve_read_pipe" ||
558 MangledName == "sub_group_commit_write_pipe" ||
559 MangledName == "sub_group_commit_read_pipe" ||
560 MangledName == "to_global" || MangledName == "to_local" ||
561 MangledName == "to_private";
562}
563
564static bool isEnqueueKernelBI(const StringRef MangledName) {
565 return MangledName == "__enqueue_kernel_basic" ||
566 MangledName == "__enqueue_kernel_basic_events" ||
567 MangledName == "__enqueue_kernel_varargs" ||
568 MangledName == "__enqueue_kernel_events_varargs";
569}
570
571static bool isKernelQueryBI(const StringRef MangledName) {
572 return MangledName == "__get_kernel_work_group_size_impl" ||
573 MangledName == "__get_kernel_sub_group_count_for_ndrange_impl" ||
574 MangledName == "__get_kernel_max_sub_group_size_for_ndrange_impl" ||
575 MangledName == "__get_kernel_preferred_work_group_size_multiple_impl";
576}
577
578static bool isNonMangledOCLBuiltin(StringRef Name) {
579 if (!Name.starts_with(Prefix: "__"))
580 return false;
581
582 return isEnqueueKernelBI(MangledName: Name) || isKernelQueryBI(MangledName: Name) ||
583 isPipeOrAddressSpaceCastBI(MangledName: Name.drop_front(N: 2)) ||
584 Name == "__translate_sampler_initializer";
585}
586
587std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
588 bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
589 bool IsNonMangledSPIRV = Name.starts_with(Prefix: "__spirv_");
590 bool IsNonMangledHLSL = Name.starts_with(Prefix: "__hlsl_");
591 bool IsMangled = Name.starts_with(Prefix: "_Z");
592
593 // Otherwise use simple demangling to return the function name.
594 if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled)
595 return Name.str();
596
597 // Try to use the itanium demangler.
598 if (char *DemangledName = itaniumDemangle(mangled_name: Name.data())) {
599 std::string Result = DemangledName;
600 free(ptr: DemangledName);
601 return Result;
602 }
603
604 // Autocheck C++, maybe need to do explicit check of the source language.
605 // OpenCL C++ built-ins are declared in cl namespace.
606 // TODO: consider using 'St' abbriviation for cl namespace mangling.
607 // Similar to ::std:: in C++.
608 size_t Start, Len = 0;
609 size_t DemangledNameLenStart = 2;
610 if (Name.starts_with(Prefix: "_ZN")) {
611 // Skip CV and ref qualifiers.
612 size_t NameSpaceStart = Name.find_first_not_of(Chars: "rVKRO", From: 3);
613 // All built-ins are in the ::cl:: namespace.
614 if (Name.substr(Start: NameSpaceStart, N: 11) != "2cl7__spirv")
615 return std::string();
616 DemangledNameLenStart = NameSpaceStart + 11;
617 }
618 Start = Name.find_first_not_of(Chars: "0123456789", From: DemangledNameLenStart);
619 [[maybe_unused]] bool Error =
620 Name.substr(Start: DemangledNameLenStart, N: Start - DemangledNameLenStart)
621 .getAsInteger(Radix: 10, Result&: Len);
622 assert(!Error && "Failed to parse demangled name length");
623 return Name.substr(Start, N: Len).str();
624}
625
626bool hasBuiltinTypePrefix(StringRef Name) {
627 if (Name.starts_with(Prefix: "opencl.") || Name.starts_with(Prefix: "ocl_") ||
628 Name.starts_with(Prefix: "spirv."))
629 return true;
630 return false;
631}
632
633bool isSpecialOpaqueType(const Type *Ty) {
634 if (const TargetExtType *ExtTy = dyn_cast<TargetExtType>(Val: Ty))
635 return isTypedPointerWrapper(ExtTy)
636 ? false
637 : hasBuiltinTypePrefix(Name: ExtTy->getName());
638
639 return false;
640}
641
642bool isEntryPoint(const Function &F) {
643 // OpenCL handling: any function with the SPIR_KERNEL
644 // calling convention will be a potential entry point.
645 if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
646 return true;
647
648 // HLSL handling: special attribute are emitted from the
649 // front-end.
650 if (F.getFnAttribute(Kind: "hlsl.shader").isValid())
651 return true;
652
653 return false;
654}
655
656Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
657 TypeName.consume_front(Prefix: "atomic_");
658 if (TypeName.consume_front(Prefix: "void"))
659 return Type::getVoidTy(C&: Ctx);
660 else if (TypeName.consume_front(Prefix: "bool") || TypeName.consume_front(Prefix: "_Bool"))
661 return Type::getIntNTy(C&: Ctx, N: 1);
662 else if (TypeName.consume_front(Prefix: "char") ||
663 TypeName.consume_front(Prefix: "signed char") ||
664 TypeName.consume_front(Prefix: "unsigned char") ||
665 TypeName.consume_front(Prefix: "uchar"))
666 return Type::getInt8Ty(C&: Ctx);
667 else if (TypeName.consume_front(Prefix: "short") ||
668 TypeName.consume_front(Prefix: "signed short") ||
669 TypeName.consume_front(Prefix: "unsigned short") ||
670 TypeName.consume_front(Prefix: "ushort"))
671 return Type::getInt16Ty(C&: Ctx);
672 else if (TypeName.consume_front(Prefix: "int") ||
673 TypeName.consume_front(Prefix: "signed int") ||
674 TypeName.consume_front(Prefix: "unsigned int") ||
675 TypeName.consume_front(Prefix: "uint"))
676 return Type::getInt32Ty(C&: Ctx);
677 else if (TypeName.consume_front(Prefix: "long") ||
678 TypeName.consume_front(Prefix: "signed long") ||
679 TypeName.consume_front(Prefix: "unsigned long") ||
680 TypeName.consume_front(Prefix: "ulong"))
681 return Type::getInt64Ty(C&: Ctx);
682 else if (TypeName.consume_front(Prefix: "half") ||
683 TypeName.consume_front(Prefix: "_Float16") ||
684 TypeName.consume_front(Prefix: "__fp16"))
685 return Type::getHalfTy(C&: Ctx);
686 else if (TypeName.consume_front(Prefix: "float"))
687 return Type::getFloatTy(C&: Ctx);
688 else if (TypeName.consume_front(Prefix: "double"))
689 return Type::getDoubleTy(C&: Ctx);
690
691 // Unable to recognize SPIRV type name
692 return nullptr;
693}
694
695SmallPtrSet<BasicBlock *, 0>
696PartialOrderingVisitor::getReachableFrom(BasicBlock *Start) {
697 std::queue<BasicBlock *> ToVisit;
698 ToVisit.push(x: Start);
699
700 SmallPtrSet<BasicBlock *, 0> Output;
701 while (ToVisit.size() != 0) {
702 BasicBlock *BB = ToVisit.front();
703 ToVisit.pop();
704
705 if (Output.count(Ptr: BB) != 0)
706 continue;
707 Output.insert(Ptr: BB);
708
709 for (BasicBlock *Successor : successors(BB)) {
710 if (DT.dominates(A: Successor, B: BB))
711 continue;
712 ToVisit.push(x: Successor);
713 }
714 }
715
716 return Output;
717}
718
719bool PartialOrderingVisitor::CanBeVisited(BasicBlock *BB) const {
720 for (BasicBlock *P : predecessors(BB)) {
721 // Ignore back-edges.
722 if (DT.dominates(A: BB, B: P))
723 continue;
724
725 // One of the predecessor hasn't been visited. Not ready yet.
726 if (BlockToOrder.count(Val: P) == 0)
727 return false;
728
729 // If the block is a loop exit, the loop must be finished before
730 // we can continue.
731 Loop *L = LI.getLoopFor(BB: P);
732 if (L == nullptr || L->contains(BB))
733 continue;
734
735 // SPIR-V requires a single back-edge. And the backend first
736 // step transforms loops into the simplified format. If we have
737 // more than 1 back-edge, something is wrong.
738 assert(L->getNumBackEdges() <= 1);
739
740 // If the loop has no latch, loop's rank won't matter, so we can
741 // proceed.
742 BasicBlock *Latch = L->getLoopLatch();
743 assert(Latch);
744 if (Latch == nullptr)
745 continue;
746
747 // The latch is not ready yet, let's wait.
748 if (BlockToOrder.count(Val: Latch) == 0)
749 return false;
750 }
751
752 return true;
753}
754
755size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const {
756 auto It = BlockToOrder.find(Val: BB);
757 if (It != BlockToOrder.end())
758 return It->second.Rank;
759
760 size_t result = 0;
761 for (BasicBlock *P : predecessors(BB)) {
762 // Ignore back-edges.
763 if (DT.dominates(A: BB, B: P))
764 continue;
765
766 auto Iterator = BlockToOrder.end();
767 Loop *L = LI.getLoopFor(BB: P);
768 BasicBlock *Latch = L ? L->getLoopLatch() : nullptr;
769
770 // If the predecessor is either outside a loop, or part of
771 // the same loop, simply take its rank + 1.
772 if (L == nullptr || L->contains(BB) || Latch == nullptr) {
773 Iterator = BlockToOrder.find(Val: P);
774 } else {
775 // Otherwise, take the loop's rank (highest rank in the loop) as base.
776 // Since loops have a single latch, highest rank is easy to find.
777 // If the loop has no latch, then it doesn't matter.
778 Iterator = BlockToOrder.find(Val: Latch);
779 }
780
781 assert(Iterator != BlockToOrder.end());
782 result = std::max(a: result, b: Iterator->second.Rank + 1);
783 }
784
785 return result;
786}
787
788size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) {
789 ToVisit.push(x: BB);
790 Queued.insert(Ptr: BB);
791
792 size_t QueueIndex = 0;
793 while (ToVisit.size() != 0) {
794 BasicBlock *BB = ToVisit.front();
795 ToVisit.pop();
796
797 if (!CanBeVisited(BB)) {
798 ToVisit.push(x: BB);
799 if (QueueIndex >= ToVisit.size())
800 llvm::report_fatal_error(
801 reason: "No valid candidate in the queue. Is the graph reducible?");
802 QueueIndex++;
803 continue;
804 }
805
806 QueueIndex = 0;
807 size_t Rank = GetNodeRank(BB);
808 OrderInfo Info = {.Rank: Rank, .TraversalIndex: BlockToOrder.size()};
809 BlockToOrder.try_emplace(Key: BB, Args&: Info);
810
811 for (BasicBlock *S : successors(BB)) {
812 if (Queued.count(Ptr: S) != 0)
813 continue;
814 ToVisit.push(x: S);
815 Queued.insert(Ptr: S);
816 }
817 }
818
819 return 0;
820}
821
822PartialOrderingVisitor::PartialOrderingVisitor(Function &F) {
823 DT.recalculate(Func&: F);
824 LI = LoopInfo(DT);
825
826 visit(BB: &*F.begin(), Unused: 0);
827
828 Order.reserve(n: F.size());
829 for (auto &[BB, Info] : BlockToOrder)
830 Order.emplace_back(args&: BB);
831
832 std::sort(first: Order.begin(), last: Order.end(), comp: [&](const auto &LHS, const auto &RHS) {
833 return compare(LHS, RHS);
834 });
835}
836
837bool PartialOrderingVisitor::compare(const BasicBlock *LHS,
838 const BasicBlock *RHS) const {
839 const OrderInfo &InfoLHS = BlockToOrder.at(Val: const_cast<BasicBlock *>(LHS));
840 const OrderInfo &InfoRHS = BlockToOrder.at(Val: const_cast<BasicBlock *>(RHS));
841 if (InfoLHS.Rank != InfoRHS.Rank)
842 return InfoLHS.Rank < InfoRHS.Rank;
843 return InfoLHS.TraversalIndex < InfoRHS.TraversalIndex;
844}
845
846void PartialOrderingVisitor::partialOrderVisit(
847 BasicBlock &Start, std::function<bool(BasicBlock *)> Op) {
848 SmallPtrSet<BasicBlock *, 0> Reachable = getReachableFrom(Start: &Start);
849 assert(BlockToOrder.count(&Start) != 0);
850
851 // Skipping blocks with a rank inferior to |Start|'s rank.
852 auto It = Order.begin();
853 while (It != Order.end() && *It != &Start)
854 ++It;
855
856 // This is unexpected. Worst case |Start| is the last block,
857 // so It should point to the last block, not past-end.
858 assert(It != Order.end());
859
860 // By default, there is no rank limit. Setting it to the maximum value.
861 std::optional<size_t> EndRank = std::nullopt;
862 for (; It != Order.end(); ++It) {
863 if (EndRank.has_value() && BlockToOrder[*It].Rank > *EndRank)
864 break;
865
866 if (Reachable.count(Ptr: *It) == 0) {
867 continue;
868 }
869
870 if (!Op(*It)) {
871 EndRank = BlockToOrder[*It].Rank;
872 }
873 }
874}
875
876bool sortBlocks(Function &F) {
877 if (F.size() == 0)
878 return false;
879
880 bool Modified = false;
881 std::vector<BasicBlock *> Order;
882 Order.reserve(n: F.size());
883
884 ReversePostOrderTraversal<Function *> RPOT(&F);
885 llvm::append_range(C&: Order, R&: RPOT);
886
887 assert(&*F.begin() == Order[0]);
888 BasicBlock *LastBlock = &*F.begin();
889 for (BasicBlock *BB : Order) {
890 if (BB != LastBlock && &*LastBlock->getNextNode() != BB) {
891 Modified = true;
892 BB->moveAfter(MovePos: LastBlock);
893 }
894 LastBlock = BB;
895 }
896
897 return Modified;
898}
899
900AllocaInst *createVariable(Function &F, Type *Type) {
901 const DataLayout &DL = F.getDataLayout();
902 return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
903 F.begin()->getFirstInsertionPt());
904}
905
906Value *
907createExitVariable(BasicBlock *BB,
908 const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
909 auto *T = BB->getTerminator();
910 if (isa<ReturnInst>(Val: T))
911 return nullptr;
912 if (auto *BI = dyn_cast<UncondBrInst>(Val: T))
913 return TargetToValue.lookup(Val: BI->getSuccessor());
914
915 IRBuilder<> Builder(BB);
916 Builder.SetInsertPoint(T);
917
918 if (auto *BI = dyn_cast<CondBrInst>(Val: T)) {
919 Value *LHS = TargetToValue.lookup(Val: BI->getSuccessor(i: 0));
920 Value *RHS = TargetToValue.lookup(Val: BI->getSuccessor(i: 1));
921
922 if (LHS == nullptr || RHS == nullptr)
923 return LHS == nullptr ? RHS : LHS;
924 return Builder.CreateSelect(C: BI->getCondition(), True: LHS, False: RHS);
925 }
926
927 // TODO: add support for switch cases.
928 llvm_unreachable("Unhandled terminator type.");
929}
930
931MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg) {
932 MachineInstr *MaybeDef = MRI.getVRegDef(Reg);
933 if (MaybeDef && MaybeDef->getOpcode() == SPIRV::ASSIGN_TYPE)
934 MaybeDef = MRI.getVRegDef(Reg: MaybeDef->getOperand(i: 1).getReg());
935 return MaybeDef;
936}
937
938bool getVacantFunctionName(Module &M, std::string &Name) {
939 // It's a bit of paranoia, but still we don't want to have even a chance that
940 // the loop will work for too long.
941 constexpr unsigned MaxIters = 1024;
942 for (unsigned I = 0; I < MaxIters; ++I) {
943 std::string OrdName = Name + Twine(I).str();
944 if (!M.getFunction(Name: OrdName)) {
945 Name = std::move(OrdName);
946 return true;
947 }
948 }
949 return false;
950}
951
952// Assign SPIR-V type to the register. If the register has no valid assigned
953// class, set register LLT type and class according to the SPIR-V type.
954void setRegClassType(Register Reg, SPIRVTypeInst SpvType,
955 SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI,
956 const MachineFunction &MF, bool Force) {
957 GR->assignSPIRVTypeToVReg(Type: SpvType, VReg: Reg, MF);
958 if (!MRI->getRegClassOrNull(Reg) || Force) {
959 MRI->setRegClass(Reg, RC: GR->getRegClass(SpvType));
960 LLT RegType = GR->getRegType(SpvType);
961 if (Force || !MRI->getType(Reg).isValid())
962 MRI->setType(VReg: Reg, Ty: RegType);
963 }
964}
965
966// Create a SPIR-V type, assign SPIR-V type to the register. If the register has
967// no valid assigned class, set register LLT type and class according to the
968// SPIR-V type.
969void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
970 MachineIRBuilder &MIRBuilder,
971 SPIRV::AccessQualifier::AccessQualifier AccessQual,
972 bool EmitIR, bool Force) {
973 setRegClassType(Reg,
974 SpvType: GR->getOrCreateSPIRVType(Type: Ty, MIRBuilder, AQ: AccessQual, EmitIR),
975 GR, MRI: MIRBuilder.getMRI(), MF: MIRBuilder.getMF(), Force);
976}
977
978// Create a virtual register and assign SPIR-V type to the register. Set
979// register LLT type and class according to the SPIR-V type.
980Register createVirtualRegister(SPIRVTypeInst SpvType, SPIRVGlobalRegistry *GR,
981 MachineRegisterInfo *MRI,
982 const MachineFunction &MF) {
983 Register Reg = MRI->createVirtualRegister(RegClass: GR->getRegClass(SpvType));
984 MRI->setType(VReg: Reg, Ty: GR->getRegType(SpvType));
985 GR->assignSPIRVTypeToVReg(Type: SpvType, VReg: Reg, MF);
986 return Reg;
987}
988
989// Create a virtual register and assign SPIR-V type to the register. Set
990// register LLT type and class according to the SPIR-V type.
991Register createVirtualRegister(SPIRVTypeInst SpvType, SPIRVGlobalRegistry *GR,
992 MachineIRBuilder &MIRBuilder) {
993 return createVirtualRegister(SpvType, GR, MRI: MIRBuilder.getMRI(),
994 MF: MIRBuilder.getMF());
995}
996
997// Create a SPIR-V type, virtual register and assign SPIR-V type to the
998// register. Set register LLT type and class according to the SPIR-V type.
999Register createVirtualRegister(
1000 const Type *Ty, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIRBuilder,
1001 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
1002 return createVirtualRegister(
1003 SpvType: GR->getOrCreateSPIRVType(Type: Ty, MIRBuilder, AQ: AccessQual, EmitIR), GR,
1004 MIRBuilder);
1005}
1006
1007CallInst *buildIntrWithMD(Intrinsic::ID IntrID, ArrayRef<Type *> Types,
1008 Value *Arg, Value *Arg2, ArrayRef<Constant *> Imms,
1009 IRBuilder<> &B) {
1010 SmallVector<Value *, 4> Args;
1011 Args.push_back(Elt: Arg2);
1012 Args.push_back(Elt: buildMD(Arg));
1013 llvm::append_range(C&: Args, R&: Imms);
1014 return B.CreateIntrinsicWithoutFolding(ID: IntrID, OverloadTypes: {Types}, Args);
1015}
1016
1017// Return true if there is an opaque pointer type nested in the argument.
1018bool isNestedPointer(const Type *Ty) {
1019 if (Ty->isPtrOrPtrVectorTy())
1020 return true;
1021 if (const FunctionType *RefTy = dyn_cast<FunctionType>(Val: Ty)) {
1022 if (isNestedPointer(Ty: RefTy->getReturnType()))
1023 return true;
1024 for (const Type *ArgTy : RefTy->params())
1025 if (isNestedPointer(Ty: ArgTy))
1026 return true;
1027 return false;
1028 }
1029 if (const ArrayType *RefTy = dyn_cast<ArrayType>(Val: Ty))
1030 return isNestedPointer(Ty: RefTy->getElementType());
1031 return false;
1032}
1033
1034bool isSpvIntrinsic(const Value *Arg) {
1035 if (const auto *II = dyn_cast<IntrinsicInst>(Val: Arg))
1036 if (Function *F = II->getCalledFunction())
1037 if (F->getName().starts_with(Prefix: "llvm.spv."))
1038 return true;
1039 return false;
1040}
1041
1042// Function to create continued instructions for SPV_INTEL_long_composites
1043// extension
1044SmallVector<MachineInstr *, 4>
1045createContinuedInstructions(MachineIRBuilder &MIRBuilder, unsigned Opcode,
1046 unsigned MinWC, unsigned ContinuedOpcode,
1047 ArrayRef<Register> Args, Register ReturnRegister,
1048 Register TypeID) {
1049
1050 SmallVector<MachineInstr *, 4> Instructions;
1051 constexpr unsigned MaxWordCount = UINT16_MAX;
1052 const size_t NumElements = Args.size();
1053 size_t MaxNumElements = MaxWordCount - MinWC;
1054 size_t SPIRVStructNumElements = NumElements;
1055
1056 if (NumElements > MaxNumElements) {
1057 // Do adjustments for continued instructions which always had only one
1058 // minumum word count.
1059 SPIRVStructNumElements = MaxNumElements;
1060 MaxNumElements = MaxWordCount - 1;
1061 }
1062
1063 auto MIB =
1064 MIRBuilder.buildInstr(Opcode).addDef(RegNo: ReturnRegister).addUse(RegNo: TypeID);
1065
1066 for (size_t I = 0; I < SPIRVStructNumElements; ++I)
1067 MIB.addUse(RegNo: Args[I]);
1068
1069 Instructions.push_back(Elt: MIB.getInstr());
1070
1071 for (size_t I = SPIRVStructNumElements; I < NumElements;
1072 I += MaxNumElements) {
1073 auto MIB = MIRBuilder.buildInstr(Opcode: ContinuedOpcode);
1074 for (size_t J = I; J < std::min(a: I + MaxNumElements, b: NumElements); ++J)
1075 MIB.addUse(RegNo: Args[J]);
1076 Instructions.push_back(Elt: MIB.getInstr());
1077 }
1078 return Instructions;
1079}
1080
1081SmallVector<unsigned, 1>
1082getSpirvLoopControlOperandsFromLoopMetadata(MDNode *LoopMD) {
1083 unsigned LC = SPIRV::LoopControl::None;
1084 // Currently used only to store PartialCount value. Later when other
1085 // LoopControls are added - this map should be sorted before making
1086 // them loop_merge operands to satisfy 3.23. Loop Control requirements.
1087 std::vector<std::pair<unsigned, unsigned>> MaskToValueMap;
1088 if (findOptionMDForLoopID(LoopID: LoopMD, Name: "llvm.loop.unroll.disable")) {
1089 LC |= SPIRV::LoopControl::DontUnroll;
1090 } else {
1091 if (findOptionMDForLoopID(LoopID: LoopMD, Name: "llvm.loop.unroll.enable") ||
1092 findOptionMDForLoopID(LoopID: LoopMD, Name: "llvm.loop.unroll.full")) {
1093 LC |= SPIRV::LoopControl::Unroll;
1094 }
1095 if (MDNode *CountMD =
1096 findOptionMDForLoopID(LoopID: LoopMD, Name: "llvm.loop.unroll.count")) {
1097 if (auto *CI =
1098 mdconst::extract_or_null<ConstantInt>(MD: CountMD->getOperand(I: 1))) {
1099 unsigned Count = CI->getZExtValue();
1100 if (Count != 1) {
1101 LC |= SPIRV::LoopControl::PartialCount;
1102 MaskToValueMap.emplace_back(
1103 args: std::make_pair(x: SPIRV::LoopControl::PartialCount, y&: Count));
1104 }
1105 }
1106 }
1107 }
1108 SmallVector<unsigned, 1> Result = {LC};
1109 for (auto &[Mask, Val] : MaskToValueMap)
1110 Result.push_back(Elt: Val);
1111 return Result;
1112}
1113
1114SmallVector<unsigned, 1> getSpirvLoopControlOperandsFromLoopMetadata(Loop *L) {
1115 return getSpirvLoopControlOperandsFromLoopMetadata(LoopMD: L->getLoopID());
1116}
1117
1118const std::set<unsigned> &getTypeFoldingSupportedOpcodes() {
1119 // clang-format off
1120 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
1121 TargetOpcode::G_ADD,
1122 TargetOpcode::G_FADD,
1123 TargetOpcode::G_STRICT_FADD,
1124 TargetOpcode::G_SUB,
1125 TargetOpcode::G_FSUB,
1126 TargetOpcode::G_STRICT_FSUB,
1127 TargetOpcode::G_MUL,
1128 TargetOpcode::G_FMUL,
1129 TargetOpcode::G_STRICT_FMUL,
1130 TargetOpcode::G_SDIV,
1131 TargetOpcode::G_UDIV,
1132 TargetOpcode::G_FDIV,
1133 TargetOpcode::G_STRICT_FDIV,
1134 TargetOpcode::G_SREM,
1135 TargetOpcode::G_UREM,
1136 TargetOpcode::G_FREM,
1137 TargetOpcode::G_STRICT_FREM,
1138 TargetOpcode::G_FNEG,
1139 TargetOpcode::G_CONSTANT,
1140 TargetOpcode::G_FCONSTANT,
1141 TargetOpcode::G_AND,
1142 TargetOpcode::G_OR,
1143 TargetOpcode::G_XOR,
1144 TargetOpcode::G_SHL,
1145 TargetOpcode::G_ASHR,
1146 TargetOpcode::G_LSHR,
1147 TargetOpcode::G_SELECT,
1148 TargetOpcode::G_EXTRACT_VECTOR_ELT,
1149 };
1150 // clang-format on
1151 return TypeFoldingSupportingOpcs;
1152}
1153
1154bool isTypeFoldingSupported(unsigned Opcode) {
1155 return getTypeFoldingSupportedOpcodes().count(x: Opcode) > 0;
1156}
1157
1158// Traversing [g]MIR accounting for pseudo-instructions.
1159MachineInstr *passCopy(MachineInstr *Def, const MachineRegisterInfo *MRI) {
1160 return (Def->getOpcode() == SPIRV::ASSIGN_TYPE ||
1161 Def->getOpcode() == TargetOpcode::COPY)
1162 ? MRI->getVRegDef(Reg: Def->getOperand(i: 1).getReg())
1163 : Def;
1164}
1165
1166MachineInstr *getDef(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
1167 if (MachineInstr *Def = MRI->getVRegDef(Reg: MO.getReg()))
1168 return passCopy(Def, MRI);
1169 return nullptr;
1170}
1171
1172MachineInstr *getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
1173 if (MachineInstr *Def = getDef(MO, MRI)) {
1174 if (Def->getOpcode() == TargetOpcode::G_CONSTANT ||
1175 Def->getOpcode() == SPIRV::OpConstantI)
1176 return Def;
1177 }
1178 return nullptr;
1179}
1180
1181int64_t foldImm(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
1182 if (MachineInstr *Def = getImm(MO, MRI)) {
1183 if (Def->getOpcode() == SPIRV::OpConstantI)
1184 return Def->getOperand(i: 2).getImm();
1185 if (Def->getOpcode() == TargetOpcode::G_CONSTANT)
1186 return Def->getOperand(i: 1).getCImm()->getZExtValue();
1187 }
1188 llvm_unreachable("Unexpected integer constant pattern");
1189}
1190
1191unsigned getArrayComponentCount(const MachineRegisterInfo *MRI,
1192 const MachineInstr *ResType) {
1193 return foldImm(MO: ResType->getOperand(i: 2), MRI);
1194}
1195
1196bool matchPeeledArrayPattern(const StructType *Ty, Type *&OriginalElementType,
1197 uint64_t &TotalSize) {
1198 // An array of N padded structs is represented as {[N-1 x <{T, pad}>], T}.
1199 if (Ty->getStructNumElements() != 2)
1200 return false;
1201
1202 Type *FirstElement = Ty->getStructElementType(N: 0);
1203 Type *SecondElement = Ty->getStructElementType(N: 1);
1204
1205 if (!FirstElement->isArrayTy())
1206 return false;
1207
1208 Type *ArrayElementType = FirstElement->getArrayElementType();
1209 if (!ArrayElementType->isStructTy() ||
1210 ArrayElementType->getStructNumElements() != 2)
1211 return false;
1212
1213 Type *T_in_struct = ArrayElementType->getStructElementType(N: 0);
1214 if (T_in_struct != SecondElement)
1215 return false;
1216
1217 auto *Padding_in_struct =
1218 dyn_cast<TargetExtType>(Val: ArrayElementType->getStructElementType(N: 1));
1219 if (!Padding_in_struct || Padding_in_struct->getName() != "spirv.Padding")
1220 return false;
1221
1222 const uint64_t ArraySize = FirstElement->getArrayNumElements();
1223 TotalSize = ArraySize + 1;
1224 OriginalElementType = ArrayElementType;
1225 return true;
1226}
1227
1228Type *reconstitutePeeledArrayType(Type *Ty) {
1229 if (!Ty->isStructTy())
1230 return Ty;
1231
1232 auto *STy = cast<StructType>(Val: Ty);
1233 Type *OriginalElementType = nullptr;
1234 uint64_t TotalSize = 0;
1235 if (matchPeeledArrayPattern(Ty: STy, OriginalElementType, TotalSize)) {
1236 Type *ResultTy = ArrayType::get(
1237 ElementType: reconstitutePeeledArrayType(Ty: OriginalElementType), NumElements: TotalSize);
1238 return ResultTy;
1239 }
1240
1241 SmallVector<Type *, 4> NewElementTypes;
1242 bool Changed = false;
1243 for (Type *ElementTy : STy->elements()) {
1244 Type *NewElementTy = reconstitutePeeledArrayType(Ty: ElementTy);
1245 if (NewElementTy != ElementTy)
1246 Changed = true;
1247 NewElementTypes.push_back(Elt: NewElementTy);
1248 }
1249
1250 if (!Changed)
1251 return Ty;
1252
1253 Type *ResultTy;
1254 if (STy->isLiteral())
1255 ResultTy =
1256 StructType::get(Context&: STy->getContext(), Elements: NewElementTypes, isPacked: STy->isPacked());
1257 else {
1258 auto *NewTy = StructType::create(Context&: STy->getContext(), Name: STy->getName());
1259 NewTy->setBody(Elements: NewElementTypes, isPacked: STy->isPacked());
1260 ResultTy = NewTy;
1261 }
1262 return ResultTy;
1263}
1264
1265std::optional<SPIRV::LinkageType::LinkageType>
1266getSpirvLinkageTypeFor(const SPIRVSubtarget &ST, const GlobalValue &GV) {
1267 if (GV.hasLocalLinkage())
1268 return std::nullopt;
1269
1270 if (GV.isDeclarationForLinker()) {
1271 // Interface variables must not get Import linkage.
1272 if (const auto *GVar = dyn_cast<GlobalVariable>(Val: &GV)) {
1273 auto SC = addressSpaceToStorageClass(AddrSpace: GVar->getAddressSpace(), STI: ST);
1274 if (SC == SPIRV::StorageClass::Input ||
1275 SC == SPIRV::StorageClass::Output ||
1276 SC == SPIRV::StorageClass::PushConstant)
1277 return std::nullopt;
1278 }
1279 return SPIRV::LinkageType::Import;
1280 }
1281
1282 if (GV.hasHiddenVisibility())
1283 return std::nullopt;
1284
1285 if (GV.hasLinkOnceODRLinkage() &&
1286 ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_linkonce_odr))
1287 return SPIRV::LinkageType::LinkOnceODR;
1288
1289 if (GV.hasWeakLinkage() &&
1290 ST.canUseExtension(E: SPIRV::Extension::SPV_AMD_weak_linkage))
1291 return SPIRV::LinkageType::WeakAMD;
1292
1293 return SPIRV::LinkageType::Export;
1294}
1295
1296Function *getOrCreateBackendServiceFunction(Module &M) {
1297 std::string ServiceFunName = SPIRV_BACKEND_SERVICE_FUN_NAME;
1298 if (!getVacantFunctionName(M, Name&: ServiceFunName))
1299 report_fatal_error(
1300 reason: "cannot allocate a name for the internal service function");
1301 if (Function *SF = M.getFunction(Name: ServiceFunName)) {
1302 if (SF->getInstructionCount() > 0)
1303 report_fatal_error(
1304 reason: "Unexpected combination of global variables and function pointers");
1305 return SF;
1306 }
1307 Function *SF = Function::Create(
1308 Ty: FunctionType::get(Result: Type::getVoidTy(C&: M.getContext()), Params: {}, isVarArg: false),
1309 Linkage: GlobalValue::PrivateLinkage, N: ServiceFunName, M);
1310 SF->addFnAttr(SPIRV_BACKEND_SERVICE_FUN_NAME, Val: "");
1311 return SF;
1312}
1313
1314} // namespace llvm
1315