1//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The analysis collects instructions that should be output at the module level
10// and performs the global register numbering.
11//
12// The results of this analysis are used in AsmPrinter to rename registers
13// globally and to output required instructions at the module level.
14//
15//===----------------------------------------------------------------------===//
16
// TODO: uses of report_fatal_error (which is also deprecated) /
// ReportFatalUsageError in this file should be refactored, as per LLVM
// best practices, to rely on the Diagnostic infrastructure.
20
21#include "SPIRVModuleAnalysis.h"
22#include "MCTargetDesc/SPIRVBaseInfo.h"
23#include "MCTargetDesc/SPIRVMCTargetDesc.h"
24#include "SPIRV.h"
25#include "SPIRVSubtarget.h"
26#include "SPIRVTargetMachine.h"
27#include "SPIRVUtils.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/CodeGen/MachineModuleInfo.h"
30#include "llvm/CodeGen/TargetPassConfig.h"
31
32using namespace llvm;
33
34#define DEBUG_TYPE "spirv-module-analysis"
35
// Debug aid: when set, the analysis dumps the MIR annotated with the
// collected SPIR-V dependency information.
static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(Val: false));

// User-specified capabilities to de-prioritize when a feature can be enabled
// by more than one capability (consulted in getSymbolicOperandRequirements).
static cl::list<SPIRV::Capability::Capability>
    AvoidCapabilities("avoid-spirv-capabilities",
                      cl::desc("SPIR-V capabilities to avoid if there are "
                               "other options enabling a feature"),
                      cl::Hidden,
                      cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
                                            "SPIR-V Shader capability")));
48// Use sets instead of cl::list to check "if contains" condition
49struct AvoidCapabilitiesSet {
50 SmallSet<SPIRV::Capability::Capability, 4> S;
51 AvoidCapabilitiesSet() { S.insert_range(R&: AvoidCapabilities); }
52};
53
// Pass identifier used by the legacy pass manager.
char llvm::SPIRVModuleAnalysis::ID = 0;

