1//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The analysis collects instructions that should be output at the module level
10// and performs the global register numbering.
11//
12// The results of this analysis are used in AsmPrinter to rename registers
13// globally and to output required instructions at the module level.
14//
15//===----------------------------------------------------------------------===//
16
// TODO: uses of report_fatal_error (which is also deprecated) /
// ReportFatalUsageError in this file should be refactored, as per LLVM
// best practices, to rely on the Diagnostic infrastructure.
20
21#include "SPIRVModuleAnalysis.h"
22#include "MCTargetDesc/SPIRVBaseInfo.h"
23#include "MCTargetDesc/SPIRVMCTargetDesc.h"
24#include "SPIRV.h"
25#include "SPIRVSubtarget.h"
26#include "SPIRVTargetMachine.h"
27#include "SPIRVUtils.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/CodeGen/MachineModuleInfo.h"
30#include "llvm/CodeGen/TargetPassConfig.h"
31
32using namespace llvm;
33
34#define DEBUG_TYPE "spirv-module-analysis"
35
36static cl::opt<bool>
37 SPVDumpDeps("spv-dump-deps",
38 cl::desc("Dump MIR with SPIR-V dependencies info"),
39 cl::Optional, cl::init(Val: false));
40
41static cl::list<SPIRV::Capability::Capability>
42 AvoidCapabilities("avoid-spirv-capabilities",
43 cl::desc("SPIR-V capabilities to avoid if there are "
44 "other options enabling a feature"),
45 cl::Hidden,
46 cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
47 "SPIR-V Shader capability")));
48// Use sets instead of cl::list to check "if contains" condition
49struct AvoidCapabilitiesSet {
50 SmallSet<SPIRV::Capability::Capability, 4> S;
51 AvoidCapabilitiesSet() { S.insert_range(R&: AvoidCapabilities); }
52};
53
54char llvm::SPIRVModuleAnalysis::ID = 0;
55
56INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
57 true)
58
59// Retrieve an unsigned from an MDNode with a list of them as operands.
60static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
61 unsigned DefaultVal = 0) {
62 if (MdNode && OpIndex < MdNode->getNumOperands()) {
63 const auto &Op = MdNode->getOperand(I: OpIndex);
64 return mdconst::extract<ConstantInt>(MD: Op)->getZExtValue();
65 }
66 return DefaultVal;
67}
68
69static SPIRV::Requirements
70getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
71 unsigned i, const SPIRVSubtarget &ST,
72 SPIRV::RequirementHandler &Reqs) {
73 // A set of capabilities to avoid if there is another option.
74 AvoidCapabilitiesSet AvoidCaps;
75 if (!ST.isShader())
76 AvoidCaps.S.insert(V: SPIRV::Capability::Shader);
77 else
78 AvoidCaps.S.insert(V: SPIRV::Capability::Kernel);
79
80 VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, Value: i);
81 VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, Value: i);
82 VersionTuple SPIRVVersion = ST.getSPIRVVersion();
83 bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
84 bool MaxVerOK =
85 ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
86 CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, Value: i);
87 ExtensionList ReqExts = getSymbolicOperandExtensions(Category, Value: i);
88 if (ReqCaps.empty()) {
89 if (ReqExts.empty()) {
90 if (MinVerOK && MaxVerOK)
91 return {true, {}, {}, ReqMinVer, ReqMaxVer};
92 return {false, {}, {}, VersionTuple(), VersionTuple()};
93 }
94 } else if (MinVerOK && MaxVerOK) {
95 if (ReqCaps.size() == 1) {
96 auto Cap = ReqCaps[0];
97 if (Reqs.isCapabilityAvailable(Cap)) {
98 ReqExts.append(RHS: getSymbolicOperandExtensions(
99 Category: SPIRV::OperandCategory::CapabilityOperand, Value: Cap));
100 return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
101 }
102 } else {
103 // By SPIR-V specification: "If an instruction, enumerant, or other
104 // feature specifies multiple enabling capabilities, only one such
105 // capability needs to be declared to use the feature." However, one
106 // capability may be preferred over another. We use command line
107 // argument(s) and AvoidCapabilities to avoid selection of certain
108 // capabilities if there are other options.
109 CapabilityList UseCaps;
110 for (auto Cap : ReqCaps)
111 if (Reqs.isCapabilityAvailable(Cap))
112 UseCaps.push_back(Elt: Cap);
113 for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
114 auto Cap = UseCaps[i];
115 if (i == Sz - 1 || !AvoidCaps.S.contains(V: Cap)) {
116 ReqExts.append(RHS: getSymbolicOperandExtensions(
117 Category: SPIRV::OperandCategory::CapabilityOperand, Value: Cap));
118 return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
119 }
120 }
121 }
122 }
123 // If there are no capabilities, or we can't satisfy the version or
124 // capability requirements, use the list of extensions (if the subtarget
125 // can handle them all).
126 if (llvm::all_of(Range&: ReqExts, P: [&ST](const SPIRV::Extension::Extension &Ext) {
127 return ST.canUseExtension(E: Ext);
128 })) {
129 return {true,
130 {},
131 std::move(ReqExts),
132 VersionTuple(),
133 VersionTuple()}; // TODO: add versions to extensions.
134 }
135 return {false, {}, {}, VersionTuple(), VersionTuple()};
136}
137
138void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
139 MAI.MaxID = 0;
140 for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
141 MAI.MS[i].clear();
142 MAI.RegisterAliasTable.clear();
143 MAI.InstrsToDelete.clear();
144 MAI.GlobalObjMap.clear();
145 MAI.GlobalVarList.clear();
146 MAI.ExtInstSetMap.clear();
147 MAI.Reqs.clear();
148 MAI.Reqs.initAvailableCapabilities(ST: *ST);
149
150 // TODO: determine memory model and source language from the configuratoin.
151 if (auto MemModel = M.getNamedMetadata(Name: "spirv.MemoryModel")) {
152 auto MemMD = MemModel->getOperand(i: 0);
153 MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
154 getMetadataUInt(MdNode: MemMD, OpIndex: 0));
155 MAI.Mem =
156 static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MdNode: MemMD, OpIndex: 1));
157 } else {
158 // TODO: Add support for VulkanMemoryModel.
159 MAI.Mem = ST->isShader() ? SPIRV::MemoryModel::GLSL450
160 : SPIRV::MemoryModel::OpenCL;
161 if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
162 unsigned PtrSize = ST->getPointerSize();
163 MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32
164 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
165 : SPIRV::AddressingModel::Logical;
166 } else {
167 // TODO: Add support for PhysicalStorageBufferAddress.
168 MAI.Addr = SPIRV::AddressingModel::Logical;
169 }
170 }
171 // Get the OpenCL version number from metadata.
172 // TODO: support other source languages.
173 if (auto VerNode = M.getNamedMetadata(Name: "opencl.ocl.version")) {
174 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
175 // Construct version literal in accordance with SPIRV-LLVM-Translator.
176 // TODO: support multiple OCL version metadata.
177 assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
178 auto VersionMD = VerNode->getOperand(i: 0);
179 unsigned MajorNum = getMetadataUInt(MdNode: VersionMD, OpIndex: 0, DefaultVal: 2);
180 unsigned MinorNum = getMetadataUInt(MdNode: VersionMD, OpIndex: 1);
181 unsigned RevNum = getMetadataUInt(MdNode: VersionMD, OpIndex: 2);
182 // Prevent Major part of OpenCL version to be 0
183 MAI.SrcLangVersion =
184 (std::max(a: 1U, b: MajorNum) * 100 + MinorNum) * 1000 + RevNum;
185 // When opencl.cxx.version is also present, validate compatibility
186 // and use C++ for OpenCL as source language with the C++ version.
187 if (auto *CxxVerNode = M.getNamedMetadata(Name: "opencl.cxx.version")) {
188 assert(CxxVerNode->getNumOperands() > 0 && "Invalid SPIR");
189 auto *CxxMD = CxxVerNode->getOperand(i: 0);
190 unsigned CxxVer =
191 (getMetadataUInt(MdNode: CxxMD, OpIndex: 0) * 100 + getMetadataUInt(MdNode: CxxMD, OpIndex: 1)) * 1000 +
192 getMetadataUInt(MdNode: CxxMD, OpIndex: 2);
193 if ((MAI.SrcLangVersion == 200000 && CxxVer == 100000) ||
194 (MAI.SrcLangVersion == 300000 && CxxVer == 202100000)) {
195 MAI.SrcLang = SPIRV::SourceLanguage::CPP_for_OpenCL;
196 MAI.SrcLangVersion = CxxVer;
197 } else {
198 report_fatal_error(
199 reason: "opencl cxx version is not compatible with opencl c version!");
200 }
201 }
202 } else {
203 // If there is no information about OpenCL version we are forced to generate
204 // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
205 // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
206 // Translator avoids potential issues with run-times in a similar manner.
207 if (!ST->isShader()) {
208 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
209 MAI.SrcLangVersion = 100000;
210 } else {
211 MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
212 MAI.SrcLangVersion = 0;
213 }
214 }
215
216 if (auto ExtNode = M.getNamedMetadata(Name: "opencl.used.extensions")) {
217 for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
218 MDNode *MD = ExtNode->getOperand(i: I);
219 if (!MD || MD->getNumOperands() == 0)
220 continue;
221 for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
222 MAI.SrcExt.insert(key: cast<MDString>(Val: MD->getOperand(I: J))->getString());
223 }
224 }
225
226 // Update required capabilities for this memory model, addressing model and
227 // source language.
228 MAI.Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::MemoryModelOperand,
229 i: MAI.Mem, ST: *ST);
230 MAI.Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::SourceLanguageOperand,
231 i: MAI.SrcLang, ST: *ST);
232 MAI.Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::AddressingModelOperand,
233 i: MAI.Addr, ST: *ST);
234
235 if (MAI.Mem == SPIRV::MemoryModel::VulkanKHR)
236 MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_vulkan_memory_model);
237
238 if (!ST->isShader()) {
239 // TODO: check if it's required by default.
240 MAI.ExtInstSetMap[static_cast<unsigned>(
241 SPIRV::InstructionSet::OpenCL_std)] = MAI.getNextIDRegister();
242 }
243}
244
245// Appends the signature of the decoration instructions that decorate R to
246// Signature.
247static void appendDecorationsForReg(const MachineRegisterInfo &MRI, Register R,
248 InstrSignature &Signature) {
249 for (MachineInstr &UseMI : MRI.use_instructions(Reg: R)) {
250 // We don't handle OpDecorateId because getting the register alias for the
251 // ID can cause problems, and we do not need it for now.
252 if (UseMI.getOpcode() != SPIRV::OpDecorate &&
253 UseMI.getOpcode() != SPIRV::OpMemberDecorate)
254 continue;
255
256 for (unsigned I = 0; I < UseMI.getNumOperands(); ++I) {
257 const MachineOperand &MO = UseMI.getOperand(i: I);
258 if (MO.isReg())
259 continue;
260 Signature.push_back(Elt: hash_value(MO));
261 }
262 }
263}
264
265// Returns a representation of an instruction as a vector of MachineOperand
266// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
267// This creates a signature of the instruction with the same content
268// that MachineOperand::isIdenticalTo uses for comparison.
269static InstrSignature instrToSignature(const MachineInstr &MI,
270 SPIRV::ModuleAnalysisInfo &MAI,
271 bool UseDefReg) {
272 Register DefReg;
273 InstrSignature Signature{MI.getOpcode()};
274 for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
275 // The only decorations that can be applied more than once to a given <id>
276 // or structure member are FuncParamAttr (38), UserSemantic (5635),
277 // CacheControlLoadINTEL (6442), and CacheControlStoreINTEL (6443). For all
278 // the rest of decorations, we will only add to the signature the Opcode,
279 // the id to which it applies, and the decoration id, disregarding any
280 // decoration flags. This will ensure that any subsequent decoration with
281 // the same id will be deemed as a duplicate. Then, at the call site, we
282 // will be able to handle duplicates in the best way.
283 unsigned Opcode = MI.getOpcode();
284 if ((Opcode == SPIRV::OpDecorate) && i >= 2) {
285 unsigned DecorationID = MI.getOperand(i: 1).getImm();
286 if (DecorationID != SPIRV::Decoration::FuncParamAttr &&
287 DecorationID != SPIRV::Decoration::UserSemantic &&
288 DecorationID != SPIRV::Decoration::CacheControlLoadINTEL &&
289 DecorationID != SPIRV::Decoration::CacheControlStoreINTEL)
290 continue;
291 }
292 const MachineOperand &MO = MI.getOperand(i);
293 size_t h;
294 if (MO.isReg()) {
295 if (!UseDefReg && MO.isDef()) {
296 assert(!DefReg.isValid() && "Multiple def registers.");
297 DefReg = MO.getReg();
298 continue;
299 }
300 Register RegAlias = MAI.getRegisterAlias(MF: MI.getMF(), Reg: MO.getReg());
301 if (!RegAlias.isValid()) {
302 LLVM_DEBUG({
303 dbgs() << "Unexpectedly, no global id found for the operand ";
304 MO.print(dbgs());
305 dbgs() << "\nInstruction: ";
306 MI.print(dbgs());
307 dbgs() << "\n";
308 });
309 report_fatal_error(reason: "All v-regs must have been mapped to global id's");
310 }
311 // mimic llvm::hash_value(const MachineOperand &MO)
312 h = hash_combine(args: MO.getType(), args: (unsigned)RegAlias, args: MO.getSubReg(),
313 args: MO.isDef());
314 } else {
315 h = hash_value(MO);
316 }
317 Signature.push_back(Elt: h);
318 }
319
320 if (DefReg.isValid()) {
321 // Decorations change the semantics of the current instruction. So two
322 // identical instruction with different decorations cannot be merged. That
323 // is why we add the decorations to the signature.
324 appendDecorationsForReg(MRI: MI.getMF()->getRegInfo(), R: DefReg, Signature);
325 }
326 return Signature;
327}
328
329bool SPIRVModuleAnalysis::isDeclSection(const MachineRegisterInfo &MRI,
330 const MachineInstr &MI) {
331 unsigned Opcode = MI.getOpcode();
332 switch (Opcode) {
333 case SPIRV::OpTypeForwardPointer:
334 // omit now, collect later
335 return false;
336 case SPIRV::OpVariable:
337 return static_cast<SPIRV::StorageClass::StorageClass>(
338 MI.getOperand(i: 2).getImm()) != SPIRV::StorageClass::Function;
339 case SPIRV::OpFunction:
340 case SPIRV::OpFunctionParameter:
341 return true;
342 }
343 if (GR->hasConstFunPtr() && Opcode == SPIRV::OpUndef) {
344 // The OpUndef may be a placeholder for a function reference recorded by
345 // selectGlobalValue. Skip emitting it if any user consumes it as a
346 // function-pointer-like operand (OpConstantFunctionPointerINTEL operand 2,
347 // or OpEnqueueKernel's Invoke operand at index 8). The rewrite happens
348 // in visitFunPtrUse, which aliases the OpUndef's vreg to the function's
349 // global <id>.
350 Register DefReg = MI.getOperand(i: 0).getReg();
351 if (GR->getFunctionDefinitionByUse(Use: &MI.getOperand(i: 0))) {
352 for (MachineInstr &UseMI : MRI.use_instructions(Reg: DefReg)) {
353 unsigned UseOp = UseMI.getOpcode();
354 if (UseOp == SPIRV::OpConstantFunctionPointerINTEL ||
355 UseOp == SPIRV::OpEnqueueKernel) {
356 MAI.setSkipEmission(&MI);
357 return false;
358 }
359 }
360 }
361 for (MachineInstr &UseMI : MRI.use_instructions(Reg: DefReg)) {
362 if (UseMI.getOpcode() != SPIRV::OpConstantFunctionPointerINTEL)
363 continue;
364 // it's a dummy definition, FP constant refers to a function,
365 // and this is resolved in another way; let's skip this definition
366 assert(UseMI.getOperand(2).isReg() &&
367 UseMI.getOperand(2).getReg() == DefReg);
368 MAI.setSkipEmission(&MI);
369 return false;
370 }
371 }
372 return TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
373 TII->isInlineAsmDefInstr(MI);
374}
375
376// This is a special case of a function pointer referring to a possibly
377// forward function declaration. The operand is a dummy OpUndef that
378// requires a special treatment.
379// FunPtrOp is the MachineOperand previously recorded via
380// SPIRVGlobalRegistry::recordFunctionPointer, identifying which Function
381// this placeholder refers to.
382void SPIRVModuleAnalysis::visitFunPtrUse(
383 Register OpReg, const MachineOperand *FunPtrOp,
384 InstrGRegsMap &SignatureToGReg,
385 std::map<const Value *, unsigned> &GlobalToGReg,
386 const MachineFunction *MF) {
387 const MachineOperand *OpFunDef = GR->getFunctionDefinitionByUse(Use: FunPtrOp);
388 assert(OpFunDef && OpFunDef->isReg());
389 // find the actual function definition and number it globally in advance
390 const MachineInstr *OpDefMI = OpFunDef->getParent();
391 assert(OpDefMI && OpDefMI->getOpcode() == SPIRV::OpFunction);
392 const MachineFunction *FunDefMF = OpDefMI->getParent()->getParent();
393 const MachineRegisterInfo &FunDefMRI = FunDefMF->getRegInfo();
394 do {
395 visitDecl(MRI: FunDefMRI, SignatureToGReg, GlobalToGReg, MF: FunDefMF, MI: *OpDefMI);
396 OpDefMI = OpDefMI->getNextNode();
397 } while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction ||
398 OpDefMI->getOpcode() == SPIRV::OpFunctionParameter));
399 // associate the function pointer with the newly assigned global number
400 MCRegister GlobalFunDefReg =
401 MAI.getRegisterAlias(MF: FunDefMF, Reg: OpFunDef->getReg());
402 assert(GlobalFunDefReg.isValid() &&
403 "Function definition must refer to a global register");
404 MAI.setRegisterAlias(MF, Reg: OpReg, AliasReg: GlobalFunDefReg);
405}
406
407// Depth first recursive traversal of dependencies. Repeated visits are guarded
408// by MAI.hasRegisterAlias().
409void SPIRVModuleAnalysis::visitDecl(
410 const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg,
411 std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
412 const MachineInstr &MI) {
413 unsigned Opcode = MI.getOpcode();
414
415 // Process each operand of the instruction to resolve dependencies
416 for (const MachineOperand &MO : MI.operands()) {
417 if (!MO.isReg() || MO.isDef())
418 continue;
419 Register OpReg = MO.getReg();
420 // Handle function pointers special case
421 if (Opcode == SPIRV::OpConstantFunctionPointerINTEL &&
422 MRI.getRegClass(Reg: OpReg) == &SPIRV::pIDRegClass) {
423 visitFunPtrUse(OpReg, FunPtrOp: &MI.getOperand(i: 2), SignatureToGReg, GlobalToGReg,
424 MF);
425 continue;
426 }
427 // Skip already processed instructions
428 if (MAI.hasRegisterAlias(MF, Reg: MO.getReg()))
429 continue;
430 // Recursively visit dependencies
431 if (const MachineInstr *OpDefMI = MRI.getUniqueVRegDef(Reg: OpReg)) {
432 if (isDeclSection(MRI, MI: *OpDefMI))
433 visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI: *OpDefMI);
434 continue;
435 }
436 // Handle the unexpected case of no unique definition for the SPIR-V
437 // instruction
438 LLVM_DEBUG({
439 dbgs() << "Unexpectedly, no unique definition for the operand ";
440 MO.print(dbgs());
441 dbgs() << "\nInstruction: ";
442 MI.print(dbgs());
443 dbgs() << "\n";
444 });
445 report_fatal_error(
446 reason: "No unique definition is found for the virtual register");
447 }
448
449 MCRegister GReg;
450 bool IsFunDef = false;
451 if (TII->isSpecConstantInstr(MI)) {
452 GReg = MAI.getNextIDRegister();
453 MAI.MS[SPIRV::MB_TypeConstVars].push_back(Elt: &MI);
454 } else if (Opcode == SPIRV::OpFunction ||
455 Opcode == SPIRV::OpFunctionParameter) {
456 GReg = handleFunctionOrParameter(MF, MI, GlobalToGReg, IsFunDef);
457 } else if (Opcode == SPIRV::OpTypeStruct ||
458 Opcode == SPIRV::OpConstantComposite) {
459 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
460 const MachineInstr *NextInstr = MI.getNextNode();
461 while (NextInstr &&
462 ((Opcode == SPIRV::OpTypeStruct &&
463 NextInstr->getOpcode() == SPIRV::OpTypeStructContinuedINTEL) ||
464 (Opcode == SPIRV::OpConstantComposite &&
465 NextInstr->getOpcode() ==
466 SPIRV::OpConstantCompositeContinuedINTEL))) {
467 MCRegister Tmp = handleTypeDeclOrConstant(MI: *NextInstr, SignatureToGReg);
468 MAI.setRegisterAlias(MF, Reg: NextInstr->getOperand(i: 0).getReg(), AliasReg: Tmp);
469 MAI.setSkipEmission(NextInstr);
470 NextInstr = NextInstr->getNextNode();
471 }
472 } else if (TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
473 TII->isInlineAsmDefInstr(MI)) {
474 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
475 } else if (Opcode == SPIRV::OpVariable) {
476 GReg = handleVariable(MF, MI, GlobalToGReg);
477 } else {
478 LLVM_DEBUG({
479 dbgs() << "\nInstruction: ";
480 MI.print(dbgs());
481 dbgs() << "\n";
482 });
483 llvm_unreachable("Unexpected instruction is visited");
484 }
485 MAI.setRegisterAlias(MF, Reg: MI.getOperand(i: 0).getReg(), AliasReg: GReg);
486 if (!IsFunDef)
487 MAI.setSkipEmission(&MI);
488}
489
490MCRegister SPIRVModuleAnalysis::handleFunctionOrParameter(
491 const MachineFunction *MF, const MachineInstr &MI,
492 std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) {
493 const Value *GObj = GR->getGlobalObject(MF, R: MI.getOperand(i: 0).getReg());
494 assert(GObj && "Unregistered global definition");
495 const Function *F = dyn_cast<Function>(Val: GObj);
496 if (!F)
497 F = dyn_cast<Argument>(Val: GObj)->getParent();
498 assert(F && "Expected a reference to a function or an argument");
499 IsFunDef = !F->isDeclaration();
500 auto [It, Inserted] = GlobalToGReg.try_emplace(k: GObj);
501 if (!Inserted)
502 return It->second;
503 MCRegister GReg = MAI.getNextIDRegister();
504 It->second = GReg;
505 if (!IsFunDef)
506 MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(Elt: &MI);
507 return GReg;
508}
509
510MCRegister
511SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI,
512 InstrGRegsMap &SignatureToGReg) {
513 InstrSignature MISign = instrToSignature(MI, MAI, UseDefReg: false);
514 auto [It, Inserted] = SignatureToGReg.try_emplace(k: MISign);
515 if (!Inserted)
516 return It->second;
517 MCRegister GReg = MAI.getNextIDRegister();
518 It->second = GReg;
519 MAI.MS[SPIRV::MB_TypeConstVars].push_back(Elt: &MI);
520 return GReg;
521}
522
523MCRegister SPIRVModuleAnalysis::handleVariable(
524 const MachineFunction *MF, const MachineInstr &MI,
525 std::map<const Value *, unsigned> &GlobalToGReg) {
526 MAI.GlobalVarList.push_back(Elt: &MI);
527 const Value *GObj = GR->getGlobalObject(MF, R: MI.getOperand(i: 0).getReg());
528 assert(GObj && "Unregistered global definition");
529 auto [It, Inserted] = GlobalToGReg.try_emplace(k: GObj);
530 if (!Inserted)
531 return It->second;
532 MCRegister GReg = MAI.getNextIDRegister();
533 It->second = GReg;
534 MAI.MS[SPIRV::MB_TypeConstVars].push_back(Elt: &MI);
535 if (const auto *GV = dyn_cast<GlobalVariable>(Val: GObj))
536 MAI.GlobalObjMap[GV] = GReg;
537 return GReg;
538}
539
540void SPIRVModuleAnalysis::collectDeclarations(const Module &M) {
541 InstrGRegsMap SignatureToGReg;
542 std::map<const Value *, unsigned> GlobalToGReg;
543 for (const Function &F : M) {
544 MachineFunction *MF = MMI->getMachineFunction(F);
545 if (!MF)
546 continue;
547 const MachineRegisterInfo &MRI = MF->getRegInfo();
548 unsigned PastHeader = 0;
549 for (MachineBasicBlock &MBB : *MF) {
550 for (MachineInstr &MI : MBB) {
551 if (MI.getNumOperands() == 0)
552 continue;
553 unsigned Opcode = MI.getOpcode();
554 if (Opcode == SPIRV::OpFunction) {
555 if (PastHeader == 0) {
556 PastHeader = 1;
557 continue;
558 }
559 } else if (Opcode == SPIRV::OpFunctionParameter) {
560 if (PastHeader < 2)
561 continue;
562 } else if (PastHeader > 0) {
563 PastHeader = 2;
564 }
565
566 const MachineOperand &DefMO = MI.getOperand(i: 0);
567 switch (Opcode) {
568 case SPIRV::OpExtension:
569 MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::Extension(DefMO.getImm()));
570 MAI.setSkipEmission(&MI);
571 break;
572 case SPIRV::OpCapability:
573 MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::Capability(DefMO.getImm()));
574 MAI.setSkipEmission(&MI);
575 if (PastHeader > 0)
576 PastHeader = 2;
577 break;
578 default:
579 if (DefMO.isReg() && isDeclSection(MRI, MI) &&
580 !MAI.hasRegisterAlias(MF, Reg: DefMO.getReg()))
581 visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI);
582 // OpEnqueueKernel is not a decl, but its Invoke operand may be a
583 // function-pointer placeholder OpUndef recorded by selectGlobalValue.
584 // Resolve it to the OpFunction's global <id> via visitFunPtrUse.
585 if (Opcode == SPIRV::OpEnqueueKernel && MI.getNumOperands() > 8) {
586 const MachineOperand &InvokeMO = MI.getOperand(i: 8);
587 if (InvokeMO.isReg()) {
588 Register InvokeReg = InvokeMO.getReg();
589 if (!MAI.hasRegisterAlias(MF, Reg: InvokeReg)) {
590 if (const MachineInstr *DefMI =
591 MRI.getUniqueVRegDef(Reg: InvokeReg)) {
592 if (DefMI->getOpcode() == SPIRV::OpUndef) {
593 const MachineOperand *FunPtrOp = &DefMI->getOperand(i: 0);
594 if (GR->getFunctionDefinitionByUse(Use: FunPtrOp))
595 visitFunPtrUse(OpReg: InvokeReg, FunPtrOp, SignatureToGReg,
596 GlobalToGReg, MF);
597 }
598 }
599 }
600 }
601 }
602 }
603 }
604 }
605 }
606}
607
608// Look for IDs declared with Import linkage, and map the corresponding function
609// to the register defining that variable (which will usually be the result of
610// an OpFunction). This lets us call externally imported functions using
611// the correct ID registers.
612void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
613 const Function *F) {
614 if (MI.getOpcode() == SPIRV::OpDecorate) {
615 // If it's got Import linkage.
616 auto Dec = MI.getOperand(i: 1).getImm();
617 if (Dec == SPIRV::Decoration::LinkageAttributes) {
618 auto Lnk = MI.getOperand(i: MI.getNumOperands() - 1).getImm();
619 if (Lnk == SPIRV::LinkageType::Import) {
620 // Map imported function name to function ID register.
621 const Function *ImportedFunc =
622 F->getParent()->getFunction(Name: getStringImm(MI, StartIndex: 2));
623 Register Target = MI.getOperand(i: 0).getReg();
624 MAI.GlobalObjMap[ImportedFunc] =
625 MAI.getRegisterAlias(MF: MI.getMF(), Reg: Target);
626 }
627 }
628 } else if (MI.getOpcode() == SPIRV::OpFunction) {
629 // Record all internal OpFunction declarations.
630 Register Reg = MI.defs().begin()->getReg();
631 MCRegister GlobalReg = MAI.getRegisterAlias(MF: MI.getMF(), Reg);
632 assert(GlobalReg.isValid());
633 MAI.GlobalObjMap[F] = GlobalReg;
634 }
635}
636
637// Collect the given instruction in the specified MS. We assume global register
638// numbering has already occurred by this point. We can directly compare reg
639// arguments when detecting duplicates.
640static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
641 SPIRV::ModuleSectionType MSType, InstrTraces &IS,
642 bool Append = true) {
643 MAI.setSkipEmission(&MI);
644 InstrSignature MISign = instrToSignature(MI, MAI, UseDefReg: true);
645 auto FoundMI = IS.insert(x: std::move(MISign));
646 if (!FoundMI.second) {
647 if (MI.getOpcode() == SPIRV::OpDecorate) {
648 assert(MI.getNumOperands() >= 2 &&
649 "Decoration instructions must have at least 2 operands");
650 assert(MSType == SPIRV::MB_Annotations &&
651 "Only OpDecorate instructions can be duplicates");
652 // For FPFastMathMode decoration, we need to merge the flags of the
653 // duplicate decoration with the original one, so we need to find the
654 // original instruction that has the same signature. For the rest of
655 // instructions, we will simply skip the duplicate.
656 if (MI.getOperand(i: 1).getImm() != SPIRV::Decoration::FPFastMathMode)
657 return; // Skip duplicates of other decorations.
658
659 const SPIRV::InstrList &Decorations = MAI.MS[MSType];
660 for (const MachineInstr *OrigMI : Decorations) {
661 if (instrToSignature(MI: *OrigMI, MAI, UseDefReg: true) == MISign) {
662 assert(OrigMI->getNumOperands() == MI.getNumOperands() &&
663 "Original instruction must have the same number of operands");
664 assert(
665 OrigMI->getNumOperands() == 3 &&
666 "FPFastMathMode decoration must have 3 operands for OpDecorate");
667 unsigned OrigFlags = OrigMI->getOperand(i: 2).getImm();
668 unsigned NewFlags = MI.getOperand(i: 2).getImm();
669 if (OrigFlags == NewFlags)
670 return; // No need to merge, the flags are the same.
671
672 // Emit warning about possible conflict between flags.
673 unsigned FinalFlags = OrigFlags | NewFlags;
674 llvm::errs()
675 << "Warning: Conflicting FPFastMathMode decoration flags "
676 "in instruction: "
677 << *OrigMI << "Original flags: " << OrigFlags
678 << ", new flags: " << NewFlags
679 << ". They will be merged on a best effort basis, but not "
680 "validated. Final flags: "
681 << FinalFlags << "\n";
682 MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI);
683 MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(i: 2);
684 OrigFlagsOp = MachineOperand::CreateImm(Val: FinalFlags);
685 return; // Merge done, so we found a duplicate; don't add it to MAI.MS
686 }
687 }
688 assert(false && "No original instruction found for the duplicate "
689 "OpDecorate, but we found one in IS.");
690 }
691 return; // insert failed, so we found a duplicate; don't add it to MAI.MS
692 }
693 // No duplicates, so add it.
694 if (Append)
695 MAI.MS[MSType].push_back(Elt: &MI);
696 else
697 MAI.MS[MSType].insert(I: MAI.MS[MSType].begin(), Elt: &MI);
698}
699
700// Some global instructions make reference to function-local ID regs, so cannot
701// be correctly collected until these registers are globally numbered.
702void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
703 InstrTraces IS;
704 for (const Function &F : M) {
705 if (F.isDeclaration())
706 continue;
707 MachineFunction *MF = MMI->getMachineFunction(F);
708 assert(MF);
709
710 for (MachineBasicBlock &MBB : *MF)
711 for (MachineInstr &MI : MBB) {
712 if (MAI.getSkipEmission(MI: &MI))
713 continue;
714 const unsigned OpCode = MI.getOpcode();
715 if (OpCode == SPIRV::OpString) {
716 collectOtherInstr(MI, MAI, MSType: SPIRV::MB_DebugStrings, IS);
717 } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(i: 2).isImm() &&
718 MI.getOperand(i: 2).getImm() ==
719 SPIRV::InstructionSet::
720 NonSemantic_Shader_DebugInfo_100) {
721 // TODO: This branch is dead. SPIRVNonSemanticDebugHandler emits NSDI
722 // instructions directly as MCInsts at print time; no
723 // MachineInstructions with the NSDI ext set are created anymore.
724 // Remove this block and
725 // MB_NonSemanticGlobalDI once per-function NSDI emission is confirmed
726 // not to need MIR routing.
727 MachineOperand Ins = MI.getOperand(i: 3);
728 namespace NS = SPIRV::NonSemanticExtInst;
729 static constexpr int64_t GlobalNonSemanticDITy[] = {
730 NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
731 NS::DebugTypeBasic, NS::DebugTypePointer};
732 bool IsGlobalDI = false;
733 for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
734 IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
735 if (IsGlobalDI)
736 collectOtherInstr(MI, MAI, MSType: SPIRV::MB_NonSemanticGlobalDI, IS);
737 } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
738 collectOtherInstr(MI, MAI, MSType: SPIRV::MB_DebugNames, IS);
739 } else if (OpCode == SPIRV::OpEntryPoint) {
740 collectOtherInstr(MI, MAI, MSType: SPIRV::MB_EntryPoints, IS);
741 } else if (TII->isAliasingInstr(MI)) {
742 collectOtherInstr(MI, MAI, MSType: SPIRV::MB_AliasingInsts, IS);
743 } else if (TII->isDecorationInstr(MI)) {
744 collectOtherInstr(MI, MAI, MSType: SPIRV::MB_Annotations, IS);
745 collectFuncNames(MI, F: &F);
746 } else if (TII->isConstantInstr(MI)) {
747 // Now OpSpecConstant*s are not in DT,
748 // but they need to be collected anyway.
749 collectOtherInstr(MI, MAI, MSType: SPIRV::MB_TypeConstVars, IS);
750 } else if (OpCode == SPIRV::OpFunction) {
751 collectFuncNames(MI, F: &F);
752 } else if (OpCode == SPIRV::OpTypeForwardPointer) {
753 collectOtherInstr(MI, MAI, MSType: SPIRV::MB_TypeConstVars, IS, Append: false);
754 }
755 }
756 }
757 // Selection order can place a scope/list ahead of a domain/scope it
758 // references. The dependency meanwhile is domain -> scope -> list, so sort
759 // the def before its uses.
760 auto AliasingTier = [](const MachineInstr *MI) {
761 switch (MI->getOpcode()) {
762 case SPIRV::OpAliasDomainDeclINTEL:
763 return 0;
764 case SPIRV::OpAliasScopeDeclINTEL:
765 return 1;
766 case SPIRV::OpAliasScopeListDeclINTEL:
767 return 2;
768 default:
769 llvm_unreachable("unexpected aliasing instruction");
770 }
771 };
772 stable_sort(Range&: MAI.MS[SPIRV::MB_AliasingInsts],
773 C: [&](const MachineInstr *LHS, const MachineInstr *RHS) {
774 return AliasingTier(LHS) < AliasingTier(RHS);
775 });
776}
777
778// Number registers in all functions globally from 0 onwards and store
779// the result in global register alias table. Some registers are already
780// numbered.
781void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
782 for (const Function &F : M) {
783 if (F.isDeclaration())
784 continue;
785 MachineFunction *MF = MMI->getMachineFunction(F);
786 assert(MF);
787 for (MachineBasicBlock &MBB : *MF) {
788 for (MachineInstr &MI : MBB) {
789 for (MachineOperand &Op : MI.operands()) {
790 if (!Op.isReg())
791 continue;
792 Register Reg = Op.getReg();
793 if (MAI.hasRegisterAlias(MF, Reg))
794 continue;
795 MCRegister NewReg = MAI.getNextIDRegister();
796 MAI.setRegisterAlias(MF, Reg, AliasReg: NewReg);
797 }
798 if (MI.getOpcode() != SPIRV::OpExtInst)
799 continue;
800 auto Set = MI.getOperand(i: 2).getImm();
801 auto [It, Inserted] = MAI.ExtInstSetMap.try_emplace(Key: Set);
802 if (Inserted)
803 It->second = MAI.getNextIDRegister();
804 }
805 }
806 }
807}
808
809// RequirementHandler implementations.
810void SPIRV::RequirementHandler::getAndAddRequirements(
811 SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
812 const SPIRVSubtarget &ST) {
813 addRequirements(Req: getSymbolicOperandRequirements(Category, i, ST, Reqs&: *this));
814}
815
816void SPIRV::RequirementHandler::recursiveAddCapabilities(
817 const CapabilityList &ToPrune) {
818 for (const auto &Cap : ToPrune) {
819 AllCaps.insert(V: Cap);
820 CapabilityList ImplicitDecls =
821 getSymbolicOperandCapabilities(Category: OperandCategory::CapabilityOperand, Value: Cap);
822 recursiveAddCapabilities(ToPrune: ImplicitDecls);
823 }
824}
825
826void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
827 for (const auto &Cap : ToAdd) {
828 bool IsNewlyInserted = AllCaps.insert(V: Cap).second;
829 if (!IsNewlyInserted) // Don't re-add if it's already been declared.
830 continue;
831 CapabilityList ImplicitDecls =
832 getSymbolicOperandCapabilities(Category: OperandCategory::CapabilityOperand, Value: Cap);
833 recursiveAddCapabilities(ToPrune: ImplicitDecls);
834 MinimalCaps.push_back(Elt: Cap);
835 }
836}
837
838void SPIRV::RequirementHandler::addRequirements(
839 const SPIRV::Requirements &Req) {
840 if (!Req.IsSatisfiable)
841 report_fatal_error(reason: "Adding SPIR-V requirements this target can't satisfy.");
842
843 if (Req.Cap.has_value())
844 addCapabilities(ToAdd: {Req.Cap.value()});
845
846 addExtensions(ToAdd: Req.Exts);
847
848 if (!Req.MinVer.empty()) {
849 if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
850 LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
851 << " and <= " << MaxVersion << "\n");
852 report_fatal_error(reason: "Adding SPIR-V requirements that can't be satisfied.");
853 }
854
855 if (MinVersion.empty() || Req.MinVer > MinVersion)
856 MinVersion = Req.MinVer;
857 }
858
859 if (!Req.MaxVer.empty()) {
860 if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
861 LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
862 << " and >= " << MinVersion << "\n");
863 report_fatal_error(reason: "Adding SPIR-V requirements that can't be satisfied.");
864 }
865
866 if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
867 MaxVersion = Req.MaxVer;
868 }
869}
870
871void SPIRV::RequirementHandler::checkSatisfiable(
872 const SPIRVSubtarget &ST) const {
873 // Report as many errors as possible before aborting the compilation.
874 bool IsSatisfiable = true;
875 auto TargetVer = ST.getSPIRVVersion();
876
877 if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
878 LLVM_DEBUG(
879 dbgs() << "Target SPIR-V version too high for required features\n"
880 << "Required max version: " << MaxVersion << " target version "
881 << TargetVer << "\n");
882 IsSatisfiable = false;
883 }
884
885 if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
886 LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
887 << "Required min version: " << MinVersion
888 << " target version " << TargetVer << "\n");
889 IsSatisfiable = false;
890 }
891
892 if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
893 LLVM_DEBUG(
894 dbgs()
895 << "Version is too low for some features and too high for others.\n"
896 << "Required SPIR-V min version: " << MinVersion
897 << " required SPIR-V max version " << MaxVersion << "\n");
898 IsSatisfiable = false;
899 }
900
901 AvoidCapabilitiesSet AvoidCaps;
902 if (!ST.isShader())
903 AvoidCaps.S.insert(V: SPIRV::Capability::Shader);
904 else
905 AvoidCaps.S.insert(V: SPIRV::Capability::Kernel);
906
907 for (auto Cap : MinimalCaps) {
908 if (AvailableCaps.contains(V: Cap) && !AvoidCaps.S.contains(V: Cap))
909 continue;
910 LLVM_DEBUG(dbgs() << "Capability not supported: "
911 << getSymbolicOperandMnemonic(
912 OperandCategory::CapabilityOperand, Cap)
913 << "\n");
914 IsSatisfiable = false;
915 }
916
917 for (auto Ext : AllExtensions) {
918 if (ST.canUseExtension(E: Ext))
919 continue;
920 LLVM_DEBUG(dbgs() << "Extension not supported: "
921 << getSymbolicOperandMnemonic(
922 OperandCategory::ExtensionOperand, Ext)
923 << "\n");
924 IsSatisfiable = false;
925 }
926
927 if (!IsSatisfiable)
928 report_fatal_error(reason: "Unable to meet SPIR-V requirements for this target.");
929}
930
931// Add the given capabilities and all their implicitly defined capabilities too.
932void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
933 for (const auto Cap : ToAdd)
934 if (AvailableCaps.insert(V: Cap).second)
935 addAvailableCaps(ToAdd: getSymbolicOperandCapabilities(
936 Category: SPIRV::OperandCategory::CapabilityOperand, Value: Cap));
937}
938
939void SPIRV::RequirementHandler::removeCapabilityIf(
940 const Capability::Capability ToRemove,
941 const Capability::Capability IfPresent) {
942 if (AllCaps.contains(V: IfPresent))
943 AllCaps.erase(V: ToRemove);
944}
945
946namespace llvm {
947namespace SPIRV {
948void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
949 // Provided by both all supported Vulkan versions and OpenCl.
950 addAvailableCaps(ToAdd: {Capability::Shader, Capability::Linkage, Capability::Int8,
951 Capability::Int16});
952
953 if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 3)))
954 addAvailableCaps(ToAdd: {Capability::GroupNonUniform,
955 Capability::GroupNonUniformVote,
956 Capability::GroupNonUniformArithmetic,
957 Capability::GroupNonUniformBallot,
958 Capability::GroupNonUniformClustered,
959 Capability::GroupNonUniformShuffle,
960 Capability::GroupNonUniformShuffleRelative,
961 Capability::GroupNonUniformQuad});
962
963 if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 6)))
964 addAvailableCaps(ToAdd: {Capability::DotProduct, Capability::DotProductInputAll,
965 Capability::DotProductInput4x8Bit,
966 Capability::DotProductInput4x8BitPacked,
967 Capability::DemoteToHelperInvocation});
968
969 // Add capabilities enabled by extensions.
970 for (auto Extension : ST.getAllAvailableExtensions()) {
971 CapabilityList EnabledCapabilities =
972 getCapabilitiesEnabledByExtension(Extension);
973 addAvailableCaps(ToAdd: EnabledCapabilities);
974 }
975
976 if (!ST.isShader()) {
977 initAvailableCapabilitiesForOpenCL(ST);
978 return;
979 }
980
981 if (ST.isShader()) {
982 initAvailableCapabilitiesForVulkan(ST);
983 return;
984 }
985
986 report_fatal_error(reason: "Unimplemented environment for SPIR-V generation.");
987}
988
989void RequirementHandler::initAvailableCapabilitiesForOpenCL(
990 const SPIRVSubtarget &ST) {
991 // Add the min requirements for different OpenCL and SPIR-V versions.
992 addAvailableCaps(ToAdd: {Capability::Addresses, Capability::Float16Buffer,
993 Capability::Kernel, Capability::Vector16,
994 Capability::Groups, Capability::GenericPointer,
995 Capability::StorageImageWriteWithoutFormat,
996 Capability::StorageImageReadWithoutFormat});
997 if (ST.hasOpenCLFullProfile())
998 addAvailableCaps(ToAdd: {Capability::Int64, Capability::Int64Atomics});
999 if (ST.hasOpenCLImageSupport()) {
1000 addAvailableCaps(ToAdd: {Capability::ImageBasic, Capability::LiteralSampler,
1001 Capability::Image1D, Capability::SampledBuffer,
1002 Capability::ImageBuffer});
1003 if (ST.isAtLeastOpenCLVer(VerToCompareTo: VersionTuple(2, 0)))
1004 addAvailableCaps(ToAdd: {Capability::ImageReadWrite});
1005 }
1006 if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 1)) &&
1007 ST.isAtLeastOpenCLVer(VerToCompareTo: VersionTuple(2, 2)))
1008 addAvailableCaps(ToAdd: {Capability::SubgroupDispatch, Capability::PipeStorage});
1009 if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 4)))
1010 addAvailableCaps(ToAdd: {Capability::DenormPreserve, Capability::DenormFlushToZero,
1011 Capability::SignedZeroInfNanPreserve,
1012 Capability::RoundingModeRTE,
1013 Capability::RoundingModeRTZ});
1014 // TODO: verify if this needs some checks.
1015 addAvailableCaps(ToAdd: {Capability::Float16, Capability::Float64});
1016
1017 // TODO: add OpenCL extensions.
1018}
1019
1020void RequirementHandler::initAvailableCapabilitiesForVulkan(
1021 const SPIRVSubtarget &ST) {
1022
1023 // Core in Vulkan 1.1 and earlier.
1024 addAvailableCaps(ToAdd: {Capability::Int64,
1025 Capability::Float16,
1026 Capability::Float64,
1027 Capability::GroupNonUniform,
1028 Capability::Image1D,
1029 Capability::SampledBuffer,
1030 Capability::ImageBuffer,
1031 Capability::UniformBufferArrayDynamicIndexing,
1032 Capability::SampledImageArrayDynamicIndexing,
1033 Capability::StorageBufferArrayDynamicIndexing,
1034 Capability::StorageImageArrayDynamicIndexing,
1035 Capability::DerivativeControl,
1036 Capability::MinLod,
1037 Capability::ImageQuery,
1038 Capability::ImageGatherExtended,
1039 Capability::Addresses,
1040 Capability::VulkanMemoryModelKHR,
1041 Capability::StorageImageExtendedFormats,
1042 Capability::StorageImageMultisample,
1043 Capability::ImageMSArray});
1044
1045 // Became core in Vulkan 1.2
1046 if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 5))) {
1047 addAvailableCaps(
1048 ToAdd: {Capability::Int64Atomics, Capability::ShaderNonUniformEXT,
1049 Capability::RuntimeDescriptorArrayEXT,
1050 Capability::InputAttachmentArrayDynamicIndexingEXT,
1051 Capability::UniformTexelBufferArrayDynamicIndexingEXT,
1052 Capability::StorageTexelBufferArrayDynamicIndexingEXT,
1053 Capability::UniformBufferArrayNonUniformIndexingEXT,
1054 Capability::SampledImageArrayNonUniformIndexingEXT,
1055 Capability::StorageBufferArrayNonUniformIndexingEXT,
1056 Capability::StorageImageArrayNonUniformIndexingEXT,
1057 Capability::InputAttachmentArrayNonUniformIndexingEXT,
1058 Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
1059 Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
1060 }
1061
1062 // Became core in Vulkan 1.3
1063 if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 6)))
1064 addAvailableCaps(ToAdd: {Capability::StorageImageWriteWithoutFormat,
1065 Capability::StorageImageReadWithoutFormat});
1066}
1067
1068} // namespace SPIRV
1069} // namespace llvm
1070
1071// Add the required capabilities from a decoration instruction (including
1072// BuiltIns).
1073static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
1074 SPIRV::RequirementHandler &Reqs,
1075 const SPIRVSubtarget &ST) {
1076 int64_t DecOp = MI.getOperand(i: DecIndex).getImm();
1077 auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
1078 Reqs.addRequirements(Req: getSymbolicOperandRequirements(
1079 Category: SPIRV::OperandCategory::DecorationOperand, i: Dec, ST, Reqs));
1080
1081 if (Dec == SPIRV::Decoration::BuiltIn) {
1082 int64_t BuiltInOp = MI.getOperand(i: DecIndex + 1).getImm();
1083 auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
1084 Reqs.addRequirements(Req: getSymbolicOperandRequirements(
1085 Category: SPIRV::OperandCategory::BuiltInOperand, i: BuiltIn, ST, Reqs));
1086 } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
1087 int64_t LinkageOp = MI.getOperand(i: MI.getNumOperands() - 1).getImm();
1088 SPIRV::LinkageType::LinkageType LnkType =
1089 static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
1090 if (LnkType == SPIRV::LinkageType::LinkOnceODR)
1091 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_linkonce_odr);
1092 else if (LnkType == SPIRV::LinkageType::WeakAMD) {
1093 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_AMD_weak_linkage);
1094 Reqs.addCapability(ToAdd: SPIRV::Capability::WeakLinkageAMD);
1095 }
1096 } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
1097 Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
1098 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_cache_controls);
1099 } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
1100 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_global_variable_host_access);
1101 } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
1102 Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
1103 Reqs.addExtension(
1104 ToAdd: SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
1105 } else if (Dec == SPIRV::Decoration::NonUniformEXT) {
1106 Reqs.addRequirements(Req: SPIRV::Capability::ShaderNonUniformEXT);
1107 } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) {
1108 Reqs.addRequirements(Req: SPIRV::Capability::FPMaxErrorINTEL);
1109 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_fp_max_error);
1110 } else if (Dec == SPIRV::Decoration::FPFastMathMode) {
1111 if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2)) {
1112 Reqs.addRequirements(Req: SPIRV::Capability::FloatControls2);
1113 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_float_controls2);
1114 }
1115 }
1116}
1117
1118// Add requirements for image handling.
1119static void addOpTypeImageReqs(const MachineInstr &MI,
1120 SPIRV::RequirementHandler &Reqs,
1121 const SPIRVSubtarget &ST) {
1122 assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
1123 // The operand indices used here are based on the OpTypeImage layout, which
1124 // the MachineInstr follows as well.
1125 int64_t ImgFormatOp = MI.getOperand(i: 7).getImm();
1126 auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
1127 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ImageFormatOperand,
1128 i: ImgFormat, ST);
1129
1130 bool IsArrayed = MI.getOperand(i: 4).getImm() == 1;
1131 bool IsMultisampled = MI.getOperand(i: 5).getImm() == 1;
1132 bool NoSampler = MI.getOperand(i: 6).getImm() == 2;
1133 // Add dimension requirements.
1134 assert(MI.getOperand(2).isImm());
1135 switch (MI.getOperand(i: 2).getImm()) {
1136 case SPIRV::Dim::DIM_1D:
1137 Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::Image1D
1138 : SPIRV::Capability::Sampled1D);
1139 break;
1140 case SPIRV::Dim::DIM_2D:
1141 if (IsMultisampled && NoSampler)
1142 Reqs.addRequirements(Req: SPIRV::Capability::StorageImageMultisample);
1143 if (IsMultisampled && IsArrayed)
1144 Reqs.addRequirements(Req: SPIRV::Capability::ImageMSArray);
1145 break;
1146 case SPIRV::Dim::DIM_3D:
1147 break;
1148 case SPIRV::Dim::DIM_Cube:
1149 Reqs.addRequirements(Req: SPIRV::Capability::Shader);
1150 if (IsArrayed)
1151 Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::ImageCubeArray
1152 : SPIRV::Capability::SampledCubeArray);
1153 break;
1154 case SPIRV::Dim::DIM_Rect:
1155 Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::ImageRect
1156 : SPIRV::Capability::SampledRect);
1157 break;
1158 case SPIRV::Dim::DIM_Buffer:
1159 Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::ImageBuffer
1160 : SPIRV::Capability::SampledBuffer);
1161 break;
1162 case SPIRV::Dim::DIM_SubpassData:
1163 Reqs.addRequirements(Req: SPIRV::Capability::InputAttachment);
1164 break;
1165 }
1166
1167 // Check if the sampled type is a 64-bit integer, which requires
1168 // Int64ImageEXT capability.
1169 assert(MI.getOperand(1).isReg());
1170 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1171 SPIRVTypeInst SampledTypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
1172 if (SampledTypeDef.isTypeIntN(N: 64)) {
1173 Reqs.addCapability(ToAdd: SPIRV::Capability::Int64ImageEXT);
1174 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_shader_image_int64);
1175 }
1176
1177 // Has optional access qualifier.
1178 if (!ST.isShader()) {
1179 if (MI.getNumOperands() > 8 &&
1180 MI.getOperand(i: 8).getImm() == SPIRV::AccessQualifier::ReadWrite)
1181 Reqs.addRequirements(Req: SPIRV::Capability::ImageReadWrite);
1182 else
1183 Reqs.addRequirements(Req: SPIRV::Capability::ImageBasic);
1184 }
1185}
1186
1187static bool isBFloat16Type(SPIRVTypeInst TypeDef) {
1188 return TypeDef && TypeDef->getNumOperands() == 3 &&
1189 TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1190 TypeDef->getOperand(i: 1).getImm() == 16 &&
1191 TypeDef->getOperand(i: 2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
1192}
1193
// Diagnostic for an atomic float instruction missing one of the
// SPV_EXT_shader_atomic_float* extensions. Suffix must be a string literal;
// it is pasted onto the extension name, e.g. "_add".
#define ATOM_FLT_REQ_EXT_MSG(Suffix)                                           \
  "The atomic float instruction requires the following SPIR-V extension: "    \
  "SPV_EXT_shader_atomic_float" Suffix
1198static void AddAtomicVectorFloatRequirements(const MachineInstr &MI,
1199 SPIRV::RequirementHandler &Reqs,
1200 const SPIRVSubtarget &ST) {
1201 SPIRVTypeInst VecTypeDef =
1202 MI.getMF()->getRegInfo().getVRegDef(Reg: MI.getOperand(i: 1).getReg());
1203
1204 const unsigned Rank = VecTypeDef->getOperand(i: 2).getImm();
1205 if (Rank != 2 && Rank != 4)
1206 reportFatalUsageError(reason: "Result type of an atomic vector float instruction "
1207 "must be a 2-component or 4 component vector");
1208
1209 SPIRVTypeInst EltTypeDef =
1210 MI.getMF()->getRegInfo().getVRegDef(Reg: VecTypeDef->getOperand(i: 1).getReg());
1211
1212 if (EltTypeDef->getOpcode() != SPIRV::OpTypeFloat ||
1213 EltTypeDef->getOperand(i: 1).getImm() != 16)
1214 reportFatalUsageError(
1215 reason: "The element type for the result type of an atomic vector float "
1216 "instruction must be a 16-bit floating-point scalar");
1217
1218 // The extension is defined for fp16, but the AMD target lets a bf16 vector
1219 // use the same instruction so it can lower to a packed bf16 atomic.
1220 if (isBFloat16Type(TypeDef: EltTypeDef) &&
1221 ST.getTargetTriple().getVendor() != Triple::AMD)
1222 reportFatalUsageError(
1223 reason: "The element type for the result type of an atomic vector float "
1224 "instruction cannot be a bfloat16 scalar");
1225 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector))
1226 reportFatalUsageError(
1227 reason: "The atomic float16 vector instruction requires the following SPIR-V "
1228 "extension: SPV_NV_shader_atomic_fp16_vector");
1229
1230 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector);
1231 Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat16VectorNV);
1232}
1233
1234static void AddAtomicFloatRequirements(const MachineInstr &MI,
1235 SPIRV::RequirementHandler &Reqs,
1236 const SPIRVSubtarget &ST) {
1237 assert(MI.getOperand(1).isReg() &&
1238 "Expect register operand in atomic float instruction");
1239 Register TypeReg = MI.getOperand(i: 1).getReg();
1240 SPIRVTypeInst TypeDef = MI.getMF()->getRegInfo().getVRegDef(Reg: TypeReg);
1241
1242 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
1243 return AddAtomicVectorFloatRequirements(MI, Reqs, ST);
1244
1245 if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
1246 report_fatal_error(reason: "Result type of an atomic float instruction must be a "
1247 "floating-point type scalar");
1248
1249 unsigned BitWidth = TypeDef->getOperand(i: 1).getImm();
1250 unsigned Op = MI.getOpcode();
1251 if (Op == SPIRV::OpAtomicFAddEXT) {
1252 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
1253 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), gen_crash_diag: false);
1254 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
1255 switch (BitWidth) {
1256 case 16:
1257 if (isBFloat16Type(TypeDef)) {
1258 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_16bit_atomics))
1259 report_fatal_error(
1260 reason: "The atomic bfloat16 instruction requires the following SPIR-V "
1261 "extension: SPV_INTEL_16bit_atomics",
1262 gen_crash_diag: false);
1263 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_16bit_atomics);
1264 Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicBFloat16AddINTEL);
1265 } else {
1266 if (!ST.canUseExtension(
1267 E: SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
1268 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), gen_crash_diag: false);
1269 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
1270 Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat16AddEXT);
1271 }
1272 break;
1273 case 32:
1274 Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat32AddEXT);
1275 break;
1276 case 64:
1277 Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat64AddEXT);
1278 break;
1279 default:
1280 report_fatal_error(
1281 reason: "Unexpected floating-point type width in atomic float instruction");
1282 }
1283 } else {
1284 if (!ST.canUseExtension(
1285 E: SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
1286 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), gen_crash_diag: false);
1287 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
1288 switch (BitWidth) {
1289 case 16:
1290 if (isBFloat16Type(TypeDef)) {
1291 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_16bit_atomics))
1292 report_fatal_error(
1293 reason: "The atomic bfloat16 instruction requires the following SPIR-V "
1294 "extension: SPV_INTEL_16bit_atomics",
1295 gen_crash_diag: false);
1296 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_16bit_atomics);
1297 Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicBFloat16MinMaxINTEL);
1298 } else {
1299 Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat16MinMaxEXT);
1300 }
1301 break;
1302 case 32:
1303 Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat32MinMaxEXT);
1304 break;
1305 case 64:
1306 Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat64MinMaxEXT);
1307 break;
1308 default:
1309 report_fatal_error(
1310 reason: "Unexpected floating-point type width in atomic float instruction");
1311 }
1312 }
1313}
1314
1315bool isUniformTexelBuffer(MachineInstr *ImageInst) {
1316 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1317 return false;
1318 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1319 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1320 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
1321}
1322
1323bool isStorageTexelBuffer(MachineInstr *ImageInst) {
1324 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1325 return false;
1326 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1327 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1328 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
1329}
1330
1331bool isSampledImage(MachineInstr *ImageInst) {
1332 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1333 return false;
1334 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1335 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1336 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
1337}
1338
1339bool isInputAttachment(MachineInstr *ImageInst) {
1340 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1341 return false;
1342 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1343 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1344 return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
1345}
1346
1347bool isStorageImage(MachineInstr *ImageInst) {
1348 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1349 return false;
1350 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1351 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1352 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
1353}
1354
1355bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
1356 if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
1357 return false;
1358
1359 const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
1360 Register ImageReg = SampledImageInst->getOperand(i: 1).getReg();
1361 auto *ImageInst = MRI.getUniqueVRegDef(Reg: ImageReg);
1362 return isSampledImage(ImageInst);
1363}
1364
1365bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
1366 for (const auto &MI : MRI.reg_instructions(Reg)) {
1367 if (MI.getOpcode() != SPIRV::OpDecorate)
1368 continue;
1369
1370 uint32_t Dec = MI.getOperand(i: 1).getImm();
1371 if (Dec == SPIRV::Decoration::NonUniformEXT)
1372 return true;
1373 }
1374 return false;
1375}
1376
1377void addOpAccessChainReqs(const MachineInstr &Instr,
1378 SPIRV::RequirementHandler &Handler,
1379 const SPIRVSubtarget &Subtarget) {
1380 const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
1381 // Get the result type. If it is an image type, then the shader uses
1382 // descriptor indexing. The appropriate capabilities will be added based
1383 // on the specifics of the image.
1384 Register ResTypeReg = Instr.getOperand(i: 1).getReg();
1385 MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(Reg: ResTypeReg);
1386
1387 assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
1388 uint32_t StorageClass = ResTypeInst->getOperand(i: 1).getImm();
1389 if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
1390 StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
1391 StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
1392 return;
1393 }
1394
1395 bool IsNonUniform =
1396 hasNonUniformDecoration(Reg: Instr.getOperand(i: 0).getReg(), MRI);
1397
1398 auto FirstIndexReg = Instr.getOperand(i: 3).getReg();
1399 bool FirstIndexIsConstant =
1400 Subtarget.getInstrInfo()->isConstantInstr(MI: *MRI.getVRegDef(Reg: FirstIndexReg));
1401
1402 if (StorageClass == SPIRV::StorageClass::StorageClass::StorageBuffer) {
1403 if (IsNonUniform)
1404 Handler.addRequirements(
1405 Req: SPIRV::Capability::StorageBufferArrayNonUniformIndexingEXT);
1406 else if (!FirstIndexIsConstant)
1407 Handler.addRequirements(
1408 Req: SPIRV::Capability::StorageBufferArrayDynamicIndexing);
1409 return;
1410 }
1411
1412 Register PointeeTypeReg = ResTypeInst->getOperand(i: 2).getReg();
1413 MachineInstr *PointeeType = MRI.getUniqueVRegDef(Reg: PointeeTypeReg);
1414 if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
1415 PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
1416 PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
1417 return;
1418 }
1419
1420 if (isUniformTexelBuffer(ImageInst: PointeeType)) {
1421 if (IsNonUniform)
1422 Handler.addRequirements(
1423 Req: SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
1424 else if (!FirstIndexIsConstant)
1425 Handler.addRequirements(
1426 Req: SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
1427 } else if (isInputAttachment(ImageInst: PointeeType)) {
1428 if (IsNonUniform)
1429 Handler.addRequirements(
1430 Req: SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
1431 else if (!FirstIndexIsConstant)
1432 Handler.addRequirements(
1433 Req: SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
1434 } else if (isStorageTexelBuffer(ImageInst: PointeeType)) {
1435 if (IsNonUniform)
1436 Handler.addRequirements(
1437 Req: SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
1438 else if (!FirstIndexIsConstant)
1439 Handler.addRequirements(
1440 Req: SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
1441 } else if (isSampledImage(ImageInst: PointeeType) ||
1442 isCombinedImageSampler(SampledImageInst: PointeeType) ||
1443 PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
1444 if (IsNonUniform)
1445 Handler.addRequirements(
1446 Req: SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
1447 else if (!FirstIndexIsConstant)
1448 Handler.addRequirements(
1449 Req: SPIRV::Capability::SampledImageArrayDynamicIndexing);
1450 } else if (isStorageImage(ImageInst: PointeeType)) {
1451 if (IsNonUniform)
1452 Handler.addRequirements(
1453 Req: SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
1454 else if (!FirstIndexIsConstant)
1455 Handler.addRequirements(
1456 Req: SPIRV::Capability::StorageImageArrayDynamicIndexing);
1457 }
1458}
1459
1460static bool isImageTypeWithUnknownFormat(SPIRVTypeInst TypeInst) {
1461 if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
1462 return false;
1463 assert(TypeInst->getOperand(7).isImm() && "The image format must be an imm.");
1464 return TypeInst->getOperand(i: 7).getImm() == 0;
1465}
1466
1467static void AddDotProductRequirements(const MachineInstr &MI,
1468 SPIRV::RequirementHandler &Reqs,
1469 const SPIRVSubtarget &ST) {
1470 if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_integer_dot_product))
1471 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_integer_dot_product);
1472 Reqs.addCapability(ToAdd: SPIRV::Capability::DotProduct);
1473
1474 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1475 assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
1476 // We do not consider what the previous instruction is. This is just used
1477 // to get the input register and to check the type.
1478 const MachineInstr *Input = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
1479 assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
1480 Register InputReg = Input->getOperand(i: 1).getReg();
1481
1482 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: InputReg);
1483 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1484 assert(TypeDef->getOperand(1).getImm() == 32);
1485 Reqs.addCapability(ToAdd: SPIRV::Capability::DotProductInput4x8BitPacked);
1486 } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
1487 SPIRVTypeInst ScalarTypeDef =
1488 MRI.getVRegDef(Reg: TypeDef->getOperand(i: 1).getReg());
1489 assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
1490 if (ScalarTypeDef->getOperand(i: 1).getImm() == 8) {
1491 assert(TypeDef->getOperand(2).getImm() == 4 &&
1492 "Dot operand of 8-bit integer type requires 4 components");
1493 Reqs.addCapability(ToAdd: SPIRV::Capability::DotProductInput4x8Bit);
1494 } else {
1495 Reqs.addCapability(ToAdd: SPIRV::Capability::DotProductInputAll);
1496 }
1497 }
1498}
1499
1500void addPrintfRequirements(const MachineInstr &MI,
1501 SPIRV::RequirementHandler &Reqs,
1502 const SPIRVSubtarget &ST) {
1503 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
1504 SPIRVTypeInst PtrType =
1505 GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 4).getReg(), MF: MI.getMF());
1506 if (PtrType) {
1507 MachineOperand ASOp = PtrType->getOperand(i: 1);
1508 if (ASOp.isImm()) {
1509 unsigned AddrSpace = ASOp.getImm();
1510 if (AddrSpace != SPIRV::StorageClass::UniformConstant) {
1511 if (!ST.canUseExtension(
1512 E: SPIRV::Extension::
1513 SPV_EXT_relaxed_printf_string_address_space)) {
1514 report_fatal_error(reason: "SPV_EXT_relaxed_printf_string_address_space is "
1515 "required because printf uses a format string not "
1516 "in constant address space.",
1517 gen_crash_diag: false);
1518 }
1519 Reqs.addExtension(
1520 ToAdd: SPIRV::Extension::SPV_EXT_relaxed_printf_string_address_space);
1521 }
1522 }
1523 }
1524}
1525
1526static void addImageOperandReqs(const MachineInstr &MI,
1527 SPIRV::RequirementHandler &Reqs,
1528 const SPIRVSubtarget &ST, unsigned OpIdx) {
1529 if (MI.getNumOperands() <= OpIdx)
1530 return;
1531 uint32_t Mask = MI.getOperand(i: OpIdx).getImm();
1532 for (uint32_t I = 0; I < 32; ++I)
1533 if (Mask & (1U << I))
1534 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ImageOperandOperand,
1535 i: 1U << I, ST);
1536}
1537
1538void addInstrRequirements(const MachineInstr &MI,
1539 SPIRV::ModuleAnalysisInfo &MAI,
1540 const SPIRVSubtarget &ST) {
1541 SPIRV::RequirementHandler &Reqs = MAI.Reqs;
1542 unsigned Op = MI.getOpcode();
1543 switch (Op) {
1544 case SPIRV::OpMemoryModel: {
1545 int64_t Addr = MI.getOperand(i: 0).getImm();
1546 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::AddressingModelOperand,
1547 i: Addr, ST);
1548 int64_t Mem = MI.getOperand(i: 1).getImm();
1549 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::MemoryModelOperand, i: Mem,
1550 ST);
1551 break;
1552 }
1553 case SPIRV::OpEntryPoint: {
1554 int64_t Exe = MI.getOperand(i: 0).getImm();
1555 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ExecutionModelOperand,
1556 i: Exe, ST);
1557 break;
1558 }
1559 case SPIRV::OpExecutionMode:
1560 case SPIRV::OpExecutionModeId: {
1561 int64_t Exe = MI.getOperand(i: 1).getImm();
1562 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ExecutionModeOperand,
1563 i: Exe, ST);
1564 break;
1565 }
1566 case SPIRV::OpTypeMatrix:
1567 Reqs.addCapability(ToAdd: SPIRV::Capability::Matrix);
1568 break;
1569 case SPIRV::OpTypeInt: {
1570 unsigned BitWidth = MI.getOperand(i: 1).getImm();
1571 if (BitWidth == 64)
1572 Reqs.addCapability(ToAdd: SPIRV::Capability::Int64);
1573 else if (BitWidth == 16)
1574 Reqs.addCapability(ToAdd: SPIRV::Capability::Int16);
1575 else if (BitWidth == 8)
1576 Reqs.addCapability(ToAdd: SPIRV::Capability::Int8);
1577 else if (BitWidth == 4 &&
1578 ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_int4)) {
1579 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_int4);
1580 Reqs.addCapability(ToAdd: SPIRV::Capability::Int4TypeINTEL);
1581 } else if (BitWidth != 32) {
1582 if (!ST.canUseExtension(
1583 E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers))
1584 reportFatalUsageError(
1585 reason: "OpTypeInt type with a width other than 8, 16, 32 or 64 bits "
1586 "requires the following SPIR-V extension: "
1587 "SPV_ALTERA_arbitrary_precision_integers");
1588 Reqs.addExtension(
1589 ToAdd: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers);
1590 Reqs.addCapability(ToAdd: SPIRV::Capability::ArbitraryPrecisionIntegersALTERA);
1591 }
1592 break;
1593 }
1594 case SPIRV::OpDot: {
1595 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1596 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
1597 if (isBFloat16Type(TypeDef))
1598 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16DotProductKHR);
1599 break;
1600 }
1601 case SPIRV::OpTypeFloat: {
1602 unsigned BitWidth = MI.getOperand(i: 1).getImm();
1603 if (BitWidth == 64)
1604 Reqs.addCapability(ToAdd: SPIRV::Capability::Float64);
1605 else if (BitWidth == 16) {
1606 if (isBFloat16Type(TypeDef: &MI)) {
1607 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_bfloat16))
1608 report_fatal_error(reason: "OpTypeFloat type with bfloat requires the "
1609 "following SPIR-V extension: SPV_KHR_bfloat16",
1610 gen_crash_diag: false);
1611 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_bfloat16);
1612 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16TypeKHR);
1613 } else {
1614 Reqs.addCapability(ToAdd: SPIRV::Capability::Float16);
1615 }
1616 }
1617 break;
1618 }
1619 case SPIRV::OpTypeVector: {
1620 unsigned NumComponents = MI.getOperand(i: 2).getImm();
1621 if (NumComponents == 8 || NumComponents == 16)
1622 Reqs.addCapability(ToAdd: SPIRV::Capability::Vector16);
1623
1624 assert(MI.getOperand(1).isReg());
1625 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1626 SPIRVTypeInst ElemTypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
1627 if (ElemTypeDef->getOpcode() == SPIRV::OpTypePointer &&
1628 ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
1629 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
1630 Reqs.addCapability(ToAdd: SPIRV::Capability::MaskedGatherScatterINTEL);
1631 }
1632 break;
1633 }
1634 case SPIRV::OpTypePointer: {
1635 auto SC = MI.getOperand(i: 1).getImm();
1636 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::StorageClassOperand, i: SC,
1637 ST);
1638 // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
1639 // capability.
1640 if (ST.isShader())
1641 break;
1642 assert(MI.getOperand(2).isReg());
1643 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1644 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
1645 if ((TypeDef->getNumOperands() == 2) &&
1646 (TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
1647 (TypeDef->getOperand(i: 1).getImm() == 16))
1648 Reqs.addCapability(ToAdd: SPIRV::Capability::Float16Buffer);
1649 break;
1650 }
1651 case SPIRV::OpExtInst: {
1652 if (MI.getOperand(i: 2).getImm() ==
1653 static_cast<int64_t>(
1654 SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
1655 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_non_semantic_info);
1656 break;
1657 }
1658 if (MI.getOperand(i: 3).getImm() ==
1659 static_cast<int64_t>(SPIRV::OpenCLExtInst::printf)) {
1660 addPrintfRequirements(MI, Reqs, ST);
1661 break;
1662 }
1663 // TODO: handle bfloat16 extended instructions when
1664 // SPV_INTEL_bfloat16_arithmetic is enabled.
1665 break;
1666 }
1667 case SPIRV::OpAliasDomainDeclINTEL:
1668 case SPIRV::OpAliasScopeDeclINTEL:
1669 case SPIRV::OpAliasScopeListDeclINTEL: {
1670 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_memory_access_aliasing);
1671 Reqs.addCapability(ToAdd: SPIRV::Capability::MemoryAccessAliasingINTEL);
1672 break;
1673 }
1674 case SPIRV::OpBitReverse:
1675 case SPIRV::OpBitFieldInsert:
1676 case SPIRV::OpBitFieldSExtract:
1677 case SPIRV::OpBitFieldUExtract:
1678 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_bit_instructions)) {
1679 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
1680 break;
1681 }
1682 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_bit_instructions);
1683 Reqs.addCapability(ToAdd: SPIRV::Capability::BitInstructions);
1684 break;
1685 case SPIRV::OpTypeRuntimeArray:
1686 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
1687 break;
1688 case SPIRV::OpTypeOpaque:
1689 case SPIRV::OpTypeEvent:
1690 Reqs.addCapability(ToAdd: SPIRV::Capability::Kernel);
1691 break;
1692 case SPIRV::OpTypePipe:
1693 case SPIRV::OpTypeReserveId:
1694 Reqs.addCapability(ToAdd: SPIRV::Capability::Pipes);
1695 break;
1696 case SPIRV::OpTypeDeviceEvent:
1697 case SPIRV::OpTypeQueue:
1698 case SPIRV::OpBuildNDRange:
1699 case SPIRV::OpEnqueueKernel:
1700 Reqs.addCapability(ToAdd: SPIRV::Capability::DeviceEnqueue);
1701 break;
1702 case SPIRV::OpDecorate:
1703 case SPIRV::OpDecorateId:
1704 case SPIRV::OpDecorateString:
1705 addOpDecorateReqs(MI, DecIndex: 1, Reqs, ST);
1706 break;
1707 case SPIRV::OpMemberDecorate:
1708 case SPIRV::OpMemberDecorateString:
1709 addOpDecorateReqs(MI, DecIndex: 2, Reqs, ST);
1710 break;
1711 case SPIRV::OpInBoundsPtrAccessChain:
1712 Reqs.addCapability(ToAdd: SPIRV::Capability::Addresses);
1713 break;
1714 case SPIRV::OpConstantSampler:
1715 Reqs.addCapability(ToAdd: SPIRV::Capability::LiteralSampler);
1716 break;
1717 case SPIRV::OpInBoundsAccessChain:
1718 case SPIRV::OpAccessChain:
1719 addOpAccessChainReqs(Instr: MI, Handler&: Reqs, Subtarget: ST);
1720 break;
1721 case SPIRV::OpTypeImage:
1722 addOpTypeImageReqs(MI, Reqs, ST);
1723 break;
1724 case SPIRV::OpTypeSampler:
1725 if (!ST.isShader()) {
1726 Reqs.addCapability(ToAdd: SPIRV::Capability::ImageBasic);
1727 }
1728 break;
1729 case SPIRV::OpTypeForwardPointer:
1730 // TODO: check if it's OpenCL's kernel.
1731 Reqs.addCapability(ToAdd: SPIRV::Capability::Addresses);
1732 break;
1733 case SPIRV::OpAtomicFlagTestAndSet:
1734 case SPIRV::OpAtomicLoad:
1735 case SPIRV::OpAtomicStore:
1736 case SPIRV::OpAtomicExchange:
1737 case SPIRV::OpAtomicCompareExchange:
1738 case SPIRV::OpAtomicIIncrement:
1739 case SPIRV::OpAtomicIDecrement:
1740 case SPIRV::OpAtomicIAdd:
1741 case SPIRV::OpAtomicISub:
1742 case SPIRV::OpAtomicUMin:
1743 case SPIRV::OpAtomicUMax:
1744 case SPIRV::OpAtomicSMin:
1745 case SPIRV::OpAtomicSMax:
1746 case SPIRV::OpAtomicAnd:
1747 case SPIRV::OpAtomicOr:
1748 case SPIRV::OpAtomicXor: {
1749 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1750 const MachineInstr *InstrPtr = &MI;
1751 if (Op == SPIRV::OpAtomicStore) {
1752 assert(MI.getOperand(3).isReg());
1753 InstrPtr = MRI.getVRegDef(Reg: MI.getOperand(i: 3).getReg());
1754 assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
1755 }
1756 assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
1757 Register TypeReg = InstrPtr->getOperand(i: 1).getReg();
1758 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: TypeReg);
1759
1760 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1761 unsigned BitWidth = TypeDef->getOperand(i: 1).getImm();
1762 if (BitWidth == 64)
1763 Reqs.addCapability(ToAdd: SPIRV::Capability::Int64Atomics);
1764 else if (BitWidth == 16) {
1765 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_16bit_atomics))
1766 report_fatal_error(
1767 reason: "16-bit integer atomic operations require the following SPIR-V "
1768 "extension: SPV_INTEL_16bit_atomics",
1769 gen_crash_diag: false);
1770 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_16bit_atomics);
1771 switch (Op) {
1772 case SPIRV::OpAtomicLoad:
1773 case SPIRV::OpAtomicStore:
1774 case SPIRV::OpAtomicExchange:
1775 case SPIRV::OpAtomicCompareExchange:
1776 case SPIRV::OpAtomicCompareExchangeWeak:
1777 Reqs.addCapability(
1778 ToAdd: SPIRV::Capability::AtomicInt16CompareExchangeINTEL);
1779 break;
1780 default:
1781 Reqs.addCapability(ToAdd: SPIRV::Capability::Int16AtomicsINTEL);
1782 break;
1783 }
1784 }
1785 } else if (isBFloat16Type(TypeDef)) {
1786 if (is_contained(Set: {SPIRV::OpAtomicLoad, SPIRV::OpAtomicStore,
1787 SPIRV::OpAtomicExchange},
1788 Element: Op)) {
1789 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_16bit_atomics))
1790 report_fatal_error(
1791 reason: "The atomic bfloat16 instruction requires the following SPIR-V "
1792 "extension: SPV_INTEL_16bit_atomics",
1793 gen_crash_diag: false);
1794 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_16bit_atomics);
1795 Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicBFloat16LoadStoreINTEL);
1796 }
1797 }
1798 break;
1799 }
1800 case SPIRV::OpGroupNonUniformIAdd:
1801 case SPIRV::OpGroupNonUniformFAdd:
1802 case SPIRV::OpGroupNonUniformIMul:
1803 case SPIRV::OpGroupNonUniformFMul:
1804 case SPIRV::OpGroupNonUniformSMin:
1805 case SPIRV::OpGroupNonUniformUMin:
1806 case SPIRV::OpGroupNonUniformFMin:
1807 case SPIRV::OpGroupNonUniformSMax:
1808 case SPIRV::OpGroupNonUniformUMax:
1809 case SPIRV::OpGroupNonUniformFMax:
1810 case SPIRV::OpGroupNonUniformBitwiseAnd:
1811 case SPIRV::OpGroupNonUniformBitwiseOr:
1812 case SPIRV::OpGroupNonUniformBitwiseXor:
1813 case SPIRV::OpGroupNonUniformLogicalAnd:
1814 case SPIRV::OpGroupNonUniformLogicalOr:
1815 case SPIRV::OpGroupNonUniformLogicalXor: {
1816 assert(MI.getOperand(3).isImm());
1817 int64_t GroupOp = MI.getOperand(i: 3).getImm();
1818 switch (GroupOp) {
1819 case SPIRV::GroupOperation::Reduce:
1820 case SPIRV::GroupOperation::InclusiveScan:
1821 case SPIRV::GroupOperation::ExclusiveScan:
1822 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformArithmetic);
1823 break;
1824 case SPIRV::GroupOperation::ClusteredReduce:
1825 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformClustered);
1826 break;
1827 case SPIRV::GroupOperation::PartitionedReduceNV:
1828 case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1829 case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1830 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformPartitionedNV);
1831 break;
1832 }
1833 break;
1834 }
1835 case SPIRV::OpGroupNonUniformQuadSwap:
1836 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformQuad);
1837 break;
1838 case SPIRV::OpImageQueryLod:
1839 Reqs.addCapability(ToAdd: SPIRV::Capability::ImageQuery);
1840 break;
1841 case SPIRV::OpImageQuerySize:
1842 case SPIRV::OpImageQuerySizeLod:
1843 case SPIRV::OpImageQueryLevels:
1844 case SPIRV::OpImageQuerySamples:
1845 if (ST.isShader())
1846 Reqs.addCapability(ToAdd: SPIRV::Capability::ImageQuery);
1847 break;
1848 case SPIRV::OpImageQueryFormat: {
1849 Register ResultReg = MI.getOperand(i: 0).getReg();
1850 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1851 static const unsigned CompareOps[] = {
1852 SPIRV::OpIEqual, SPIRV::OpINotEqual,
1853 SPIRV::OpUGreaterThan, SPIRV::OpUGreaterThanEqual,
1854 SPIRV::OpULessThan, SPIRV::OpULessThanEqual,
1855 SPIRV::OpSGreaterThan, SPIRV::OpSGreaterThanEqual,
1856 SPIRV::OpSLessThan, SPIRV::OpSLessThanEqual};
1857
1858 auto CheckAndAddExtension = [&](int64_t ImmVal) {
1859 if (ImmVal == 4323 || ImmVal == 4324) {
1860 if (ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_image_raw10_raw12))
1861 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_image_raw10_raw12);
1862 else
1863 report_fatal_error(reason: "This requires the "
1864 "SPV_EXT_image_raw10_raw12 extension");
1865 }
1866 };
1867
1868 for (MachineInstr &UseInst : MRI.use_instructions(Reg: ResultReg)) {
1869 unsigned Opc = UseInst.getOpcode();
1870
1871 if (Opc == SPIRV::OpSwitch) {
1872 for (const MachineOperand &Op : UseInst.operands())
1873 if (Op.isImm())
1874 CheckAndAddExtension(Op.getImm());
1875 } else if (llvm::is_contained(Range: CompareOps, Element: Opc)) {
1876 for (unsigned i = 1; i < UseInst.getNumOperands(); ++i) {
1877 Register UseReg = UseInst.getOperand(i).getReg();
1878 MachineInstr *ConstInst = MRI.getVRegDef(Reg: UseReg);
1879 if (ConstInst && ConstInst->getOpcode() == SPIRV::OpConstantI) {
1880 int64_t ImmVal = ConstInst->getOperand(i: 2).getImm();
1881 if (ImmVal)
1882 CheckAndAddExtension(ImmVal);
1883 }
1884 }
1885 }
1886 }
1887 break;
1888 }
1889
1890 case SPIRV::OpGroupNonUniformShuffle:
1891 case SPIRV::OpGroupNonUniformShuffleXor:
1892 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformShuffle);
1893 break;
1894 case SPIRV::OpGroupNonUniformShuffleUp:
1895 case SPIRV::OpGroupNonUniformShuffleDown:
1896 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformShuffleRelative);
1897 break;
1898 case SPIRV::OpGroupAll:
1899 case SPIRV::OpGroupAny:
1900 case SPIRV::OpGroupBroadcast:
1901 case SPIRV::OpGroupIAdd:
1902 case SPIRV::OpGroupFAdd:
1903 case SPIRV::OpGroupFMin:
1904 case SPIRV::OpGroupUMin:
1905 case SPIRV::OpGroupSMin:
1906 case SPIRV::OpGroupFMax:
1907 case SPIRV::OpGroupUMax:
1908 case SPIRV::OpGroupSMax:
1909 Reqs.addCapability(ToAdd: SPIRV::Capability::Groups);
1910 break;
1911 case SPIRV::OpGroupNonUniformElect:
1912 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniform);
1913 break;
1914 case SPIRV::OpGroupNonUniformAll:
1915 case SPIRV::OpGroupNonUniformAny:
1916 case SPIRV::OpGroupNonUniformAllEqual:
1917 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformVote);
1918 break;
1919 case SPIRV::OpGroupNonUniformBroadcast:
1920 case SPIRV::OpGroupNonUniformBroadcastFirst:
1921 case SPIRV::OpGroupNonUniformBallot:
1922 case SPIRV::OpGroupNonUniformInverseBallot:
1923 case SPIRV::OpGroupNonUniformBallotBitExtract:
1924 case SPIRV::OpGroupNonUniformBallotBitCount:
1925 case SPIRV::OpGroupNonUniformBallotFindLSB:
1926 case SPIRV::OpGroupNonUniformBallotFindMSB:
1927 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformBallot);
1928 break;
1929 case SPIRV::OpSubgroupShuffleINTEL:
1930 case SPIRV::OpSubgroupShuffleDownINTEL:
1931 case SPIRV::OpSubgroupShuffleUpINTEL:
1932 case SPIRV::OpSubgroupShuffleXorINTEL:
1933 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_subgroups)) {
1934 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_subgroups);
1935 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupShuffleINTEL);
1936 }
1937 break;
1938 case SPIRV::OpSubgroupBlockReadINTEL:
1939 case SPIRV::OpSubgroupBlockWriteINTEL:
1940 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_subgroups)) {
1941 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_subgroups);
1942 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1943 }
1944 break;
1945 case SPIRV::OpSubgroupImageBlockReadINTEL:
1946 case SPIRV::OpSubgroupImageBlockWriteINTEL:
1947 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_subgroups)) {
1948 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_subgroups);
1949 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupImageBlockIOINTEL);
1950 }
1951 break;
1952 case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
1953 case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
1954 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_media_block_io)) {
1955 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_media_block_io);
1956 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
1957 }
1958 break;
1959 case SPIRV::OpAssumeTrueKHR:
1960 case SPIRV::OpExpectKHR:
1961 if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_expect_assume)) {
1962 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_expect_assume);
1963 Reqs.addCapability(ToAdd: SPIRV::Capability::ExpectAssumeKHR);
1964 }
1965 break;
1966 case SPIRV::OpFmaKHR:
1967 if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_fma)) {
1968 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_fma);
1969 Reqs.addCapability(ToAdd: SPIRV::Capability::FmaKHR);
1970 }
1971 break;
1972 case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1973 case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1974 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1975 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1976 Reqs.addCapability(ToAdd: SPIRV::Capability::USMStorageClassesINTEL);
1977 }
1978 break;
1979 case SPIRV::OpConstantFunctionPointerINTEL:
1980 case SPIRV::OpFunctionPointerCallINTEL:
1981 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers)) {
1982 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_function_pointers);
1983 Reqs.addCapability(ToAdd: SPIRV::Capability::FunctionPointersINTEL);
1984 }
1985 break;
1986 case SPIRV::OpGroupNonUniformRotateKHR:
1987 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_subgroup_rotate))
1988 report_fatal_error(reason: "OpGroupNonUniformRotateKHR instruction requires the "
1989 "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1990 gen_crash_diag: false);
1991 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_subgroup_rotate);
1992 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformRotateKHR);
1993 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniform);
1994 break;
1995 case SPIRV::OpFixedCosALTERA:
1996 case SPIRV::OpFixedSinALTERA:
1997 case SPIRV::OpFixedCosPiALTERA:
1998 case SPIRV::OpFixedSinPiALTERA:
1999 case SPIRV::OpFixedExpALTERA:
2000 case SPIRV::OpFixedLogALTERA:
2001 case SPIRV::OpFixedRecipALTERA:
2002 case SPIRV::OpFixedSqrtALTERA:
2003 case SPIRV::OpFixedSinCosALTERA:
2004 case SPIRV::OpFixedSinCosPiALTERA:
2005 case SPIRV::OpFixedRsqrtALTERA:
2006 if (!ST.canUseExtension(
2007 E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_fixed_point))
2008 report_fatal_error(reason: "This instruction requires the "
2009 "following SPIR-V extension: "
2010 "SPV_ALTERA_arbitrary_precision_fixed_point",
2011 gen_crash_diag: false);
2012 Reqs.addExtension(
2013 ToAdd: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_fixed_point);
2014 Reqs.addCapability(ToAdd: SPIRV::Capability::ArbitraryPrecisionFixedPointALTERA);
2015 break;
2016 case SPIRV::OpGroupIMulKHR:
2017 case SPIRV::OpGroupFMulKHR:
2018 case SPIRV::OpGroupBitwiseAndKHR:
2019 case SPIRV::OpGroupBitwiseOrKHR:
2020 case SPIRV::OpGroupBitwiseXorKHR:
2021 case SPIRV::OpGroupLogicalAndKHR:
2022 case SPIRV::OpGroupLogicalOrKHR:
2023 case SPIRV::OpGroupLogicalXorKHR:
2024 if (ST.canUseExtension(
2025 E: SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
2026 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_uniform_group_instructions);
2027 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupUniformArithmeticKHR);
2028 }
2029 break;
2030 case SPIRV::OpReadClockKHR:
2031 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_shader_clock))
2032 report_fatal_error(reason: "OpReadClockKHR instruction requires the "
2033 "following SPIR-V extension: SPV_KHR_shader_clock",
2034 gen_crash_diag: false);
2035 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_shader_clock);
2036 Reqs.addCapability(ToAdd: SPIRV::Capability::ShaderClockKHR);
2037 break;
2038 case SPIRV::OpAbortKHR:
2039 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_abort))
2040 report_fatal_error(reason: "OpAbortKHR instruction requires the "
2041 "following SPIR-V extension: SPV_KHR_abort",
2042 gen_crash_diag: false);
2043 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_abort);
2044 Reqs.addCapability(ToAdd: SPIRV::Capability::AbortKHR);
2045 break;
2046 case SPIRV::OpPoisonKHR:
2047 case SPIRV::OpFreezeKHR:
2048 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_poison_freeze))
2049 report_fatal_error(reason: "OpPoisonKHR/OpFreezeKHR instruction requires the "
2050 "following SPIR-V extension: SPV_KHR_poison_freeze",
2051 gen_crash_diag: false);
2052 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_poison_freeze);
2053 Reqs.addCapability(ToAdd: SPIRV::Capability::PoisonFreezeKHR);
2054 break;
2055 case SPIRV::OpAtomicFAddEXT:
2056 case SPIRV::OpAtomicFMinEXT:
2057 case SPIRV::OpAtomicFMaxEXT:
2058 AddAtomicFloatRequirements(MI, Reqs, ST);
2059 break;
2060 case SPIRV::OpConvertBF16ToFINTEL:
2061 case SPIRV::OpConvertFToBF16INTEL:
2062 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
2063 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
2064 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16ConversionINTEL);
2065 }
2066 break;
2067 case SPIRV::OpRoundFToTF32INTEL:
2068 if (ST.canUseExtension(
2069 E: SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
2070 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
2071 Reqs.addCapability(ToAdd: SPIRV::Capability::TensorFloat32RoundingINTEL);
2072 }
2073 break;
2074 case SPIRV::OpVariableLengthArrayINTEL:
2075 case SPIRV::OpSaveMemoryINTEL:
2076 case SPIRV::OpRestoreMemoryINTEL:
2077 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_variable_length_array)) {
2078 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_variable_length_array);
2079 Reqs.addCapability(ToAdd: SPIRV::Capability::VariableLengthArrayINTEL);
2080 }
2081 break;
2082 case SPIRV::OpAsmTargetINTEL:
2083 case SPIRV::OpAsmINTEL:
2084 case SPIRV::OpAsmCallINTEL:
2085 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_inline_assembly)) {
2086 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_inline_assembly);
2087 Reqs.addCapability(ToAdd: SPIRV::Capability::AsmINTEL);
2088 }
2089 break;
2090 case SPIRV::OpTypeCooperativeMatrixKHR: {
2091 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_cooperative_matrix))
2092 report_fatal_error(
2093 reason: "OpTypeCooperativeMatrixKHR type requires the "
2094 "following SPIR-V extension: SPV_KHR_cooperative_matrix",
2095 gen_crash_diag: false);
2096 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_cooperative_matrix);
2097 Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixKHR);
2098 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2099 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
2100 if (isBFloat16Type(TypeDef))
2101 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16CooperativeMatrixKHR);
2102 break;
2103 }
2104 case SPIRV::OpArithmeticFenceEXT:
2105 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_arithmetic_fence))
2106 report_fatal_error(reason: "OpArithmeticFenceEXT requires the "
2107 "following SPIR-V extension: SPV_EXT_arithmetic_fence",
2108 gen_crash_diag: false);
2109 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_arithmetic_fence);
2110 Reqs.addCapability(ToAdd: SPIRV::Capability::ArithmeticFenceEXT);
2111 break;
2112 case SPIRV::OpControlBarrierArriveINTEL:
2113 case SPIRV::OpControlBarrierWaitINTEL:
2114 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_split_barrier)) {
2115 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_split_barrier);
2116 Reqs.addCapability(ToAdd: SPIRV::Capability::SplitBarrierINTEL);
2117 }
2118 break;
2119 case SPIRV::OpCooperativeMatrixMulAddKHR: {
2120 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_cooperative_matrix))
2121 report_fatal_error(reason: "Cooperative matrix instructions require the "
2122 "following SPIR-V extension: "
2123 "SPV_KHR_cooperative_matrix",
2124 gen_crash_diag: false);
2125 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_cooperative_matrix);
2126 Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixKHR);
2127 constexpr unsigned MulAddMaxSize = 6;
2128 if (MI.getNumOperands() != MulAddMaxSize)
2129 break;
2130 const int64_t CoopOperands = MI.getOperand(i: MulAddMaxSize - 1).getImm();
2131 if (CoopOperands &
2132 SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
2133 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2134 report_fatal_error(reason: "MatrixAAndBTF32ComponentsINTEL type interpretation "
2135 "require the following SPIR-V extension: "
2136 "SPV_INTEL_joint_matrix",
2137 gen_crash_diag: false);
2138 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2139 Reqs.addCapability(
2140 ToAdd: SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
2141 }
2142 if (CoopOperands & SPIRV::CooperativeMatrixOperands::
2143 MatrixAAndBBFloat16ComponentsINTEL ||
2144 CoopOperands &
2145 SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
2146 CoopOperands & SPIRV::CooperativeMatrixOperands::
2147 MatrixResultBFloat16ComponentsINTEL) {
2148 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2149 report_fatal_error(reason: "***BF16ComponentsINTEL type interpretations "
2150 "require the following SPIR-V extension: "
2151 "SPV_INTEL_joint_matrix",
2152 gen_crash_diag: false);
2153 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2154 Reqs.addCapability(
2155 ToAdd: SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
2156 }
2157 break;
2158 }
2159 case SPIRV::OpCooperativeMatrixLoadKHR:
2160 case SPIRV::OpCooperativeMatrixStoreKHR:
2161 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
2162 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
2163 case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
2164 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_cooperative_matrix))
2165 report_fatal_error(reason: "Cooperative matrix instructions require the "
2166 "following SPIR-V extension: "
2167 "SPV_KHR_cooperative_matrix",
2168 gen_crash_diag: false);
2169 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_cooperative_matrix);
2170 Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixKHR);
2171
2172 // Check Layout operand in case if it's not a standard one and add the
2173 // appropriate capability.
2174 unsigned LayoutNum;
2175 switch (Op) {
2176 case SPIRV::OpCooperativeMatrixLoadKHR:
2177 LayoutNum = 3;
2178 break;
2179 case SPIRV::OpCooperativeMatrixStoreKHR:
2180 LayoutNum = 2;
2181 break;
2182 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
2183 LayoutNum = 5;
2184 break;
2185 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
2186 case SPIRV::OpCooperativeMatrixPrefetchINTEL:
2187 LayoutNum = 4;
2188 break;
2189 default:
2190 llvm_unreachable("unexpected cooperative matrix opcode");
2191 }
2192 Register RegLayout = MI.getOperand(i: LayoutNum).getReg();
2193 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2194 MachineInstr *MILayout = MRI.getUniqueVRegDef(Reg: RegLayout);
2195 if (MILayout->getOpcode() == SPIRV::OpConstantI) {
2196 const unsigned LayoutVal = MILayout->getOperand(i: 2).getImm();
2197 if (LayoutVal ==
2198 static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
2199 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2200 report_fatal_error(reason: "PackedINTEL layout require the following SPIR-V "
2201 "extension: SPV_INTEL_joint_matrix",
2202 gen_crash_diag: false);
2203 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2204 Reqs.addCapability(ToAdd: SPIRV::Capability::PackedCooperativeMatrixINTEL);
2205 }
2206 }
2207
2208 // Nothing to do.
2209 if (Op == SPIRV::OpCooperativeMatrixLoadKHR ||
2210 Op == SPIRV::OpCooperativeMatrixStoreKHR)
2211 break;
2212
2213 std::string InstName;
2214 switch (Op) {
2215 case SPIRV::OpCooperativeMatrixPrefetchINTEL:
2216 InstName = "OpCooperativeMatrixPrefetchINTEL";
2217 break;
2218 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
2219 InstName = "OpCooperativeMatrixLoadCheckedINTEL";
2220 break;
2221 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
2222 InstName = "OpCooperativeMatrixStoreCheckedINTEL";
2223 break;
2224 }
2225
2226 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix)) {
2227 const std::string ErrorMsg =
2228 InstName + " instruction requires the "
2229 "following SPIR-V extension: SPV_INTEL_joint_matrix";
2230 report_fatal_error(reason: ErrorMsg.c_str(), gen_crash_diag: false);
2231 }
2232 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2233 if (Op == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
2234 Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
2235 break;
2236 }
2237 Reqs.addCapability(
2238 ToAdd: SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
2239 break;
2240 }
2241 case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
2242 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2243 report_fatal_error(reason: "OpCooperativeMatrixConstructCheckedINTEL "
2244 "instructions require the following SPIR-V extension: "
2245 "SPV_INTEL_joint_matrix",
2246 gen_crash_diag: false);
2247 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2248 Reqs.addCapability(
2249 ToAdd: SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
2250 break;
2251 case SPIRV::OpReadPipeBlockingALTERA:
2252 case SPIRV::OpWritePipeBlockingALTERA:
2253 if (ST.canUseExtension(E: SPIRV::Extension::SPV_ALTERA_blocking_pipes)) {
2254 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_ALTERA_blocking_pipes);
2255 Reqs.addCapability(ToAdd: SPIRV::Capability::BlockingPipesALTERA);
2256 }
2257 break;
2258 case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
2259 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2260 report_fatal_error(reason: "OpCooperativeMatrixGetElementCoordINTEL requires the "
2261 "following SPIR-V extension: SPV_INTEL_joint_matrix",
2262 gen_crash_diag: false);
2263 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2264 Reqs.addCapability(
2265 ToAdd: SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
2266 break;
2267 case SPIRV::OpConvertHandleToImageINTEL:
2268 case SPIRV::OpConvertHandleToSamplerINTEL:
2269 case SPIRV::OpConvertHandleToSampledImageINTEL: {
2270 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bindless_images))
2271 report_fatal_error(reason: "OpConvertHandleTo[Image/Sampler/SampledImage]INTEL "
2272 "instructions require the following SPIR-V extension: "
2273 "SPV_INTEL_bindless_images",
2274 gen_crash_diag: false);
2275 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
2276 SPIRV::AddressingModel::AddressingModel AddrModel = MAI.Addr;
2277 SPIRVTypeInst TyDef = GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 1).getReg());
2278 if (Op == SPIRV::OpConvertHandleToImageINTEL &&
2279 TyDef->getOpcode() != SPIRV::OpTypeImage) {
2280 report_fatal_error(reason: "Incorrect return type for the instruction "
2281 "OpConvertHandleToImageINTEL",
2282 gen_crash_diag: false);
2283 } else if (Op == SPIRV::OpConvertHandleToSamplerINTEL &&
2284 TyDef->getOpcode() != SPIRV::OpTypeSampler) {
2285 report_fatal_error(reason: "Incorrect return type for the instruction "
2286 "OpConvertHandleToSamplerINTEL",
2287 gen_crash_diag: false);
2288 } else if (Op == SPIRV::OpConvertHandleToSampledImageINTEL &&
2289 TyDef->getOpcode() != SPIRV::OpTypeSampledImage) {
2290 report_fatal_error(reason: "Incorrect return type for the instruction "
2291 "OpConvertHandleToSampledImageINTEL",
2292 gen_crash_diag: false);
2293 }
2294 SPIRVTypeInst SpvTy = GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 2).getReg());
2295 unsigned Bitwidth = GR->getScalarOrVectorBitWidth(Type: SpvTy);
2296 if (!(Bitwidth == 32 && AddrModel == SPIRV::AddressingModel::Physical32) &&
2297 !(Bitwidth == 64 && AddrModel == SPIRV::AddressingModel::Physical64)) {
2298 report_fatal_error(
2299 reason: "Parameter value must be a 32-bit scalar in case of "
2300 "Physical32 addressing model or a 64-bit scalar in case of "
2301 "Physical64 addressing model",
2302 gen_crash_diag: false);
2303 }
2304 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bindless_images);
2305 Reqs.addCapability(ToAdd: SPIRV::Capability::BindlessImagesINTEL);
2306 break;
2307 }
2308 case SPIRV::OpSubgroup2DBlockLoadINTEL:
2309 case SPIRV::OpSubgroup2DBlockLoadTransposeINTEL:
2310 case SPIRV::OpSubgroup2DBlockLoadTransformINTEL:
2311 case SPIRV::OpSubgroup2DBlockPrefetchINTEL:
2312 case SPIRV::OpSubgroup2DBlockStoreINTEL: {
2313 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_2d_block_io))
2314 report_fatal_error(reason: "OpSubgroup2DBlock[Load/LoadTranspose/LoadTransform/"
2315 "Prefetch/Store]INTEL instructions require the "
2316 "following SPIR-V extension: SPV_INTEL_2d_block_io",
2317 gen_crash_diag: false);
2318 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_2d_block_io);
2319 Reqs.addCapability(ToAdd: SPIRV::Capability::Subgroup2DBlockIOINTEL);
2320
2321 if (Op == SPIRV::OpSubgroup2DBlockLoadTransposeINTEL) {
2322 Reqs.addCapability(ToAdd: SPIRV::Capability::Subgroup2DBlockTransposeINTEL);
2323 break;
2324 }
2325 if (Op == SPIRV::OpSubgroup2DBlockLoadTransformINTEL) {
2326 Reqs.addCapability(ToAdd: SPIRV::Capability::Subgroup2DBlockTransformINTEL);
2327 break;
2328 }
2329 break;
2330 }
2331 case SPIRV::OpKill: {
2332 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2333 } break;
2334 case SPIRV::OpDemoteToHelperInvocation:
2335 Reqs.addCapability(ToAdd: SPIRV::Capability::DemoteToHelperInvocation);
2336
2337 if (ST.canUseExtension(
2338 E: SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
2339 if (!ST.isAtLeastSPIRVVer(VerToCompareTo: llvm::VersionTuple(1, 6)))
2340 Reqs.addExtension(
2341 ToAdd: SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
2342 }
2343 break;
2344 case SPIRV::OpSDot:
2345 case SPIRV::OpUDot:
2346 case SPIRV::OpSUDot:
2347 case SPIRV::OpSDotAccSat:
2348 case SPIRV::OpUDotAccSat:
2349 case SPIRV::OpSUDotAccSat:
2350 AddDotProductRequirements(MI, Reqs, ST);
2351 break;
2352 case SPIRV::OpImageSampleImplicitLod:
2353 case SPIRV::OpImageFetch:
2354 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2355 addImageOperandReqs(MI, Reqs, ST, OpIdx: 4);
2356 break;
2357 case SPIRV::OpImageSampleExplicitLod:
2358 addImageOperandReqs(MI, Reqs, ST, OpIdx: 4);
2359 break;
2360 case SPIRV::OpImageSampleDrefImplicitLod:
2361 case SPIRV::OpImageSampleDrefExplicitLod:
2362 case SPIRV::OpImageDrefGather:
2363 case SPIRV::OpImageGather:
2364 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2365 addImageOperandReqs(MI, Reqs, ST, OpIdx: 5);
2366 break;
2367 case SPIRV::OpImageRead: {
2368 Register ImageReg = MI.getOperand(i: 2).getReg();
2369 SPIRVTypeInst TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2370 VReg: ImageReg, MF: const_cast<MachineFunction *>(MI.getMF()));
2371 // OpImageRead and OpImageWrite can use Unknown Image Formats
2372 // when the Kernel capability is declared. In the OpenCL environment we are
2373 // not allowed to produce
2374 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2375 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2376
2377 if (isImageTypeWithUnknownFormat(TypeInst: TypeDef) && ST.isShader())
2378 Reqs.addCapability(ToAdd: SPIRV::Capability::StorageImageReadWithoutFormat);
2379 break;
2380 }
2381 case SPIRV::OpImageWrite: {
2382 Register ImageReg = MI.getOperand(i: 0).getReg();
2383 SPIRVTypeInst TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2384 VReg: ImageReg, MF: const_cast<MachineFunction *>(MI.getMF()));
2385 // OpImageRead and OpImageWrite can use Unknown Image Formats
2386 // when the Kernel capability is declared. In the OpenCL environment we are
2387 // not allowed to produce
2388 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2389 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2390
2391 if (isImageTypeWithUnknownFormat(TypeInst: TypeDef) && ST.isShader())
2392 Reqs.addCapability(ToAdd: SPIRV::Capability::StorageImageWriteWithoutFormat);
2393 break;
2394 }
2395 case SPIRV::OpTypeStructContinuedINTEL:
2396 case SPIRV::OpConstantCompositeContinuedINTEL:
2397 case SPIRV::OpSpecConstantCompositeContinuedINTEL:
2398 case SPIRV::OpCompositeConstructContinuedINTEL: {
2399 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_long_composites))
2400 report_fatal_error(
2401 reason: "Continued instructions require the "
2402 "following SPIR-V extension: SPV_INTEL_long_composites",
2403 gen_crash_diag: false);
2404 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_long_composites);
2405 Reqs.addCapability(ToAdd: SPIRV::Capability::LongCompositesINTEL);
2406 break;
2407 }
2408 case SPIRV::OpArbitraryFloatEQALTERA:
2409 case SPIRV::OpArbitraryFloatGEALTERA:
2410 case SPIRV::OpArbitraryFloatGTALTERA:
2411 case SPIRV::OpArbitraryFloatLEALTERA:
2412 case SPIRV::OpArbitraryFloatLTALTERA:
2413 case SPIRV::OpArbitraryFloatCbrtALTERA:
2414 case SPIRV::OpArbitraryFloatCosALTERA:
2415 case SPIRV::OpArbitraryFloatCosPiALTERA:
2416 case SPIRV::OpArbitraryFloatExp10ALTERA:
2417 case SPIRV::OpArbitraryFloatExp2ALTERA:
2418 case SPIRV::OpArbitraryFloatExpALTERA:
2419 case SPIRV::OpArbitraryFloatExpm1ALTERA:
2420 case SPIRV::OpArbitraryFloatHypotALTERA:
2421 case SPIRV::OpArbitraryFloatLog10ALTERA:
2422 case SPIRV::OpArbitraryFloatLog1pALTERA:
2423 case SPIRV::OpArbitraryFloatLog2ALTERA:
2424 case SPIRV::OpArbitraryFloatLogALTERA:
2425 case SPIRV::OpArbitraryFloatRecipALTERA:
2426 case SPIRV::OpArbitraryFloatSinCosALTERA:
2427 case SPIRV::OpArbitraryFloatSinCosPiALTERA:
2428 case SPIRV::OpArbitraryFloatSinALTERA:
2429 case SPIRV::OpArbitraryFloatSinPiALTERA:
2430 case SPIRV::OpArbitraryFloatSqrtALTERA:
2431 case SPIRV::OpArbitraryFloatACosALTERA:
2432 case SPIRV::OpArbitraryFloatACosPiALTERA:
2433 case SPIRV::OpArbitraryFloatAddALTERA:
2434 case SPIRV::OpArbitraryFloatASinALTERA:
2435 case SPIRV::OpArbitraryFloatASinPiALTERA:
2436 case SPIRV::OpArbitraryFloatATan2ALTERA:
2437 case SPIRV::OpArbitraryFloatATanALTERA:
2438 case SPIRV::OpArbitraryFloatATanPiALTERA:
2439 case SPIRV::OpArbitraryFloatCastFromIntALTERA:
2440 case SPIRV::OpArbitraryFloatCastALTERA:
2441 case SPIRV::OpArbitraryFloatCastToIntALTERA:
2442 case SPIRV::OpArbitraryFloatDivALTERA:
2443 case SPIRV::OpArbitraryFloatMulALTERA:
2444 case SPIRV::OpArbitraryFloatPowALTERA:
2445 case SPIRV::OpArbitraryFloatPowNALTERA:
2446 case SPIRV::OpArbitraryFloatPowRALTERA:
2447 case SPIRV::OpArbitraryFloatRSqrtALTERA:
2448 case SPIRV::OpArbitraryFloatSubALTERA: {
2449 if (!ST.canUseExtension(
2450 E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_floating_point))
2451 report_fatal_error(
2452 reason: "Floating point instructions can't be translated correctly without "
2453 "enabled SPV_ALTERA_arbitrary_precision_floating_point extension!",
2454 gen_crash_diag: false);
2455 Reqs.addExtension(
2456 ToAdd: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_floating_point);
2457 Reqs.addCapability(
2458 ToAdd: SPIRV::Capability::ArbitraryPrecisionFloatingPointALTERA);
2459 break;
2460 }
2461 case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
2462 if (!ST.canUseExtension(
2463 E: SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
2464 report_fatal_error(
2465 reason: "OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
2466 "following SPIR-V "
2467 "extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
2468 gen_crash_diag: false);
2469 Reqs.addExtension(
2470 ToAdd: SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
2471 Reqs.addCapability(
2472 ToAdd: SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
2473 break;
2474 }
2475 case SPIRV::OpBitwiseFunctionINTEL: {
2476 if (!ST.canUseExtension(
2477 E: SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))
2478 report_fatal_error(
2479 reason: "OpBitwiseFunctionINTEL instruction requires the following SPIR-V "
2480 "extension: SPV_INTEL_ternary_bitwise_function",
2481 gen_crash_diag: false);
2482 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_ternary_bitwise_function);
2483 Reqs.addCapability(ToAdd: SPIRV::Capability::TernaryBitwiseFunctionINTEL);
2484 break;
2485 }
2486 case SPIRV::OpCopyMemorySized: {
2487 Reqs.addCapability(ToAdd: SPIRV::Capability::Addresses);
2488 // TODO: Add UntypedPointersKHR when implemented.
2489 break;
2490 }
2491 case SPIRV::OpPredicatedLoadINTEL:
2492 case SPIRV::OpPredicatedStoreINTEL: {
2493 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_predicated_io))
2494 report_fatal_error(
2495 reason: "OpPredicated[Load/Store]INTEL instructions require "
2496 "the following SPIR-V extension: SPV_INTEL_predicated_io",
2497 gen_crash_diag: false);
2498 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_predicated_io);
2499 Reqs.addCapability(ToAdd: SPIRV::Capability::PredicatedIOINTEL);
2500 break;
2501 }
2502 case SPIRV::OpFAddS:
2503 case SPIRV::OpFSubS:
2504 case SPIRV::OpFMulS:
2505 case SPIRV::OpFDivS:
2506 case SPIRV::OpFRemS:
2507 case SPIRV::OpFMod:
2508 case SPIRV::OpFNegate:
2509 case SPIRV::OpFAddV:
2510 case SPIRV::OpFSubV:
2511 case SPIRV::OpFMulV:
2512 case SPIRV::OpFDivV:
2513 case SPIRV::OpFRemV:
2514 case SPIRV::OpFNegateV: {
2515 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2516 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
2517 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2518 TypeDef = MRI.getVRegDef(Reg: TypeDef->getOperand(i: 1).getReg());
2519 if (isBFloat16Type(TypeDef)) {
2520 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2521 report_fatal_error(
2522 reason: "Arithmetic instructions with bfloat16 arguments require the "
2523 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2524 gen_crash_diag: false);
2525 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2526 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16ArithmeticINTEL);
2527 }
2528 break;
2529 }
2530 case SPIRV::OpOrdered:
2531 case SPIRV::OpUnordered:
2532 case SPIRV::OpFOrdEqual:
2533 case SPIRV::OpFOrdNotEqual:
2534 case SPIRV::OpFOrdLessThan:
2535 case SPIRV::OpFOrdLessThanEqual:
2536 case SPIRV::OpFOrdGreaterThan:
2537 case SPIRV::OpFOrdGreaterThanEqual:
2538 case SPIRV::OpFUnordEqual:
2539 case SPIRV::OpFUnordNotEqual:
2540 case SPIRV::OpFUnordLessThan:
2541 case SPIRV::OpFUnordLessThanEqual:
2542 case SPIRV::OpFUnordGreaterThan:
2543 case SPIRV::OpFUnordGreaterThanEqual: {
2544 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2545 MachineInstr *OperandDef = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
2546 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: OperandDef->getOperand(i: 1).getReg());
2547 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2548 TypeDef = MRI.getVRegDef(Reg: TypeDef->getOperand(i: 1).getReg());
2549 if (isBFloat16Type(TypeDef)) {
2550 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2551 report_fatal_error(
2552 reason: "Relational instructions with bfloat16 arguments require the "
2553 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2554 gen_crash_diag: false);
2555 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2556 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16ArithmeticINTEL);
2557 }
2558 break;
2559 }
2560 case SPIRV::OpDPdxCoarse:
2561 case SPIRV::OpDPdyCoarse:
2562 case SPIRV::OpDPdxFine:
2563 case SPIRV::OpDPdyFine: {
2564 Reqs.addCapability(ToAdd: SPIRV::Capability::DerivativeControl);
2565 break;
2566 }
2567 case SPIRV::OpLoopControlINTEL: {
2568 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_unstructured_loop_controls);
2569 Reqs.addCapability(ToAdd: SPIRV::Capability::UnstructuredLoopControlsINTEL);
2570 break;
2571 }
2572
2573 default:
2574 break;
2575 }
2576
2577 // If we require capability Shader, then we can remove the requirement for
2578 // the BitInstructions capability, since Shader is a superset capability
2579 // of BitInstructions.
2580 Reqs.removeCapabilityIf(ToRemove: SPIRV::Capability::BitInstructions,
2581 IfPresent: SPIRV::Capability::Shader);
2582}
2583
2584static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
2585 MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
2586 // Collect requirements for existing instructions.
2587 for (const Function &F : M) {
2588 MachineFunction *MF = MMI->getMachineFunction(F);
2589 if (!MF)
2590 continue;
2591 for (const MachineBasicBlock &MBB : *MF)
2592 for (const MachineInstr &MI : MBB)
2593 addInstrRequirements(MI, MAI, ST);
2594 }
2595 // Collect requirements for OpExecutionMode instructions.
2596 auto Node = M.getNamedMetadata(Name: "spirv.ExecutionMode");
2597 if (Node) {
2598 bool RequireFloatControls = false, RequireIntelFloatControls2 = false,
2599 RequireKHRFloatControls2 = false,
2600 VerLower14 = !ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 4));
2601 bool HasIntelFloatControls2 =
2602 ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_float_controls2);
2603 bool HasKHRFloatControls2 =
2604 ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2);
2605 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
2606 MDNode *MDN = cast<MDNode>(Val: Node->getOperand(i));
2607 const MDOperand &MDOp = MDN->getOperand(I: 1);
2608 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MDOp)) {
2609 Constant *C = CMeta->getValue();
2610 if (ConstantInt *Const = dyn_cast<ConstantInt>(Val: C)) {
2611 auto EM = Const->getZExtValue();
2612 // SPV_KHR_float_controls is not available until v1.4:
2613 // add SPV_KHR_float_controls if the version is too low
2614 switch (EM) {
2615 case SPIRV::ExecutionMode::DenormPreserve:
2616 case SPIRV::ExecutionMode::DenormFlushToZero:
2617 case SPIRV::ExecutionMode::RoundingModeRTE:
2618 case SPIRV::ExecutionMode::RoundingModeRTZ:
2619 RequireFloatControls = VerLower14;
2620 MAI.Reqs.getAndAddRequirements(
2621 Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
2622 break;
2623 case SPIRV::ExecutionMode::RoundingModeRTPINTEL:
2624 case SPIRV::ExecutionMode::RoundingModeRTNINTEL:
2625 case SPIRV::ExecutionMode::FloatingPointModeALTINTEL:
2626 case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL:
2627 if (HasIntelFloatControls2) {
2628 RequireIntelFloatControls2 = true;
2629 MAI.Reqs.getAndAddRequirements(
2630 Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
2631 }
2632 break;
2633 case SPIRV::ExecutionMode::FPFastMathDefault: {
2634 if (HasKHRFloatControls2) {
2635 RequireKHRFloatControls2 = true;
2636 MAI.Reqs.getAndAddRequirements(
2637 Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
2638 }
2639 break;
2640 }
2641 case SPIRV::ExecutionMode::ContractionOff:
2642 case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
2643 if (HasKHRFloatControls2) {
2644 RequireKHRFloatControls2 = true;
2645 MAI.Reqs.getAndAddRequirements(
2646 Category: SPIRV::OperandCategory::ExecutionModeOperand,
2647 i: SPIRV::ExecutionMode::FPFastMathDefault, ST);
2648 } else {
2649 MAI.Reqs.getAndAddRequirements(
2650 Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
2651 }
2652 break;
2653 default:
2654 MAI.Reqs.getAndAddRequirements(
2655 Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
2656 }
2657 }
2658 }
2659 }
2660 if (RequireFloatControls &&
2661 ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls))
2662 MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_float_controls);
2663 if (RequireIntelFloatControls2)
2664 MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_float_controls2);
2665 if (RequireKHRFloatControls2)
2666 MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_float_controls2);
2667 }
2668 for (const Function &F : M) {
2669 if (F.isDeclaration())
2670 continue;
2671 if (F.getMetadata(Kind: "reqd_work_group_size"))
2672 MAI.Reqs.getAndAddRequirements(
2673 Category: SPIRV::OperandCategory::ExecutionModeOperand,
2674 i: SPIRV::ExecutionMode::LocalSize, ST);
2675 if (F.getFnAttribute(Kind: "hlsl.numthreads").isValid()) {
2676 MAI.Reqs.getAndAddRequirements(
2677 Category: SPIRV::OperandCategory::ExecutionModeOperand,
2678 i: SPIRV::ExecutionMode::LocalSize, ST);
2679 }
2680 if (F.getFnAttribute(Kind: "enable-maximal-reconvergence").getValueAsBool()) {
2681 MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_maximal_reconvergence);
2682 }
2683 if (F.getMetadata(Kind: "work_group_size_hint"))
2684 MAI.Reqs.getAndAddRequirements(
2685 Category: SPIRV::OperandCategory::ExecutionModeOperand,
2686 i: SPIRV::ExecutionMode::LocalSizeHint, ST);
2687 if (F.getMetadata(Kind: "intel_reqd_sub_group_size") ||
2688 F.getMetadata(Kind: "reqd_sub_group_size"))
2689 MAI.Reqs.getAndAddRequirements(
2690 Category: SPIRV::OperandCategory::ExecutionModeOperand,
2691 i: SPIRV::ExecutionMode::SubgroupSize, ST);
2692 if (F.getMetadata(Kind: "max_work_group_size"))
2693 MAI.Reqs.getAndAddRequirements(
2694 Category: SPIRV::OperandCategory::ExecutionModeOperand,
2695 i: SPIRV::ExecutionMode::MaxWorkgroupSizeINTEL, ST);
2696 if (F.getMetadata(Kind: "vec_type_hint"))
2697 MAI.Reqs.getAndAddRequirements(
2698 Category: SPIRV::OperandCategory::ExecutionModeOperand,
2699 i: SPIRV::ExecutionMode::VecTypeHint, ST);
2700
2701 if (F.hasOptNone()) {
2702 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_optnone)) {
2703 MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_optnone);
2704 MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::OptNoneINTEL);
2705 } else if (ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_optnone)) {
2706 MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_optnone);
2707 MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::OptNoneEXT);
2708 }
2709 }
2710 }
2711}
2712
2713static unsigned getFastMathFlags(const MachineInstr &I,
2714 const SPIRVSubtarget &ST) {
2715 unsigned Flags = SPIRV::FPFastMathMode::None;
2716 bool CanUseKHRFloatControls2 =
2717 ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2);
2718 if (I.getFlag(Flag: MachineInstr::MIFlag::FmNoNans))
2719 Flags |= SPIRV::FPFastMathMode::NotNaN;
2720 if (I.getFlag(Flag: MachineInstr::MIFlag::FmNoInfs))
2721 Flags |= SPIRV::FPFastMathMode::NotInf;
2722 if (I.getFlag(Flag: MachineInstr::MIFlag::FmNsz))
2723 Flags |= SPIRV::FPFastMathMode::NSZ;
2724 if (I.getFlag(Flag: MachineInstr::MIFlag::FmArcp))
2725 Flags |= SPIRV::FPFastMathMode::AllowRecip;
2726 if (I.getFlag(Flag: MachineInstr::MIFlag::FmContract) && CanUseKHRFloatControls2)
2727 Flags |= SPIRV::FPFastMathMode::AllowContract;
2728 if (I.getFlag(Flag: MachineInstr::MIFlag::FmReassoc)) {
2729 if (CanUseKHRFloatControls2)
2730 // LLVM reassoc maps to SPIRV transform, see
2731 // https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for details.
2732 // Because we are enabling AllowTransform, we must enable AllowReassoc and
2733 // AllowContract too, as required by SPIRV spec. Also, we used to map
2734 // MIFlag::FmReassoc to FPFastMathMode::Fast, which now should instead by
2735 // replaced by turning all the other bits instead. Therefore, we're
2736 // enabling every bit here except None and Fast.
2737 Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
2738 SPIRV::FPFastMathMode::NSZ | SPIRV::FPFastMathMode::AllowRecip |
2739 SPIRV::FPFastMathMode::AllowTransform |
2740 SPIRV::FPFastMathMode::AllowReassoc |
2741 SPIRV::FPFastMathMode::AllowContract;
2742 else
2743 Flags |= SPIRV::FPFastMathMode::Fast;
2744 }
2745
2746 if (CanUseKHRFloatControls2) {
2747 // Error out if SPIRV::FPFastMathMode::Fast is enabled.
2748 assert(!(Flags & SPIRV::FPFastMathMode::Fast) &&
2749 "SPIRV::FPFastMathMode::Fast is deprecated and should not be used "
2750 "anymore.");
2751
2752 // Error out if AllowTransform is enabled without AllowReassoc and
2753 // AllowContract.
2754 assert((!(Flags & SPIRV::FPFastMathMode::AllowTransform) ||
2755 ((Flags & SPIRV::FPFastMathMode::AllowReassoc &&
2756 Flags & SPIRV::FPFastMathMode::AllowContract))) &&
2757 "SPIRV::FPFastMathMode::AllowTransform requires AllowReassoc and "
2758 "AllowContract flags to be enabled as well.");
2759 }
2760
2761 return Flags;
2762}
2763
2764static bool isFastMathModeAvailable(const SPIRVSubtarget &ST) {
2765 if (ST.isKernel())
2766 return true;
2767 if (ST.getSPIRVVersion() < VersionTuple(1, 2))
2768 return false;
2769 return ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2);
2770}
2771
2772static void handleMIFlagDecoration(
2773 MachineInstr &I, const SPIRVSubtarget &ST, const SPIRVInstrInfo &TII,
2774 SPIRV::RequirementHandler &Reqs, const SPIRVGlobalRegistry *GR,
2775 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec) {
2776 if (TII.canUseIntegerWrapDecoration(MI: I)) {
2777 if (I.getFlag(Flag: MachineInstr::MIFlag::NoSWrap) &&
2778 getSymbolicOperandRequirements(
2779 Category: SPIRV::OperandCategory::DecorationOperand,
2780 i: SPIRV::Decoration::NoSignedWrap, ST, Reqs)
2781 .IsSatisfiable)
2782 buildOpDecorate(Reg: I.getOperand(i: 0).getReg(), I, TII,
2783 Dec: SPIRV::Decoration::NoSignedWrap, DecArgs: {});
2784 if (I.getFlag(Flag: MachineInstr::MIFlag::NoUWrap) &&
2785 getSymbolicOperandRequirements(
2786 Category: SPIRV::OperandCategory::DecorationOperand,
2787 i: SPIRV::Decoration::NoUnsignedWrap, ST, Reqs)
2788 .IsSatisfiable)
2789 buildOpDecorate(Reg: I.getOperand(i: 0).getReg(), I, TII,
2790 Dec: SPIRV::Decoration::NoUnsignedWrap, DecArgs: {});
2791 }
2792 // In Kernel environments, FPFastMathMode on OpExtInst is valid per core
2793 // spec. For other instruction types, SPV_KHR_float_controls2 is required.
2794 bool CanUseFM =
2795 TII.canUseFastMathFlags(
2796 MI: I, KHRFloatControls2: ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2)) ||
2797 (ST.isKernel() && I.getOpcode() == SPIRV::OpExtInst);
2798 if (!CanUseFM)
2799 return;
2800
2801 unsigned FMFlags = getFastMathFlags(I, ST);
2802 if (FMFlags == SPIRV::FPFastMathMode::None) {
2803 // We also need to check if any FPFastMathDefault info was set for the
2804 // types used in this instruction.
2805 if (FPFastMathDefaultInfoVec.empty())
2806 return;
2807
2808 // There are three types of instructions that can use fast math flags:
2809 // 1. Arithmetic instructions (FAdd, FMul, FSub, FDiv, FRem, etc.)
2810 // 2. Relational instructions (FCmp, FOrd, FUnord, etc.)
2811 // 3. Extended instructions (ExtInst)
2812 // For arithmetic instructions, the floating point type can be in the
2813 // result type or in the operands, but they all must be the same.
2814 // For the relational and logical instructions, the floating point type
2815 // can only be in the operands 1 and 2, not the result type. Also, the
2816 // operands must have the same type. For the extended instructions, the
2817 // floating point type can be in the result type or in the operands. It's
2818 // unclear if the operands and the result type must be the same. Let's
2819 // assume they must be. Therefore, for 1. and 2., we can check the first
2820 // operand type, and for 3. we can check the result type.
2821 assert(I.getNumOperands() >= 3 && "Expected at least 3 operands");
2822 Register ResReg = I.getOpcode() == SPIRV::OpExtInst
2823 ? I.getOperand(i: 1).getReg()
2824 : I.getOperand(i: 2).getReg();
2825 SPIRVTypeInst ResType = GR->getSPIRVTypeForVReg(VReg: ResReg, MF: I.getMF());
2826 const Type *Ty = GR->getTypeForSPIRVType(Ty: ResType);
2827 Ty = Ty->isVectorTy() ? cast<VectorType>(Val: Ty)->getElementType() : Ty;
2828
2829 // Match instruction type with the FPFastMathDefaultInfoVec.
2830 bool Emit = false;
2831 for (SPIRV::FPFastMathDefaultInfo &Elem : FPFastMathDefaultInfoVec) {
2832 if (Ty == Elem.Ty) {
2833 FMFlags = Elem.FastMathFlags;
2834 Emit = Elem.ContractionOff || Elem.SignedZeroInfNanPreserve ||
2835 Elem.FPFastMathDefault;
2836 break;
2837 }
2838 }
2839
2840 if (FMFlags == SPIRV::FPFastMathMode::None && !Emit)
2841 return;
2842 }
2843 if (isFastMathModeAvailable(ST)) {
2844 Register DstReg = I.getOperand(i: 0).getReg();
2845 buildOpDecorate(Reg: DstReg, I, TII, Dec: SPIRV::Decoration::FPFastMathMode,
2846 DecArgs: {FMFlags});
2847 }
2848}
2849
2850// Walk all functions and add decorations related to MI flags.
2851static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
2852 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2853 SPIRV::ModuleAnalysisInfo &MAI,
2854 const SPIRVGlobalRegistry *GR) {
2855 for (const Function &F : M) {
2856 MachineFunction *MF = MMI->getMachineFunction(F);
2857 if (!MF)
2858 continue;
2859
2860 for (auto &MBB : *MF)
2861 for (auto &MI : MBB)
2862 handleMIFlagDecoration(I&: MI, ST, TII, Reqs&: MAI.Reqs, GR,
2863 FPFastMathDefaultInfoVec&: MAI.FPFastMathDefaultInfoMap[&F]);
2864 }
2865}
2866
2867static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
2868 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2869 SPIRV::ModuleAnalysisInfo &MAI) {
2870 for (const Function &F : M) {
2871 MachineFunction *MF = MMI->getMachineFunction(F);
2872 if (!MF)
2873 continue;
2874 if (MF->getFunction()
2875 .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
2876 .isValid())
2877 continue;
2878 MachineRegisterInfo &MRI = MF->getRegInfo();
2879 for (auto &MBB : *MF) {
2880 if (!MBB.hasName() || MBB.empty())
2881 continue;
2882 // Emit basic block names.
2883 Register Reg = MRI.createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 64));
2884 MRI.setRegClass(Reg, RC: &SPIRV::IDRegClass);
2885 buildOpName(Target: Reg, Name: MBB.getName(), I&: *std::prev(x: MBB.end()), TII);
2886 MCRegister GlobalReg = MAI.getOrCreateMBBRegister(MBB);
2887 MAI.setRegisterAlias(MF, Reg, AliasReg: GlobalReg);
2888 }
2889 }
2890}
2891
2892// patching Instruction::PHI to SPIRV::OpPhi
2893static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
2894 const SPIRVInstrInfo &TII, MachineModuleInfo *MMI) {
2895 for (const Function &F : M) {
2896 MachineFunction *MF = MMI->getMachineFunction(F);
2897 if (!MF)
2898 continue;
2899 for (auto &MBB : *MF) {
2900 for (MachineInstr &MI : MBB.phis()) {
2901 MI.setDesc(TII.get(Opcode: SPIRV::OpPhi));
2902 Register ResTypeReg = GR->getSPIRVTypeID(
2903 SpirvType: GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 0).getReg(), MF));
2904 MI.insert(InsertBefore: MI.operands_begin() + 1,
2905 Ops: {MachineOperand::CreateReg(Reg: ResTypeReg, isDef: false)});
2906 }
2907 }
2908
2909 MF->getProperties().setNoPHIs();
2910 }
2911}
2912
2913static SPIRV::FPFastMathDefaultInfoVector &getOrCreateFPFastMathDefaultInfoVec(
2914 const Module &M, SPIRV::ModuleAnalysisInfo &MAI, const Function *F) {
2915 auto it = MAI.FPFastMathDefaultInfoMap.find(Val: F);
2916 if (it != MAI.FPFastMathDefaultInfoMap.end())
2917 return it->second;
2918
2919 // If the map does not contain the entry, create a new one. Initialize it to
2920 // contain all 3 elements sorted by bit width of target type: {half, float,
2921 // double}.
2922 SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec;
2923 FPFastMathDefaultInfoVec.emplace_back(Args: Type::getHalfTy(C&: M.getContext()),
2924 Args: SPIRV::FPFastMathMode::None);
2925 FPFastMathDefaultInfoVec.emplace_back(Args: Type::getFloatTy(C&: M.getContext()),
2926 Args: SPIRV::FPFastMathMode::None);
2927 FPFastMathDefaultInfoVec.emplace_back(Args: Type::getDoubleTy(C&: M.getContext()),
2928 Args: SPIRV::FPFastMathMode::None);
2929 return MAI.FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec);
2930}
2931
2932static SPIRV::FPFastMathDefaultInfo &getFPFastMathDefaultInfo(
2933 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec,
2934 const Type *Ty) {
2935 size_t BitWidth = Ty->getScalarSizeInBits();
2936 int Index =
2937 SPIRV::FPFastMathDefaultInfoVector::computeFPFastMathDefaultInfoVecIndex(
2938 BitWidth);
2939 assert(Index >= 0 && Index < 3 &&
2940 "Expected FPFastMathDefaultInfo for half, float, or double");
2941 assert(FPFastMathDefaultInfoVec.size() == 3 &&
2942 "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2943 return FPFastMathDefaultInfoVec[Index];
2944}
2945
2946static void collectFPFastMathDefaults(const Module &M,
2947 SPIRV::ModuleAnalysisInfo &MAI,
2948 const SPIRVSubtarget &ST) {
2949 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2))
2950 return;
2951
2952 // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap.
2953 // We need the entry point (function) as the key, and the target
2954 // type and flags as the value.
2955 // We also need to check ContractionOff and SignedZeroInfNanPreserve
2956 // execution modes, as they are now deprecated and must be replaced
2957 // with FPFastMathDefaultInfo.
2958 auto Node = M.getNamedMetadata(Name: "spirv.ExecutionMode");
2959 if (!Node)
2960 return;
2961
2962 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
2963 MDNode *MDN = cast<MDNode>(Val: Node->getOperand(i));
2964 assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands");
2965 const Function *F = cast<Function>(
2966 Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 0))->getValue());
2967 const auto EM =
2968 cast<ConstantInt>(
2969 Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 1))->getValue())
2970 ->getZExtValue();
2971 if (EM == SPIRV::ExecutionMode::FPFastMathDefault) {
2972 assert(MDN->getNumOperands() == 4 &&
2973 "Expected 4 operands for FPFastMathDefault");
2974
2975 const Type *T = cast<ValueAsMetadata>(Val: MDN->getOperand(I: 2))->getType();
2976 unsigned Flags =
2977 cast<ConstantInt>(
2978 Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 3))->getValue())
2979 ->getZExtValue();
2980 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2981 getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
2982 SPIRV::FPFastMathDefaultInfo &Info =
2983 getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, Ty: T);
2984 Info.FastMathFlags = Flags;
2985 Info.FPFastMathDefault = true;
2986 } else if (EM == SPIRV::ExecutionMode::ContractionOff) {
2987 assert(MDN->getNumOperands() == 2 &&
2988 "Expected no operands for ContractionOff");
2989
2990 // We need to save this info for every possible FP type, i.e. {half,
2991 // float, double, fp128}.
2992 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2993 getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
2994 for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) {
2995 Info.ContractionOff = true;
2996 }
2997 } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) {
2998 assert(MDN->getNumOperands() == 3 &&
2999 "Expected 1 operand for SignedZeroInfNanPreserve");
3000 unsigned TargetWidth =
3001 cast<ConstantInt>(
3002 Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 2))->getValue())
3003 ->getZExtValue();
3004 // We need to save this info only for the FP type with TargetWidth.
3005 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
3006 getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
3007 int Index = SPIRV::FPFastMathDefaultInfoVector::
3008 computeFPFastMathDefaultInfoVecIndex(BitWidth: TargetWidth);
3009 assert(Index >= 0 && Index < 3 &&
3010 "Expected FPFastMathDefaultInfo for half, float, or double");
3011 assert(FPFastMathDefaultInfoVec.size() == 3 &&
3012 "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
3013 FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true;
3014 }
3015 }
3016}
3017
// Declare the analyses runOnModule reads: TargetPassConfig (to reach the
// SPIRVTargetMachine and its subtarget) and MachineModuleInfoWrapperPass (to
// reach each function's MachineFunction).
void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}
3022
3023bool SPIRVModuleAnalysis::runOnModule(Module &M) {
3024 SPIRVTargetMachine &TM =
3025 getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
3026 ST = TM.getSubtargetImpl();
3027 GR = ST->getSPIRVGlobalRegistry();
3028 TII = ST->getInstrInfo();
3029
3030 MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
3031
3032 setBaseInfo(M);
3033
3034 patchPhis(M, GR, TII: *TII, MMI);
3035
3036 addMBBNames(M, TII: *TII, MMI, ST: *ST, MAI);
3037 collectFPFastMathDefaults(M, MAI, ST: *ST);
3038 addDecorations(M, TII: *TII, MMI, ST: *ST, MAI, GR);
3039
3040 collectReqs(M, MAI, MMI, ST: *ST);
3041
3042 // Process type/const/global var/func decl instructions, number their
3043 // destination registers from 0 to N, collect Extensions and Capabilities.
3044 collectDeclarations(M);
3045
3046 // Number rest of registers from N+1 onwards.
3047 numberRegistersGlobally(M);
3048
3049 // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
3050 processOtherInstrs(M);
3051
3052 // If there are no entry points, we need the Linkage capability.
3053 if (MAI.MS[SPIRV::MB_EntryPoints].empty())
3054 MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::Linkage);
3055
3056 // Set maximum ID used.
3057 GR->setBound(MAI.MaxID);
3058
3059 return false;
3060}
3061