// Register the pass; the two trailing booleans are the (cfg, analysis) flags
// of the INITIALIZE_PASS convention.
INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)
58
59// Retrieve an unsigned from an MDNode with a list of them as operands.
60static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
61 unsigned DefaultVal = 0) {
62 if (MdNode && OpIndex < MdNode->getNumOperands()) {
63 const auto &Op = MdNode->getOperand(I: OpIndex);
64 return mdconst::extract<ConstantInt>(MD: Op)->getZExtValue();
65 }
66 return DefaultVal;
67}
68
// Computes the requirements (capability, extensions, and SPIR-V version
// bounds) needed to use the symbolic operand with value `i` of the given
// operand Category on subtarget ST. Reqs supplies the set of capabilities
// already known to be available. Returns an unsatisfiable requirement
// (first field false) when the subtarget cannot support the operand.
static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
                               unsigned i, const SPIRVSubtarget &ST,
                               SPIRV::RequirementHandler &Reqs) {
  // A set of capabilities to avoid if there is another option.
  AvoidCapabilitiesSet AvoidCaps;
  // Shader and Kernel are mutually exclusive environments; avoid the one
  // that does not match the current subtarget.
  if (!ST.isShader())
    AvoidCaps.S.insert(V: SPIRV::Capability::Shader);
  else
    AvoidCaps.S.insert(V: SPIRV::Capability::Kernel);

  // Version window within which the operand is available.
  VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, Value: i);
  VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, Value: i);
  VersionTuple SPIRVVersion = ST.getSPIRVVersion();
  // An empty subtarget version means "no constraint" on either bound.
  bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
  bool MaxVerOK =
      ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
  CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, Value: i);
  ExtensionList ReqExts = getSymbolicOperandExtensions(Category, Value: i);
  if (ReqCaps.empty()) {
    if (ReqExts.empty()) {
      // Neither capabilities nor extensions required: only the version
      // window decides.
      if (MinVerOK && MaxVerOK)
        return {true, {}, {}, ReqMinVer, ReqMaxVer};
      return {false, {}, {}, VersionTuple(), VersionTuple()};
    }
  } else if (MinVerOK && MaxVerOK) {
    if (ReqCaps.size() == 1) {
      auto Cap = ReqCaps[0];
      if (Reqs.isCapabilityAvailable(Cap)) {
        // Also carry any extensions the chosen capability itself requires.
        ReqExts.append(RHS: getSymbolicOperandExtensions(
            Category: SPIRV::OperandCategory::CapabilityOperand, Value: Cap));
        return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
      }
    } else {
      // By SPIR-V specification: "If an instruction, enumerant, or other
      // feature specifies multiple enabling capabilities, only one such
      // capability needs to be declared to use the feature." However, one
      // capability may be preferred over another. We use command line
      // argument(s) and AvoidCapabilities to avoid selection of certain
      // capabilities if there are other options.
      CapabilityList UseCaps;
      for (auto Cap : ReqCaps)
        if (Reqs.isCapabilityAvailable(Cap))
          UseCaps.push_back(Elt: Cap);
      // Pick the first available capability not in the avoid set; if all
      // available ones are avoided, fall back to the last of them.
      for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
        auto Cap = UseCaps[i];
        if (i == Sz - 1 || !AvoidCaps.S.contains(V: Cap)) {
          ReqExts.append(RHS: getSymbolicOperandExtensions(
              Category: SPIRV::OperandCategory::CapabilityOperand, Value: Cap));
          return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
        }
      }
    }
  }
  // If there are no capabilities, or we can't satisfy the version or
  // capability requirements, use the list of extensions (if the subtarget
  // can handle them all).
  if (llvm::all_of(Range&: ReqExts, P: [&ST](const SPIRV::Extension::Extension &Ext) {
        return ST.canUseExtension(E: Ext);
      })) {
    return {true,
            {},
            std::move(ReqExts),
            VersionTuple(),
            VersionTuple()}; // TODO: add versions to extensions.
  }
  return {false, {}, {}, VersionTuple(), VersionTuple()};
}
137
// Resets MAI and computes the instruction-independent module information:
// memory and addressing models, source language and version, used source
// extensions, and the capability/extension requirements those imply.
void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  // Start from a clean slate; the analysis may be re-run on another module.
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.GlobalObjMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(ST: *ST);

  // TODO: determine memory model and source language from the configuration.
  if (auto MemModel = M.getNamedMetadata(Name: "spirv.MemoryModel")) {
    // Explicit module metadata overrides the defaults derived below.
    auto MemMD = MemModel->getOperand(i: 0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MdNode: MemMD, OpIndex: 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MdNode: MemMD, OpIndex: 1));
  } else {
    // TODO: Add support for VulkanMemoryModel.
    MAI.Mem = ST->isShader() ? SPIRV::MemoryModel::GLSL450
                             : SPIRV::MemoryModel::OpenCL;
    if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
      // Derive the addressing model from the target pointer width.
      unsigned PtrSize = ST->getPointerSize();
      MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32
                 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
                                 : SPIRV::AddressingModel::Logical;
    } else {
      // TODO: Add support for PhysicalStorageBufferAddress.
      MAI.Addr = SPIRV::AddressingModel::Logical;
    }
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata(Name: "opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct version literal in accordance with SPIRV-LLVM-Translator.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(i: 0);
    unsigned MajorNum = getMetadataUInt(MdNode: VersionMD, OpIndex: 0, DefaultVal: 2);
    unsigned MinorNum = getMetadataUInt(MdNode: VersionMD, OpIndex: 1);
    unsigned RevNum = getMetadataUInt(MdNode: VersionMD, OpIndex: 2);
    // Prevent Major part of OpenCL version to be 0
    // (encoding: (Major*100 + Minor)*1000 + Rev).
    MAI.SrcLangVersion =
        (std::max(a: 1U, b: MajorNum) * 100 + MinorNum) * 1000 + RevNum;
    // When opencl.cxx.version is also present, validate compatibility
    // and use C++ for OpenCL as source language with the C++ version.
    if (auto *CxxVerNode = M.getNamedMetadata(Name: "opencl.cxx.version")) {
      assert(CxxVerNode->getNumOperands() > 0 && "Invalid SPIR");
      auto *CxxMD = CxxVerNode->getOperand(i: 0);
      unsigned CxxVer =
          (getMetadataUInt(MdNode: CxxMD, OpIndex: 0) * 100 + getMetadataUInt(MdNode: CxxMD, OpIndex: 1)) * 1000 +
          getMetadataUInt(MdNode: CxxMD, OpIndex: 2);
      // Only C++17-for-OpenCL-2.0 and C++-2021-for-OpenCL-3.0 pairs are
      // accepted here.
      if ((MAI.SrcLangVersion == 200000 && CxxVer == 100000) ||
          (MAI.SrcLangVersion == 300000 && CxxVer == 202100000)) {
        MAI.SrcLang = SPIRV::SourceLanguage::CPP_for_OpenCL;
        MAI.SrcLangVersion = CxxVer;
      } else {
        report_fatal_error(
            reason: "opencl cxx version is not compatible with opencl c version!");
      }
    }
  } else {
    // If there is no information about OpenCL version we are forced to generate
    // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
    // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
    // Translator avoids potential issues with run-times in a similar manner.
    if (!ST->isShader()) {
      MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
      MAI.SrcLangVersion = 100000;
    } else {
      MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
      MAI.SrcLangVersion = 0;
    }
  }

  // Collect source-level extensions declared via metadata; each operand is a
  // (possibly multi-entry) MDNode of extension-name strings.
  if (auto ExtNode = M.getNamedMetadata(Name: "opencl.used.extensions")) {
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(i: I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(key: cast<MDString>(Val: MD->getOperand(I: J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::MemoryModelOperand,
                                 i: MAI.Mem, ST: *ST);
  MAI.Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::SourceLanguageOperand,
                                 i: MAI.SrcLang, ST: *ST);
  MAI.Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::AddressingModelOperand,
                                 i: MAI.Addr, ST: *ST);

  if (MAI.Mem == SPIRV::MemoryModel::VulkanKHR)
    MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_vulkan_memory_model);

  if (!ST->isShader()) {
    // Pre-assign a global id for the OpenCL extended-instruction set.
    // TODO: check if it's required by default.
    MAI.ExtInstSetMap[static_cast<unsigned>(
        SPIRV::InstructionSet::OpenCL_std)] = MAI.getNextIDRegister();
  }
}
244
245// Appends the signature of the decoration instructions that decorate R to
246// Signature.
247static void appendDecorationsForReg(const MachineRegisterInfo &MRI, Register R,
248 InstrSignature &Signature) {
249 for (MachineInstr &UseMI : MRI.use_instructions(Reg: R)) {
250 // We don't handle OpDecorateId because getting the register alias for the
251 // ID can cause problems, and we do not need it for now.
252 if (UseMI.getOpcode() != SPIRV::OpDecorate &&
253 UseMI.getOpcode() != SPIRV::OpMemberDecorate)
254 continue;
255
256 for (unsigned I = 0; I < UseMI.getNumOperands(); ++I) {
257 const MachineOperand &MO = UseMI.getOperand(i: I);
258 if (MO.isReg())
259 continue;
260 Signature.push_back(Elt: hash_value(MO));
261 }
262 }
263}
264
// Returns a representation of an instruction as a vector of MachineOperand
// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
// This creates a signature of the instruction with the same content
// that MachineOperand::isIdenticalTo uses for comparison.
//
// When UseDefReg is false, the def register is excluded from the signature
// (so identical defs can be merged) and the hashes of the decorations
// attached to the def are appended instead.
static InstrSignature instrToSignature(const MachineInstr &MI,
                                       SPIRV::ModuleAnalysisInfo &MAI,
                                       bool UseDefReg) {
  Register DefReg;
  InstrSignature Signature{MI.getOpcode()};
  for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
    // The only decorations that can be applied more than once to a given <id>
    // or structure member are FuncParamAttr (38), UserSemantic (5635),
    // CacheControlLoadINTEL (6442), and CacheControlStoreINTEL (6443). For all
    // the rest of decorations, we will only add to the signature the Opcode,
    // the id to which it applies, and the decoration id, disregarding any
    // decoration flags. This will ensure that any subsequent decoration with
    // the same id will be deemed as a duplicate. Then, at the call site, we
    // will be able to handle duplicates in the best way.
    unsigned Opcode = MI.getOpcode();
    if ((Opcode == SPIRV::OpDecorate) && i >= 2) {
      unsigned DecorationID = MI.getOperand(i: 1).getImm();
      if (DecorationID != SPIRV::Decoration::FuncParamAttr &&
          DecorationID != SPIRV::Decoration::UserSemantic &&
          DecorationID != SPIRV::Decoration::CacheControlLoadINTEL &&
          DecorationID != SPIRV::Decoration::CacheControlStoreINTEL)
        continue;
    }
    const MachineOperand &MO = MI.getOperand(i);
    size_t h;
    if (MO.isReg()) {
      if (!UseDefReg && MO.isDef()) {
        // Remember the def for the decoration pass below; it does not
        // participate in the signature itself.
        assert(!DefReg.isValid() && "Multiple def registers.");
        DefReg = MO.getReg();
        continue;
      }
      // Register operands are hashed through their global alias so signatures
      // are stable across functions.
      Register RegAlias = MAI.getRegisterAlias(MF: MI.getMF(), Reg: MO.getReg());
      if (!RegAlias.isValid()) {
        LLVM_DEBUG({
          dbgs() << "Unexpectedly, no global id found for the operand ";
          MO.print(dbgs());
          dbgs() << "\nInstruction: ";
          MI.print(dbgs());
          dbgs() << "\n";
        });
        report_fatal_error(reason: "All v-regs must have been mapped to global id's");
      }
      // mimic llvm::hash_value(const MachineOperand &MO)
      h = hash_combine(args: MO.getType(), args: (unsigned)RegAlias, args: MO.getSubReg(),
                       args: MO.isDef());
    } else {
      h = hash_value(MO);
    }
    Signature.push_back(Elt: h);
  }

  if (DefReg.isValid()) {
    // Decorations change the semantics of the current instruction. So two
    // identical instruction with different decorations cannot be merged. That
    // is why we add the decorations to the signature.
    appendDecorationsForReg(MRI: MI.getMF()->getRegInfo(), R: DefReg, Signature);
  }
  return Signature;
}
328
// Returns true if MI belongs in a module-level declaration section (types,
// constants, non-local variables, function headers). As a side effect, dummy
// OpUndef definitions that only feed OpConstantFunctionPointerINTEL are
// marked to be skipped during emission.
bool SPIRVModuleAnalysis::isDeclSection(const MachineRegisterInfo &MRI,
                                        const MachineInstr &MI) {
  unsigned Opcode = MI.getOpcode();
  switch (Opcode) {
  case SPIRV::OpTypeForwardPointer:
    // omit now, collect later
    return false;
  case SPIRV::OpVariable:
    // Function-local variables stay in the function body; all other storage
    // classes are module-level.
    return static_cast<SPIRV::StorageClass::StorageClass>(
               MI.getOperand(i: 2).getImm()) != SPIRV::StorageClass::Function;
  case SPIRV::OpFunction:
  case SPIRV::OpFunctionParameter:
    return true;
  }
  if (GR->hasConstFunPtr() && Opcode == SPIRV::OpUndef) {
    Register DefReg = MI.getOperand(i: 0).getReg();
    for (MachineInstr &UseMI : MRI.use_instructions(Reg: DefReg)) {
      if (UseMI.getOpcode() != SPIRV::OpConstantFunctionPointerINTEL)
        continue;
      // it's a dummy definition, FP constant refers to a function,
      // and this is resolved in another way; let's skip this definition
      assert(UseMI.getOperand(2).isReg() &&
             UseMI.getOperand(2).getReg() == DefReg);
      MAI.setSkipEmission(&MI);
      return false;
    }
  }
  return TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
         TII->isInlineAsmDefInstr(MI);
}
359
// This is a special case of a function pointer refering to a possibly
// forward function declaration. The operand is a dummy OpUndef that
// requires a special treatment.
//
// Numbers the referenced function definition (OpFunction plus its
// OpFunctionParameters) ahead of time, then aliases OpReg to the function's
// global register.
void SPIRVModuleAnalysis::visitFunPtrUse(
    Register OpReg, InstrGRegsMap &SignatureToGReg,
    std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
    const MachineInstr &MI) {
  const MachineOperand *OpFunDef =
      GR->getFunctionDefinitionByUse(Use: &MI.getOperand(i: 2));
  assert(OpFunDef && OpFunDef->isReg());
  // find the actual function definition and number it globally in advance
  const MachineInstr *OpDefMI = OpFunDef->getParent();
  assert(OpDefMI && OpDefMI->getOpcode() == SPIRV::OpFunction);
  const MachineFunction *FunDefMF = OpDefMI->getParent()->getParent();
  const MachineRegisterInfo &FunDefMRI = FunDefMF->getRegInfo();
  // Visit the OpFunction and every following OpFunctionParameter so the whole
  // header receives global numbers.
  do {
    visitDecl(MRI: FunDefMRI, SignatureToGReg, GlobalToGReg, MF: FunDefMF, MI: *OpDefMI);
    OpDefMI = OpDefMI->getNextNode();
  } while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction ||
                       OpDefMI->getOpcode() == SPIRV::OpFunctionParameter));
  // associate the function pointer with the newly assigned global number
  MCRegister GlobalFunDefReg =
      MAI.getRegisterAlias(MF: FunDefMF, Reg: OpFunDef->getReg());
  assert(GlobalFunDefReg.isValid() &&
         "Function definition must refer to a global register");
  MAI.setRegisterAlias(MF, Reg: OpReg, AliasReg: GlobalFunDefReg);
}
387
// Depth first recursive traversal of dependencies. Repeated visits are guarded
// by MAI.hasRegisterAlias().
//
// After all register operands of MI are resolved, assigns MI's def register a
// global number according to the instruction kind and records MI in the
// appropriate module section. Instructions hoisted to module level are marked
// to be skipped during per-function emission.
void SPIRVModuleAnalysis::visitDecl(
    const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg,
    std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
    const MachineInstr &MI) {
  unsigned Opcode = MI.getOpcode();

  // Process each operand of the instruction to resolve dependencies
  for (const MachineOperand &MO : MI.operands()) {
    if (!MO.isReg() || MO.isDef())
      continue;
    Register OpReg = MO.getReg();
    // Handle function pointers special case
    if (Opcode == SPIRV::OpConstantFunctionPointerINTEL &&
        MRI.getRegClass(Reg: OpReg) == &SPIRV::pIDRegClass) {
      visitFunPtrUse(OpReg, SignatureToGReg, GlobalToGReg, MF, MI);
      continue;
    }
    // Skip already processed instructions
    if (MAI.hasRegisterAlias(MF, Reg: MO.getReg()))
      continue;
    // Recursively visit dependencies
    if (const MachineInstr *OpDefMI = MRI.getUniqueVRegDef(Reg: OpReg)) {
      if (isDeclSection(MRI, MI: *OpDefMI))
        visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI: *OpDefMI);
      continue;
    }
    // Handle the unexpected case of no unique definition for the SPIR-V
    // instruction
    LLVM_DEBUG({
      dbgs() << "Unexpectedly, no unique definition for the operand ";
      MO.print(dbgs());
      dbgs() << "\nInstruction: ";
      MI.print(dbgs());
      dbgs() << "\n";
    });
    report_fatal_error(
        reason: "No unique definition is found for the virtual register");
  }

  // All dependencies are numbered; now number MI itself.
  MCRegister GReg;
  bool IsFunDef = false;
  if (TII->isSpecConstantInstr(MI)) {
    // Spec constants are never deduplicated: each gets a fresh id.
    GReg = MAI.getNextIDRegister();
    MAI.MS[SPIRV::MB_TypeConstVars].push_back(Elt: &MI);
  } else if (Opcode == SPIRV::OpFunction ||
             Opcode == SPIRV::OpFunctionParameter) {
    GReg = handleFunctionOrParameter(MF, MI, GlobalToGReg, IsFunDef);
  } else if (Opcode == SPIRV::OpTypeStruct ||
             Opcode == SPIRV::OpConstantComposite) {
    GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
    // Large structs/composites may continue into *ContinuedINTEL
    // instructions; number and skip those as part of this declaration.
    const MachineInstr *NextInstr = MI.getNextNode();
    while (NextInstr &&
           ((Opcode == SPIRV::OpTypeStruct &&
             NextInstr->getOpcode() == SPIRV::OpTypeStructContinuedINTEL) ||
            (Opcode == SPIRV::OpConstantComposite &&
             NextInstr->getOpcode() ==
                 SPIRV::OpConstantCompositeContinuedINTEL))) {
      MCRegister Tmp = handleTypeDeclOrConstant(MI: *NextInstr, SignatureToGReg);
      MAI.setRegisterAlias(MF, Reg: NextInstr->getOperand(i: 0).getReg(), AliasReg: Tmp);
      MAI.setSkipEmission(NextInstr);
      NextInstr = NextInstr->getNextNode();
    }
  } else if (TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
             TII->isInlineAsmDefInstr(MI)) {
    GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
  } else if (Opcode == SPIRV::OpVariable) {
    GReg = handleVariable(MF, MI, GlobalToGReg);
  } else {
    LLVM_DEBUG({
      dbgs() << "\nInstruction: ";
      MI.print(dbgs());
      dbgs() << "\n";
    });
    llvm_unreachable("Unexpected instruction is visited");
  }
  MAI.setRegisterAlias(MF, Reg: MI.getOperand(i: 0).getReg(), AliasReg: GReg);
  // Function definitions must still be emitted in place; everything else is
  // emitted from the module sections instead.
  if (!IsFunDef)
    MAI.setSkipEmission(&MI);
}
469
470MCRegister SPIRVModuleAnalysis::handleFunctionOrParameter(
471 const MachineFunction *MF, const MachineInstr &MI,
472 std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) {
473 const Value *GObj = GR->getGlobalObject(MF, R: MI.getOperand(i: 0).getReg());
474 assert(GObj && "Unregistered global definition");
475 const Function *F = dyn_cast<Function>(Val: GObj);
476 if (!F)
477 F = dyn_cast<Argument>(Val: GObj)->getParent();
478 assert(F && "Expected a reference to a function or an argument");
479 IsFunDef = !F->isDeclaration();
480 auto [It, Inserted] = GlobalToGReg.try_emplace(k: GObj);
481 if (!Inserted)
482 return It->second;
483 MCRegister GReg = MAI.getNextIDRegister();
484 It->second = GReg;
485 if (!IsFunDef)
486 MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(Elt: &MI);
487 return GReg;
488}
489
490MCRegister
491SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI,
492 InstrGRegsMap &SignatureToGReg) {
493 InstrSignature MISign = instrToSignature(MI, MAI, UseDefReg: false);
494 auto [It, Inserted] = SignatureToGReg.try_emplace(k: MISign);
495 if (!Inserted)
496 return It->second;
497 MCRegister GReg = MAI.getNextIDRegister();
498 It->second = GReg;
499 MAI.MS[SPIRV::MB_TypeConstVars].push_back(Elt: &MI);
500 return GReg;
501}
502
503MCRegister SPIRVModuleAnalysis::handleVariable(
504 const MachineFunction *MF, const MachineInstr &MI,
505 std::map<const Value *, unsigned> &GlobalToGReg) {
506 MAI.GlobalVarList.push_back(Elt: &MI);
507 const Value *GObj = GR->getGlobalObject(MF, R: MI.getOperand(i: 0).getReg());
508 assert(GObj && "Unregistered global definition");
509 auto [It, Inserted] = GlobalToGReg.try_emplace(k: GObj);
510 if (!Inserted)
511 return It->second;
512 MCRegister GReg = MAI.getNextIDRegister();
513 It->second = GReg;
514 MAI.MS[SPIRV::MB_TypeConstVars].push_back(Elt: &MI);
515 if (const auto *GV = dyn_cast<GlobalVariable>(Val: GObj))
516 MAI.GlobalObjMap[GV] = GReg;
517 return GReg;
518}
519
// Walks every machine function and globally numbers all declarations (types,
// constants, non-local variables, function headers) reachable from the code,
// while also collecting OpExtension/OpCapability instructions.
void SPIRVModuleAnalysis::collectDeclarations(const Module &M) {
  InstrGRegsMap SignatureToGReg;
  std::map<const Value *, unsigned> GlobalToGReg;
  for (const Function &F : M) {
    MachineFunction *MF = MMI->getMachineFunction(F);
    if (!MF)
      continue;
    const MachineRegisterInfo &MRI = MF->getRegInfo();
    // PastHeader tracks position relative to the function's own header:
    // 0 - the leading OpFunction has not been seen yet,
    // 1 - inside the header (OpFunction / OpFunctionParameter run),
    // 2 - past the header.
    unsigned PastHeader = 0;
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getNumOperands() == 0)
          continue;
        unsigned Opcode = MI.getOpcode();
        if (Opcode == SPIRV::OpFunction) {
          // The function's own header is not numbered here; any registers
          // left without an alias are picked up later by
          // numberRegistersGlobally.
          if (PastHeader == 0) {
            PastHeader = 1;
            continue;
          }
        } else if (Opcode == SPIRV::OpFunctionParameter) {
          if (PastHeader < 2)
            continue;
        } else if (PastHeader > 0) {
          PastHeader = 2;
        }

        const MachineOperand &DefMO = MI.getOperand(i: 0);
        switch (Opcode) {
        case SPIRV::OpExtension:
          // Record the extension and drop the instruction from per-function
          // output; it is re-emitted at module level.
          MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::Extension(DefMO.getImm()));
          MAI.setSkipEmission(&MI);
          break;
        case SPIRV::OpCapability:
          MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::Capability(DefMO.getImm()));
          MAI.setSkipEmission(&MI);
          // A capability after the header start still means the header run
          // has ended.
          if (PastHeader > 0)
            PastHeader = 2;
          break;
        default:
          // Number any not-yet-seen declaration (and, recursively, its
          // dependencies).
          if (DefMO.isReg() && isDeclSection(MRI, MI) &&
              !MAI.hasRegisterAlias(MF, Reg: DefMO.getReg()))
            visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI);
        }
      }
    }
  }
}
567
// Look for IDs declared with Import linkage, and map the corresponding function
// to the register defining that variable (which will usually be the result of
// an OpFunction). This lets us call externally imported functions using
// the correct ID registers.
void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
                                           const Function *F) {
  if (MI.getOpcode() == SPIRV::OpDecorate) {
    // If it's got Import linkage.
    auto Dec = MI.getOperand(i: 1).getImm();
    if (Dec == SPIRV::Decoration::LinkageAttributes) {
      // LinkageAttributes carries the linkage type as its last operand.
      auto Lnk = MI.getOperand(i: MI.getNumOperands() - 1).getImm();
      if (Lnk == SPIRV::LinkageType::Import) {
        // Map imported function name to function ID register.
        // The linkage name (a string immediately after the decoration id)
        // identifies the IR function being imported.
        const Function *ImportedFunc =
            F->getParent()->getFunction(Name: getStringImm(MI, StartIndex: 2));
        Register Target = MI.getOperand(i: 0).getReg();
        MAI.GlobalObjMap[ImportedFunc] =
            MAI.getRegisterAlias(MF: MI.getMF(), Reg: Target);
      }
    }
  } else if (MI.getOpcode() == SPIRV::OpFunction) {
    // Record all internal OpFunction declarations.
    Register Reg = MI.defs().begin()->getReg();
    MCRegister GlobalReg = MAI.getRegisterAlias(MF: MI.getMF(), Reg);
    assert(GlobalReg.isValid());
    MAI.GlobalObjMap[F] = GlobalReg;
  }
}
596
// Collect the given instruction in the specified MS. We assume global register
// numbering has already occurred by this point. We can directly compare reg
// arguments when detecting duplicates.
//
// Duplicates (same signature) are dropped, except FPFastMathMode decorations
// whose flags are OR-merged into the first occurrence. When Append is false
// the instruction is prepended to the section instead of appended.
static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
                              SPIRV::ModuleSectionType MSType, InstrTraces &IS,
                              bool Append = true) {
  // The instruction will be emitted at module level, never in-function.
  MAI.setSkipEmission(&MI);
  InstrSignature MISign = instrToSignature(MI, MAI, UseDefReg: true);
  auto FoundMI = IS.insert(x: std::move(MISign));
  if (!FoundMI.second) {
    // Signature already seen: this is a duplicate.
    if (MI.getOpcode() == SPIRV::OpDecorate) {
      assert(MI.getNumOperands() >= 2 &&
             "Decoration instructions must have at least 2 operands");
      assert(MSType == SPIRV::MB_Annotations &&
             "Only OpDecorate instructions can be duplicates");
      // For FPFastMathMode decoration, we need to merge the flags of the
      // duplicate decoration with the original one, so we need to find the
      // original instruction that has the same signature. For the rest of
      // instructions, we will simply skip the duplicate.
      if (MI.getOperand(i: 1).getImm() != SPIRV::Decoration::FPFastMathMode)
        return; // Skip duplicates of other decorations.

      // Note: MISign was moved into IS above; FoundMI.first still refers to
      // the stored signature used for the comparisons below.
      const SPIRV::InstrList &Decorations = MAI.MS[MSType];
      for (const MachineInstr *OrigMI : Decorations) {
        if (instrToSignature(MI: *OrigMI, MAI, UseDefReg: true) == MISign) {
          assert(OrigMI->getNumOperands() == MI.getNumOperands() &&
                 "Original instruction must have the same number of operands");
          assert(
              OrigMI->getNumOperands() == 3 &&
              "FPFastMathMode decoration must have 3 operands for OpDecorate");
          unsigned OrigFlags = OrigMI->getOperand(i: 2).getImm();
          unsigned NewFlags = MI.getOperand(i: 2).getImm();
          if (OrigFlags == NewFlags)
            return; // No need to merge, the flags are the same.

          // Emit warning about possible conflict between flags.
          unsigned FinalFlags = OrigFlags | NewFlags;
          llvm::errs()
              << "Warning: Conflicting FPFastMathMode decoration flags "
                 "in instruction: "
              << *OrigMI << "Original flags: " << OrigFlags
              << ", new flags: " << NewFlags
              << ". They will be merged on a best effort basis, but not "
                 "validated. Final flags: "
              << FinalFlags << "\n";
          // Patch the already-collected instruction in place with the merged
          // flags.
          MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI);
          MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(i: 2);
          OrigFlagsOp = MachineOperand::CreateImm(Val: FinalFlags);
          return; // Merge done, so we found a duplicate; don't add it to MAI.MS
        }
      }
      assert(false && "No original instruction found for the duplicate "
                      "OpDecorate, but we found one in IS.");
    }
    return; // insert failed, so we found a duplicate; don't add it to MAI.MS
  }
  // No duplicates, so add it.
  if (Append)
    MAI.MS[MSType].push_back(Elt: &MI);
  else
    MAI.MS[MSType].insert(I: MAI.MS[MSType].begin(), Elt: &MI);
}
659
// Some global instructions make reference to function-local ID regs, so cannot
// be correctly collected until these registers are globally numbered.
//
// Dispatches every not-yet-skipped instruction of every function into the
// appropriate module section (debug strings/names, entry points, annotations,
// constants, etc.), deduplicating via the shared InstrTraces set.
void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
  InstrTraces IS;
  for (const Function &F : M) {
    if (F.isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(F);
    assert(MF);

    for (MachineBasicBlock &MBB : *MF)
      for (MachineInstr &MI : MBB) {
        // Instructions already hoisted (or dropped) earlier are skipped.
        if (MAI.getSkipEmission(MI: &MI))
          continue;
        const unsigned OpCode = MI.getOpcode();
        if (OpCode == SPIRV::OpString) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_DebugStrings, IS);
        } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(i: 2).isImm() &&
                   MI.getOperand(i: 2).getImm() ==
                       SPIRV::InstructionSet::
                           NonSemantic_Shader_DebugInfo_100) {
          // Only module-scope debug-info ext-instructions are hoisted to the
          // global DI section; the rest stay in-function.
          MachineOperand Ins = MI.getOperand(i: 3);
          namespace NS = SPIRV::NonSemanticExtInst;
          static constexpr int64_t GlobalNonSemanticDITy[] = {
              NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
              NS::DebugTypeBasic, NS::DebugTypePointer};
          bool IsGlobalDI = false;
          for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
            IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
          if (IsGlobalDI)
            collectOtherInstr(MI, MAI, MSType: SPIRV::MB_NonSemanticGlobalDI, IS);
        } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_DebugNames, IS);
        } else if (OpCode == SPIRV::OpEntryPoint) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_EntryPoints, IS);
        } else if (TII->isAliasingInstr(MI)) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_AliasingInsts, IS);
        } else if (TII->isDecorationInstr(MI)) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_Annotations, IS);
          collectFuncNames(MI, F: &F);
        } else if (TII->isConstantInstr(MI)) {
          // Now OpSpecConstant*s are not in DT,
          // but they need to be collected anyway.
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_TypeConstVars, IS);
        } else if (OpCode == SPIRV::OpFunction) {
          collectFuncNames(MI, F: &F);
        } else if (OpCode == SPIRV::OpTypeForwardPointer) {
          // Forward pointers must precede the types they refer to: prepend.
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_TypeConstVars, IS, Append: false);
        }
      }
  }
}
712
713// Number registers in all functions globally from 0 onwards and store
714// the result in global register alias table. Some registers are already
715// numbered.
716void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
717 for (const Function &F : M) {
718 if (F.isDeclaration())
719 continue;
720 MachineFunction *MF = MMI->getMachineFunction(F);
721 assert(MF);
722 for (MachineBasicBlock &MBB : *MF) {
723 for (MachineInstr &MI : MBB) {
724 for (MachineOperand &Op : MI.operands()) {
725 if (!Op.isReg())
726 continue;
727 Register Reg = Op.getReg();
728 if (MAI.hasRegisterAlias(MF, Reg))
729 continue;
730 MCRegister NewReg = MAI.getNextIDRegister();
731 MAI.setRegisterAlias(MF, Reg, AliasReg: NewReg);
732 }
733 if (MI.getOpcode() != SPIRV::OpExtInst)
734 continue;
735 auto Set = MI.getOperand(i: 2).getImm();
736 auto [It, Inserted] = MAI.ExtInstSetMap.try_emplace(Key: Set);
737 if (Inserted)
738 It->second = MAI.getNextIDRegister();
739 }
740 }
741 }
742}
743
744// RequirementHandler implementations.
745void SPIRV::RequirementHandler::getAndAddRequirements(
746 SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
747 const SPIRVSubtarget &ST) {
748 addRequirements(Req: getSymbolicOperandRequirements(Category, i, ST, Reqs&: *this));
749}
750
751void SPIRV::RequirementHandler::recursiveAddCapabilities(
752 const CapabilityList &ToPrune) {
753 for (const auto &Cap : ToPrune) {
754 AllCaps.insert(V: Cap);
755 CapabilityList ImplicitDecls =
756 getSymbolicOperandCapabilities(Category: OperandCategory::CapabilityOperand, Value: Cap);
757 recursiveAddCapabilities(ToPrune: ImplicitDecls);
758 }
759}
760
761void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
762 for (const auto &Cap : ToAdd) {
763 bool IsNewlyInserted = AllCaps.insert(V: Cap).second;
764 if (!IsNewlyInserted) // Don't re-add if it's already been declared.
765 continue;
766 CapabilityList ImplicitDecls =
767 getSymbolicOperandCapabilities(Category: OperandCategory::CapabilityOperand, Value: Cap);
768 recursiveAddCapabilities(ToPrune: ImplicitDecls);
769 MinimalCaps.push_back(Elt: Cap);
770 }
771}
772
// Folds a single requirement into the tracked state: declares its capability
// (if any), records its extensions, and tightens the allowed
// [MinVersion, MaxVersion] window. Aborts compilation when the requirement is
// unsatisfiable or conflicts with previously recorded version bounds.
void SPIRV::RequirementHandler::addRequirements(
    const SPIRV::Requirements &Req) {
  if (!Req.IsSatisfiable)
    report_fatal_error(reason: "Adding SPIR-V requirements this target can't satisfy.");

  if (Req.Cap.has_value())
    addCapabilities(ToAdd: {Req.Cap.value()});

  addExtensions(ToAdd: Req.Exts);

  if (!Req.MinVer.empty()) {
    // A new lower bound above the current upper bound is a contradiction.
    if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
                        << " and <= " << MaxVersion << "\n");
      report_fatal_error(reason: "Adding SPIR-V requirements that can't be satisfied.");
    }

    // Keep the largest lower bound seen so far.
    if (MinVersion.empty() || Req.MinVer > MinVersion)
      MinVersion = Req.MinVer;
  }

  if (!Req.MaxVer.empty()) {
    // A new upper bound below the current lower bound is a contradiction.
    if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
                        << " and >= " << MinVersion << "\n");
      report_fatal_error(reason: "Adding SPIR-V requirements that can't be satisfied.");
    }

    // Keep the smallest upper bound seen so far.
    if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
      MaxVersion = Req.MaxVer;
  }
}
805
// Verify that all collected requirements (version window, minimal
// capabilities, extensions) can be met by the target described by ST.
// Every violation is logged (under -debug) before a single fatal error
// aborts the compilation.
void SPIRV::RequirementHandler::checkSatisfiable(
    const SPIRVSubtarget &ST) const {
  // Report as many errors as possible before aborting the compilation.
  bool IsSatisfiable = true;
  auto TargetVer = ST.getSPIRVVersion();

  // The target's SPIR-V version must not exceed the requirements' upper
  // bound...
  if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
    LLVM_DEBUG(
        dbgs() << "Target SPIR-V version too high for required features\n"
               << "Required max version: " << MaxVersion << " target version "
               << TargetVer << "\n");
    IsSatisfiable = false;
  }

  // ... and must reach the lower bound.
  if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
    LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
                      << "Required min version: " << MinVersion
                      << " target version " << TargetVer << "\n");
    IsSatisfiable = false;
  }

  // An empty [MinVersion, MaxVersion] window can never be satisfied.
  if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
    LLVM_DEBUG(
        dbgs()
        << "Version is too low for some features and too high for others.\n"
        << "Required SPIR-V min version: " << MinVersion
        << " required SPIR-V max version " << MaxVersion << "\n");
    IsSatisfiable = false;
  }

  // A kernel module must not require the Shader capability, and a shader
  // module must not require Kernel.
  AvoidCapabilitiesSet AvoidCaps;
  if (!ST.isShader())
    AvoidCaps.S.insert(V: SPIRV::Capability::Shader);
  else
    AvoidCaps.S.insert(V: SPIRV::Capability::Kernel);

  // Each explicitly required capability must be available on the target and
  // not be in the avoided set above.
  for (auto Cap : MinimalCaps) {
    if (AvailableCaps.contains(V: Cap) && !AvoidCaps.S.contains(V: Cap))
      continue;
    LLVM_DEBUG(dbgs() << "Capability not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::CapabilityOperand, Cap)
                      << "\n");
    IsSatisfiable = false;
  }

  // Each required extension must be enabled on the target.
  for (auto Ext : AllExtensions) {
    if (ST.canUseExtension(E: Ext))
      continue;
    LLVM_DEBUG(dbgs() << "Extension not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::ExtensionOperand, Ext)
                      << "\n");
    IsSatisfiable = false;
  }

  if (!IsSatisfiable)
    report_fatal_error(reason: "Unable to meet SPIR-V requirements for this target.");
}
865
866// Add the given capabilities and all their implicitly defined capabilities too.
867void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
868 for (const auto Cap : ToAdd)
869 if (AvailableCaps.insert(V: Cap).second)
870 addAvailableCaps(ToAdd: getSymbolicOperandCapabilities(
871 Category: SPIRV::OperandCategory::CapabilityOperand, Value: Cap));
872}
873
874void SPIRV::RequirementHandler::removeCapabilityIf(
875 const Capability::Capability ToRemove,
876 const Capability::Capability IfPresent) {
877 if (AllCaps.contains(V: IfPresent))
878 AllCaps.erase(V: ToRemove);
879}
880
881namespace llvm {
882namespace SPIRV {
883void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
884 // Provided by both all supported Vulkan versions and OpenCl.
885 addAvailableCaps(ToAdd: {Capability::Shader, Capability::Linkage, Capability::Int8,
886 Capability::Int16});
887
888 if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 3)))
889 addAvailableCaps(ToAdd: {Capability::GroupNonUniform,
890 Capability::GroupNonUniformVote,
891 Capability::GroupNonUniformArithmetic,
892 Capability::GroupNonUniformBallot,
893 Capability::GroupNonUniformClustered,
894 Capability::GroupNonUniformShuffle,
895 Capability::GroupNonUniformShuffleRelative,
896 Capability::GroupNonUniformQuad});
897
898 if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 6)))
899 addAvailableCaps(ToAdd: {Capability::DotProduct, Capability::DotProductInputAll,
900 Capability::DotProductInput4x8Bit,
901 Capability::DotProductInput4x8BitPacked,
902 Capability::DemoteToHelperInvocation});
903
904 // Add capabilities enabled by extensions.
905 for (auto Extension : ST.getAllAvailableExtensions()) {
906 CapabilityList EnabledCapabilities =
907 getCapabilitiesEnabledByExtension(Extension);
908 addAvailableCaps(ToAdd: EnabledCapabilities);
909 }
910
911 if (!ST.isShader()) {
912 initAvailableCapabilitiesForOpenCL(ST);
913 return;
914 }
915
916 if (ST.isShader()) {
917 initAvailableCapabilitiesForVulkan(ST);
918 return;
919 }
920
921 report_fatal_error(reason: "Unimplemented environment for SPIR-V generation.");
922}
923
// Seed the available-capability set for the OpenCL (kernel) environment.
// Groups are gated on the OpenCL profile, image support, and the OpenCL /
// SPIR-V versions reported by the subtarget.
void RequirementHandler::initAvailableCapabilitiesForOpenCL(
    const SPIRVSubtarget &ST) {
  // Add the min requirements for different OpenCL and SPIR-V versions.
  addAvailableCaps(ToAdd: {Capability::Addresses, Capability::Float16Buffer,
                    Capability::Kernel, Capability::Vector16,
                    Capability::Groups, Capability::GenericPointer,
                    Capability::StorageImageWriteWithoutFormat,
                    Capability::StorageImageReadWithoutFormat});
  // 64-bit integers (and their atomics) require the full profile.
  if (ST.hasOpenCLFullProfile())
    addAvailableCaps(ToAdd: {Capability::Int64, Capability::Int64Atomics});
  if (ST.hasOpenCLImageSupport()) {
    addAvailableCaps(ToAdd: {Capability::ImageBasic, Capability::LiteralSampler,
                      Capability::Image1D, Capability::SampledBuffer,
                      Capability::ImageBuffer});
    // Read-write images were introduced with OpenCL 2.0.
    if (ST.isAtLeastOpenCLVer(VerToCompareTo: VersionTuple(2, 0)))
      addAvailableCaps(ToAdd: {Capability::ImageReadWrite});
  }
  // Subgroup dispatch / pipe storage need both SPIR-V 1.1 and OpenCL 2.2.
  if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 1)) &&
      ST.isAtLeastOpenCLVer(VerToCompareTo: VersionTuple(2, 2)))
    addAvailableCaps(ToAdd: {Capability::SubgroupDispatch, Capability::PipeStorage});
  // Float-controls capabilities became available with SPIR-V 1.4.
  if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 4)))
    addAvailableCaps(ToAdd: {Capability::DenormPreserve, Capability::DenormFlushToZero,
                      Capability::SignedZeroInfNanPreserve,
                      Capability::RoundingModeRTE,
                      Capability::RoundingModeRTZ});
  // TODO: verify if this needs some checks.
  addAvailableCaps(ToAdd: {Capability::Float16, Capability::Float64});

  // TODO: add OpenCL extensions.
}
954
// Seed the available-capability set for the Vulkan (shader) environment,
// grouped by the Vulkan version in which each set became core. Note the
// gating uses the corresponding SPIR-V versions (1.5 for Vulkan 1.2, 1.6
// for Vulkan 1.3).
void RequirementHandler::initAvailableCapabilitiesForVulkan(
    const SPIRVSubtarget &ST) {

  // Core in Vulkan 1.1 and earlier.
  addAvailableCaps(ToAdd: {Capability::Int64, Capability::Float16, Capability::Float64,
                    Capability::GroupNonUniform, Capability::Image1D,
                    Capability::SampledBuffer, Capability::ImageBuffer,
                    Capability::UniformBufferArrayDynamicIndexing,
                    Capability::SampledImageArrayDynamicIndexing,
                    Capability::StorageBufferArrayDynamicIndexing,
                    Capability::StorageImageArrayDynamicIndexing,
                    Capability::DerivativeControl, Capability::MinLod,
                    Capability::ImageQuery, Capability::ImageGatherExtended,
                    Capability::Addresses, Capability::VulkanMemoryModelKHR});

  // Became core in Vulkan 1.2 (descriptor-indexing capabilities).
  if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 5))) {
    addAvailableCaps(
        ToAdd: {Capability::ShaderNonUniformEXT, Capability::RuntimeDescriptorArrayEXT,
         Capability::InputAttachmentArrayDynamicIndexingEXT,
         Capability::UniformTexelBufferArrayDynamicIndexingEXT,
         Capability::StorageTexelBufferArrayDynamicIndexingEXT,
         Capability::UniformBufferArrayNonUniformIndexingEXT,
         Capability::SampledImageArrayNonUniformIndexingEXT,
         Capability::StorageBufferArrayNonUniformIndexingEXT,
         Capability::StorageImageArrayNonUniformIndexingEXT,
         Capability::InputAttachmentArrayNonUniformIndexingEXT,
         Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
         Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
  }

  // Became core in Vulkan 1.3
  if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 6)))
    addAvailableCaps(ToAdd: {Capability::StorageImageWriteWithoutFormat,
                      Capability::StorageImageReadWithoutFormat});
}
991
992} // namespace SPIRV
993} // namespace llvm
994
995// Add the required capabilities from a decoration instruction (including
996// BuiltIns).
997static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
998 SPIRV::RequirementHandler &Reqs,
999 const SPIRVSubtarget &ST) {
1000 int64_t DecOp = MI.getOperand(i: DecIndex).getImm();
1001 auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
1002 Reqs.addRequirements(Req: getSymbolicOperandRequirements(
1003 Category: SPIRV::OperandCategory::DecorationOperand, i: Dec, ST, Reqs));
1004
1005 if (Dec == SPIRV::Decoration::BuiltIn) {
1006 int64_t BuiltInOp = MI.getOperand(i: DecIndex + 1).getImm();
1007 auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
1008 Reqs.addRequirements(Req: getSymbolicOperandRequirements(
1009 Category: SPIRV::OperandCategory::BuiltInOperand, i: BuiltIn, ST, Reqs));
1010 } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
1011 int64_t LinkageOp = MI.getOperand(i: MI.getNumOperands() - 1).getImm();
1012 SPIRV::LinkageType::LinkageType LnkType =
1013 static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
1014 if (LnkType == SPIRV::LinkageType::LinkOnceODR)
1015 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_linkonce_odr);
1016 } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
1017 Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
1018 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_cache_controls);
1019 } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
1020 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_global_variable_host_access);
1021 } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
1022 Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
1023 Reqs.addExtension(
1024 ToAdd: SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
1025 } else if (Dec == SPIRV::Decoration::NonUniformEXT) {
1026 Reqs.addRequirements(Req: SPIRV::Capability::ShaderNonUniformEXT);
1027 } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) {
1028 Reqs.addRequirements(Req: SPIRV::Capability::FPMaxErrorINTEL);
1029 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_fp_max_error);
1030 } else if (Dec == SPIRV::Decoration::FPFastMathMode) {
1031 if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2)) {
1032 Reqs.addRequirements(Req: SPIRV::Capability::FloatControls2);
1033 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_float_controls2);
1034 }
1035 }
1036}
1037
// Add requirements for image handling.
// Derives capability requirements from an OpTypeImage definition: the image
// format, the dimensionality (with sampled/arrayed/multisampled variants),
// and — for the OpenCL environment only — the optional access qualifier.
static void addOpTypeImageReqs(const MachineInstr &MI,
                               SPIRV::RequirementHandler &Reqs,
                               const SPIRVSubtarget &ST) {
  assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
  // The operand indices used here are based on the OpTypeImage layout, which
  // the MachineInstr follows as well.
  int64_t ImgFormatOp = MI.getOperand(i: 7).getImm();
  auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
  Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ImageFormatOperand,
                             i: ImgFormat, ST);

  // Operand 4 = Arrayed, 5 = MS (multisampled), 6 = Sampled; Sampled == 2
  // means the image is used without a sampler (read/write access).
  bool IsArrayed = MI.getOperand(i: 4).getImm() == 1;
  bool IsMultisampled = MI.getOperand(i: 5).getImm() == 1;
  bool NoSampler = MI.getOperand(i: 6).getImm() == 2;
  // Add dimension requirements.
  assert(MI.getOperand(2).isImm());
  switch (MI.getOperand(i: 2).getImm()) {
  case SPIRV::Dim::DIM_1D:
    Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::Image1D
                                   : SPIRV::Capability::Sampled1D);
    break;
  case SPIRV::Dim::DIM_2D:
    // Multisampled storage 2D images need the ImageMSArray capability.
    if (IsMultisampled && NoSampler)
      Reqs.addRequirements(Req: SPIRV::Capability::ImageMSArray);
    break;
  case SPIRV::Dim::DIM_Cube:
    Reqs.addRequirements(Req: SPIRV::Capability::Shader);
    if (IsArrayed)
      Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::ImageCubeArray
                                     : SPIRV::Capability::SampledCubeArray);
    break;
  case SPIRV::Dim::DIM_Rect:
    Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::ImageRect
                                   : SPIRV::Capability::SampledRect);
    break;
  case SPIRV::Dim::DIM_Buffer:
    Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::ImageBuffer
                                   : SPIRV::Capability::SampledBuffer);
    break;
  case SPIRV::Dim::DIM_SubpassData:
    Reqs.addRequirements(Req: SPIRV::Capability::InputAttachment);
    break;
  }

  // Has optional access qualifier.
  // Only the OpenCL (kernel) environment uses the access-qualifier operand.
  if (!ST.isShader()) {
    if (MI.getNumOperands() > 8 &&
        MI.getOperand(i: 8).getImm() == SPIRV::AccessQualifier::ReadWrite)
      Reqs.addRequirements(Req: SPIRV::Capability::ImageReadWrite);
    else
      Reqs.addRequirements(Req: SPIRV::Capability::ImageBasic);
  }
}
1092
1093static bool isBFloat16Type(SPIRVTypeInst TypeDef) {
1094 return TypeDef && TypeDef->getNumOperands() == 3 &&
1095 TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1096 TypeDef->getOperand(i: 1).getImm() == 16 &&
1097 TypeDef->getOperand(i: 2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
1098}
1099
// Add requirements for handling atomic float instructions
#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
 "The atomic float instruction requires the following SPIR-V " \
 "extension: SPV_EXT_shader_atomic_float" ExtName
// Handles the vector-result case of atomic floats: the result must be a
// 2- or 4-component vector of (non-bfloat) float16, and the
// SPV_NV_shader_atomic_fp16_vector extension must be enabled.
static void AddAtomicVectorFloatRequirements(const MachineInstr &MI,
                                             SPIRV::RequirementHandler &Reqs,
                                             const SPIRVSubtarget &ST) {
  // Operand 1 is the result-type vreg; its def is the OpTypeVector.
  SPIRVTypeInst VecTypeDef =
      MI.getMF()->getRegInfo().getVRegDef(Reg: MI.getOperand(i: 1).getReg());

  // Operand 2 of OpTypeVector is the component count.
  const unsigned Rank = VecTypeDef->getOperand(i: 2).getImm();
  if (Rank != 2 && Rank != 4)
    reportFatalUsageError(reason: "Result type of an atomic vector float instruction "
                          "must be a 2-component or 4 component vector");

  // Operand 1 of OpTypeVector is the element-type vreg.
  SPIRVTypeInst EltTypeDef =
      MI.getMF()->getRegInfo().getVRegDef(Reg: VecTypeDef->getOperand(i: 1).getReg());

  if (EltTypeDef->getOpcode() != SPIRV::OpTypeFloat ||
      EltTypeDef->getOperand(i: 1).getImm() != 16)
    reportFatalUsageError(
        reason: "The element type for the result type of an atomic vector float "
        "instruction must be a 16-bit floating-point scalar");

  // A 16-bit float may still be bfloat16, which is explicitly rejected.
  if (isBFloat16Type(TypeDef: EltTypeDef))
    reportFatalUsageError(
        reason: "The element type for the result type of an atomic vector float "
        "instruction cannot be a bfloat16 scalar");
  if (!ST.canUseExtension(E: SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector))
    reportFatalUsageError(
        reason: "The atomic float16 vector instruction requires the following SPIR-V "
        "extension: SPV_NV_shader_atomic_fp16_vector");

  Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector);
  Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat16VectorNV);
}
1136
// Dispatch point for atomic float requirements: vector results are handled
// by AddAtomicVectorFloatRequirements; scalar results select extensions and
// capabilities based on the opcode (FAdd vs min/max) and the float width.
static void AddAtomicFloatRequirements(const MachineInstr &MI,
                                       SPIRV::RequirementHandler &Reqs,
                                       const SPIRVSubtarget &ST) {
  assert(MI.getOperand(1).isReg() &&
         "Expect register operand in atomic float instruction");
  Register TypeReg = MI.getOperand(i: 1).getReg();
  SPIRVTypeInst TypeDef = MI.getMF()->getRegInfo().getVRegDef(Reg: TypeReg);

  if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
    return AddAtomicVectorFloatRequirements(MI, Reqs, ST);

  if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
    report_fatal_error(reason: "Result type of an atomic float instruction must be a "
                       "floating-point type scalar");

  unsigned BitWidth = TypeDef->getOperand(i: 1).getImm();
  unsigned Op = MI.getOpcode();
  if (Op == SPIRV::OpAtomicFAddEXT) {
    // Atomic float add: base extension plus a width-specific capability.
    if (!ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), gen_crash_diag: false);
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
    switch (BitWidth) {
    case 16:
      // bfloat16 and IEEE half need different extensions/capabilities.
      if (isBFloat16Type(TypeDef)) {
        if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_16bit_atomics))
          report_fatal_error(
              reason: "The atomic bfloat16 instruction requires the following SPIR-V "
              "extension: SPV_INTEL_16bit_atomics",
              gen_crash_diag: false);
        Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_16bit_atomics);
        Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicBFloat16AddINTEL);
      } else {
        if (!ST.canUseExtension(
                E: SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
          report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), gen_crash_diag: false);
        Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
        Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat16AddEXT);
      }
      break;
    case 32:
      Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat32AddEXT);
      break;
    case 64:
      Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat64AddEXT);
      break;
    default:
      report_fatal_error(
          reason: "Unexpected floating-point type width in atomic float instruction");
    }
  } else {
    // All other atomic float opcodes reaching here are min/max variants.
    if (!ST.canUseExtension(
            E: SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), gen_crash_diag: false);
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
    switch (BitWidth) {
    case 16:
      if (isBFloat16Type(TypeDef)) {
        if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_16bit_atomics))
          report_fatal_error(
              reason: "The atomic bfloat16 instruction requires the following SPIR-V "
              "extension: SPV_INTEL_16bit_atomics",
              gen_crash_diag: false);
        Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_16bit_atomics);
        Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicBFloat16MinMaxINTEL);
      } else {
        Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat16MinMaxEXT);
      }
      break;
    case 32:
      Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat32MinMaxEXT);
      break;
    case 64:
      Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat64MinMaxEXT);
      break;
    default:
      report_fatal_error(
          reason: "Unexpected floating-point type width in atomic float instruction");
    }
  }
}
1217
1218bool isUniformTexelBuffer(MachineInstr *ImageInst) {
1219 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1220 return false;
1221 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1222 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1223 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
1224}
1225
1226bool isStorageTexelBuffer(MachineInstr *ImageInst) {
1227 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1228 return false;
1229 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1230 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1231 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
1232}
1233
1234bool isSampledImage(MachineInstr *ImageInst) {
1235 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1236 return false;
1237 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1238 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1239 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
1240}
1241
1242bool isInputAttachment(MachineInstr *ImageInst) {
1243 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1244 return false;
1245 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1246 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1247 return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
1248}
1249
1250bool isStorageImage(MachineInstr *ImageInst) {
1251 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1252 return false;
1253 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1254 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1255 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
1256}
1257
1258bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
1259 if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
1260 return false;
1261
1262 const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
1263 Register ImageReg = SampledImageInst->getOperand(i: 1).getReg();
1264 auto *ImageInst = MRI.getUniqueVRegDef(Reg: ImageReg);
1265 return isSampledImage(ImageInst);
1266}
1267
1268bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
1269 for (const auto &MI : MRI.reg_instructions(Reg)) {
1270 if (MI.getOpcode() != SPIRV::OpDecorate)
1271 continue;
1272
1273 uint32_t Dec = MI.getOperand(i: 1).getImm();
1274 if (Dec == SPIRV::Decoration::NonUniformEXT)
1275 return true;
1276 }
1277 return false;
1278}
1279
// Add descriptor-indexing requirements for an OpAccessChain. Depending on
// the pointee type (texel buffer, input attachment, sampled/storage image,
// sampler) and whether the access is decorated NonUniform or its first index
// is non-constant, the matching *NonUniformIndexing / *DynamicIndexing
// capability is required.
void addOpAccessChainReqs(const MachineInstr &Instr,
                          SPIRV::RequirementHandler &Handler,
                          const SPIRVSubtarget &Subtarget) {
  const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
  // Get the result type. If it is an image type, then the shader uses
  // descriptor indexing. The appropriate capabilities will be added based
  // on the specifics of the image.
  Register ResTypeReg = Instr.getOperand(i: 1).getReg();
  MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(Reg: ResTypeReg);

  assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
  uint32_t StorageClass = ResTypeInst->getOperand(i: 1).getImm();
  // Only these storage classes participate in descriptor indexing.
  if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
      StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
      StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
    return;
  }

  // NonUniform decoration on the access-chain result selects the stronger
  // *NonUniformIndexing capabilities below.
  bool IsNonUniform =
      hasNonUniformDecoration(Reg: Instr.getOperand(i: 0).getReg(), MRI);

  // A non-constant first index means the array is dynamically indexed.
  auto FirstIndexReg = Instr.getOperand(i: 3).getReg();
  bool FirstIndexIsConstant =
      Subtarget.getInstrInfo()->isConstantInstr(MI: *MRI.getVRegDef(Reg: FirstIndexReg));

  // StorageBuffer is resolved without inspecting the pointee type.
  if (StorageClass == SPIRV::StorageClass::StorageClass::StorageBuffer) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageBufferArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageBufferArrayDynamicIndexing);
    return;
  }

  // For the remaining storage classes, only image/sampler pointees matter.
  Register PointeeTypeReg = ResTypeInst->getOperand(i: 2).getReg();
  MachineInstr *PointeeType = MRI.getUniqueVRegDef(Reg: PointeeTypeReg);
  if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
      PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
      PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
    return;
  }

  // Pick the capability matching the specific descriptor kind.
  if (isUniformTexelBuffer(ImageInst: PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
  } else if (isInputAttachment(ImageInst: PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
  } else if (isStorageTexelBuffer(ImageInst: PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
  } else if (isSampledImage(ImageInst: PointeeType) ||
             isCombinedImageSampler(SampledImageInst: PointeeType) ||
             PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::SampledImageArrayDynamicIndexing);
  } else if (isStorageImage(ImageInst: PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageImageArrayDynamicIndexing);
  }
}
1362
1363static bool isImageTypeWithUnknownFormat(SPIRVTypeInst TypeInst) {
1364 if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
1365 return false;
1366 assert(TypeInst->getOperand(7).isImm() && "The image format must be an imm.");
1367 return TypeInst->getOperand(i: 7).getImm() == 0;
1368}
1369
// Add requirements for an integer dot-product instruction. The input type
// (32-bit int = 4x8-bit packed, vector of 8-bit ints, or other int vector)
// selects the specific DotProductInput* capability.
static void AddDotProductRequirements(const MachineInstr &MI,
                                      SPIRV::RequirementHandler &Reqs,
                                      const SPIRVSubtarget &ST) {
  // DotProduct may come from the extension or from SPIR-V >= 1.6 (where
  // initAvailableCapabilities makes it available), so the extension is only
  // recorded when the target actually enables it.
  if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_integer_dot_product))
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_integer_dot_product);
  Reqs.addCapability(ToAdd: SPIRV::Capability::DotProduct);

  const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
  // We do not consider what the previous instruction is. This is just used
  // to get the input register and to check the type.
  const MachineInstr *Input = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
  assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
  Register InputReg = Input->getOperand(i: 1).getReg();

  SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: InputReg);
  if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
    // Scalar 32-bit int input: the 4x8-bit packed form.
    assert(TypeDef->getOperand(1).getImm() == 32);
    Reqs.addCapability(ToAdd: SPIRV::Capability::DotProductInput4x8BitPacked);
  } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
    SPIRVTypeInst ScalarTypeDef =
        MRI.getVRegDef(Reg: TypeDef->getOperand(i: 1).getReg());
    assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
    if (ScalarTypeDef->getOperand(i: 1).getImm() == 8) {
      assert(TypeDef->getOperand(2).getImm() == 4 &&
             "Dot operand of 8-bit integer type requires 4 components");
      Reqs.addCapability(ToAdd: SPIRV::Capability::DotProductInput4x8Bit);
    } else {
      // Any other integer component width.
      Reqs.addCapability(ToAdd: SPIRV::Capability::DotProductInputAll);
    }
  }
}
1402
1403void addPrintfRequirements(const MachineInstr &MI,
1404 SPIRV::RequirementHandler &Reqs,
1405 const SPIRVSubtarget &ST) {
1406 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
1407 SPIRVTypeInst PtrType = GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 4).getReg());
1408 if (PtrType) {
1409 MachineOperand ASOp = PtrType->getOperand(i: 1);
1410 if (ASOp.isImm()) {
1411 unsigned AddrSpace = ASOp.getImm();
1412 if (AddrSpace != SPIRV::StorageClass::UniformConstant) {
1413 if (!ST.canUseExtension(
1414 E: SPIRV::Extension::
1415 SPV_EXT_relaxed_printf_string_address_space)) {
1416 report_fatal_error(reason: "SPV_EXT_relaxed_printf_string_address_space is "
1417 "required because printf uses a format string not "
1418 "in constant address space.",
1419 gen_crash_diag: false);
1420 }
1421 Reqs.addExtension(
1422 ToAdd: SPIRV::Extension::SPV_EXT_relaxed_printf_string_address_space);
1423 }
1424 }
1425 }
1426}
1427
1428static void addImageOperandReqs(const MachineInstr &MI,
1429 SPIRV::RequirementHandler &Reqs,
1430 const SPIRVSubtarget &ST, unsigned OpIdx) {
1431 if (MI.getNumOperands() <= OpIdx)
1432 return;
1433 uint32_t Mask = MI.getOperand(i: OpIdx).getImm();
1434 for (uint32_t I = 0; I < 32; ++I)
1435 if (Mask & (1U << I))
1436 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ImageOperandOperand,
1437 i: 1U << I, ST);
1438}
1439
1440void addInstrRequirements(const MachineInstr &MI,
1441 SPIRV::ModuleAnalysisInfo &MAI,
1442 const SPIRVSubtarget &ST) {
1443 SPIRV::RequirementHandler &Reqs = MAI.Reqs;
1444 unsigned Op = MI.getOpcode();
1445 switch (Op) {
1446 case SPIRV::OpMemoryModel: {
1447 int64_t Addr = MI.getOperand(i: 0).getImm();
1448 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::AddressingModelOperand,
1449 i: Addr, ST);
1450 int64_t Mem = MI.getOperand(i: 1).getImm();
1451 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::MemoryModelOperand, i: Mem,
1452 ST);
1453 break;
1454 }
1455 case SPIRV::OpEntryPoint: {
1456 int64_t Exe = MI.getOperand(i: 0).getImm();
1457 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ExecutionModelOperand,
1458 i: Exe, ST);
1459 break;
1460 }
1461 case SPIRV::OpExecutionMode:
1462 case SPIRV::OpExecutionModeId: {
1463 int64_t Exe = MI.getOperand(i: 1).getImm();
1464 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ExecutionModeOperand,
1465 i: Exe, ST);
1466 break;
1467 }
1468 case SPIRV::OpTypeMatrix:
1469 Reqs.addCapability(ToAdd: SPIRV::Capability::Matrix);
1470 break;
1471 case SPIRV::OpTypeInt: {
1472 unsigned BitWidth = MI.getOperand(i: 1).getImm();
1473 if (BitWidth == 64)
1474 Reqs.addCapability(ToAdd: SPIRV::Capability::Int64);
1475 else if (BitWidth == 16)
1476 Reqs.addCapability(ToAdd: SPIRV::Capability::Int16);
1477 else if (BitWidth == 8)
1478 Reqs.addCapability(ToAdd: SPIRV::Capability::Int8);
1479 else if (BitWidth == 4 &&
1480 ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_int4)) {
1481 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_int4);
1482 Reqs.addCapability(ToAdd: SPIRV::Capability::Int4TypeINTEL);
1483 } else if (BitWidth != 32) {
1484 if (!ST.canUseExtension(
1485 E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers))
1486 reportFatalUsageError(
1487 reason: "OpTypeInt type with a width other than 8, 16, 32 or 64 bits "
1488 "requires the following SPIR-V extension: "
1489 "SPV_ALTERA_arbitrary_precision_integers");
1490 Reqs.addExtension(
1491 ToAdd: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers);
1492 Reqs.addCapability(ToAdd: SPIRV::Capability::ArbitraryPrecisionIntegersALTERA);
1493 }
1494 break;
1495 }
1496 case SPIRV::OpDot: {
1497 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1498 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
1499 if (isBFloat16Type(TypeDef))
1500 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16DotProductKHR);
1501 break;
1502 }
1503 case SPIRV::OpTypeFloat: {
1504 unsigned BitWidth = MI.getOperand(i: 1).getImm();
1505 if (BitWidth == 64)
1506 Reqs.addCapability(ToAdd: SPIRV::Capability::Float64);
1507 else if (BitWidth == 16) {
1508 if (isBFloat16Type(TypeDef: &MI)) {
1509 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_bfloat16))
1510 report_fatal_error(reason: "OpTypeFloat type with bfloat requires the "
1511 "following SPIR-V extension: SPV_KHR_bfloat16",
1512 gen_crash_diag: false);
1513 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_bfloat16);
1514 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16TypeKHR);
1515 } else {
1516 Reqs.addCapability(ToAdd: SPIRV::Capability::Float16);
1517 }
1518 }
1519 break;
1520 }
1521 case SPIRV::OpTypeVector: {
1522 unsigned NumComponents = MI.getOperand(i: 2).getImm();
1523 if (NumComponents == 8 || NumComponents == 16)
1524 Reqs.addCapability(ToAdd: SPIRV::Capability::Vector16);
1525
1526 assert(MI.getOperand(1).isReg());
1527 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1528 SPIRVTypeInst ElemTypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
1529 if (ElemTypeDef->getOpcode() == SPIRV::OpTypePointer &&
1530 ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
1531 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
1532 Reqs.addCapability(ToAdd: SPIRV::Capability::MaskedGatherScatterINTEL);
1533 }
1534 break;
1535 }
1536 case SPIRV::OpTypePointer: {
1537 auto SC = MI.getOperand(i: 1).getImm();
1538 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::StorageClassOperand, i: SC,
1539 ST);
1540 // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
1541 // capability.
1542 if (ST.isShader())
1543 break;
1544 assert(MI.getOperand(2).isReg());
1545 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1546 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
1547 if ((TypeDef->getNumOperands() == 2) &&
1548 (TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
1549 (TypeDef->getOperand(i: 1).getImm() == 16))
1550 Reqs.addCapability(ToAdd: SPIRV::Capability::Float16Buffer);
1551 break;
1552 }
1553 case SPIRV::OpExtInst: {
1554 if (MI.getOperand(i: 2).getImm() ==
1555 static_cast<int64_t>(
1556 SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
1557 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_non_semantic_info);
1558 break;
1559 }
1560 if (MI.getOperand(i: 3).getImm() ==
1561 static_cast<int64_t>(SPIRV::OpenCLExtInst::printf)) {
1562 addPrintfRequirements(MI, Reqs, ST);
1563 break;
1564 }
1565 // TODO: handle bfloat16 extended instructions when
1566 // SPV_INTEL_bfloat16_arithmetic is enabled.
1567 break;
1568 }
1569 case SPIRV::OpAliasDomainDeclINTEL:
1570 case SPIRV::OpAliasScopeDeclINTEL:
1571 case SPIRV::OpAliasScopeListDeclINTEL: {
1572 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_memory_access_aliasing);
1573 Reqs.addCapability(ToAdd: SPIRV::Capability::MemoryAccessAliasingINTEL);
1574 break;
1575 }
1576 case SPIRV::OpBitReverse:
1577 case SPIRV::OpBitFieldInsert:
1578 case SPIRV::OpBitFieldSExtract:
1579 case SPIRV::OpBitFieldUExtract:
1580 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_bit_instructions)) {
1581 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
1582 break;
1583 }
1584 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_bit_instructions);
1585 Reqs.addCapability(ToAdd: SPIRV::Capability::BitInstructions);
1586 break;
1587 case SPIRV::OpTypeRuntimeArray:
1588 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
1589 break;
1590 case SPIRV::OpTypeOpaque:
1591 case SPIRV::OpTypeEvent:
1592 Reqs.addCapability(ToAdd: SPIRV::Capability::Kernel);
1593 break;
1594 case SPIRV::OpTypePipe:
1595 case SPIRV::OpTypeReserveId:
1596 Reqs.addCapability(ToAdd: SPIRV::Capability::Pipes);
1597 break;
1598 case SPIRV::OpTypeDeviceEvent:
1599 case SPIRV::OpTypeQueue:
1600 case SPIRV::OpBuildNDRange:
1601 Reqs.addCapability(ToAdd: SPIRV::Capability::DeviceEnqueue);
1602 break;
1603 case SPIRV::OpDecorate:
1604 case SPIRV::OpDecorateId:
1605 case SPIRV::OpDecorateString:
1606 addOpDecorateReqs(MI, DecIndex: 1, Reqs, ST);
1607 break;
1608 case SPIRV::OpMemberDecorate:
1609 case SPIRV::OpMemberDecorateString:
1610 addOpDecorateReqs(MI, DecIndex: 2, Reqs, ST);
1611 break;
1612 case SPIRV::OpInBoundsPtrAccessChain:
1613 Reqs.addCapability(ToAdd: SPIRV::Capability::Addresses);
1614 break;
1615 case SPIRV::OpConstantSampler:
1616 Reqs.addCapability(ToAdd: SPIRV::Capability::LiteralSampler);
1617 break;
1618 case SPIRV::OpInBoundsAccessChain:
1619 case SPIRV::OpAccessChain:
1620 addOpAccessChainReqs(Instr: MI, Handler&: Reqs, Subtarget: ST);
1621 break;
1622 case SPIRV::OpTypeImage:
1623 addOpTypeImageReqs(MI, Reqs, ST);
1624 break;
1625 case SPIRV::OpTypeSampler:
1626 if (!ST.isShader()) {
1627 Reqs.addCapability(ToAdd: SPIRV::Capability::ImageBasic);
1628 }
1629 break;
1630 case SPIRV::OpTypeForwardPointer:
1631 // TODO: check if it's OpenCL's kernel.
1632 Reqs.addCapability(ToAdd: SPIRV::Capability::Addresses);
1633 break;
1634 case SPIRV::OpAtomicFlagTestAndSet:
1635 case SPIRV::OpAtomicLoad:
1636 case SPIRV::OpAtomicStore:
1637 case SPIRV::OpAtomicExchange:
1638 case SPIRV::OpAtomicCompareExchange:
1639 case SPIRV::OpAtomicIIncrement:
1640 case SPIRV::OpAtomicIDecrement:
1641 case SPIRV::OpAtomicIAdd:
1642 case SPIRV::OpAtomicISub:
1643 case SPIRV::OpAtomicUMin:
1644 case SPIRV::OpAtomicUMax:
1645 case SPIRV::OpAtomicSMin:
1646 case SPIRV::OpAtomicSMax:
1647 case SPIRV::OpAtomicAnd:
1648 case SPIRV::OpAtomicOr:
1649 case SPIRV::OpAtomicXor: {
1650 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1651 const MachineInstr *InstrPtr = &MI;
1652 if (Op == SPIRV::OpAtomicStore) {
1653 assert(MI.getOperand(3).isReg());
1654 InstrPtr = MRI.getVRegDef(Reg: MI.getOperand(i: 3).getReg());
1655 assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
1656 }
1657 assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
1658 Register TypeReg = InstrPtr->getOperand(i: 1).getReg();
1659 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: TypeReg);
1660
1661 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1662 unsigned BitWidth = TypeDef->getOperand(i: 1).getImm();
1663 if (BitWidth == 64)
1664 Reqs.addCapability(ToAdd: SPIRV::Capability::Int64Atomics);
1665 else if (BitWidth == 16) {
1666 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_16bit_atomics))
1667 report_fatal_error(
1668 reason: "16-bit integer atomic operations require the following SPIR-V "
1669 "extension: SPV_INTEL_16bit_atomics",
1670 gen_crash_diag: false);
1671 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_16bit_atomics);
1672 switch (Op) {
1673 case SPIRV::OpAtomicLoad:
1674 case SPIRV::OpAtomicStore:
1675 case SPIRV::OpAtomicExchange:
1676 case SPIRV::OpAtomicCompareExchange:
1677 case SPIRV::OpAtomicCompareExchangeWeak:
1678 Reqs.addCapability(
1679 ToAdd: SPIRV::Capability::AtomicInt16CompareExchangeINTEL);
1680 break;
1681 default:
1682 Reqs.addCapability(ToAdd: SPIRV::Capability::Int16AtomicsINTEL);
1683 break;
1684 }
1685 }
1686 } else if (isBFloat16Type(TypeDef)) {
1687 if (is_contained(Set: {SPIRV::OpAtomicLoad, SPIRV::OpAtomicStore,
1688 SPIRV::OpAtomicExchange},
1689 Element: Op)) {
1690 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_16bit_atomics))
1691 report_fatal_error(
1692 reason: "The atomic bfloat16 instruction requires the following SPIR-V "
1693 "extension: SPV_INTEL_16bit_atomics",
1694 gen_crash_diag: false);
1695 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_16bit_atomics);
1696 Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicBFloat16LoadStoreINTEL);
1697 }
1698 }
1699 break;
1700 }
1701 case SPIRV::OpGroupNonUniformIAdd:
1702 case SPIRV::OpGroupNonUniformFAdd:
1703 case SPIRV::OpGroupNonUniformIMul:
1704 case SPIRV::OpGroupNonUniformFMul:
1705 case SPIRV::OpGroupNonUniformSMin:
1706 case SPIRV::OpGroupNonUniformUMin:
1707 case SPIRV::OpGroupNonUniformFMin:
1708 case SPIRV::OpGroupNonUniformSMax:
1709 case SPIRV::OpGroupNonUniformUMax:
1710 case SPIRV::OpGroupNonUniformFMax:
1711 case SPIRV::OpGroupNonUniformBitwiseAnd:
1712 case SPIRV::OpGroupNonUniformBitwiseOr:
1713 case SPIRV::OpGroupNonUniformBitwiseXor:
1714 case SPIRV::OpGroupNonUniformLogicalAnd:
1715 case SPIRV::OpGroupNonUniformLogicalOr:
1716 case SPIRV::OpGroupNonUniformLogicalXor: {
1717 assert(MI.getOperand(3).isImm());
1718 int64_t GroupOp = MI.getOperand(i: 3).getImm();
1719 switch (GroupOp) {
1720 case SPIRV::GroupOperation::Reduce:
1721 case SPIRV::GroupOperation::InclusiveScan:
1722 case SPIRV::GroupOperation::ExclusiveScan:
1723 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformArithmetic);
1724 break;
1725 case SPIRV::GroupOperation::ClusteredReduce:
1726 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformClustered);
1727 break;
1728 case SPIRV::GroupOperation::PartitionedReduceNV:
1729 case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1730 case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1731 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformPartitionedNV);
1732 break;
1733 }
1734 break;
1735 }
1736 case SPIRV::OpGroupNonUniformQuadSwap:
1737 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformQuad);
1738 break;
1739 case SPIRV::OpImageQueryLod:
1740 Reqs.addCapability(ToAdd: SPIRV::Capability::ImageQuery);
1741 break;
1742 case SPIRV::OpImageQuerySize:
1743 case SPIRV::OpImageQuerySizeLod:
1744 case SPIRV::OpImageQueryLevels:
1745 case SPIRV::OpImageQuerySamples:
1746 if (ST.isShader())
1747 Reqs.addCapability(ToAdd: SPIRV::Capability::ImageQuery);
1748 break;
1749 case SPIRV::OpImageQueryFormat: {
1750 Register ResultReg = MI.getOperand(i: 0).getReg();
1751 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1752 static const unsigned CompareOps[] = {
1753 SPIRV::OpIEqual, SPIRV::OpINotEqual,
1754 SPIRV::OpUGreaterThan, SPIRV::OpUGreaterThanEqual,
1755 SPIRV::OpULessThan, SPIRV::OpULessThanEqual,
1756 SPIRV::OpSGreaterThan, SPIRV::OpSGreaterThanEqual,
1757 SPIRV::OpSLessThan, SPIRV::OpSLessThanEqual};
1758
1759 auto CheckAndAddExtension = [&](int64_t ImmVal) {
1760 if (ImmVal == 4323 || ImmVal == 4324) {
1761 if (ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_image_raw10_raw12))
1762 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_image_raw10_raw12);
1763 else
1764 report_fatal_error(reason: "This requires the "
1765 "SPV_EXT_image_raw10_raw12 extension");
1766 }
1767 };
1768
1769 for (MachineInstr &UseInst : MRI.use_instructions(Reg: ResultReg)) {
1770 unsigned Opc = UseInst.getOpcode();
1771
1772 if (Opc == SPIRV::OpSwitch) {
1773 for (const MachineOperand &Op : UseInst.operands())
1774 if (Op.isImm())
1775 CheckAndAddExtension(Op.getImm());
1776 } else if (llvm::is_contained(Range: CompareOps, Element: Opc)) {
1777 for (unsigned i = 1; i < UseInst.getNumOperands(); ++i) {
1778 Register UseReg = UseInst.getOperand(i).getReg();
1779 MachineInstr *ConstInst = MRI.getVRegDef(Reg: UseReg);
1780 if (ConstInst && ConstInst->getOpcode() == SPIRV::OpConstantI) {
1781 int64_t ImmVal = ConstInst->getOperand(i: 2).getImm();
1782 if (ImmVal)
1783 CheckAndAddExtension(ImmVal);
1784 }
1785 }
1786 }
1787 }
1788 break;
1789 }
1790
1791 case SPIRV::OpGroupNonUniformShuffle:
1792 case SPIRV::OpGroupNonUniformShuffleXor:
1793 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformShuffle);
1794 break;
1795 case SPIRV::OpGroupNonUniformShuffleUp:
1796 case SPIRV::OpGroupNonUniformShuffleDown:
1797 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformShuffleRelative);
1798 break;
1799 case SPIRV::OpGroupAll:
1800 case SPIRV::OpGroupAny:
1801 case SPIRV::OpGroupBroadcast:
1802 case SPIRV::OpGroupIAdd:
1803 case SPIRV::OpGroupFAdd:
1804 case SPIRV::OpGroupFMin:
1805 case SPIRV::OpGroupUMin:
1806 case SPIRV::OpGroupSMin:
1807 case SPIRV::OpGroupFMax:
1808 case SPIRV::OpGroupUMax:
1809 case SPIRV::OpGroupSMax:
1810 Reqs.addCapability(ToAdd: SPIRV::Capability::Groups);
1811 break;
1812 case SPIRV::OpGroupNonUniformElect:
1813 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniform);
1814 break;
1815 case SPIRV::OpGroupNonUniformAll:
1816 case SPIRV::OpGroupNonUniformAny:
1817 case SPIRV::OpGroupNonUniformAllEqual:
1818 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformVote);
1819 break;
1820 case SPIRV::OpGroupNonUniformBroadcast:
1821 case SPIRV::OpGroupNonUniformBroadcastFirst:
1822 case SPIRV::OpGroupNonUniformBallot:
1823 case SPIRV::OpGroupNonUniformInverseBallot:
1824 case SPIRV::OpGroupNonUniformBallotBitExtract:
1825 case SPIRV::OpGroupNonUniformBallotBitCount:
1826 case SPIRV::OpGroupNonUniformBallotFindLSB:
1827 case SPIRV::OpGroupNonUniformBallotFindMSB:
1828 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformBallot);
1829 break;
1830 case SPIRV::OpSubgroupShuffleINTEL:
1831 case SPIRV::OpSubgroupShuffleDownINTEL:
1832 case SPIRV::OpSubgroupShuffleUpINTEL:
1833 case SPIRV::OpSubgroupShuffleXorINTEL:
1834 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_subgroups)) {
1835 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_subgroups);
1836 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupShuffleINTEL);
1837 }
1838 break;
1839 case SPIRV::OpSubgroupBlockReadINTEL:
1840 case SPIRV::OpSubgroupBlockWriteINTEL:
1841 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_subgroups)) {
1842 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_subgroups);
1843 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1844 }
1845 break;
1846 case SPIRV::OpSubgroupImageBlockReadINTEL:
1847 case SPIRV::OpSubgroupImageBlockWriteINTEL:
1848 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_subgroups)) {
1849 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_subgroups);
1850 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupImageBlockIOINTEL);
1851 }
1852 break;
1853 case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
1854 case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
1855 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_media_block_io)) {
1856 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_media_block_io);
1857 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
1858 }
1859 break;
1860 case SPIRV::OpAssumeTrueKHR:
1861 case SPIRV::OpExpectKHR:
1862 if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_expect_assume)) {
1863 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_expect_assume);
1864 Reqs.addCapability(ToAdd: SPIRV::Capability::ExpectAssumeKHR);
1865 }
1866 break;
1867 case SPIRV::OpFmaKHR:
1868 if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_fma)) {
1869 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_fma);
1870 Reqs.addCapability(ToAdd: SPIRV::Capability::FmaKHR);
1871 }
1872 break;
1873 case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1874 case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1875 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1876 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1877 Reqs.addCapability(ToAdd: SPIRV::Capability::USMStorageClassesINTEL);
1878 }
1879 break;
1880 case SPIRV::OpConstantFunctionPointerINTEL:
1881 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers)) {
1882 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_function_pointers);
1883 Reqs.addCapability(ToAdd: SPIRV::Capability::FunctionPointersINTEL);
1884 }
1885 break;
1886 case SPIRV::OpGroupNonUniformRotateKHR:
1887 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_subgroup_rotate))
1888 report_fatal_error(reason: "OpGroupNonUniformRotateKHR instruction requires the "
1889 "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1890 gen_crash_diag: false);
1891 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_subgroup_rotate);
1892 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformRotateKHR);
1893 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniform);
1894 break;
1895 case SPIRV::OpFixedCosALTERA:
1896 case SPIRV::OpFixedSinALTERA:
1897 case SPIRV::OpFixedCosPiALTERA:
1898 case SPIRV::OpFixedSinPiALTERA:
1899 case SPIRV::OpFixedExpALTERA:
1900 case SPIRV::OpFixedLogALTERA:
1901 case SPIRV::OpFixedRecipALTERA:
1902 case SPIRV::OpFixedSqrtALTERA:
1903 case SPIRV::OpFixedSinCosALTERA:
1904 case SPIRV::OpFixedSinCosPiALTERA:
1905 case SPIRV::OpFixedRsqrtALTERA:
1906 if (!ST.canUseExtension(
1907 E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_fixed_point))
1908 report_fatal_error(reason: "This instruction requires the "
1909 "following SPIR-V extension: "
1910 "SPV_ALTERA_arbitrary_precision_fixed_point",
1911 gen_crash_diag: false);
1912 Reqs.addExtension(
1913 ToAdd: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_fixed_point);
1914 Reqs.addCapability(ToAdd: SPIRV::Capability::ArbitraryPrecisionFixedPointALTERA);
1915 break;
1916 case SPIRV::OpGroupIMulKHR:
1917 case SPIRV::OpGroupFMulKHR:
1918 case SPIRV::OpGroupBitwiseAndKHR:
1919 case SPIRV::OpGroupBitwiseOrKHR:
1920 case SPIRV::OpGroupBitwiseXorKHR:
1921 case SPIRV::OpGroupLogicalAndKHR:
1922 case SPIRV::OpGroupLogicalOrKHR:
1923 case SPIRV::OpGroupLogicalXorKHR:
1924 if (ST.canUseExtension(
1925 E: SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1926 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1927 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupUniformArithmeticKHR);
1928 }
1929 break;
1930 case SPIRV::OpReadClockKHR:
1931 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_shader_clock))
1932 report_fatal_error(reason: "OpReadClockKHR instruction requires the "
1933 "following SPIR-V extension: SPV_KHR_shader_clock",
1934 gen_crash_diag: false);
1935 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_shader_clock);
1936 Reqs.addCapability(ToAdd: SPIRV::Capability::ShaderClockKHR);
1937 break;
1938 case SPIRV::OpFunctionPointerCallINTEL:
1939 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers)) {
1940 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_function_pointers);
1941 Reqs.addCapability(ToAdd: SPIRV::Capability::FunctionPointersINTEL);
1942 }
1943 break;
1944 case SPIRV::OpAtomicFAddEXT:
1945 case SPIRV::OpAtomicFMinEXT:
1946 case SPIRV::OpAtomicFMaxEXT:
1947 AddAtomicFloatRequirements(MI, Reqs, ST);
1948 break;
1949 case SPIRV::OpConvertBF16ToFINTEL:
1950 case SPIRV::OpConvertFToBF16INTEL:
1951 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1952 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1953 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16ConversionINTEL);
1954 }
1955 break;
1956 case SPIRV::OpRoundFToTF32INTEL:
1957 if (ST.canUseExtension(
1958 E: SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
1959 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
1960 Reqs.addCapability(ToAdd: SPIRV::Capability::TensorFloat32RoundingINTEL);
1961 }
1962 break;
1963 case SPIRV::OpVariableLengthArrayINTEL:
1964 case SPIRV::OpSaveMemoryINTEL:
1965 case SPIRV::OpRestoreMemoryINTEL:
1966 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_variable_length_array)) {
1967 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_variable_length_array);
1968 Reqs.addCapability(ToAdd: SPIRV::Capability::VariableLengthArrayINTEL);
1969 }
1970 break;
1971 case SPIRV::OpAsmTargetINTEL:
1972 case SPIRV::OpAsmINTEL:
1973 case SPIRV::OpAsmCallINTEL:
1974 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_inline_assembly)) {
1975 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_inline_assembly);
1976 Reqs.addCapability(ToAdd: SPIRV::Capability::AsmINTEL);
1977 }
1978 break;
1979 case SPIRV::OpTypeCooperativeMatrixKHR: {
1980 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_cooperative_matrix))
1981 report_fatal_error(
1982 reason: "OpTypeCooperativeMatrixKHR type requires the "
1983 "following SPIR-V extension: SPV_KHR_cooperative_matrix",
1984 gen_crash_diag: false);
1985 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_cooperative_matrix);
1986 Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixKHR);
1987 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1988 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
1989 if (isBFloat16Type(TypeDef))
1990 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16CooperativeMatrixKHR);
1991 break;
1992 }
1993 case SPIRV::OpArithmeticFenceEXT:
1994 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_arithmetic_fence))
1995 report_fatal_error(reason: "OpArithmeticFenceEXT requires the "
1996 "following SPIR-V extension: SPV_EXT_arithmetic_fence",
1997 gen_crash_diag: false);
1998 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_arithmetic_fence);
1999 Reqs.addCapability(ToAdd: SPIRV::Capability::ArithmeticFenceEXT);
2000 break;
2001 case SPIRV::OpControlBarrierArriveINTEL:
2002 case SPIRV::OpControlBarrierWaitINTEL:
2003 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_split_barrier)) {
2004 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_split_barrier);
2005 Reqs.addCapability(ToAdd: SPIRV::Capability::SplitBarrierINTEL);
2006 }
2007 break;
2008 case SPIRV::OpCooperativeMatrixMulAddKHR: {
2009 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_cooperative_matrix))
2010 report_fatal_error(reason: "Cooperative matrix instructions require the "
2011 "following SPIR-V extension: "
2012 "SPV_KHR_cooperative_matrix",
2013 gen_crash_diag: false);
2014 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_cooperative_matrix);
2015 Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixKHR);
2016 constexpr unsigned MulAddMaxSize = 6;
2017 if (MI.getNumOperands() != MulAddMaxSize)
2018 break;
2019 const int64_t CoopOperands = MI.getOperand(i: MulAddMaxSize - 1).getImm();
2020 if (CoopOperands &
2021 SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
2022 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2023 report_fatal_error(reason: "MatrixAAndBTF32ComponentsINTEL type interpretation "
2024 "require the following SPIR-V extension: "
2025 "SPV_INTEL_joint_matrix",
2026 gen_crash_diag: false);
2027 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2028 Reqs.addCapability(
2029 ToAdd: SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
2030 }
2031 if (CoopOperands & SPIRV::CooperativeMatrixOperands::
2032 MatrixAAndBBFloat16ComponentsINTEL ||
2033 CoopOperands &
2034 SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
2035 CoopOperands & SPIRV::CooperativeMatrixOperands::
2036 MatrixResultBFloat16ComponentsINTEL) {
2037 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2038 report_fatal_error(reason: "***BF16ComponentsINTEL type interpretations "
2039 "require the following SPIR-V extension: "
2040 "SPV_INTEL_joint_matrix",
2041 gen_crash_diag: false);
2042 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2043 Reqs.addCapability(
2044 ToAdd: SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
2045 }
2046 break;
2047 }
2048 case SPIRV::OpCooperativeMatrixLoadKHR:
2049 case SPIRV::OpCooperativeMatrixStoreKHR:
2050 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
2051 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
2052 case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
2053 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_cooperative_matrix))
2054 report_fatal_error(reason: "Cooperative matrix instructions require the "
2055 "following SPIR-V extension: "
2056 "SPV_KHR_cooperative_matrix",
2057 gen_crash_diag: false);
2058 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_cooperative_matrix);
2059 Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixKHR);
2060
2061 // Check Layout operand in case if it's not a standard one and add the
2062 // appropriate capability.
2063 std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
2064 {SPIRV::OpCooperativeMatrixLoadKHR, 3},
2065 {SPIRV::OpCooperativeMatrixStoreKHR, 2},
2066 {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
2067 {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
2068 {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};
2069
2070 const unsigned LayoutNum = LayoutToInstMap[Op];
2071 Register RegLayout = MI.getOperand(i: LayoutNum).getReg();
2072 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2073 MachineInstr *MILayout = MRI.getUniqueVRegDef(Reg: RegLayout);
2074 if (MILayout->getOpcode() == SPIRV::OpConstantI) {
2075 const unsigned LayoutVal = MILayout->getOperand(i: 2).getImm();
2076 if (LayoutVal ==
2077 static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
2078 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2079 report_fatal_error(reason: "PackedINTEL layout require the following SPIR-V "
2080 "extension: SPV_INTEL_joint_matrix",
2081 gen_crash_diag: false);
2082 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2083 Reqs.addCapability(ToAdd: SPIRV::Capability::PackedCooperativeMatrixINTEL);
2084 }
2085 }
2086
2087 // Nothing to do.
2088 if (Op == SPIRV::OpCooperativeMatrixLoadKHR ||
2089 Op == SPIRV::OpCooperativeMatrixStoreKHR)
2090 break;
2091
2092 std::string InstName;
2093 switch (Op) {
2094 case SPIRV::OpCooperativeMatrixPrefetchINTEL:
2095 InstName = "OpCooperativeMatrixPrefetchINTEL";
2096 break;
2097 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
2098 InstName = "OpCooperativeMatrixLoadCheckedINTEL";
2099 break;
2100 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
2101 InstName = "OpCooperativeMatrixStoreCheckedINTEL";
2102 break;
2103 }
2104
2105 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix)) {
2106 const std::string ErrorMsg =
2107 InstName + " instruction requires the "
2108 "following SPIR-V extension: SPV_INTEL_joint_matrix";
2109 report_fatal_error(reason: ErrorMsg.c_str(), gen_crash_diag: false);
2110 }
2111 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2112 if (Op == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
2113 Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
2114 break;
2115 }
2116 Reqs.addCapability(
2117 ToAdd: SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
2118 break;
2119 }
2120 case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
2121 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2122 report_fatal_error(reason: "OpCooperativeMatrixConstructCheckedINTEL "
2123 "instructions require the following SPIR-V extension: "
2124 "SPV_INTEL_joint_matrix",
2125 gen_crash_diag: false);
2126 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2127 Reqs.addCapability(
2128 ToAdd: SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
2129 break;
2130 case SPIRV::OpReadPipeBlockingALTERA:
2131 case SPIRV::OpWritePipeBlockingALTERA:
2132 if (ST.canUseExtension(E: SPIRV::Extension::SPV_ALTERA_blocking_pipes)) {
2133 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_ALTERA_blocking_pipes);
2134 Reqs.addCapability(ToAdd: SPIRV::Capability::BlockingPipesALTERA);
2135 }
2136 break;
2137 case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
2138 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2139 report_fatal_error(reason: "OpCooperativeMatrixGetElementCoordINTEL requires the "
2140 "following SPIR-V extension: SPV_INTEL_joint_matrix",
2141 gen_crash_diag: false);
2142 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2143 Reqs.addCapability(
2144 ToAdd: SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
2145 break;
2146 case SPIRV::OpConvertHandleToImageINTEL:
2147 case SPIRV::OpConvertHandleToSamplerINTEL:
2148 case SPIRV::OpConvertHandleToSampledImageINTEL: {
2149 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bindless_images))
2150 report_fatal_error(reason: "OpConvertHandleTo[Image/Sampler/SampledImage]INTEL "
2151 "instructions require the following SPIR-V extension: "
2152 "SPV_INTEL_bindless_images",
2153 gen_crash_diag: false);
2154 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
2155 SPIRV::AddressingModel::AddressingModel AddrModel = MAI.Addr;
2156 SPIRVTypeInst TyDef = GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 1).getReg());
2157 if (Op == SPIRV::OpConvertHandleToImageINTEL &&
2158 TyDef->getOpcode() != SPIRV::OpTypeImage) {
2159 report_fatal_error(reason: "Incorrect return type for the instruction "
2160 "OpConvertHandleToImageINTEL",
2161 gen_crash_diag: false);
2162 } else if (Op == SPIRV::OpConvertHandleToSamplerINTEL &&
2163 TyDef->getOpcode() != SPIRV::OpTypeSampler) {
2164 report_fatal_error(reason: "Incorrect return type for the instruction "
2165 "OpConvertHandleToSamplerINTEL",
2166 gen_crash_diag: false);
2167 } else if (Op == SPIRV::OpConvertHandleToSampledImageINTEL &&
2168 TyDef->getOpcode() != SPIRV::OpTypeSampledImage) {
2169 report_fatal_error(reason: "Incorrect return type for the instruction "
2170 "OpConvertHandleToSampledImageINTEL",
2171 gen_crash_diag: false);
2172 }
2173 SPIRVTypeInst SpvTy = GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 2).getReg());
2174 unsigned Bitwidth = GR->getScalarOrVectorBitWidth(Type: SpvTy);
2175 if (!(Bitwidth == 32 && AddrModel == SPIRV::AddressingModel::Physical32) &&
2176 !(Bitwidth == 64 && AddrModel == SPIRV::AddressingModel::Physical64)) {
2177 report_fatal_error(
2178 reason: "Parameter value must be a 32-bit scalar in case of "
2179 "Physical32 addressing model or a 64-bit scalar in case of "
2180 "Physical64 addressing model",
2181 gen_crash_diag: false);
2182 }
2183 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bindless_images);
2184 Reqs.addCapability(ToAdd: SPIRV::Capability::BindlessImagesINTEL);
2185 break;
2186 }
2187 case SPIRV::OpSubgroup2DBlockLoadINTEL:
2188 case SPIRV::OpSubgroup2DBlockLoadTransposeINTEL:
2189 case SPIRV::OpSubgroup2DBlockLoadTransformINTEL:
2190 case SPIRV::OpSubgroup2DBlockPrefetchINTEL:
2191 case SPIRV::OpSubgroup2DBlockStoreINTEL: {
2192 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_2d_block_io))
2193 report_fatal_error(reason: "OpSubgroup2DBlock[Load/LoadTranspose/LoadTransform/"
2194 "Prefetch/Store]INTEL instructions require the "
2195 "following SPIR-V extension: SPV_INTEL_2d_block_io",
2196 gen_crash_diag: false);
2197 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_2d_block_io);
2198 Reqs.addCapability(ToAdd: SPIRV::Capability::Subgroup2DBlockIOINTEL);
2199
2200 if (Op == SPIRV::OpSubgroup2DBlockLoadTransposeINTEL) {
2201 Reqs.addCapability(ToAdd: SPIRV::Capability::Subgroup2DBlockTransposeINTEL);
2202 break;
2203 }
2204 if (Op == SPIRV::OpSubgroup2DBlockLoadTransformINTEL) {
2205 Reqs.addCapability(ToAdd: SPIRV::Capability::Subgroup2DBlockTransformINTEL);
2206 break;
2207 }
2208 break;
2209 }
2210 case SPIRV::OpKill: {
2211 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2212 } break;
2213 case SPIRV::OpDemoteToHelperInvocation:
2214 Reqs.addCapability(ToAdd: SPIRV::Capability::DemoteToHelperInvocation);
2215
2216 if (ST.canUseExtension(
2217 E: SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
2218 if (!ST.isAtLeastSPIRVVer(VerToCompareTo: llvm::VersionTuple(1, 6)))
2219 Reqs.addExtension(
2220 ToAdd: SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
2221 }
2222 break;
2223 case SPIRV::OpSDot:
2224 case SPIRV::OpUDot:
2225 case SPIRV::OpSUDot:
2226 case SPIRV::OpSDotAccSat:
2227 case SPIRV::OpUDotAccSat:
2228 case SPIRV::OpSUDotAccSat:
2229 AddDotProductRequirements(MI, Reqs, ST);
2230 break;
2231 case SPIRV::OpImageSampleImplicitLod:
2232 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2233 addImageOperandReqs(MI, Reqs, ST, OpIdx: 4);
2234 break;
2235 case SPIRV::OpImageSampleExplicitLod:
2236 addImageOperandReqs(MI, Reqs, ST, OpIdx: 4);
2237 break;
2238 case SPIRV::OpImageSampleDrefImplicitLod:
2239 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2240 addImageOperandReqs(MI, Reqs, ST, OpIdx: 5);
2241 break;
2242 case SPIRV::OpImageSampleDrefExplicitLod:
2243 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2244 addImageOperandReqs(MI, Reqs, ST, OpIdx: 5);
2245 break;
2246 case SPIRV::OpImageFetch:
2247 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2248 addImageOperandReqs(MI, Reqs, ST, OpIdx: 4);
2249 break;
2250 case SPIRV::OpImageDrefGather:
2251 case SPIRV::OpImageGather:
2252 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2253 addImageOperandReqs(MI, Reqs, ST, OpIdx: 5);
2254 break;
2255 case SPIRV::OpImageRead: {
2256 Register ImageReg = MI.getOperand(i: 2).getReg();
2257 SPIRVTypeInst TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2258 VReg: ImageReg, MF: const_cast<MachineFunction *>(MI.getMF()));
2259 // OpImageRead and OpImageWrite can use Unknown Image Formats
2260 // when the Kernel capability is declared. In the OpenCL environment we are
2261 // not allowed to produce
2262 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2263 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2264
2265 if (isImageTypeWithUnknownFormat(TypeInst: TypeDef) && ST.isShader())
2266 Reqs.addCapability(ToAdd: SPIRV::Capability::StorageImageReadWithoutFormat);
2267 break;
2268 }
2269 case SPIRV::OpImageWrite: {
2270 Register ImageReg = MI.getOperand(i: 0).getReg();
2271 SPIRVTypeInst TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2272 VReg: ImageReg, MF: const_cast<MachineFunction *>(MI.getMF()));
2273 // OpImageRead and OpImageWrite can use Unknown Image Formats
2274 // when the Kernel capability is declared. In the OpenCL environment we are
2275 // not allowed to produce
2276 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2277 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2278
2279 if (isImageTypeWithUnknownFormat(TypeInst: TypeDef) && ST.isShader())
2280 Reqs.addCapability(ToAdd: SPIRV::Capability::StorageImageWriteWithoutFormat);
2281 break;
2282 }
2283 case SPIRV::OpTypeStructContinuedINTEL:
2284 case SPIRV::OpConstantCompositeContinuedINTEL:
2285 case SPIRV::OpSpecConstantCompositeContinuedINTEL:
2286 case SPIRV::OpCompositeConstructContinuedINTEL: {
2287 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_long_composites))
2288 report_fatal_error(
2289 reason: "Continued instructions require the "
2290 "following SPIR-V extension: SPV_INTEL_long_composites",
2291 gen_crash_diag: false);
2292 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_long_composites);
2293 Reqs.addCapability(ToAdd: SPIRV::Capability::LongCompositesINTEL);
2294 break;
2295 }
2296 case SPIRV::OpArbitraryFloatEQALTERA:
2297 case SPIRV::OpArbitraryFloatGEALTERA:
2298 case SPIRV::OpArbitraryFloatGTALTERA:
2299 case SPIRV::OpArbitraryFloatLEALTERA:
2300 case SPIRV::OpArbitraryFloatLTALTERA:
2301 case SPIRV::OpArbitraryFloatCbrtALTERA:
2302 case SPIRV::OpArbitraryFloatCosALTERA:
2303 case SPIRV::OpArbitraryFloatCosPiALTERA:
2304 case SPIRV::OpArbitraryFloatExp10ALTERA:
2305 case SPIRV::OpArbitraryFloatExp2ALTERA:
2306 case SPIRV::OpArbitraryFloatExpALTERA:
2307 case SPIRV::OpArbitraryFloatExpm1ALTERA:
2308 case SPIRV::OpArbitraryFloatHypotALTERA:
2309 case SPIRV::OpArbitraryFloatLog10ALTERA:
2310 case SPIRV::OpArbitraryFloatLog1pALTERA:
2311 case SPIRV::OpArbitraryFloatLog2ALTERA:
2312 case SPIRV::OpArbitraryFloatLogALTERA:
2313 case SPIRV::OpArbitraryFloatRecipALTERA:
2314 case SPIRV::OpArbitraryFloatSinCosALTERA:
2315 case SPIRV::OpArbitraryFloatSinCosPiALTERA:
2316 case SPIRV::OpArbitraryFloatSinALTERA:
2317 case SPIRV::OpArbitraryFloatSinPiALTERA:
2318 case SPIRV::OpArbitraryFloatSqrtALTERA:
2319 case SPIRV::OpArbitraryFloatACosALTERA:
2320 case SPIRV::OpArbitraryFloatACosPiALTERA:
2321 case SPIRV::OpArbitraryFloatAddALTERA:
2322 case SPIRV::OpArbitraryFloatASinALTERA:
2323 case SPIRV::OpArbitraryFloatASinPiALTERA:
2324 case SPIRV::OpArbitraryFloatATan2ALTERA:
2325 case SPIRV::OpArbitraryFloatATanALTERA:
2326 case SPIRV::OpArbitraryFloatATanPiALTERA:
2327 case SPIRV::OpArbitraryFloatCastFromIntALTERA:
2328 case SPIRV::OpArbitraryFloatCastALTERA:
2329 case SPIRV::OpArbitraryFloatCastToIntALTERA:
2330 case SPIRV::OpArbitraryFloatDivALTERA:
2331 case SPIRV::OpArbitraryFloatMulALTERA:
2332 case SPIRV::OpArbitraryFloatPowALTERA:
2333 case SPIRV::OpArbitraryFloatPowNALTERA:
2334 case SPIRV::OpArbitraryFloatPowRALTERA:
2335 case SPIRV::OpArbitraryFloatRSqrtALTERA:
2336 case SPIRV::OpArbitraryFloatSubALTERA: {
2337 if (!ST.canUseExtension(
2338 E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_floating_point))
2339 report_fatal_error(
2340 reason: "Floating point instructions can't be translated correctly without "
2341 "enabled SPV_ALTERA_arbitrary_precision_floating_point extension!",
2342 gen_crash_diag: false);
2343 Reqs.addExtension(
2344 ToAdd: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_floating_point);
2345 Reqs.addCapability(
2346 ToAdd: SPIRV::Capability::ArbitraryPrecisionFloatingPointALTERA);
2347 break;
2348 }
2349 case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
2350 if (!ST.canUseExtension(
2351 E: SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
2352 report_fatal_error(
2353 reason: "OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
2354 "following SPIR-V "
2355 "extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
2356 gen_crash_diag: false);
2357 Reqs.addExtension(
2358 ToAdd: SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
2359 Reqs.addCapability(
2360 ToAdd: SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
2361 break;
2362 }
2363 case SPIRV::OpBitwiseFunctionINTEL: {
2364 if (!ST.canUseExtension(
2365 E: SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))
2366 report_fatal_error(
2367 reason: "OpBitwiseFunctionINTEL instruction requires the following SPIR-V "
2368 "extension: SPV_INTEL_ternary_bitwise_function",
2369 gen_crash_diag: false);
2370 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_ternary_bitwise_function);
2371 Reqs.addCapability(ToAdd: SPIRV::Capability::TernaryBitwiseFunctionINTEL);
2372 break;
2373 }
2374 case SPIRV::OpCopyMemorySized: {
2375 Reqs.addCapability(ToAdd: SPIRV::Capability::Addresses);
2376 // TODO: Add UntypedPointersKHR when implemented.
2377 break;
2378 }
2379 case SPIRV::OpPredicatedLoadINTEL:
2380 case SPIRV::OpPredicatedStoreINTEL: {
2381 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_predicated_io))
2382 report_fatal_error(
2383 reason: "OpPredicated[Load/Store]INTEL instructions require "
2384 "the following SPIR-V extension: SPV_INTEL_predicated_io",
2385 gen_crash_diag: false);
2386 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_predicated_io);
2387 Reqs.addCapability(ToAdd: SPIRV::Capability::PredicatedIOINTEL);
2388 break;
2389 }
2390 case SPIRV::OpFAddS:
2391 case SPIRV::OpFSubS:
2392 case SPIRV::OpFMulS:
2393 case SPIRV::OpFDivS:
2394 case SPIRV::OpFRemS:
2395 case SPIRV::OpFMod:
2396 case SPIRV::OpFNegate:
2397 case SPIRV::OpFAddV:
2398 case SPIRV::OpFSubV:
2399 case SPIRV::OpFMulV:
2400 case SPIRV::OpFDivV:
2401 case SPIRV::OpFRemV:
2402 case SPIRV::OpFNegateV: {
2403 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2404 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
2405 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2406 TypeDef = MRI.getVRegDef(Reg: TypeDef->getOperand(i: 1).getReg());
2407 if (isBFloat16Type(TypeDef)) {
2408 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2409 report_fatal_error(
2410 reason: "Arithmetic instructions with bfloat16 arguments require the "
2411 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2412 gen_crash_diag: false);
2413 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2414 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16ArithmeticINTEL);
2415 }
2416 break;
2417 }
2418 case SPIRV::OpOrdered:
2419 case SPIRV::OpUnordered:
2420 case SPIRV::OpFOrdEqual:
2421 case SPIRV::OpFOrdNotEqual:
2422 case SPIRV::OpFOrdLessThan:
2423 case SPIRV::OpFOrdLessThanEqual:
2424 case SPIRV::OpFOrdGreaterThan:
2425 case SPIRV::OpFOrdGreaterThanEqual:
2426 case SPIRV::OpFUnordEqual:
2427 case SPIRV::OpFUnordNotEqual:
2428 case SPIRV::OpFUnordLessThan:
2429 case SPIRV::OpFUnordLessThanEqual:
2430 case SPIRV::OpFUnordGreaterThan:
2431 case SPIRV::OpFUnordGreaterThanEqual: {
2432 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2433 MachineInstr *OperandDef = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
2434 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: OperandDef->getOperand(i: 1).getReg());
2435 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2436 TypeDef = MRI.getVRegDef(Reg: TypeDef->getOperand(i: 1).getReg());
2437 if (isBFloat16Type(TypeDef)) {
2438 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2439 report_fatal_error(
2440 reason: "Relational instructions with bfloat16 arguments require the "
2441 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2442 gen_crash_diag: false);
2443 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2444 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16ArithmeticINTEL);
2445 }
2446 break;
2447 }
2448 case SPIRV::OpDPdxCoarse:
2449 case SPIRV::OpDPdyCoarse:
2450 case SPIRV::OpDPdxFine:
2451 case SPIRV::OpDPdyFine: {
2452 Reqs.addCapability(ToAdd: SPIRV::Capability::DerivativeControl);
2453 break;
2454 }
2455 case SPIRV::OpLoopControlINTEL: {
2456 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_unstructured_loop_controls);
2457 Reqs.addCapability(ToAdd: SPIRV::Capability::UnstructuredLoopControlsINTEL);
2458 break;
2459 }
2460
2461 default:
2462 break;
2463 }
2464
2465 // If we require capability Shader, then we can remove the requirement for
2466 // the BitInstructions capability, since Shader is a superset capability
2467 // of BitInstructions.
2468 Reqs.removeCapabilityIf(ToRemove: SPIRV::Capability::BitInstructions,
2469 IfPresent: SPIRV::Capability::Shader);
2470}
2471
// Gather all module-level SPIR-V requirements (capabilities and extensions):
// (1) per-instruction requirements for every machine instruction in the
//     module,
// (2) requirements implied by the "spirv.ExecutionMode" named metadata, and
// (3) requirements implied by per-function attributes/metadata (work-group
//     sizes, subgroup size, optnone, ...).
// Results are accumulated into MAI.Reqs.
static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
                        MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
  // Collect requirements for existing instructions.
  for (const Function &F : M) {
    MachineFunction *MF = MMI->getMachineFunction(F);
    if (!MF)
      continue;
    for (const MachineBasicBlock &MBB : *MF)
      for (const MachineInstr &MI : MBB)
        addInstrRequirements(MI, MAI, ST);
  }
  // Collect requirements for OpExecutionMode instructions.
  auto Node = M.getNamedMetadata(Name: "spirv.ExecutionMode");
  if (Node) {
    // Extension emission is deferred: the Require* flags remember which
    // float-controls family was actually used so the matching extension(s)
    // can be added once, after the whole metadata list has been scanned.
    bool RequireFloatControls = false, RequireIntelFloatControls2 = false,
         RequireKHRFloatControls2 = false,
         VerLower14 = !ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 4));
    bool HasIntelFloatControls2 =
        ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_float_controls2);
    bool HasKHRFloatControls2 =
        ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2);
    for (unsigned i = 0; i < Node->getNumOperands(); i++) {
      MDNode *MDN = cast<MDNode>(Val: Node->getOperand(i));
      // Operand 1 of each metadata entry is the execution-mode id; entries
      // whose second operand is not a constant int are ignored here.
      const MDOperand &MDOp = MDN->getOperand(I: 1);
      if (auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MDOp)) {
        Constant *C = CMeta->getValue();
        if (ConstantInt *Const = dyn_cast<ConstantInt>(Val: C)) {
          auto EM = Const->getZExtValue();
          // SPV_KHR_float_controls is not available until v1.4:
          // add SPV_KHR_float_controls if the version is too low
          switch (EM) {
          case SPIRV::ExecutionMode::DenormPreserve:
          case SPIRV::ExecutionMode::DenormFlushToZero:
          case SPIRV::ExecutionMode::RoundingModeRTE:
          case SPIRV::ExecutionMode::RoundingModeRTZ:
            RequireFloatControls = VerLower14;
            MAI.Reqs.getAndAddRequirements(
                Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
            break;
          case SPIRV::ExecutionMode::RoundingModeRTPINTEL:
          case SPIRV::ExecutionMode::RoundingModeRTNINTEL:
          case SPIRV::ExecutionMode::FloatingPointModeALTINTEL:
          case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL:
            // INTEL-specific modes are only emitted when the INTEL
            // float-controls2 extension is usable; otherwise they are
            // silently dropped.
            if (HasIntelFloatControls2) {
              RequireIntelFloatControls2 = true;
              MAI.Reqs.getAndAddRequirements(
                  Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
            }
            break;
          case SPIRV::ExecutionMode::FPFastMathDefault: {
            if (HasKHRFloatControls2) {
              RequireKHRFloatControls2 = true;
              MAI.Reqs.getAndAddRequirements(
                  Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
            }
            break;
          }
          case SPIRV::ExecutionMode::ContractionOff:
          case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
            // These two modes are deprecated by SPV_KHR_float_controls2: when
            // that extension is available they are folded into
            // FPFastMathDefault; otherwise they are emitted as-is.
            if (HasKHRFloatControls2) {
              RequireKHRFloatControls2 = true;
              MAI.Reqs.getAndAddRequirements(
                  Category: SPIRV::OperandCategory::ExecutionModeOperand,
                  i: SPIRV::ExecutionMode::FPFastMathDefault, ST);
            } else {
              MAI.Reqs.getAndAddRequirements(
                  Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
            }
            break;
          default:
            MAI.Reqs.getAndAddRequirements(
                Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
          }
        }
      }
    }
    // Emit the extensions recorded during the scan above.
    if (RequireFloatControls &&
        ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls))
      MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_float_controls);
    if (RequireIntelFloatControls2)
      MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_float_controls2);
    if (RequireKHRFloatControls2)
      MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_float_controls2);
  }
  // Map per-function metadata/attributes onto execution-mode requirements.
  for (const Function &F : M) {
    if (F.isDeclaration())
      continue;
    if (F.getMetadata(Kind: "reqd_work_group_size"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::LocalSize, ST);
    if (F.getFnAttribute(Kind: "hlsl.numthreads").isValid()) {
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::LocalSize, ST);
    }
    if (F.getFnAttribute(Kind: "enable-maximal-reconvergence").getValueAsBool()) {
      MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_maximal_reconvergence);
    }
    if (F.getMetadata(Kind: "work_group_size_hint"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::LocalSizeHint, ST);
    if (F.getMetadata(Kind: "intel_reqd_sub_group_size"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::SubgroupSize, ST);
    if (F.getMetadata(Kind: "max_work_group_size"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::MaxWorkgroupSizeINTEL, ST);
    if (F.getMetadata(Kind: "vec_type_hint"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::VecTypeHint, ST);

    // optnone: prefer the INTEL spelling when available, else the EXT one.
    if (F.hasOptNone()) {
      if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_optnone)) {
        MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_optnone);
        MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::OptNoneINTEL);
      } else if (ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_optnone)) {
        MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_optnone);
        MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::OptNoneEXT);
      }
    }
  }
}
2599
2600static unsigned getFastMathFlags(const MachineInstr &I,
2601 const SPIRVSubtarget &ST) {
2602 unsigned Flags = SPIRV::FPFastMathMode::None;
2603 bool CanUseKHRFloatControls2 =
2604 ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2);
2605 if (I.getFlag(Flag: MachineInstr::MIFlag::FmNoNans))
2606 Flags |= SPIRV::FPFastMathMode::NotNaN;
2607 if (I.getFlag(Flag: MachineInstr::MIFlag::FmNoInfs))
2608 Flags |= SPIRV::FPFastMathMode::NotInf;
2609 if (I.getFlag(Flag: MachineInstr::MIFlag::FmNsz))
2610 Flags |= SPIRV::FPFastMathMode::NSZ;
2611 if (I.getFlag(Flag: MachineInstr::MIFlag::FmArcp))
2612 Flags |= SPIRV::FPFastMathMode::AllowRecip;
2613 if (I.getFlag(Flag: MachineInstr::MIFlag::FmContract) && CanUseKHRFloatControls2)
2614 Flags |= SPIRV::FPFastMathMode::AllowContract;
2615 if (I.getFlag(Flag: MachineInstr::MIFlag::FmReassoc)) {
2616 if (CanUseKHRFloatControls2)
2617 // LLVM reassoc maps to SPIRV transform, see
2618 // https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for details.
2619 // Because we are enabling AllowTransform, we must enable AllowReassoc and
2620 // AllowContract too, as required by SPIRV spec. Also, we used to map
2621 // MIFlag::FmReassoc to FPFastMathMode::Fast, which now should instead by
2622 // replaced by turning all the other bits instead. Therefore, we're
2623 // enabling every bit here except None and Fast.
2624 Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
2625 SPIRV::FPFastMathMode::NSZ | SPIRV::FPFastMathMode::AllowRecip |
2626 SPIRV::FPFastMathMode::AllowTransform |
2627 SPIRV::FPFastMathMode::AllowReassoc |
2628 SPIRV::FPFastMathMode::AllowContract;
2629 else
2630 Flags |= SPIRV::FPFastMathMode::Fast;
2631 }
2632
2633 if (CanUseKHRFloatControls2) {
2634 // Error out if SPIRV::FPFastMathMode::Fast is enabled.
2635 assert(!(Flags & SPIRV::FPFastMathMode::Fast) &&
2636 "SPIRV::FPFastMathMode::Fast is deprecated and should not be used "
2637 "anymore.");
2638
2639 // Error out if AllowTransform is enabled without AllowReassoc and
2640 // AllowContract.
2641 assert((!(Flags & SPIRV::FPFastMathMode::AllowTransform) ||
2642 ((Flags & SPIRV::FPFastMathMode::AllowReassoc &&
2643 Flags & SPIRV::FPFastMathMode::AllowContract))) &&
2644 "SPIRV::FPFastMathMode::AllowTransform requires AllowReassoc and "
2645 "AllowContract flags to be enabled as well.");
2646 }
2647
2648 return Flags;
2649}
2650
2651static bool isFastMathModeAvailable(const SPIRVSubtarget &ST) {
2652 if (ST.isKernel())
2653 return true;
2654 if (ST.getSPIRVVersion() < VersionTuple(1, 2))
2655 return false;
2656 return ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2);
2657}
2658
// Convert wrap and fast-math MI flags on instruction \p I into SPIR-V
// decorations (NoSignedWrap, NoUnsignedWrap, FPFastMathMode), emitting
// OpDecorate instructions next to \p I when the corresponding symbolic
// operand requirements are satisfiable for this subtarget. Per-type
// FPFastMathDefault information collected earlier (see
// collectFPFastMathDefaults) is consulted when the instruction itself
// carries no fast-math flags.
static void handleMIFlagDecoration(
    MachineInstr &I, const SPIRVSubtarget &ST, const SPIRVInstrInfo &TII,
    SPIRV::RequirementHandler &Reqs, const SPIRVGlobalRegistry *GR,
    SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec) {
  // nsw -> NoSignedWrap decoration, when the decoration is available.
  if (I.getFlag(Flag: MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(MI: I) &&
      getSymbolicOperandRequirements(Category: SPIRV::OperandCategory::DecorationOperand,
                                     i: SPIRV::Decoration::NoSignedWrap, ST, Reqs)
          .IsSatisfiable) {
    buildOpDecorate(Reg: I.getOperand(i: 0).getReg(), I, TII,
                    Dec: SPIRV::Decoration::NoSignedWrap, DecArgs: {});
  }
  // nuw -> NoUnsignedWrap decoration, when the decoration is available.
  if (I.getFlag(Flag: MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(MI: I) &&
      getSymbolicOperandRequirements(Category: SPIRV::OperandCategory::DecorationOperand,
                                     i: SPIRV::Decoration::NoUnsignedWrap, ST,
                                     Reqs)
          .IsSatisfiable) {
    buildOpDecorate(Reg: I.getOperand(i: 0).getReg(), I, TII,
                    Dec: SPIRV::Decoration::NoUnsignedWrap, DecArgs: {});
  }
  // In Kernel environments, FPFastMathMode on OpExtInst is valid per core
  // spec. For other instruction types, SPV_KHR_float_controls2 is required.
  bool CanUseFM =
      TII.canUseFastMathFlags(
          MI: I, KHRFloatControls2: ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2)) ||
      (ST.isKernel() && I.getOpcode() == SPIRV::OpExtInst);
  if (!CanUseFM)
    return;

  unsigned FMFlags = getFastMathFlags(I, ST);
  if (FMFlags == SPIRV::FPFastMathMode::None) {
    // We also need to check if any FPFastMathDefault info was set for the
    // types used in this instruction.
    if (FPFastMathDefaultInfoVec.empty())
      return;

    // There are three types of instructions that can use fast math flags:
    // 1. Arithmetic instructions (FAdd, FMul, FSub, FDiv, FRem, etc.)
    // 2. Relational instructions (FCmp, FOrd, FUnord, etc.)
    // 3. Extended instructions (ExtInst)
    // For arithmetic instructions, the floating point type can be in the
    // result type or in the operands, but they all must be the same.
    // For the relational and logical instructions, the floating point type
    // can only be in the operands 1 and 2, not the result type. Also, the
    // operands must have the same type. For the extended instructions, the
    // floating point type can be in the result type or in the operands. It's
    // unclear if the operands and the result type must be the same. Let's
    // assume they must be. Therefore, for 1. and 2., we can check the first
    // operand type, and for 3. we can check the result type.
    assert(I.getNumOperands() >= 3 && "Expected at least 3 operands");
    Register ResReg = I.getOpcode() == SPIRV::OpExtInst
                          ? I.getOperand(i: 1).getReg()
                          : I.getOperand(i: 2).getReg();
    SPIRVTypeInst ResType = GR->getSPIRVTypeForVReg(VReg: ResReg, MF: I.getMF());
    const Type *Ty = GR->getTypeForSPIRVType(Ty: ResType);
    // Defaults are keyed on scalar FP types; reduce vectors to their element.
    Ty = Ty->isVectorTy() ? cast<VectorType>(Val: Ty)->getElementType() : Ty;

    // Match instruction type with the FPFastMathDefaultInfoVec.
    bool Emit = false;
    for (SPIRV::FPFastMathDefaultInfo &Elem : FPFastMathDefaultInfoVec) {
      if (Ty == Elem.Ty) {
        FMFlags = Elem.FastMathFlags;
        // A decoration must still be emitted for flags == None when one of
        // the (deprecated) modes below was requested for this type.
        Emit = Elem.ContractionOff || Elem.SignedZeroInfNanPreserve ||
               Elem.FPFastMathDefault;
        break;
      }
    }

    if (FMFlags == SPIRV::FPFastMathMode::None && !Emit)
      return;
  }
  if (isFastMathModeAvailable(ST)) {
    Register DstReg = I.getOperand(i: 0).getReg();
    buildOpDecorate(Reg: DstReg, I, TII, Dec: SPIRV::Decoration::FPFastMathMode,
                    DecArgs: {FMFlags});
  }
}
2735
2736// Walk all functions and add decorations related to MI flags.
2737static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
2738 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2739 SPIRV::ModuleAnalysisInfo &MAI,
2740 const SPIRVGlobalRegistry *GR) {
2741 for (const Function &F : M) {
2742 MachineFunction *MF = MMI->getMachineFunction(F);
2743 if (!MF)
2744 continue;
2745
2746 for (auto &MBB : *MF)
2747 for (auto &MI : MBB)
2748 handleMIFlagDecoration(I&: MI, ST, TII, Reqs&: MAI.Reqs, GR,
2749 FPFastMathDefaultInfoVec&: MAI.FPFastMathDefaultInfoMap[&F]);
2750 }
2751}
2752
2753static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
2754 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2755 SPIRV::ModuleAnalysisInfo &MAI) {
2756 for (const Function &F : M) {
2757 MachineFunction *MF = MMI->getMachineFunction(F);
2758 if (!MF)
2759 continue;
2760 if (MF->getFunction()
2761 .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
2762 .isValid())
2763 continue;
2764 MachineRegisterInfo &MRI = MF->getRegInfo();
2765 for (auto &MBB : *MF) {
2766 if (!MBB.hasName() || MBB.empty())
2767 continue;
2768 // Emit basic block names.
2769 Register Reg = MRI.createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 64));
2770 MRI.setRegClass(Reg, RC: &SPIRV::IDRegClass);
2771 buildOpName(Target: Reg, Name: MBB.getName(), I&: *std::prev(x: MBB.end()), TII);
2772 MCRegister GlobalReg = MAI.getOrCreateMBBRegister(MBB);
2773 MAI.setRegisterAlias(MF, Reg, AliasReg: GlobalReg);
2774 }
2775 }
2776}
2777
2778// patching Instruction::PHI to SPIRV::OpPhi
2779static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
2780 const SPIRVInstrInfo &TII, MachineModuleInfo *MMI) {
2781 for (const Function &F : M) {
2782 MachineFunction *MF = MMI->getMachineFunction(F);
2783 if (!MF)
2784 continue;
2785 for (auto &MBB : *MF) {
2786 for (MachineInstr &MI : MBB.phis()) {
2787 MI.setDesc(TII.get(Opcode: SPIRV::OpPhi));
2788 Register ResTypeReg = GR->getSPIRVTypeID(
2789 SpirvType: GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 0).getReg(), MF));
2790 MI.insert(InsertBefore: MI.operands_begin() + 1,
2791 Ops: {MachineOperand::CreateReg(Reg: ResTypeReg, isDef: false)});
2792 }
2793 }
2794
2795 MF->getProperties().setNoPHIs();
2796 }
2797}
2798
2799static SPIRV::FPFastMathDefaultInfoVector &getOrCreateFPFastMathDefaultInfoVec(
2800 const Module &M, SPIRV::ModuleAnalysisInfo &MAI, const Function *F) {
2801 auto it = MAI.FPFastMathDefaultInfoMap.find(Val: F);
2802 if (it != MAI.FPFastMathDefaultInfoMap.end())
2803 return it->second;
2804
2805 // If the map does not contain the entry, create a new one. Initialize it to
2806 // contain all 3 elements sorted by bit width of target type: {half, float,
2807 // double}.
2808 SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec;
2809 FPFastMathDefaultInfoVec.emplace_back(Args: Type::getHalfTy(C&: M.getContext()),
2810 Args: SPIRV::FPFastMathMode::None);
2811 FPFastMathDefaultInfoVec.emplace_back(Args: Type::getFloatTy(C&: M.getContext()),
2812 Args: SPIRV::FPFastMathMode::None);
2813 FPFastMathDefaultInfoVec.emplace_back(Args: Type::getDoubleTy(C&: M.getContext()),
2814 Args: SPIRV::FPFastMathMode::None);
2815 return MAI.FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec);
2816}
2817
2818static SPIRV::FPFastMathDefaultInfo &getFPFastMathDefaultInfo(
2819 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec,
2820 const Type *Ty) {
2821 size_t BitWidth = Ty->getScalarSizeInBits();
2822 int Index =
2823 SPIRV::FPFastMathDefaultInfoVector::computeFPFastMathDefaultInfoVecIndex(
2824 BitWidth);
2825 assert(Index >= 0 && Index < 3 &&
2826 "Expected FPFastMathDefaultInfo for half, float, or double");
2827 assert(FPFastMathDefaultInfoVec.size() == 3 &&
2828 "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2829 return FPFastMathDefaultInfoVec[Index];
2830}
2831
// Pre-scan the "spirv.ExecutionMode" named metadata and populate the
// per-function FPFastMathDefaultInfo map in MAI. Only runs when the
// SPV_KHR_float_controls2 extension is usable; handles the
// FPFastMathDefault mode itself as well as the deprecated ContractionOff
// and SignedZeroInfNanPreserve modes that it supersedes.
static void collectFPFastMathDefaults(const Module &M,
                                      SPIRV::ModuleAnalysisInfo &MAI,
                                      const SPIRVSubtarget &ST) {
  if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2))
    return;

  // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap.
  // We need the entry point (function) as the key, and the target
  // type and flags as the value.
  // We also need to check ContractionOff and SignedZeroInfNanPreserve
  // execution modes, as they are now deprecated and must be replaced
  // with FPFastMathDefaultInfo.
  auto Node = M.getNamedMetadata(Name: "spirv.ExecutionMode");
  if (!Node)
    return;

  for (unsigned i = 0; i < Node->getNumOperands(); i++) {
    MDNode *MDN = cast<MDNode>(Val: Node->getOperand(i));
    assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands");
    // Operand 0 is the entry-point function, operand 1 the execution mode.
    const Function *F = cast<Function>(
        Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 0))->getValue());
    const auto EM =
        cast<ConstantInt>(
            Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 1))->getValue())
            ->getZExtValue();
    if (EM == SPIRV::ExecutionMode::FPFastMathDefault) {
      // Layout: {function, mode, target FP type, fast-math flags}.
      assert(MDN->getNumOperands() == 4 &&
             "Expected 4 operands for FPFastMathDefault");

      const Type *T = cast<ValueAsMetadata>(Val: MDN->getOperand(I: 2))->getType();
      unsigned Flags =
          cast<ConstantInt>(
              Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 3))->getValue())
              ->getZExtValue();
      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
          getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
      SPIRV::FPFastMathDefaultInfo &Info =
          getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, Ty: T);
      Info.FastMathFlags = Flags;
      Info.FPFastMathDefault = true;
    } else if (EM == SPIRV::ExecutionMode::ContractionOff) {
      assert(MDN->getNumOperands() == 2 &&
             "Expected no operands for ContractionOff");

      // We need to save this info for every possible FP type, i.e. {half,
      // float, double, fp128}.
      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
          getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
      for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) {
        Info.ContractionOff = true;
      }
    } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) {
      // Layout: {function, mode, target FP bit width}.
      assert(MDN->getNumOperands() == 3 &&
             "Expected 1 operand for SignedZeroInfNanPreserve");
      unsigned TargetWidth =
          cast<ConstantInt>(
              Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 2))->getValue())
              ->getZExtValue();
      // We need to save this info only for the FP type with TargetWidth.
      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
          getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
      int Index = SPIRV::FPFastMathDefaultInfoVector::
          computeFPFastMathDefaultInfoVecIndex(BitWidth: TargetWidth);
      assert(Index >= 0 && Index < 3 &&
             "Expected FPFastMathDefaultInfo for half, float, or double");
      assert(FPFastMathDefaultInfoVec.size() == 3 &&
             "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
      FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true;
    }
  }
}
2903
// Declare pass dependencies: the target configuration (to reach the SPIR-V
// subtarget) and the machine-module info holding the machine functions.
void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}
2908
2909bool SPIRVModuleAnalysis::runOnModule(Module &M) {
2910 SPIRVTargetMachine &TM =
2911 getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
2912 ST = TM.getSubtargetImpl();
2913 GR = ST->getSPIRVGlobalRegistry();
2914 TII = ST->getInstrInfo();
2915
2916 MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
2917
2918 setBaseInfo(M);
2919
2920 patchPhis(M, GR, TII: *TII, MMI);
2921
2922 addMBBNames(M, TII: *TII, MMI, ST: *ST, MAI);
2923 collectFPFastMathDefaults(M, MAI, ST: *ST);
2924 addDecorations(M, TII: *TII, MMI, ST: *ST, MAI, GR);
2925
2926 collectReqs(M, MAI, MMI, ST: *ST);
2927
2928 // Process type/const/global var/func decl instructions, number their
2929 // destination registers from 0 to N, collect Extensions and Capabilities.
2930 collectReqs(M, MAI, MMI, ST: *ST);
2931 collectDeclarations(M);
2932
2933 // Number rest of registers from N+1 onwards.
2934 numberRegistersGlobally(M);
2935
2936 // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
2937 processOtherInstrs(M);
2938
2939 // If there are no entry points, we need the Linkage capability.
2940 if (MAI.MS[SPIRV::MB_EntryPoints].empty())
2941 MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::Linkage);
2942
2943 // Set maximum ID used.
2944 GR->setBound(MAI.MaxID);
2945
2946 return false;
2947}
2948