1//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The analysis collects instructions that should be output at the module level
10// and performs the global register numbering.
11//
12// The results of this analysis are used in AsmPrinter to rename registers
13// globally and to output required instructions at the module level.
14//
15//===----------------------------------------------------------------------===//
16
17// TODO: uses or report_fatal_error (which is also deprecated) /
18// ReportFatalUsageError in this file should be refactored, as per LLVM
19// best practices, to rely on the Diagnostic infrastructure.
20
21#include "SPIRVModuleAnalysis.h"
22#include "MCTargetDesc/SPIRVBaseInfo.h"
23#include "MCTargetDesc/SPIRVMCTargetDesc.h"
24#include "SPIRV.h"
25#include "SPIRVSubtarget.h"
26#include "SPIRVTargetMachine.h"
27#include "SPIRVUtils.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/CodeGen/MachineModuleInfo.h"
30#include "llvm/CodeGen/TargetPassConfig.h"
31
32using namespace llvm;
33
34#define DEBUG_TYPE "spirv-module-analysis"
35
// Debugging aid: dump the analyzed MIR together with the SPIR-V dependency
// info collected by this pass.
static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(Val: false));

// Capabilities the user asks us to avoid when several capabilities could
// enable the same feature (consulted in getSymbolicOperandRequirements).
static cl::list<SPIRV::Capability::Capability>
    AvoidCapabilities("avoid-spirv-capabilities",
                      cl::desc("SPIR-V capabilities to avoid if there are "
                               "other options enabling a feature"),
                      cl::Hidden,
                      cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
                                            "SPIR-V Shader capability")));
// Use sets instead of cl::list to check "if contains" condition
struct AvoidCapabilitiesSet {
  // Set view of the AvoidCapabilities command-line list for fast lookup.
  SmallSet<SPIRV::Capability::Capability, 4> S;
  AvoidCapabilitiesSet() { S.insert_range(R&: AvoidCapabilities); }
};
53
// Pass identification; the address of ID is used as a unique pass identity.
char llvm::SPIRVModuleAnalysis::ID = 0;

INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)
58
59// Retrieve an unsigned from an MDNode with a list of them as operands.
60static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
61 unsigned DefaultVal = 0) {
62 if (MdNode && OpIndex < MdNode->getNumOperands()) {
63 const auto &Op = MdNode->getOperand(I: OpIndex);
64 return mdconst::extract<ConstantInt>(MD: Op)->getZExtValue();
65 }
66 return DefaultVal;
67}
68
// Computes the requirements (capability, extensions, version bounds) for
// using value i of the given symbolic-operand category on subtarget ST.
// Returns an unsatisfiable result (IsSatisfiable == false) when neither the
// SPIR-V version range, an available capability, nor supported extensions
// can enable the operand.
static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
                               unsigned i, const SPIRVSubtarget &ST,
                               SPIRV::RequirementHandler &Reqs) {
  // A set of capabilities to avoid if there is another option.
  AvoidCapabilitiesSet AvoidCaps;
  // Shader and Kernel environments are mutually exclusive, so always avoid
  // the capability that does not match the current target environment.
  if (!ST.isShader())
    AvoidCaps.S.insert(V: SPIRV::Capability::Shader);
  else
    AvoidCaps.S.insert(V: SPIRV::Capability::Kernel);

  VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, Value: i);
  VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, Value: i);
  VersionTuple SPIRVVersion = ST.getSPIRVVersion();
  // An empty subtarget version means "unconstrained" and satisfies any bound.
  bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
  bool MaxVerOK =
      ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
  CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, Value: i);
  ExtensionList ReqExts = getSymbolicOperandExtensions(Category, Value: i);
  if (ReqCaps.empty()) {
    if (ReqExts.empty()) {
      // Neither capabilities nor extensions are needed: the version range
      // alone decides satisfiability.
      if (MinVerOK && MaxVerOK)
        return {true, {}, {}, ReqMinVer, ReqMaxVer};
      return {false, {}, {}, VersionTuple(), VersionTuple()};
    }
  } else if (MinVerOK && MaxVerOK) {
    if (ReqCaps.size() == 1) {
      auto Cap = ReqCaps[0];
      if (Reqs.isCapabilityAvailable(Cap)) {
        // Also pull in any extensions the chosen capability itself requires.
        ReqExts.append(RHS: getSymbolicOperandExtensions(
            Category: SPIRV::OperandCategory::CapabilityOperand, Value: Cap));
        return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
      }
    } else {
      // By SPIR-V specification: "If an instruction, enumerant, or other
      // feature specifies multiple enabling capabilities, only one such
      // capability needs to be declared to use the feature." However, one
      // capability may be preferred over another. We use command line
      // argument(s) and AvoidCapabilities to avoid selection of certain
      // capabilities if there are other options.
      CapabilityList UseCaps;
      for (auto Cap : ReqCaps)
        if (Reqs.isCapabilityAvailable(Cap))
          UseCaps.push_back(Elt: Cap);
      // Pick the first available capability not on the avoid list; the last
      // candidate is accepted unconditionally as a fallback.
      for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
        auto Cap = UseCaps[i];
        if (i == Sz - 1 || !AvoidCaps.S.contains(V: Cap)) {
          ReqExts.append(RHS: getSymbolicOperandExtensions(
              Category: SPIRV::OperandCategory::CapabilityOperand, Value: Cap));
          return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
        }
      }
    }
  }
  // If there are no capabilities, or we can't satisfy the version or
  // capability requirements, use the list of extensions (if the subtarget
  // can handle them all).
  if (llvm::all_of(Range&: ReqExts, P: [&ST](const SPIRV::Extension::Extension &Ext) {
        return ST.canUseExtension(E: Ext);
      })) {
    return {true,
            {},
            std::move(ReqExts),
            VersionTuple(),
            VersionTuple()}; // TODO: add versions to extensions.
  }
  return {false, {}, {}, VersionTuple(), VersionTuple()};
}
137
// Resets the per-module state in MAI and derives the base module info
// (addressing model, memory model, source language and version, used source
// extensions) from the subtarget and the module's named metadata, then
// registers the capability/extension requirements implied by those choices.
void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.FuncMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(ST: *ST);

  // TODO: determine memory model and source language from the configuration.
  if (auto MemModel = M.getNamedMetadata(Name: "spirv.MemoryModel")) {
    // Explicit "spirv.MemoryModel" metadata overrides the defaults:
    // operand 0 is the addressing model, operand 1 the memory model.
    auto MemMD = MemModel->getOperand(i: 0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MdNode: MemMD, OpIndex: 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MdNode: MemMD, OpIndex: 1));
  } else {
    // TODO: Add support for VulkanMemoryModel.
    MAI.Mem = ST->isShader() ? SPIRV::MemoryModel::GLSL450
                             : SPIRV::MemoryModel::OpenCL;
    if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
      // Derive the physical addressing model from the target pointer width.
      unsigned PtrSize = ST->getPointerSize();
      MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32
                 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
                 : SPIRV::AddressingModel::Logical;
    } else {
      // TODO: Add support for PhysicalStorageBufferAddress.
      MAI.Addr = SPIRV::AddressingModel::Logical;
    }
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata(Name: "opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct version literal in accordance with SPIRV-LLVM-Translator.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(i: 0);
    unsigned MajorNum = getMetadataUInt(MdNode: VersionMD, OpIndex: 0, DefaultVal: 2);
    unsigned MinorNum = getMetadataUInt(MdNode: VersionMD, OpIndex: 1);
    unsigned RevNum = getMetadataUInt(MdNode: VersionMD, OpIndex: 2);
    // Prevent Major part of OpenCL version to be 0
    MAI.SrcLangVersion =
        (std::max(a: 1U, b: MajorNum) * 100 + MinorNum) * 1000 + RevNum;
  } else {
    // If there is no information about OpenCL version we are forced to generate
    // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
    // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
    // Translator avoids potential issues with run-times in a similar manner.
    if (!ST->isShader()) {
      MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
      MAI.SrcLangVersion = 100000;
    } else {
      MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
      MAI.SrcLangVersion = 0;
    }
  }

  // Collect source extensions listed in "opencl.used.extensions" metadata.
  if (auto ExtNode = M.getNamedMetadata(Name: "opencl.used.extensions")) {
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(i: I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(key: cast<MDString>(Val: MD->getOperand(I: J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::MemoryModelOperand,
                                 i: MAI.Mem, ST: *ST);
  MAI.Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::SourceLanguageOperand,
                                 i: MAI.SrcLang, ST: *ST);
  MAI.Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::AddressingModelOperand,
                                 i: MAI.Addr, ST: *ST);

  if (MAI.Mem == SPIRV::MemoryModel::VulkanKHR)
    MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_vulkan_memory_model);

  if (!ST->isShader()) {
    // TODO: check if it's required by default.
    // Pre-assign a global id for the OpenCL extended instruction set import.
    MAI.ExtInstSetMap[static_cast<unsigned>(
        SPIRV::InstructionSet::OpenCL_std)] = MAI.getNextIDRegister();
  }
}
227
228// Appends the signature of the decoration instructions that decorate R to
229// Signature.
230static void appendDecorationsForReg(const MachineRegisterInfo &MRI, Register R,
231 InstrSignature &Signature) {
232 for (MachineInstr &UseMI : MRI.use_instructions(Reg: R)) {
233 // We don't handle OpDecorateId because getting the register alias for the
234 // ID can cause problems, and we do not need it for now.
235 if (UseMI.getOpcode() != SPIRV::OpDecorate &&
236 UseMI.getOpcode() != SPIRV::OpMemberDecorate)
237 continue;
238
239 for (unsigned I = 0; I < UseMI.getNumOperands(); ++I) {
240 const MachineOperand &MO = UseMI.getOperand(i: I);
241 if (MO.isReg())
242 continue;
243 Signature.push_back(Elt: hash_value(MO));
244 }
245 }
246}
247
// Returns a representation of an instruction as a vector of MachineOperand
// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
// This creates a signature of the instruction with the same content
// that MachineOperand::isIdenticalTo uses for comparison.
//
// Register operands are hashed through their global register alias (all
// v-regs must already be globally numbered). When UseDefReg is false, the
// def register is left out of the signature but the decorations attached to
// it are appended instead, so two defs of the "same" entity compare equal.
static InstrSignature instrToSignature(const MachineInstr &MI,
                                       SPIRV::ModuleAnalysisInfo &MAI,
                                       bool UseDefReg) {
  Register DefReg;
  InstrSignature Signature{MI.getOpcode()};
  for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
    // The only decorations that can be applied more than once to a given <id>
    // or structure member are FuncParamAttr (38), UserSemantic (5635),
    // CacheControlLoadINTEL (6442), and CacheControlStoreINTEL (6443). For all
    // the rest of decorations, we will only add to the signature the Opcode,
    // the id to which it applies, and the decoration id, disregarding any
    // decoration flags. This will ensure that any subsequent decoration with
    // the same id will be deemed as a duplicate. Then, at the call site, we
    // will be able to handle duplicates in the best way.
    unsigned Opcode = MI.getOpcode();
    if ((Opcode == SPIRV::OpDecorate) && i >= 2) {
      unsigned DecorationID = MI.getOperand(i: 1).getImm();
      if (DecorationID != SPIRV::Decoration::FuncParamAttr &&
          DecorationID != SPIRV::Decoration::UserSemantic &&
          DecorationID != SPIRV::Decoration::CacheControlLoadINTEL &&
          DecorationID != SPIRV::Decoration::CacheControlStoreINTEL)
        continue;
    }
    const MachineOperand &MO = MI.getOperand(i);
    size_t h;
    if (MO.isReg()) {
      // Remember the def register (at most one is expected) and skip it when
      // the caller asked to exclude defs from the signature.
      if (!UseDefReg && MO.isDef()) {
        assert(!DefReg.isValid() && "Multiple def registers.");
        DefReg = MO.getReg();
        continue;
      }
      // Hash the globally-numbered alias, not the function-local v-reg.
      Register RegAlias = MAI.getRegisterAlias(MF: MI.getMF(), Reg: MO.getReg());
      if (!RegAlias.isValid()) {
        LLVM_DEBUG({
          dbgs() << "Unexpectedly, no global id found for the operand ";
          MO.print(dbgs());
          dbgs() << "\nInstruction: ";
          MI.print(dbgs());
          dbgs() << "\n";
        });
        report_fatal_error(reason: "All v-regs must have been mapped to global id's");
      }
      // mimic llvm::hash_value(const MachineOperand &MO)
      h = hash_combine(args: MO.getType(), args: (unsigned)RegAlias, args: MO.getSubReg(),
                       args: MO.isDef());
    } else {
      h = hash_value(MO);
    }
    Signature.push_back(Elt: h);
  }

  if (DefReg.isValid()) {
    // Decorations change the semantics of the current instruction. So two
    // identical instruction with different decorations cannot be merged. That
    // is why we add the decorations to the signature.
    appendDecorationsForReg(MRI: MI.getMF()->getRegInfo(), R: DefReg, Signature);
  }
  return Signature;
}
311
312bool SPIRVModuleAnalysis::isDeclSection(const MachineRegisterInfo &MRI,
313 const MachineInstr &MI) {
314 unsigned Opcode = MI.getOpcode();
315 switch (Opcode) {
316 case SPIRV::OpTypeForwardPointer:
317 // omit now, collect later
318 return false;
319 case SPIRV::OpVariable:
320 return static_cast<SPIRV::StorageClass::StorageClass>(
321 MI.getOperand(i: 2).getImm()) != SPIRV::StorageClass::Function;
322 case SPIRV::OpFunction:
323 case SPIRV::OpFunctionParameter:
324 return true;
325 }
326 if (GR->hasConstFunPtr() && Opcode == SPIRV::OpUndef) {
327 Register DefReg = MI.getOperand(i: 0).getReg();
328 for (MachineInstr &UseMI : MRI.use_instructions(Reg: DefReg)) {
329 if (UseMI.getOpcode() != SPIRV::OpConstantFunctionPointerINTEL)
330 continue;
331 // it's a dummy definition, FP constant refers to a function,
332 // and this is resolved in another way; let's skip this definition
333 assert(UseMI.getOperand(2).isReg() &&
334 UseMI.getOperand(2).getReg() == DefReg);
335 MAI.setSkipEmission(&MI);
336 return false;
337 }
338 }
339 return TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
340 TII->isInlineAsmDefInstr(MI);
341}
342
// This is a special case of a function pointer referring to a possibly
// forward function declaration. The operand is a dummy OpUndef that
// requires a special treatment.
//
// OpReg is the dummy operand register inside MF/MI; the referenced function
// definition (possibly in another machine function) is numbered globally
// first, and OpReg is then aliased to that global number.
void SPIRVModuleAnalysis::visitFunPtrUse(
    Register OpReg, InstrGRegsMap &SignatureToGReg,
    std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
    const MachineInstr &MI) {
  const MachineOperand *OpFunDef =
      GR->getFunctionDefinitionByUse(Use: &MI.getOperand(i: 2));
  assert(OpFunDef && OpFunDef->isReg());
  // find the actual function definition and number it globally in advance
  const MachineInstr *OpDefMI = OpFunDef->getParent();
  assert(OpDefMI && OpDefMI->getOpcode() == SPIRV::OpFunction);
  const MachineFunction *FunDefMF = OpDefMI->getParent()->getParent();
  const MachineRegisterInfo &FunDefMRI = FunDefMF->getRegInfo();
  // Visit the OpFunction and all its OpFunctionParameter instructions so
  // they all get global numbers before the pointer is resolved.
  do {
    visitDecl(MRI: FunDefMRI, SignatureToGReg, GlobalToGReg, MF: FunDefMF, MI: *OpDefMI);
    OpDefMI = OpDefMI->getNextNode();
  } while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction ||
                       OpDefMI->getOpcode() == SPIRV::OpFunctionParameter));
  // associate the function pointer with the newly assigned global number
  MCRegister GlobalFunDefReg =
      MAI.getRegisterAlias(MF: FunDefMF, Reg: OpFunDef->getReg());
  assert(GlobalFunDefReg.isValid() &&
         "Function definition must refer to a global register");
  MAI.setRegisterAlias(MF, Reg: OpReg, AliasReg: GlobalFunDefReg);
}
370
// Depth first recursive traversal of dependencies. Repeated visits are guarded
// by MAI.hasRegisterAlias().
//
// For each operand of MI, the defining instruction is visited first so all
// dependencies get global numbers before MI itself; then MI is dispatched to
// the appropriate handler and its def register aliased to the resulting
// global register. Function definitions keep being emitted locally; every
// other visited instruction is marked as skip-emission here.
void SPIRVModuleAnalysis::visitDecl(
    const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg,
    std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
    const MachineInstr &MI) {
  unsigned Opcode = MI.getOpcode();

  // Process each operand of the instruction to resolve dependencies
  for (const MachineOperand &MO : MI.operands()) {
    if (!MO.isReg() || MO.isDef())
      continue;
    Register OpReg = MO.getReg();
    // Handle function pointers special case
    if (Opcode == SPIRV::OpConstantFunctionPointerINTEL &&
        MRI.getRegClass(Reg: OpReg) == &SPIRV::pIDRegClass) {
      visitFunPtrUse(OpReg, SignatureToGReg, GlobalToGReg, MF, MI);
      continue;
    }
    // Skip already processed instructions
    if (MAI.hasRegisterAlias(MF, Reg: MO.getReg()))
      continue;
    // Recursively visit dependencies
    if (const MachineInstr *OpDefMI = MRI.getUniqueVRegDef(Reg: OpReg)) {
      if (isDeclSection(MRI, MI: *OpDefMI))
        visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI: *OpDefMI);
      continue;
    }
    // Handle the unexpected case of no unique definition for the SPIR-V
    // instruction
    LLVM_DEBUG({
      dbgs() << "Unexpectedly, no unique definition for the operand ";
      MO.print(dbgs());
      dbgs() << "\nInstruction: ";
      MI.print(dbgs());
      dbgs() << "\n";
    });
    report_fatal_error(
        reason: "No unique definition is found for the virtual register");
  }

  // All dependencies are resolved; now number this instruction itself.
  MCRegister GReg;
  bool IsFunDef = false;
  if (TII->isSpecConstantInstr(MI)) {
    // Spec constants are never deduplicated: each gets a fresh global id.
    GReg = MAI.getNextIDRegister();
    MAI.MS[SPIRV::MB_TypeConstVars].push_back(Elt: &MI);
  } else if (Opcode == SPIRV::OpFunction ||
             Opcode == SPIRV::OpFunctionParameter) {
    GReg = handleFunctionOrParameter(MF, MI, GlobalToGReg, IsFunDef);
  } else if (Opcode == SPIRV::OpTypeStruct ||
             Opcode == SPIRV::OpConstantComposite) {
    GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
    // Also number and skip the *ContinuedINTEL instructions that extend a
    // long struct type or composite constant immediately after it.
    const MachineInstr *NextInstr = MI.getNextNode();
    while (NextInstr &&
           ((Opcode == SPIRV::OpTypeStruct &&
             NextInstr->getOpcode() == SPIRV::OpTypeStructContinuedINTEL) ||
            (Opcode == SPIRV::OpConstantComposite &&
             NextInstr->getOpcode() ==
                 SPIRV::OpConstantCompositeContinuedINTEL))) {
      MCRegister Tmp = handleTypeDeclOrConstant(MI: *NextInstr, SignatureToGReg);
      MAI.setRegisterAlias(MF, Reg: NextInstr->getOperand(i: 0).getReg(), AliasReg: Tmp);
      MAI.setSkipEmission(NextInstr);
      NextInstr = NextInstr->getNextNode();
    }
  } else if (TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
             TII->isInlineAsmDefInstr(MI)) {
    GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
  } else if (Opcode == SPIRV::OpVariable) {
    GReg = handleVariable(MF, MI, GlobalToGReg);
  } else {
    LLVM_DEBUG({
      dbgs() << "\nInstruction: ";
      MI.print(dbgs());
      dbgs() << "\n";
    });
    llvm_unreachable("Unexpected instruction is visited");
  }
  MAI.setRegisterAlias(MF, Reg: MI.getOperand(i: 0).getReg(), AliasReg: GReg);
  // Function definitions must still be emitted in place; everything else is
  // emitted from the module sections instead.
  if (!IsFunDef)
    MAI.setSkipEmission(&MI);
}
452
453MCRegister SPIRVModuleAnalysis::handleFunctionOrParameter(
454 const MachineFunction *MF, const MachineInstr &MI,
455 std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) {
456 const Value *GObj = GR->getGlobalObject(MF, R: MI.getOperand(i: 0).getReg());
457 assert(GObj && "Unregistered global definition");
458 const Function *F = dyn_cast<Function>(Val: GObj);
459 if (!F)
460 F = dyn_cast<Argument>(Val: GObj)->getParent();
461 assert(F && "Expected a reference to a function or an argument");
462 IsFunDef = !F->isDeclaration();
463 auto [It, Inserted] = GlobalToGReg.try_emplace(k: GObj);
464 if (!Inserted)
465 return It->second;
466 MCRegister GReg = MAI.getNextIDRegister();
467 It->second = GReg;
468 if (!IsFunDef)
469 MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(Elt: &MI);
470 return GReg;
471}
472
473MCRegister
474SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI,
475 InstrGRegsMap &SignatureToGReg) {
476 InstrSignature MISign = instrToSignature(MI, MAI, UseDefReg: false);
477 auto [It, Inserted] = SignatureToGReg.try_emplace(k: MISign);
478 if (!Inserted)
479 return It->second;
480 MCRegister GReg = MAI.getNextIDRegister();
481 It->second = GReg;
482 MAI.MS[SPIRV::MB_TypeConstVars].push_back(Elt: &MI);
483 return GReg;
484}
485
486MCRegister SPIRVModuleAnalysis::handleVariable(
487 const MachineFunction *MF, const MachineInstr &MI,
488 std::map<const Value *, unsigned> &GlobalToGReg) {
489 MAI.GlobalVarList.push_back(Elt: &MI);
490 const Value *GObj = GR->getGlobalObject(MF, R: MI.getOperand(i: 0).getReg());
491 assert(GObj && "Unregistered global definition");
492 auto [It, Inserted] = GlobalToGReg.try_emplace(k: GObj);
493 if (!Inserted)
494 return It->second;
495 MCRegister GReg = MAI.getNextIDRegister();
496 It->second = GReg;
497 MAI.MS[SPIRV::MB_TypeConstVars].push_back(Elt: &MI);
498 return GReg;
499}
500
// Scans every machine function and globally numbers all declaration-section
// instructions (types, constants, global variables, function headers) via
// visitDecl. OpExtension/OpCapability instructions are folded into the
// requirement handler and skipped from emission.
void SPIRVModuleAnalysis::collectDeclarations(const Module &M) {
  InstrGRegsMap SignatureToGReg;
  std::map<const Value *, unsigned> GlobalToGReg;
  for (const Function &F : M) {
    MachineFunction *MF = MMI->getMachineFunction(F);
    if (!MF)
      continue;
    const MachineRegisterInfo &MRI = MF->getRegInfo();
    // Small state machine tracking the function header:
    //   0 = haven't seen the function's own OpFunction yet,
    //   1 = inside the header (OpFunction seen, params may follow),
    //   2 = past the header.
    unsigned PastHeader = 0;
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getNumOperands() == 0)
          continue;
        unsigned Opcode = MI.getOpcode();
        if (Opcode == SPIRV::OpFunction) {
          // The first OpFunction is this function's own header; skip it here
          // (it is reached through references when needed).
          if (PastHeader == 0) {
            PastHeader = 1;
            continue;
          }
        } else if (Opcode == SPIRV::OpFunctionParameter) {
          // Parameters of the header OpFunction are skipped the same way.
          if (PastHeader < 2)
            continue;
        } else if (PastHeader > 0) {
          PastHeader = 2;
        }

        const MachineOperand &DefMO = MI.getOperand(i: 0);
        switch (Opcode) {
        case SPIRV::OpExtension:
          // Extension requirements are recorded, not re-emitted in place.
          MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::Extension(DefMO.getImm()));
          MAI.setSkipEmission(&MI);
          break;
        case SPIRV::OpCapability:
          // Capability requirements are recorded, not re-emitted in place.
          MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::Capability(DefMO.getImm()));
          MAI.setSkipEmission(&MI);
          if (PastHeader > 0)
            PastHeader = 2;
          break;
        default:
          // Number every not-yet-aliased declaration-section instruction.
          if (DefMO.isReg() && isDeclSection(MRI, MI) &&
              !MAI.hasRegisterAlias(MF, Reg: DefMO.getReg()))
            visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI);
        }
      }
    }
  }
}
548
549// Look for IDs declared with Import linkage, and map the corresponding function
550// to the register defining that variable (which will usually be the result of
551// an OpFunction). This lets us call externally imported functions using
552// the correct ID registers.
553void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
554 const Function *F) {
555 if (MI.getOpcode() == SPIRV::OpDecorate) {
556 // If it's got Import linkage.
557 auto Dec = MI.getOperand(i: 1).getImm();
558 if (Dec == SPIRV::Decoration::LinkageAttributes) {
559 auto Lnk = MI.getOperand(i: MI.getNumOperands() - 1).getImm();
560 if (Lnk == SPIRV::LinkageType::Import) {
561 // Map imported function name to function ID register.
562 const Function *ImportedFunc =
563 F->getParent()->getFunction(Name: getStringImm(MI, StartIndex: 2));
564 Register Target = MI.getOperand(i: 0).getReg();
565 MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MF: MI.getMF(), Reg: Target);
566 }
567 }
568 } else if (MI.getOpcode() == SPIRV::OpFunction) {
569 // Record all internal OpFunction declarations.
570 Register Reg = MI.defs().begin()->getReg();
571 MCRegister GlobalReg = MAI.getRegisterAlias(MF: MI.getMF(), Reg);
572 assert(GlobalReg.isValid());
573 MAI.FuncMap[F] = GlobalReg;
574 }
575}
576
// Collect the given instruction in the specified MS. We assume global register
// numbering has already occurred by this point. We can directly compare reg
// arguments when detecting duplicates.
//
// Duplicates (same full signature including the def register) are dropped,
// except FPFastMathMode decorations whose flags are merged into the already
// collected instruction. With Append == false the instruction is prepended
// to the section instead of appended.
static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
                              SPIRV::ModuleSectionType MSType, InstrTraces &IS,
                              bool Append = true) {
  MAI.setSkipEmission(&MI);
  InstrSignature MISign = instrToSignature(MI, MAI, UseDefReg: true);
  auto FoundMI = IS.insert(x: std::move(MISign));
  if (!FoundMI.second) {
    // The signature was seen before: this instruction is a duplicate.
    if (MI.getOpcode() == SPIRV::OpDecorate) {
      assert(MI.getNumOperands() >= 2 &&
             "Decoration instructions must have at least 2 operands");
      assert(MSType == SPIRV::MB_Annotations &&
             "Only OpDecorate instructions can be duplicates");
      // For FPFastMathMode decoration, we need to merge the flags of the
      // duplicate decoration with the original one, so we need to find the
      // original instruction that has the same signature. For the rest of
      // instructions, we will simply skip the duplicate.
      if (MI.getOperand(i: 1).getImm() != SPIRV::Decoration::FPFastMathMode)
        return; // Skip duplicates of other decorations.

      // Linear search for the previously collected instruction with the same
      // signature; note MISign was moved into IS above, so FoundMI.first is
      // the surviving copy to compare against.
      const SPIRV::InstrList &Decorations = MAI.MS[MSType];
      for (const MachineInstr *OrigMI : Decorations) {
        if (instrToSignature(MI: *OrigMI, MAI, UseDefReg: true) == MISign) {
          assert(OrigMI->getNumOperands() == MI.getNumOperands() &&
                 "Original instruction must have the same number of operands");
          assert(
              OrigMI->getNumOperands() == 3 &&
              "FPFastMathMode decoration must have 3 operands for OpDecorate");
          unsigned OrigFlags = OrigMI->getOperand(i: 2).getImm();
          unsigned NewFlags = MI.getOperand(i: 2).getImm();
          if (OrigFlags == NewFlags)
            return; // No need to merge, the flags are the same.

          // Emit warning about possible conflict between flags.
          unsigned FinalFlags = OrigFlags | NewFlags;
          llvm::errs()
              << "Warning: Conflicting FPFastMathMode decoration flags "
                 "in instruction: "
              << *OrigMI << "Original flags: " << OrigFlags
              << ", new flags: " << NewFlags
              << ". They will be merged on a best effort basis, but not "
                 "validated. Final flags: "
              << FinalFlags << "\n";
          // Patch the already-collected decoration in place with the union
          // of the flags.
          MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI);
          MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(i: 2);
          OrigFlagsOp = MachineOperand::CreateImm(Val: FinalFlags);
          return; // Merge done, so we found a duplicate; don't add it to MAI.MS
        }
      }
      assert(false && "No original instruction found for the duplicate "
                      "OpDecorate, but we found one in IS.");
    }
    return; // insert failed, so we found a duplicate; don't add it to MAI.MS
  }
  // No duplicates, so add it.
  if (Append)
    MAI.MS[MSType].push_back(Elt: &MI);
  else
    MAI.MS[MSType].insert(I: MAI.MS[MSType].begin(), Elt: &MI);
}
639
// Some global instructions make reference to function-local ID regs, so cannot
// be correctly collected until these registers are globally numbered.
//
// Routes each not-yet-skipped instruction into its module section: debug
// strings/names, global non-semantic debug info, entry points, aliasing
// instructions, annotations (also feeding FuncMap), constants, and forward
// pointers (prepended to the types/constants section).
void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
  InstrTraces IS;
  for (const Function &F : M) {
    if (F.isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(F);
    assert(MF);

    for (MachineBasicBlock &MBB : *MF)
      for (MachineInstr &MI : MBB) {
        // Instructions already handled by earlier phases are skipped.
        if (MAI.getSkipEmission(MI: &MI))
          continue;
        const unsigned OpCode = MI.getOpcode();
        if (OpCode == SPIRV::OpString) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_DebugStrings, IS);
        } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(i: 2).isImm() &&
                   MI.getOperand(i: 2).getImm() ==
                       SPIRV::InstructionSet::
                           NonSemantic_Shader_DebugInfo_100) {
          // Only the module-scoped debug-info instructions listed below are
          // hoisted into the global DI section.
          MachineOperand Ins = MI.getOperand(i: 3);
          namespace NS = SPIRV::NonSemanticExtInst;
          static constexpr int64_t GlobalNonSemanticDITy[] = {
              NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
              NS::DebugTypeBasic, NS::DebugTypePointer};
          bool IsGlobalDI = false;
          for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
            IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
          if (IsGlobalDI)
            collectOtherInstr(MI, MAI, MSType: SPIRV::MB_NonSemanticGlobalDI, IS);
        } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_DebugNames, IS);
        } else if (OpCode == SPIRV::OpEntryPoint) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_EntryPoints, IS);
        } else if (TII->isAliasingInstr(MI)) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_AliasingInsts, IS);
        } else if (TII->isDecorationInstr(MI)) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_Annotations, IS);
          collectFuncNames(MI, F: &F);
        } else if (TII->isConstantInstr(MI)) {
          // Now OpSpecConstant*s are not in DT,
          // but they need to be collected anyway.
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_TypeConstVars, IS);
        } else if (OpCode == SPIRV::OpFunction) {
          collectFuncNames(MI, F: &F);
        } else if (OpCode == SPIRV::OpTypeForwardPointer) {
          // Forward pointers must precede the types they refer to, hence
          // Append = false (prepend to the section).
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_TypeConstVars, IS, Append: false);
        }
      }
  }
}
692
693// Number registers in all functions globally from 0 onwards and store
694// the result in global register alias table. Some registers are already
695// numbered.
696void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
697 for (const Function &F : M) {
698 if (F.isDeclaration())
699 continue;
700 MachineFunction *MF = MMI->getMachineFunction(F);
701 assert(MF);
702 for (MachineBasicBlock &MBB : *MF) {
703 for (MachineInstr &MI : MBB) {
704 for (MachineOperand &Op : MI.operands()) {
705 if (!Op.isReg())
706 continue;
707 Register Reg = Op.getReg();
708 if (MAI.hasRegisterAlias(MF, Reg))
709 continue;
710 MCRegister NewReg = MAI.getNextIDRegister();
711 MAI.setRegisterAlias(MF, Reg, AliasReg: NewReg);
712 }
713 if (MI.getOpcode() != SPIRV::OpExtInst)
714 continue;
715 auto Set = MI.getOperand(i: 2).getImm();
716 auto [It, Inserted] = MAI.ExtInstSetMap.try_emplace(Key: Set);
717 if (Inserted)
718 It->second = MAI.getNextIDRegister();
719 }
720 }
721 }
722}
723
724// RequirementHandler implementations.
725void SPIRV::RequirementHandler::getAndAddRequirements(
726 SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
727 const SPIRVSubtarget &ST) {
728 addRequirements(Req: getSymbolicOperandRequirements(Category, i, ST, Reqs&: *this));
729}
730
731void SPIRV::RequirementHandler::recursiveAddCapabilities(
732 const CapabilityList &ToPrune) {
733 for (const auto &Cap : ToPrune) {
734 AllCaps.insert(V: Cap);
735 CapabilityList ImplicitDecls =
736 getSymbolicOperandCapabilities(Category: OperandCategory::CapabilityOperand, Value: Cap);
737 recursiveAddCapabilities(ToPrune: ImplicitDecls);
738 }
739}
740
741void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
742 for (const auto &Cap : ToAdd) {
743 bool IsNewlyInserted = AllCaps.insert(V: Cap).second;
744 if (!IsNewlyInserted) // Don't re-add if it's already been declared.
745 continue;
746 CapabilityList ImplicitDecls =
747 getSymbolicOperandCapabilities(Category: OperandCategory::CapabilityOperand, Value: Cap);
748 recursiveAddCapabilities(ToPrune: ImplicitDecls);
749 MinimalCaps.push_back(Elt: Cap);
750 }
751}
752
// Folds Req into this handler: declares its capability (if any), records its
// extensions, and tightens the accumulated [MinVersion, MaxVersion] range.
// Aborts compilation if Req is unsatisfiable or conflicts with the range
// collected so far.
void SPIRV::RequirementHandler::addRequirements(
    const SPIRV::Requirements &Req) {
  if (!Req.IsSatisfiable)
    report_fatal_error(reason: "Adding SPIR-V requirements this target can't satisfy.");

  if (Req.Cap.has_value())
    addCapabilities(ToAdd: {Req.Cap.value()});

  addExtensions(ToAdd: Req.Exts);

  // Raise the accumulated minimum version, checking it does not climb past
  // the accumulated maximum.
  if (!Req.MinVer.empty()) {
    if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
                        << " and <= " << MaxVersion << "\n");
      report_fatal_error(reason: "Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MinVersion.empty() || Req.MinVer > MinVersion)
      MinVersion = Req.MinVer;
  }

  // Lower the accumulated maximum version, checking it does not drop below
  // the accumulated minimum.
  if (!Req.MaxVer.empty()) {
    if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
                        << " and >= " << MinVersion << "\n");
      report_fatal_error(reason: "Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
      MaxVersion = Req.MaxVer;
  }
}
785
// Verifies that the accumulated requirements (version range, minimal
// capabilities, extensions) can all be met by subtarget ST. Collects every
// violation before aborting so the user sees as many errors as possible.
void SPIRV::RequirementHandler::checkSatisfiable(
    const SPIRVSubtarget &ST) const {
  // Report as many errors as possible before aborting the compilation.
  bool IsSatisfiable = true;
  auto TargetVer = ST.getSPIRVVersion();

  // Target version must fit inside the required [MinVersion, MaxVersion]
  // range; empty tuples mean "unconstrained".
  if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
    LLVM_DEBUG(
        dbgs() << "Target SPIR-V version too high for required features\n"
               << "Required max version: " << MaxVersion << " target version "
               << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
    LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
                      << "Required min version: " << MinVersion
                      << " target version " << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
    LLVM_DEBUG(
        dbgs()
        << "Version is too low for some features and too high for others.\n"
        << "Required SPIR-V min version: " << MinVersion
        << " required SPIR-V max version " << MaxVersion << "\n");
    IsSatisfiable = false;
  }

  // The environment-mismatched capability (Shader for kernels, Kernel for
  // shaders) must never appear among the minimal capabilities.
  AvoidCapabilitiesSet AvoidCaps;
  if (!ST.isShader())
    AvoidCaps.S.insert(V: SPIRV::Capability::Shader);
  else
    AvoidCaps.S.insert(V: SPIRV::Capability::Kernel);

  for (auto Cap : MinimalCaps) {
    if (AvailableCaps.contains(V: Cap) && !AvoidCaps.S.contains(V: Cap))
      continue;
    LLVM_DEBUG(dbgs() << "Capability not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::CapabilityOperand, Cap)
                      << "\n");
    IsSatisfiable = false;
  }

  // Every required extension must be usable on this subtarget.
  for (auto Ext : AllExtensions) {
    if (ST.canUseExtension(E: Ext))
      continue;
    LLVM_DEBUG(dbgs() << "Extension not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::ExtensionOperand, Ext)
                      << "\n");
    IsSatisfiable = false;
  }

  if (!IsSatisfiable)
    report_fatal_error(reason: "Unable to meet SPIR-V requirements for this target.");
}
845
846// Add the given capabilities and all their implicitly defined capabilities too.
847void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
848 for (const auto Cap : ToAdd)
849 if (AvailableCaps.insert(V: Cap).second)
850 addAvailableCaps(ToAdd: getSymbolicOperandCapabilities(
851 Category: SPIRV::OperandCategory::CapabilityOperand, Value: Cap));
852}
853
854void SPIRV::RequirementHandler::removeCapabilityIf(
855 const Capability::Capability ToRemove,
856 const Capability::Capability IfPresent) {
857 if (AllCaps.contains(V: IfPresent))
858 AllCaps.erase(V: ToRemove);
859}
860
861namespace llvm {
862namespace SPIRV {
863void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
864 // Provided by both all supported Vulkan versions and OpenCl.
865 addAvailableCaps(ToAdd: {Capability::Shader, Capability::Linkage, Capability::Int8,
866 Capability::Int16});
867
868 if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 3)))
869 addAvailableCaps(ToAdd: {Capability::GroupNonUniform,
870 Capability::GroupNonUniformVote,
871 Capability::GroupNonUniformArithmetic,
872 Capability::GroupNonUniformBallot,
873 Capability::GroupNonUniformClustered,
874 Capability::GroupNonUniformShuffle,
875 Capability::GroupNonUniformShuffleRelative});
876
877 if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 6)))
878 addAvailableCaps(ToAdd: {Capability::DotProduct, Capability::DotProductInputAll,
879 Capability::DotProductInput4x8Bit,
880 Capability::DotProductInput4x8BitPacked,
881 Capability::DemoteToHelperInvocation});
882
883 // Add capabilities enabled by extensions.
884 for (auto Extension : ST.getAllAvailableExtensions()) {
885 CapabilityList EnabledCapabilities =
886 getCapabilitiesEnabledByExtension(Extension);
887 addAvailableCaps(ToAdd: EnabledCapabilities);
888 }
889
890 if (!ST.isShader()) {
891 initAvailableCapabilitiesForOpenCL(ST);
892 return;
893 }
894
895 if (ST.isShader()) {
896 initAvailableCapabilitiesForVulkan(ST);
897 return;
898 }
899
900 report_fatal_error(reason: "Unimplemented environment for SPIR-V generation.");
901}
902
// Seed the capabilities an OpenCL (kernel) environment can provide, gated on
// the subtarget's OpenCL/SPIR-V versions, profile, and image support.
void RequirementHandler::initAvailableCapabilitiesForOpenCL(
    const SPIRVSubtarget &ST) {
  // Add the min requirements for different OpenCL and SPIR-V versions.
  addAvailableCaps(ToAdd: {Capability::Addresses, Capability::Float16Buffer,
                    Capability::Kernel, Capability::Vector16,
                    Capability::Groups, Capability::GenericPointer,
                    Capability::StorageImageWriteWithoutFormat,
                    Capability::StorageImageReadWithoutFormat});
  // 64-bit integers (and their atomics) are tied to the full profile here.
  if (ST.hasOpenCLFullProfile())
    addAvailableCaps(ToAdd: {Capability::Int64, Capability::Int64Atomics});
  if (ST.hasOpenCLImageSupport()) {
    addAvailableCaps(ToAdd: {Capability::ImageBasic, Capability::LiteralSampler,
                      Capability::Image1D, Capability::SampledBuffer,
                      Capability::ImageBuffer});
    // Read-write images additionally require OpenCL 2.0+.
    if (ST.isAtLeastOpenCLVer(VerToCompareTo: VersionTuple(2, 0)))
      addAvailableCaps(ToAdd: {Capability::ImageReadWrite});
  }
  // Subgroup dispatch and pipe storage need both SPIR-V 1.1 and OpenCL 2.2.
  if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 1)) &&
      ST.isAtLeastOpenCLVer(VerToCompareTo: VersionTuple(2, 2)))
    addAvailableCaps(ToAdd: {Capability::SubgroupDispatch, Capability::PipeStorage});
  // Float-controls capabilities are gated on SPIR-V 1.4+.
  if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 4)))
    addAvailableCaps(ToAdd: {Capability::DenormPreserve, Capability::DenormFlushToZero,
                      Capability::SignedZeroInfNanPreserve,
                      Capability::RoundingModeRTE,
                      Capability::RoundingModeRTZ});
  // TODO: verify if this needs some checks.
  addAvailableCaps(ToAdd: {Capability::Float16, Capability::Float64});

  // TODO: add OpenCL extensions.
}
933
// Seed the capabilities a Vulkan (shader) environment can provide, gated on
// the SPIR-V version the subtarget targets.
void RequirementHandler::initAvailableCapabilitiesForVulkan(
    const SPIRVSubtarget &ST) {

  // Core in Vulkan 1.1 and earlier.
  addAvailableCaps(ToAdd: {Capability::Int64, Capability::Float16, Capability::Float64,
                    Capability::GroupNonUniform, Capability::Image1D,
                    Capability::SampledBuffer, Capability::ImageBuffer,
                    Capability::UniformBufferArrayDynamicIndexing,
                    Capability::SampledImageArrayDynamicIndexing,
                    Capability::StorageBufferArrayDynamicIndexing,
                    Capability::StorageImageArrayDynamicIndexing,
                    Capability::DerivativeControl, Capability::MinLod,
                    Capability::ImageGatherExtended, Capability::Addresses,
                    Capability::VulkanMemoryModelKHR});

  // Became core in Vulkan 1.2 (gated on SPIR-V 1.5): the descriptor-indexing
  // capability family.
  if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 5))) {
    addAvailableCaps(
        ToAdd: {Capability::ShaderNonUniformEXT, Capability::RuntimeDescriptorArrayEXT,
         Capability::InputAttachmentArrayDynamicIndexingEXT,
         Capability::UniformTexelBufferArrayDynamicIndexingEXT,
         Capability::StorageTexelBufferArrayDynamicIndexingEXT,
         Capability::UniformBufferArrayNonUniformIndexingEXT,
         Capability::SampledImageArrayNonUniformIndexingEXT,
         Capability::StorageBufferArrayNonUniformIndexingEXT,
         Capability::StorageImageArrayNonUniformIndexingEXT,
         Capability::InputAttachmentArrayNonUniformIndexingEXT,
         Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
         Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
  }

  // Became core in Vulkan 1.3 (gated on SPIR-V 1.6).
  if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 6)))
    addAvailableCaps(ToAdd: {Capability::StorageImageWriteWithoutFormat,
                      Capability::StorageImageReadWithoutFormat});
}
970
971} // namespace SPIRV
972} // namespace llvm
973
// Add the required capabilities from a decoration instruction (including
// BuiltIns). DecIndex is the operand index of the decoration kind (it differs
// between OpDecorate and OpMemberDecorate).
static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
                              SPIRV::RequirementHandler &Reqs,
                              const SPIRVSubtarget &ST) {
  // Requirements attached to the decoration kind itself.
  int64_t DecOp = MI.getOperand(i: DecIndex).getImm();
  auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
  Reqs.addRequirements(Req: getSymbolicOperandRequirements(
      Category: SPIRV::OperandCategory::DecorationOperand, i: Dec, ST, Reqs));

  if (Dec == SPIRV::Decoration::BuiltIn) {
    // The BuiltIn id immediately follows the decoration kind and carries its
    // own requirements.
    int64_t BuiltInOp = MI.getOperand(i: DecIndex + 1).getImm();
    auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
    Reqs.addRequirements(Req: getSymbolicOperandRequirements(
        Category: SPIRV::OperandCategory::BuiltInOperand, i: BuiltIn, ST, Reqs));
  } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
    // The linkage type is the last operand of the decoration.
    int64_t LinkageOp = MI.getOperand(i: MI.getNumOperands() - 1).getImm();
    SPIRV::LinkageType::LinkageType LnkType =
        static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
    if (LnkType == SPIRV::LinkageType::LinkOnceODR)
      Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_linkonce_odr);
  } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
             Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_cache_controls);
  } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_global_variable_host_access);
  } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
             Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
    Reqs.addExtension(
        ToAdd: SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
  } else if (Dec == SPIRV::Decoration::NonUniformEXT) {
    Reqs.addRequirements(Req: SPIRV::Capability::ShaderNonUniformEXT);
  } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) {
    Reqs.addRequirements(Req: SPIRV::Capability::FPMaxErrorINTEL);
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_fp_max_error);
  } else if (Dec == SPIRV::Decoration::FPFastMathMode) {
    // Only record float_controls2 requirements when the extension is usable.
    if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2)) {
      Reqs.addRequirements(Req: SPIRV::Capability::FloatControls2);
      Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_float_controls2);
    }
  }
}
1016
// Add requirements for image handling.
static void addOpTypeImageReqs(const MachineInstr &MI,
                               SPIRV::RequirementHandler &Reqs,
                               const SPIRVSubtarget &ST) {
  assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
  // The operand indices used here are based on the OpTypeImage layout, which
  // the MachineInstr follows as well.
  int64_t ImgFormatOp = MI.getOperand(i: 7).getImm();
  auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
  Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ImageFormatOperand,
                             i: ImgFormat, ST);

  // Operand 4 is Arrayed, operand 5 is MS, operand 6 is Sampled
  // (2 = no sampler / storage access).
  bool IsArrayed = MI.getOperand(i: 4).getImm() == 1;
  bool IsMultisampled = MI.getOperand(i: 5).getImm() == 1;
  bool NoSampler = MI.getOperand(i: 6).getImm() == 2;
  // Add dimension requirements.
  assert(MI.getOperand(2).isImm());
  switch (MI.getOperand(i: 2).getImm()) {
  case SPIRV::Dim::DIM_1D:
    Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::Image1D
                                     : SPIRV::Capability::Sampled1D);
    break;
  case SPIRV::Dim::DIM_2D:
    // Plain 2D needs nothing extra; multisampled storage images do.
    if (IsMultisampled && NoSampler)
      Reqs.addRequirements(Req: SPIRV::Capability::ImageMSArray);
    break;
  case SPIRV::Dim::DIM_Cube:
    Reqs.addRequirements(Req: SPIRV::Capability::Shader);
    if (IsArrayed)
      Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::ImageCubeArray
                                       : SPIRV::Capability::SampledCubeArray);
    break;
  case SPIRV::Dim::DIM_Rect:
    Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::ImageRect
                                     : SPIRV::Capability::SampledRect);
    break;
  case SPIRV::Dim::DIM_Buffer:
    Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::ImageBuffer
                                     : SPIRV::Capability::SampledBuffer);
    break;
  case SPIRV::Dim::DIM_SubpassData:
    Reqs.addRequirements(Req: SPIRV::Capability::InputAttachment);
    break;
  }

  // Has optional access qualifier.
  // Only non-shader (kernel) environments use the access-qualifier operand.
  if (!ST.isShader()) {
    if (MI.getNumOperands() > 8 &&
        MI.getOperand(i: 8).getImm() == SPIRV::AccessQualifier::ReadWrite)
      Reqs.addRequirements(Req: SPIRV::Capability::ImageReadWrite);
    else
      Reqs.addRequirements(Req: SPIRV::Capability::ImageBasic);
  }
}
1071
1072static bool isBFloat16Type(SPIRVTypeInst TypeDef) {
1073 return TypeDef && TypeDef->getNumOperands() == 3 &&
1074 TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1075 TypeDef->getOperand(i: 1).getImm() == 16 &&
1076 TypeDef->getOperand(i: 2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
1077}
1078
// Add requirements for handling atomic float instructions
#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
  "The atomic float instruction requires the following SPIR-V " \
  "extension: SPV_EXT_shader_atomic_float" ExtName
// Validate and add requirements for an atomic instruction whose result type
// is a vector of 16-bit floats (SPV_NV_shader_atomic_fp16_vector path).
static void AddAtomicVectorFloatRequirements(const MachineInstr &MI,
                                             SPIRV::RequirementHandler &Reqs,
                                             const SPIRVSubtarget &ST) {
  SPIRVTypeInst VecTypeDef =
      MI.getMF()->getRegInfo().getVRegDef(Reg: MI.getOperand(i: 1).getReg());

  // Operand 2 of the vector type is its component count.
  const unsigned Rank = VecTypeDef->getOperand(i: 2).getImm();
  if (Rank != 2 && Rank != 4)
    reportFatalUsageError(reason: "Result type of an atomic vector float instruction "
                          "must be a 2-component or 4 component vector");

  // Operand 1 of the vector type is its element type.
  SPIRVTypeInst EltTypeDef =
      MI.getMF()->getRegInfo().getVRegDef(Reg: VecTypeDef->getOperand(i: 1).getReg());

  if (EltTypeDef->getOpcode() != SPIRV::OpTypeFloat ||
      EltTypeDef->getOperand(i: 1).getImm() != 16)
    reportFatalUsageError(
        reason: "The element type for the result type of an atomic vector float "
        "instruction must be a 16-bit floating-point scalar");

  // bfloat16 is also a width-16 OpTypeFloat, so reject it explicitly.
  if (isBFloat16Type(TypeDef: EltTypeDef))
    reportFatalUsageError(
        reason: "The element type for the result type of an atomic vector float "
        "instruction cannot be a bfloat16 scalar");
  if (!ST.canUseExtension(E: SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector))
    reportFatalUsageError(
        reason: "The atomic float16 vector instruction requires the following SPIR-V "
        "extension: SPV_NV_shader_atomic_fp16_vector");

  Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector);
  Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat16VectorNV);
}
1115
// Validate an atomic float instruction's result type and record the
// extension/capability pair required for its opcode and bit width.
static void AddAtomicFloatRequirements(const MachineInstr &MI,
                                       SPIRV::RequirementHandler &Reqs,
                                       const SPIRVSubtarget &ST) {
  assert(MI.getOperand(1).isReg() &&
         "Expect register operand in atomic float instruction");
  Register TypeReg = MI.getOperand(i: 1).getReg();
  SPIRVTypeInst TypeDef = MI.getMF()->getRegInfo().getVRegDef(Reg: TypeReg);

  // Vector result types are handled by the NV fp16-vector path.
  if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
    return AddAtomicVectorFloatRequirements(MI, Reqs, ST);

  if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
    report_fatal_error(reason: "Result type of an atomic float instruction must be a "
                       "floating-point type scalar");

  unsigned BitWidth = TypeDef->getOperand(i: 1).getImm();
  unsigned Op = MI.getOpcode();
  if (Op == SPIRV::OpAtomicFAddEXT) {
    // Atomic fadd needs SPV_EXT_shader_atomic_float_add plus a
    // width-specific capability; 16-bit additionally needs its own extension.
    if (!ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), gen_crash_diag: false);
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
    switch (BitWidth) {
    case 16:
      // bfloat16 uses the INTEL extension; regular half uses the EXT one.
      if (isBFloat16Type(TypeDef)) {
        if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_16bit_atomics))
          report_fatal_error(
              reason: "The atomic bfloat16 instruction requires the following SPIR-V "
              "extension: SPV_INTEL_16bit_atomics",
              gen_crash_diag: false);
        Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_16bit_atomics);
        Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicBFloat16AddINTEL);
      } else {
        if (!ST.canUseExtension(
                E: SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
          report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), gen_crash_diag: false);
        Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
        Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat16AddEXT);
      }
      break;
    case 32:
      Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat32AddEXT);
      break;
    case 64:
      Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat64AddEXT);
      break;
    default:
      report_fatal_error(
          reason: "Unexpected floating-point type width in atomic float instruction");
    }
  } else {
    // All other atomic float opcodes here are min/max:
    // SPV_EXT_shader_atomic_float_min_max plus a width-specific capability.
    if (!ST.canUseExtension(
            E: SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), gen_crash_diag: false);
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
    switch (BitWidth) {
    case 16:
      if (isBFloat16Type(TypeDef)) {
        if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_16bit_atomics))
          report_fatal_error(
              reason: "The atomic bfloat16 instruction requires the following SPIR-V "
              "extension: SPV_INTEL_16bit_atomics",
              gen_crash_diag: false);
        Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_16bit_atomics);
        Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicBFloat16MinMaxINTEL);
      } else {
        Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat16MinMaxEXT);
      }
      break;
    case 32:
      Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat32MinMaxEXT);
      break;
    case 64:
      Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat64MinMaxEXT);
      break;
    default:
      report_fatal_error(
          reason: "Unexpected floating-point type width in atomic float instruction");
    }
  }
}
1196
1197bool isUniformTexelBuffer(MachineInstr *ImageInst) {
1198 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1199 return false;
1200 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1201 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1202 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
1203}
1204
1205bool isStorageTexelBuffer(MachineInstr *ImageInst) {
1206 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1207 return false;
1208 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1209 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1210 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
1211}
1212
1213bool isSampledImage(MachineInstr *ImageInst) {
1214 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1215 return false;
1216 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1217 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1218 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
1219}
1220
1221bool isInputAttachment(MachineInstr *ImageInst) {
1222 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1223 return false;
1224 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1225 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1226 return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
1227}
1228
1229bool isStorageImage(MachineInstr *ImageInst) {
1230 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1231 return false;
1232 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1233 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1234 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
1235}
1236
1237bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
1238 if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
1239 return false;
1240
1241 const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
1242 Register ImageReg = SampledImageInst->getOperand(i: 1).getReg();
1243 auto *ImageInst = MRI.getUniqueVRegDef(Reg: ImageReg);
1244 return isSampledImage(ImageInst);
1245}
1246
1247bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
1248 for (const auto &MI : MRI.reg_instructions(Reg)) {
1249 if (MI.getOpcode() != SPIRV::OpDecorate)
1250 continue;
1251
1252 uint32_t Dec = MI.getOperand(i: 1).getImm();
1253 if (Dec == SPIRV::Decoration::NonUniformEXT)
1254 return true;
1255 }
1256 return false;
1257}
1258
// For an OpAccessChain whose result points into a resource array, add the
// descriptor-indexing capability (dynamic or non-uniform, per resource kind)
// that the access pattern requires.
void addOpAccessChainReqs(const MachineInstr &Instr,
                          SPIRV::RequirementHandler &Handler,
                          const SPIRVSubtarget &Subtarget) {
  const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
  // Get the result type. If it is an image type, then the shader uses
  // descriptor indexing. The appropriate capabilities will be added based
  // on the specifics of the image.
  Register ResTypeReg = Instr.getOperand(i: 1).getReg();
  MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(Reg: ResTypeReg);

  assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
  uint32_t StorageClass = ResTypeInst->getOperand(i: 1).getImm();
  // Only these storage classes are considered for descriptor indexing.
  if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
      StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
      StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
    return;
  }

  // Non-uniform indexing is signalled by a NonUniformEXT decoration on the
  // access-chain result register.
  bool IsNonUniform =
      hasNonUniformDecoration(Reg: Instr.getOperand(i: 0).getReg(), MRI);

  // Operand 3 is the first index; a non-constant first index means the
  // access uses dynamic indexing.
  auto FirstIndexReg = Instr.getOperand(i: 3).getReg();
  bool FirstIndexIsConstant =
      Subtarget.getInstrInfo()->isConstantInstr(MI: *MRI.getVRegDef(Reg: FirstIndexReg));

  // Storage buffers are classified by storage class alone.
  if (StorageClass == SPIRV::StorageClass::StorageClass::StorageBuffer) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageBufferArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageBufferArrayDynamicIndexing);
    return;
  }

  // Otherwise classify by the pointee's image/sampler type; other pointee
  // types need no descriptor-indexing capability.
  Register PointeeTypeReg = ResTypeInst->getOperand(i: 2).getReg();
  MachineInstr *PointeeType = MRI.getUniqueVRegDef(Reg: PointeeTypeReg);
  if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
      PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
      PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
    return;
  }

  if (isUniformTexelBuffer(ImageInst: PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
  } else if (isInputAttachment(ImageInst: PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
  } else if (isStorageTexelBuffer(ImageInst: PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
  } else if (isSampledImage(ImageInst: PointeeType) ||
             isCombinedImageSampler(SampledImageInst: PointeeType) ||
             PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::SampledImageArrayDynamicIndexing);
  } else if (isStorageImage(ImageInst: PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageImageArrayDynamicIndexing);
  }
}
1341
1342static bool isImageTypeWithUnknownFormat(SPIRVTypeInst TypeInst) {
1343 if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
1344 return false;
1345 assert(TypeInst->getOperand(7).isImm() && "The image format must be an imm.");
1346 return TypeInst->getOperand(i: 7).getImm() == 0;
1347}
1348
// Add the extension and capabilities needed by an integer dot-product
// instruction, refining the capability by the shape of the input type.
static void AddDotProductRequirements(const MachineInstr &MI,
                                      SPIRV::RequirementHandler &Reqs,
                                      const SPIRVSubtarget &ST) {
  if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_integer_dot_product))
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_integer_dot_product);
  Reqs.addCapability(ToAdd: SPIRV::Capability::DotProduct);

  const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
  // We do not consider what the previous instruction is. This is just used
  // to get the input register and to check the type.
  const MachineInstr *Input = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
  assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
  Register InputReg = Input->getOperand(i: 1).getReg();

  SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: InputReg);
  if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
    // Scalar-integer input is the packed 4x8-bit form (always 32-bit).
    assert(TypeDef->getOperand(1).getImm() == 32);
    Reqs.addCapability(ToAdd: SPIRV::Capability::DotProductInput4x8BitPacked);
  } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
    // Vector input: 4x8-bit gets its dedicated capability, anything else
    // falls under DotProductInputAll.
    SPIRVTypeInst ScalarTypeDef =
        MRI.getVRegDef(Reg: TypeDef->getOperand(i: 1).getReg());
    assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
    if (ScalarTypeDef->getOperand(i: 1).getImm() == 8) {
      assert(TypeDef->getOperand(2).getImm() == 4 &&
             "Dot operand of 8-bit integer type requires 4 components");
      Reqs.addCapability(ToAdd: SPIRV::Capability::DotProductInput4x8Bit);
    } else {
      Reqs.addCapability(ToAdd: SPIRV::Capability::DotProductInputAll);
    }
  }
}
1381
1382void addPrintfRequirements(const MachineInstr &MI,
1383 SPIRV::RequirementHandler &Reqs,
1384 const SPIRVSubtarget &ST) {
1385 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
1386 SPIRVTypeInst PtrType = GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 4).getReg());
1387 if (PtrType) {
1388 MachineOperand ASOp = PtrType->getOperand(i: 1);
1389 if (ASOp.isImm()) {
1390 unsigned AddrSpace = ASOp.getImm();
1391 if (AddrSpace != SPIRV::StorageClass::UniformConstant) {
1392 if (!ST.canUseExtension(
1393 E: SPIRV::Extension::
1394 SPV_EXT_relaxed_printf_string_address_space)) {
1395 report_fatal_error(reason: "SPV_EXT_relaxed_printf_string_address_space is "
1396 "required because printf uses a format string not "
1397 "in constant address space.",
1398 gen_crash_diag: false);
1399 }
1400 Reqs.addExtension(
1401 ToAdd: SPIRV::Extension::SPV_EXT_relaxed_printf_string_address_space);
1402 }
1403 }
1404 }
1405}
1406
1407static void addImageOperandReqs(const MachineInstr &MI,
1408 SPIRV::RequirementHandler &Reqs,
1409 const SPIRVSubtarget &ST, unsigned OpIdx) {
1410 if (MI.getNumOperands() <= OpIdx)
1411 return;
1412 uint32_t Mask = MI.getOperand(i: OpIdx).getImm();
1413 for (uint32_t I = 0; I < 32; ++I)
1414 if (Mask & (1U << I))
1415 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ImageOperandOperand,
1416 i: 1U << I, ST);
1417}
1418
1419void addInstrRequirements(const MachineInstr &MI,
1420 SPIRV::ModuleAnalysisInfo &MAI,
1421 const SPIRVSubtarget &ST) {
1422 SPIRV::RequirementHandler &Reqs = MAI.Reqs;
1423 switch (MI.getOpcode()) {
1424 case SPIRV::OpMemoryModel: {
1425 int64_t Addr = MI.getOperand(i: 0).getImm();
1426 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::AddressingModelOperand,
1427 i: Addr, ST);
1428 int64_t Mem = MI.getOperand(i: 1).getImm();
1429 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::MemoryModelOperand, i: Mem,
1430 ST);
1431 break;
1432 }
1433 case SPIRV::OpEntryPoint: {
1434 int64_t Exe = MI.getOperand(i: 0).getImm();
1435 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ExecutionModelOperand,
1436 i: Exe, ST);
1437 break;
1438 }
1439 case SPIRV::OpExecutionMode:
1440 case SPIRV::OpExecutionModeId: {
1441 int64_t Exe = MI.getOperand(i: 1).getImm();
1442 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ExecutionModeOperand,
1443 i: Exe, ST);
1444 break;
1445 }
1446 case SPIRV::OpTypeMatrix:
1447 Reqs.addCapability(ToAdd: SPIRV::Capability::Matrix);
1448 break;
1449 case SPIRV::OpTypeInt: {
1450 unsigned BitWidth = MI.getOperand(i: 1).getImm();
1451 if (BitWidth == 64)
1452 Reqs.addCapability(ToAdd: SPIRV::Capability::Int64);
1453 else if (BitWidth == 16)
1454 Reqs.addCapability(ToAdd: SPIRV::Capability::Int16);
1455 else if (BitWidth == 8)
1456 Reqs.addCapability(ToAdd: SPIRV::Capability::Int8);
1457 else if (BitWidth == 4 &&
1458 ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_int4)) {
1459 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_int4);
1460 Reqs.addCapability(ToAdd: SPIRV::Capability::Int4TypeINTEL);
1461 } else if (BitWidth != 32) {
1462 if (!ST.canUseExtension(
1463 E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers))
1464 reportFatalUsageError(
1465 reason: "OpTypeInt type with a width other than 8, 16, 32 or 64 bits "
1466 "requires the following SPIR-V extension: "
1467 "SPV_ALTERA_arbitrary_precision_integers");
1468 Reqs.addExtension(
1469 ToAdd: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers);
1470 Reqs.addCapability(ToAdd: SPIRV::Capability::ArbitraryPrecisionIntegersALTERA);
1471 }
1472 break;
1473 }
1474 case SPIRV::OpDot: {
1475 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1476 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
1477 if (isBFloat16Type(TypeDef))
1478 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16DotProductKHR);
1479 break;
1480 }
1481 case SPIRV::OpTypeFloat: {
1482 unsigned BitWidth = MI.getOperand(i: 1).getImm();
1483 if (BitWidth == 64)
1484 Reqs.addCapability(ToAdd: SPIRV::Capability::Float64);
1485 else if (BitWidth == 16) {
1486 if (isBFloat16Type(TypeDef: &MI)) {
1487 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_bfloat16))
1488 report_fatal_error(reason: "OpTypeFloat type with bfloat requires the "
1489 "following SPIR-V extension: SPV_KHR_bfloat16",
1490 gen_crash_diag: false);
1491 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_bfloat16);
1492 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16TypeKHR);
1493 } else {
1494 Reqs.addCapability(ToAdd: SPIRV::Capability::Float16);
1495 }
1496 }
1497 break;
1498 }
1499 case SPIRV::OpTypeVector: {
1500 unsigned NumComponents = MI.getOperand(i: 2).getImm();
1501 if (NumComponents == 8 || NumComponents == 16)
1502 Reqs.addCapability(ToAdd: SPIRV::Capability::Vector16);
1503
1504 assert(MI.getOperand(1).isReg());
1505 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1506 SPIRVTypeInst ElemTypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
1507 if (ElemTypeDef->getOpcode() == SPIRV::OpTypePointer &&
1508 ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
1509 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
1510 Reqs.addCapability(ToAdd: SPIRV::Capability::MaskedGatherScatterINTEL);
1511 }
1512 break;
1513 }
1514 case SPIRV::OpTypePointer: {
1515 auto SC = MI.getOperand(i: 1).getImm();
1516 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::StorageClassOperand, i: SC,
1517 ST);
1518 // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
1519 // capability.
1520 if (ST.isShader())
1521 break;
1522 assert(MI.getOperand(2).isReg());
1523 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1524 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
1525 if ((TypeDef->getNumOperands() == 2) &&
1526 (TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
1527 (TypeDef->getOperand(i: 1).getImm() == 16))
1528 Reqs.addCapability(ToAdd: SPIRV::Capability::Float16Buffer);
1529 break;
1530 }
1531 case SPIRV::OpExtInst: {
1532 if (MI.getOperand(i: 2).getImm() ==
1533 static_cast<int64_t>(
1534 SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
1535 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_non_semantic_info);
1536 break;
1537 }
1538 if (MI.getOperand(i: 3).getImm() ==
1539 static_cast<int64_t>(SPIRV::OpenCLExtInst::printf)) {
1540 addPrintfRequirements(MI, Reqs, ST);
1541 break;
1542 }
1543 // TODO: handle bfloat16 extended instructions when
1544 // SPV_INTEL_bfloat16_arithmetic is enabled.
1545 break;
1546 }
1547 case SPIRV::OpAliasDomainDeclINTEL:
1548 case SPIRV::OpAliasScopeDeclINTEL:
1549 case SPIRV::OpAliasScopeListDeclINTEL: {
1550 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_memory_access_aliasing);
1551 Reqs.addCapability(ToAdd: SPIRV::Capability::MemoryAccessAliasingINTEL);
1552 break;
1553 }
1554 case SPIRV::OpBitReverse:
1555 case SPIRV::OpBitFieldInsert:
1556 case SPIRV::OpBitFieldSExtract:
1557 case SPIRV::OpBitFieldUExtract:
1558 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_bit_instructions)) {
1559 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
1560 break;
1561 }
1562 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_bit_instructions);
1563 Reqs.addCapability(ToAdd: SPIRV::Capability::BitInstructions);
1564 break;
1565 case SPIRV::OpTypeRuntimeArray:
1566 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
1567 break;
1568 case SPIRV::OpTypeOpaque:
1569 case SPIRV::OpTypeEvent:
1570 Reqs.addCapability(ToAdd: SPIRV::Capability::Kernel);
1571 break;
1572 case SPIRV::OpTypePipe:
1573 case SPIRV::OpTypeReserveId:
1574 Reqs.addCapability(ToAdd: SPIRV::Capability::Pipes);
1575 break;
1576 case SPIRV::OpTypeDeviceEvent:
1577 case SPIRV::OpTypeQueue:
1578 case SPIRV::OpBuildNDRange:
1579 Reqs.addCapability(ToAdd: SPIRV::Capability::DeviceEnqueue);
1580 break;
1581 case SPIRV::OpDecorate:
1582 case SPIRV::OpDecorateId:
1583 case SPIRV::OpDecorateString:
1584 addOpDecorateReqs(MI, DecIndex: 1, Reqs, ST);
1585 break;
1586 case SPIRV::OpMemberDecorate:
1587 case SPIRV::OpMemberDecorateString:
1588 addOpDecorateReqs(MI, DecIndex: 2, Reqs, ST);
1589 break;
1590 case SPIRV::OpInBoundsPtrAccessChain:
1591 Reqs.addCapability(ToAdd: SPIRV::Capability::Addresses);
1592 break;
1593 case SPIRV::OpConstantSampler:
1594 Reqs.addCapability(ToAdd: SPIRV::Capability::LiteralSampler);
1595 break;
1596 case SPIRV::OpInBoundsAccessChain:
1597 case SPIRV::OpAccessChain:
1598 addOpAccessChainReqs(Instr: MI, Handler&: Reqs, Subtarget: ST);
1599 break;
1600 case SPIRV::OpTypeImage:
1601 addOpTypeImageReqs(MI, Reqs, ST);
1602 break;
1603 case SPIRV::OpTypeSampler:
1604 if (!ST.isShader()) {
1605 Reqs.addCapability(ToAdd: SPIRV::Capability::ImageBasic);
1606 }
1607 break;
1608 case SPIRV::OpTypeForwardPointer:
1609 // TODO: check if it's OpenCL's kernel.
1610 Reqs.addCapability(ToAdd: SPIRV::Capability::Addresses);
1611 break;
  case SPIRV::OpAtomicFlagTestAndSet:
  case SPIRV::OpAtomicLoad:
  case SPIRV::OpAtomicStore:
  case SPIRV::OpAtomicExchange:
  case SPIRV::OpAtomicCompareExchange:
  case SPIRV::OpAtomicIIncrement:
  case SPIRV::OpAtomicIDecrement:
  case SPIRV::OpAtomicIAdd:
  case SPIRV::OpAtomicISub:
  case SPIRV::OpAtomicUMin:
  case SPIRV::OpAtomicUMax:
  case SPIRV::OpAtomicSMin:
  case SPIRV::OpAtomicSMax:
  case SPIRV::OpAtomicAnd:
  case SPIRV::OpAtomicOr:
  case SPIRV::OpAtomicXor: {
    // Determine the integer width the atomic operates on. For most atomics
    // the type register is operand 1 of the instruction itself; OpAtomicStore
    // has no result, so we chase the definition of the stored value
    // (operand 3) and read the type from there instead.
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    const MachineInstr *InstrPtr = &MI;
    if (MI.getOpcode() == SPIRV::OpAtomicStore) {
      assert(MI.getOperand(3).isReg());
      InstrPtr = MRI.getVRegDef(Reg: MI.getOperand(i: 3).getReg());
      assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
    }
    assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
    Register TypeReg = InstrPtr->getOperand(i: 1).getReg();
    SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: TypeReg);
    // 64-bit integer atomics additionally require the Int64Atomics
    // capability; non-integer types need nothing extra here.
    if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
      unsigned BitWidth = TypeDef->getOperand(i: 1).getImm();
      if (BitWidth == 64)
        Reqs.addCapability(ToAdd: SPIRV::Capability::Int64Atomics);
    }
    break;
  }
1645 case SPIRV::OpGroupNonUniformIAdd:
1646 case SPIRV::OpGroupNonUniformFAdd:
1647 case SPIRV::OpGroupNonUniformIMul:
1648 case SPIRV::OpGroupNonUniformFMul:
1649 case SPIRV::OpGroupNonUniformSMin:
1650 case SPIRV::OpGroupNonUniformUMin:
1651 case SPIRV::OpGroupNonUniformFMin:
1652 case SPIRV::OpGroupNonUniformSMax:
1653 case SPIRV::OpGroupNonUniformUMax:
1654 case SPIRV::OpGroupNonUniformFMax:
1655 case SPIRV::OpGroupNonUniformBitwiseAnd:
1656 case SPIRV::OpGroupNonUniformBitwiseOr:
1657 case SPIRV::OpGroupNonUniformBitwiseXor:
1658 case SPIRV::OpGroupNonUniformLogicalAnd:
1659 case SPIRV::OpGroupNonUniformLogicalOr:
1660 case SPIRV::OpGroupNonUniformLogicalXor: {
1661 assert(MI.getOperand(3).isImm());
1662 int64_t GroupOp = MI.getOperand(i: 3).getImm();
1663 switch (GroupOp) {
1664 case SPIRV::GroupOperation::Reduce:
1665 case SPIRV::GroupOperation::InclusiveScan:
1666 case SPIRV::GroupOperation::ExclusiveScan:
1667 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformArithmetic);
1668 break;
1669 case SPIRV::GroupOperation::ClusteredReduce:
1670 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformClustered);
1671 break;
1672 case SPIRV::GroupOperation::PartitionedReduceNV:
1673 case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1674 case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1675 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformPartitionedNV);
1676 break;
1677 }
1678 break;
1679 }
  case SPIRV::OpImageQueryFormat: {
    // If the queried image format is ever compared against (or switched on)
    // the raw10/raw12 channel-data-type encodings, the module must declare
    // the SPV_EXT_image_raw10_raw12 extension. We detect that by scanning
    // all users of the query result.
    Register ResultReg = MI.getOperand(i: 0).getReg();
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    // Integer comparison opcodes whose operands may consume the query result.
    static const unsigned CompareOps[] = {
        SPIRV::OpIEqual, SPIRV::OpINotEqual,
        SPIRV::OpUGreaterThan, SPIRV::OpUGreaterThanEqual,
        SPIRV::OpULessThan, SPIRV::OpULessThanEqual,
        SPIRV::OpSGreaterThan, SPIRV::OpSGreaterThanEqual,
        SPIRV::OpSLessThan, SPIRV::OpSLessThanEqual};

    // 4323/4324 are the ImageChannelDataType encodings added by
    // SPV_EXT_image_raw10_raw12 (presumably UnsignedIntRaw10EXT /
    // UnsignedIntRaw12EXT -- TODO confirm against the extension spec and
    // replace the literals with named constants).
    auto CheckAndAddExtension = [&](int64_t ImmVal) {
      if (ImmVal == 4323 || ImmVal == 4324) {
        if (ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_image_raw10_raw12))
          Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_image_raw10_raw12);
        else
          report_fatal_error(reason: "This requires the "
                             "SPV_EXT_image_raw10_raw12 extension");
      }
    };

    for (MachineInstr &UseInst : MRI.use_instructions(Reg: ResultReg)) {
      unsigned Opc = UseInst.getOpcode();

      if (Opc == SPIRV::OpSwitch) {
        // Check every literal case value of an OpSwitch on the result.
        for (const MachineOperand &Op : UseInst.operands())
          if (Op.isImm())
            CheckAndAddExtension(Op.getImm());
      } else if (llvm::is_contained(Range: CompareOps, Element: Opc)) {
        // For comparisons, look through each register operand for an
        // OpConstantI definition and test its immediate value (operand 2 of
        // OpConstantI is the literal).
        for (unsigned i = 1; i < UseInst.getNumOperands(); ++i) {
          Register UseReg = UseInst.getOperand(i).getReg();
          MachineInstr *ConstInst = MRI.getVRegDef(Reg: UseReg);
          if (ConstInst && ConstInst->getOpcode() == SPIRV::OpConstantI) {
            int64_t ImmVal = ConstInst->getOperand(i: 2).getImm();
            if (ImmVal)
              CheckAndAddExtension(ImmVal);
          }
        }
      }
    }
    break;
  }
1721
1722 case SPIRV::OpGroupNonUniformShuffle:
1723 case SPIRV::OpGroupNonUniformShuffleXor:
1724 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformShuffle);
1725 break;
1726 case SPIRV::OpGroupNonUniformShuffleUp:
1727 case SPIRV::OpGroupNonUniformShuffleDown:
1728 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformShuffleRelative);
1729 break;
1730 case SPIRV::OpGroupAll:
1731 case SPIRV::OpGroupAny:
1732 case SPIRV::OpGroupBroadcast:
1733 case SPIRV::OpGroupIAdd:
1734 case SPIRV::OpGroupFAdd:
1735 case SPIRV::OpGroupFMin:
1736 case SPIRV::OpGroupUMin:
1737 case SPIRV::OpGroupSMin:
1738 case SPIRV::OpGroupFMax:
1739 case SPIRV::OpGroupUMax:
1740 case SPIRV::OpGroupSMax:
1741 Reqs.addCapability(ToAdd: SPIRV::Capability::Groups);
1742 break;
1743 case SPIRV::OpGroupNonUniformElect:
1744 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniform);
1745 break;
1746 case SPIRV::OpGroupNonUniformAll:
1747 case SPIRV::OpGroupNonUniformAny:
1748 case SPIRV::OpGroupNonUniformAllEqual:
1749 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformVote);
1750 break;
1751 case SPIRV::OpGroupNonUniformBroadcast:
1752 case SPIRV::OpGroupNonUniformBroadcastFirst:
1753 case SPIRV::OpGroupNonUniformBallot:
1754 case SPIRV::OpGroupNonUniformInverseBallot:
1755 case SPIRV::OpGroupNonUniformBallotBitExtract:
1756 case SPIRV::OpGroupNonUniformBallotBitCount:
1757 case SPIRV::OpGroupNonUniformBallotFindLSB:
1758 case SPIRV::OpGroupNonUniformBallotFindMSB:
1759 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformBallot);
1760 break;
1761 case SPIRV::OpSubgroupShuffleINTEL:
1762 case SPIRV::OpSubgroupShuffleDownINTEL:
1763 case SPIRV::OpSubgroupShuffleUpINTEL:
1764 case SPIRV::OpSubgroupShuffleXorINTEL:
1765 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_subgroups)) {
1766 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_subgroups);
1767 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupShuffleINTEL);
1768 }
1769 break;
1770 case SPIRV::OpSubgroupBlockReadINTEL:
1771 case SPIRV::OpSubgroupBlockWriteINTEL:
1772 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_subgroups)) {
1773 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_subgroups);
1774 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1775 }
1776 break;
1777 case SPIRV::OpSubgroupImageBlockReadINTEL:
1778 case SPIRV::OpSubgroupImageBlockWriteINTEL:
1779 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_subgroups)) {
1780 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_subgroups);
1781 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupImageBlockIOINTEL);
1782 }
1783 break;
1784 case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
1785 case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
1786 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_media_block_io)) {
1787 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_media_block_io);
1788 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
1789 }
1790 break;
1791 case SPIRV::OpAssumeTrueKHR:
1792 case SPIRV::OpExpectKHR:
1793 if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_expect_assume)) {
1794 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_expect_assume);
1795 Reqs.addCapability(ToAdd: SPIRV::Capability::ExpectAssumeKHR);
1796 }
1797 break;
1798 case SPIRV::OpFmaKHR:
1799 if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_fma)) {
1800 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_fma);
1801 Reqs.addCapability(ToAdd: SPIRV::Capability::FmaKHR);
1802 }
1803 break;
1804 case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1805 case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1806 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1807 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1808 Reqs.addCapability(ToAdd: SPIRV::Capability::USMStorageClassesINTEL);
1809 }
1810 break;
1811 case SPIRV::OpConstantFunctionPointerINTEL:
1812 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers)) {
1813 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_function_pointers);
1814 Reqs.addCapability(ToAdd: SPIRV::Capability::FunctionPointersINTEL);
1815 }
1816 break;
1817 case SPIRV::OpGroupNonUniformRotateKHR:
1818 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_subgroup_rotate))
1819 report_fatal_error(reason: "OpGroupNonUniformRotateKHR instruction requires the "
1820 "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1821 gen_crash_diag: false);
1822 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_subgroup_rotate);
1823 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformRotateKHR);
1824 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniform);
1825 break;
1826 case SPIRV::OpFixedCosALTERA:
1827 case SPIRV::OpFixedSinALTERA:
1828 case SPIRV::OpFixedCosPiALTERA:
1829 case SPIRV::OpFixedSinPiALTERA:
1830 case SPIRV::OpFixedExpALTERA:
1831 case SPIRV::OpFixedLogALTERA:
1832 case SPIRV::OpFixedRecipALTERA:
1833 case SPIRV::OpFixedSqrtALTERA:
1834 case SPIRV::OpFixedSinCosALTERA:
1835 case SPIRV::OpFixedSinCosPiALTERA:
1836 case SPIRV::OpFixedRsqrtALTERA:
1837 if (!ST.canUseExtension(
1838 E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_fixed_point))
1839 report_fatal_error(reason: "This instruction requires the "
1840 "following SPIR-V extension: "
1841 "SPV_ALTERA_arbitrary_precision_fixed_point",
1842 gen_crash_diag: false);
1843 Reqs.addExtension(
1844 ToAdd: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_fixed_point);
1845 Reqs.addCapability(ToAdd: SPIRV::Capability::ArbitraryPrecisionFixedPointALTERA);
1846 break;
1847 case SPIRV::OpGroupIMulKHR:
1848 case SPIRV::OpGroupFMulKHR:
1849 case SPIRV::OpGroupBitwiseAndKHR:
1850 case SPIRV::OpGroupBitwiseOrKHR:
1851 case SPIRV::OpGroupBitwiseXorKHR:
1852 case SPIRV::OpGroupLogicalAndKHR:
1853 case SPIRV::OpGroupLogicalOrKHR:
1854 case SPIRV::OpGroupLogicalXorKHR:
1855 if (ST.canUseExtension(
1856 E: SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1857 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1858 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupUniformArithmeticKHR);
1859 }
1860 break;
1861 case SPIRV::OpReadClockKHR:
1862 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_shader_clock))
1863 report_fatal_error(reason: "OpReadClockKHR instruction requires the "
1864 "following SPIR-V extension: SPV_KHR_shader_clock",
1865 gen_crash_diag: false);
1866 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_shader_clock);
1867 Reqs.addCapability(ToAdd: SPIRV::Capability::ShaderClockKHR);
1868 break;
1869 case SPIRV::OpFunctionPointerCallINTEL:
1870 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers)) {
1871 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_function_pointers);
1872 Reqs.addCapability(ToAdd: SPIRV::Capability::FunctionPointersINTEL);
1873 }
1874 break;
1875 case SPIRV::OpAtomicFAddEXT:
1876 case SPIRV::OpAtomicFMinEXT:
1877 case SPIRV::OpAtomicFMaxEXT:
1878 AddAtomicFloatRequirements(MI, Reqs, ST);
1879 break;
1880 case SPIRV::OpConvertBF16ToFINTEL:
1881 case SPIRV::OpConvertFToBF16INTEL:
1882 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1883 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1884 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16ConversionINTEL);
1885 }
1886 break;
1887 case SPIRV::OpRoundFToTF32INTEL:
1888 if (ST.canUseExtension(
1889 E: SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
1890 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
1891 Reqs.addCapability(ToAdd: SPIRV::Capability::TensorFloat32RoundingINTEL);
1892 }
1893 break;
1894 case SPIRV::OpVariableLengthArrayINTEL:
1895 case SPIRV::OpSaveMemoryINTEL:
1896 case SPIRV::OpRestoreMemoryINTEL:
1897 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_variable_length_array)) {
1898 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_variable_length_array);
1899 Reqs.addCapability(ToAdd: SPIRV::Capability::VariableLengthArrayINTEL);
1900 }
1901 break;
1902 case SPIRV::OpAsmTargetINTEL:
1903 case SPIRV::OpAsmINTEL:
1904 case SPIRV::OpAsmCallINTEL:
1905 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_inline_assembly)) {
1906 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_inline_assembly);
1907 Reqs.addCapability(ToAdd: SPIRV::Capability::AsmINTEL);
1908 }
1909 break;
1910 case SPIRV::OpTypeCooperativeMatrixKHR: {
1911 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_cooperative_matrix))
1912 report_fatal_error(
1913 reason: "OpTypeCooperativeMatrixKHR type requires the "
1914 "following SPIR-V extension: SPV_KHR_cooperative_matrix",
1915 gen_crash_diag: false);
1916 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_cooperative_matrix);
1917 Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixKHR);
1918 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1919 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
1920 if (isBFloat16Type(TypeDef))
1921 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16CooperativeMatrixKHR);
1922 break;
1923 }
1924 case SPIRV::OpArithmeticFenceEXT:
1925 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_arithmetic_fence))
1926 report_fatal_error(reason: "OpArithmeticFenceEXT requires the "
1927 "following SPIR-V extension: SPV_EXT_arithmetic_fence",
1928 gen_crash_diag: false);
1929 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_arithmetic_fence);
1930 Reqs.addCapability(ToAdd: SPIRV::Capability::ArithmeticFenceEXT);
1931 break;
1932 case SPIRV::OpControlBarrierArriveINTEL:
1933 case SPIRV::OpControlBarrierWaitINTEL:
1934 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_split_barrier)) {
1935 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_split_barrier);
1936 Reqs.addCapability(ToAdd: SPIRV::Capability::SplitBarrierINTEL);
1937 }
1938 break;
  case SPIRV::OpCooperativeMatrixMulAddKHR: {
    if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_cooperative_matrix))
      report_fatal_error(reason: "Cooperative matrix instructions require the "
                         "following SPIR-V extension: "
                         "SPV_KHR_cooperative_matrix",
                         gen_crash_diag: false);
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_cooperative_matrix);
    Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixKHR);
    // The optional Cooperative Matrix Operands literal is the last (6th)
    // operand; when absent there is nothing further to check.
    constexpr unsigned MulAddMaxSize = 6;
    if (MI.getNumOperands() != MulAddMaxSize)
      break;
    const int64_t CoopOperands = MI.getOperand(i: MulAddMaxSize - 1).getImm();
    // The TF32 component-type interpretation bit is an Intel joint-matrix
    // feature and needs its own extension/capability pair.
    if (CoopOperands &
        SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
      if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
        report_fatal_error(reason: "MatrixAAndBTF32ComponentsINTEL type interpretation "
                           "require the following SPIR-V extension: "
                           "SPV_INTEL_joint_matrix",
                           gen_crash_diag: false);
      Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
      Reqs.addCapability(
          ToAdd: SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
    }
    // Likewise for any of the BF16 component-type interpretation bits
    // (A/B operands, the C accumulator, or the result matrix).
    if (CoopOperands & SPIRV::CooperativeMatrixOperands::
                           MatrixAAndBBFloat16ComponentsINTEL ||
        CoopOperands &
            SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
        CoopOperands & SPIRV::CooperativeMatrixOperands::
                           MatrixResultBFloat16ComponentsINTEL) {
      if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
        report_fatal_error(reason: "***BF16ComponentsINTEL type interpretations "
                           "require the following SPIR-V extension: "
                           "SPV_INTEL_joint_matrix",
                           gen_crash_diag: false);
      Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
      Reqs.addCapability(
          ToAdd: SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
    }
    break;
  }
  case SPIRV::OpCooperativeMatrixLoadKHR:
  case SPIRV::OpCooperativeMatrixStoreKHR:
  case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
  case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
  case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
    if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_cooperative_matrix))
      report_fatal_error(reason: "Cooperative matrix instructions require the "
                         "following SPIR-V extension: "
                         "SPV_KHR_cooperative_matrix",
                         gen_crash_diag: false);
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_cooperative_matrix);
    Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixKHR);

    // The Layout operand sits at a different operand index for each opcode;
    // map opcode -> index so we can inspect the layout below and add the
    // extra capability for non-standard (Intel packed) layouts.
    std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
        {SPIRV::OpCooperativeMatrixLoadKHR, 3},
        {SPIRV::OpCooperativeMatrixStoreKHR, 2},
        {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
        {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
        {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};

    const auto OpCode = MI.getOpcode();
    const unsigned LayoutNum = LayoutToInstMap[OpCode];
    Register RegLayout = MI.getOperand(i: LayoutNum).getReg();
    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
    // NOTE(review): getUniqueVRegDef can return null (no def, or multiple
    // defs); MILayout is dereferenced unconditionally below -- confirm the
    // layout register is always uniquely defined at this point.
    MachineInstr *MILayout = MRI.getUniqueVRegDef(Reg: RegLayout);
    if (MILayout->getOpcode() == SPIRV::OpConstantI) {
      const unsigned LayoutVal = MILayout->getOperand(i: 2).getImm();
      if (LayoutVal ==
          static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
        if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
          report_fatal_error(reason: "PackedINTEL layout require the following SPIR-V "
                             "extension: SPV_INTEL_joint_matrix",
                             gen_crash_diag: false);
        Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
        Reqs.addCapability(ToAdd: SPIRV::Capability::PackedCooperativeMatrixINTEL);
      }
    }

    // Plain KHR load/store need no Intel-specific requirements beyond the
    // layout handled above.
    if (OpCode == SPIRV::OpCooperativeMatrixLoadKHR ||
        OpCode == SPIRV::OpCooperativeMatrixStoreKHR)
      break;

    // The remaining opcodes are Intel joint-matrix instructions; pick a
    // human-readable name for the error message.
    std::string InstName;
    switch (OpCode) {
    case SPIRV::OpCooperativeMatrixPrefetchINTEL:
      InstName = "OpCooperativeMatrixPrefetchINTEL";
      break;
    case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
      InstName = "OpCooperativeMatrixLoadCheckedINTEL";
      break;
    case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
      InstName = "OpCooperativeMatrixStoreCheckedINTEL";
      break;
    }

    if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix)) {
      const std::string ErrorMsg =
          InstName + " instruction requires the "
                     "following SPIR-V extension: SPV_INTEL_joint_matrix";
      report_fatal_error(reason: ErrorMsg.c_str(), gen_crash_diag: false);
    }
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
    // Prefetch has its own capability; the checked load/store share one.
    if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
      Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
      break;
    }
    Reqs.addCapability(
        ToAdd: SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
    break;
  }
2052 case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
2053 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2054 report_fatal_error(reason: "OpCooperativeMatrixConstructCheckedINTEL "
2055 "instructions require the following SPIR-V extension: "
2056 "SPV_INTEL_joint_matrix",
2057 gen_crash_diag: false);
2058 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2059 Reqs.addCapability(
2060 ToAdd: SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
2061 break;
2062 case SPIRV::OpReadPipeBlockingALTERA:
2063 case SPIRV::OpWritePipeBlockingALTERA:
2064 if (ST.canUseExtension(E: SPIRV::Extension::SPV_ALTERA_blocking_pipes)) {
2065 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_ALTERA_blocking_pipes);
2066 Reqs.addCapability(ToAdd: SPIRV::Capability::BlockingPipesALTERA);
2067 }
2068 break;
2069 case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
2070 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2071 report_fatal_error(reason: "OpCooperativeMatrixGetElementCoordINTEL requires the "
2072 "following SPIR-V extension: SPV_INTEL_joint_matrix",
2073 gen_crash_diag: false);
2074 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2075 Reqs.addCapability(
2076 ToAdd: SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
2077 break;
2078 case SPIRV::OpConvertHandleToImageINTEL:
2079 case SPIRV::OpConvertHandleToSamplerINTEL:
2080 case SPIRV::OpConvertHandleToSampledImageINTEL: {
2081 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bindless_images))
2082 report_fatal_error(reason: "OpConvertHandleTo[Image/Sampler/SampledImage]INTEL "
2083 "instructions require the following SPIR-V extension: "
2084 "SPV_INTEL_bindless_images",
2085 gen_crash_diag: false);
2086 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
2087 SPIRV::AddressingModel::AddressingModel AddrModel = MAI.Addr;
2088 SPIRVTypeInst TyDef = GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 1).getReg());
2089 if (MI.getOpcode() == SPIRV::OpConvertHandleToImageINTEL &&
2090 TyDef->getOpcode() != SPIRV::OpTypeImage) {
2091 report_fatal_error(reason: "Incorrect return type for the instruction "
2092 "OpConvertHandleToImageINTEL",
2093 gen_crash_diag: false);
2094 } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSamplerINTEL &&
2095 TyDef->getOpcode() != SPIRV::OpTypeSampler) {
2096 report_fatal_error(reason: "Incorrect return type for the instruction "
2097 "OpConvertHandleToSamplerINTEL",
2098 gen_crash_diag: false);
2099 } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSampledImageINTEL &&
2100 TyDef->getOpcode() != SPIRV::OpTypeSampledImage) {
2101 report_fatal_error(reason: "Incorrect return type for the instruction "
2102 "OpConvertHandleToSampledImageINTEL",
2103 gen_crash_diag: false);
2104 }
2105 SPIRVTypeInst SpvTy = GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 2).getReg());
2106 unsigned Bitwidth = GR->getScalarOrVectorBitWidth(Type: SpvTy);
2107 if (!(Bitwidth == 32 && AddrModel == SPIRV::AddressingModel::Physical32) &&
2108 !(Bitwidth == 64 && AddrModel == SPIRV::AddressingModel::Physical64)) {
2109 report_fatal_error(
2110 reason: "Parameter value must be a 32-bit scalar in case of "
2111 "Physical32 addressing model or a 64-bit scalar in case of "
2112 "Physical64 addressing model",
2113 gen_crash_diag: false);
2114 }
2115 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bindless_images);
2116 Reqs.addCapability(ToAdd: SPIRV::Capability::BindlessImagesINTEL);
2117 break;
2118 }
2119 case SPIRV::OpSubgroup2DBlockLoadINTEL:
2120 case SPIRV::OpSubgroup2DBlockLoadTransposeINTEL:
2121 case SPIRV::OpSubgroup2DBlockLoadTransformINTEL:
2122 case SPIRV::OpSubgroup2DBlockPrefetchINTEL:
2123 case SPIRV::OpSubgroup2DBlockStoreINTEL: {
2124 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_2d_block_io))
2125 report_fatal_error(reason: "OpSubgroup2DBlock[Load/LoadTranspose/LoadTransform/"
2126 "Prefetch/Store]INTEL instructions require the "
2127 "following SPIR-V extension: SPV_INTEL_2d_block_io",
2128 gen_crash_diag: false);
2129 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_2d_block_io);
2130 Reqs.addCapability(ToAdd: SPIRV::Capability::Subgroup2DBlockIOINTEL);
2131
2132 const auto OpCode = MI.getOpcode();
2133 if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransposeINTEL) {
2134 Reqs.addCapability(ToAdd: SPIRV::Capability::Subgroup2DBlockTransposeINTEL);
2135 break;
2136 }
2137 if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransformINTEL) {
2138 Reqs.addCapability(ToAdd: SPIRV::Capability::Subgroup2DBlockTransformINTEL);
2139 break;
2140 }
2141 break;
2142 }
2143 case SPIRV::OpKill: {
2144 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2145 } break;
2146 case SPIRV::OpDemoteToHelperInvocation:
2147 Reqs.addCapability(ToAdd: SPIRV::Capability::DemoteToHelperInvocation);
2148
2149 if (ST.canUseExtension(
2150 E: SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
2151 if (!ST.isAtLeastSPIRVVer(VerToCompareTo: llvm::VersionTuple(1, 6)))
2152 Reqs.addExtension(
2153 ToAdd: SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
2154 }
2155 break;
2156 case SPIRV::OpSDot:
2157 case SPIRV::OpUDot:
2158 case SPIRV::OpSUDot:
2159 case SPIRV::OpSDotAccSat:
2160 case SPIRV::OpUDotAccSat:
2161 case SPIRV::OpSUDotAccSat:
2162 AddDotProductRequirements(MI, Reqs, ST);
2163 break;
2164 case SPIRV::OpImageSampleImplicitLod:
2165 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2166 addImageOperandReqs(MI, Reqs, ST, OpIdx: 4);
2167 break;
2168 case SPIRV::OpImageSampleExplicitLod:
2169 addImageOperandReqs(MI, Reqs, ST, OpIdx: 4);
2170 break;
2171 case SPIRV::OpImageSampleDrefImplicitLod:
2172 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2173 addImageOperandReqs(MI, Reqs, ST, OpIdx: 5);
2174 break;
2175 case SPIRV::OpImageSampleDrefExplicitLod:
2176 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2177 addImageOperandReqs(MI, Reqs, ST, OpIdx: 5);
2178 break;
2179 case SPIRV::OpImageFetch:
2180 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2181 addImageOperandReqs(MI, Reqs, ST, OpIdx: 4);
2182 break;
2183 case SPIRV::OpImageDrefGather:
2184 case SPIRV::OpImageGather:
2185 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2186 addImageOperandReqs(MI, Reqs, ST, OpIdx: 5);
2187 break;
2188 case SPIRV::OpImageRead: {
2189 Register ImageReg = MI.getOperand(i: 2).getReg();
2190 SPIRVTypeInst TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2191 VReg: ImageReg, MF: const_cast<MachineFunction *>(MI.getMF()));
2192 // OpImageRead and OpImageWrite can use Unknown Image Formats
2193 // when the Kernel capability is declared. In the OpenCL environment we are
2194 // not allowed to produce
2195 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2196 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2197
2198 if (isImageTypeWithUnknownFormat(TypeInst: TypeDef) && ST.isShader())
2199 Reqs.addCapability(ToAdd: SPIRV::Capability::StorageImageReadWithoutFormat);
2200 break;
2201 }
2202 case SPIRV::OpImageWrite: {
2203 Register ImageReg = MI.getOperand(i: 0).getReg();
2204 SPIRVTypeInst TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2205 VReg: ImageReg, MF: const_cast<MachineFunction *>(MI.getMF()));
2206 // OpImageRead and OpImageWrite can use Unknown Image Formats
2207 // when the Kernel capability is declared. In the OpenCL environment we are
2208 // not allowed to produce
2209 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2210 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2211
2212 if (isImageTypeWithUnknownFormat(TypeInst: TypeDef) && ST.isShader())
2213 Reqs.addCapability(ToAdd: SPIRV::Capability::StorageImageWriteWithoutFormat);
2214 break;
2215 }
2216 case SPIRV::OpTypeStructContinuedINTEL:
2217 case SPIRV::OpConstantCompositeContinuedINTEL:
2218 case SPIRV::OpSpecConstantCompositeContinuedINTEL:
2219 case SPIRV::OpCompositeConstructContinuedINTEL: {
2220 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_long_composites))
2221 report_fatal_error(
2222 reason: "Continued instructions require the "
2223 "following SPIR-V extension: SPV_INTEL_long_composites",
2224 gen_crash_diag: false);
2225 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_long_composites);
2226 Reqs.addCapability(ToAdd: SPIRV::Capability::LongCompositesINTEL);
2227 break;
2228 }
2229 case SPIRV::OpArbitraryFloatEQALTERA:
2230 case SPIRV::OpArbitraryFloatGEALTERA:
2231 case SPIRV::OpArbitraryFloatGTALTERA:
2232 case SPIRV::OpArbitraryFloatLEALTERA:
2233 case SPIRV::OpArbitraryFloatLTALTERA:
2234 case SPIRV::OpArbitraryFloatCbrtALTERA:
2235 case SPIRV::OpArbitraryFloatCosALTERA:
2236 case SPIRV::OpArbitraryFloatCosPiALTERA:
2237 case SPIRV::OpArbitraryFloatExp10ALTERA:
2238 case SPIRV::OpArbitraryFloatExp2ALTERA:
2239 case SPIRV::OpArbitraryFloatExpALTERA:
2240 case SPIRV::OpArbitraryFloatExpm1ALTERA:
2241 case SPIRV::OpArbitraryFloatHypotALTERA:
2242 case SPIRV::OpArbitraryFloatLog10ALTERA:
2243 case SPIRV::OpArbitraryFloatLog1pALTERA:
2244 case SPIRV::OpArbitraryFloatLog2ALTERA:
2245 case SPIRV::OpArbitraryFloatLogALTERA:
2246 case SPIRV::OpArbitraryFloatRecipALTERA:
2247 case SPIRV::OpArbitraryFloatSinCosALTERA:
2248 case SPIRV::OpArbitraryFloatSinCosPiALTERA:
2249 case SPIRV::OpArbitraryFloatSinALTERA:
2250 case SPIRV::OpArbitraryFloatSinPiALTERA:
2251 case SPIRV::OpArbitraryFloatSqrtALTERA:
2252 case SPIRV::OpArbitraryFloatACosALTERA:
2253 case SPIRV::OpArbitraryFloatACosPiALTERA:
2254 case SPIRV::OpArbitraryFloatAddALTERA:
2255 case SPIRV::OpArbitraryFloatASinALTERA:
2256 case SPIRV::OpArbitraryFloatASinPiALTERA:
2257 case SPIRV::OpArbitraryFloatATan2ALTERA:
2258 case SPIRV::OpArbitraryFloatATanALTERA:
2259 case SPIRV::OpArbitraryFloatATanPiALTERA:
2260 case SPIRV::OpArbitraryFloatCastFromIntALTERA:
2261 case SPIRV::OpArbitraryFloatCastALTERA:
2262 case SPIRV::OpArbitraryFloatCastToIntALTERA:
2263 case SPIRV::OpArbitraryFloatDivALTERA:
2264 case SPIRV::OpArbitraryFloatMulALTERA:
2265 case SPIRV::OpArbitraryFloatPowALTERA:
2266 case SPIRV::OpArbitraryFloatPowNALTERA:
2267 case SPIRV::OpArbitraryFloatPowRALTERA:
2268 case SPIRV::OpArbitraryFloatRSqrtALTERA:
2269 case SPIRV::OpArbitraryFloatSubALTERA: {
2270 if (!ST.canUseExtension(
2271 E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_floating_point))
2272 report_fatal_error(
2273 reason: "Floating point instructions can't be translated correctly without "
2274 "enabled SPV_ALTERA_arbitrary_precision_floating_point extension!",
2275 gen_crash_diag: false);
2276 Reqs.addExtension(
2277 ToAdd: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_floating_point);
2278 Reqs.addCapability(
2279 ToAdd: SPIRV::Capability::ArbitraryPrecisionFloatingPointALTERA);
2280 break;
2281 }
2282 case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
2283 if (!ST.canUseExtension(
2284 E: SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
2285 report_fatal_error(
2286 reason: "OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
2287 "following SPIR-V "
2288 "extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
2289 gen_crash_diag: false);
2290 Reqs.addExtension(
2291 ToAdd: SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
2292 Reqs.addCapability(
2293 ToAdd: SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
2294 break;
2295 }
2296 case SPIRV::OpBitwiseFunctionINTEL: {
2297 if (!ST.canUseExtension(
2298 E: SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))
2299 report_fatal_error(
2300 reason: "OpBitwiseFunctionINTEL instruction requires the following SPIR-V "
2301 "extension: SPV_INTEL_ternary_bitwise_function",
2302 gen_crash_diag: false);
2303 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_ternary_bitwise_function);
2304 Reqs.addCapability(ToAdd: SPIRV::Capability::TernaryBitwiseFunctionINTEL);
2305 break;
2306 }
2307 case SPIRV::OpCopyMemorySized: {
2308 Reqs.addCapability(ToAdd: SPIRV::Capability::Addresses);
2309 // TODO: Add UntypedPointersKHR when implemented.
2310 break;
2311 }
2312 case SPIRV::OpPredicatedLoadINTEL:
2313 case SPIRV::OpPredicatedStoreINTEL: {
2314 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_predicated_io))
2315 report_fatal_error(
2316 reason: "OpPredicated[Load/Store]INTEL instructions require "
2317 "the following SPIR-V extension: SPV_INTEL_predicated_io",
2318 gen_crash_diag: false);
2319 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_predicated_io);
2320 Reqs.addCapability(ToAdd: SPIRV::Capability::PredicatedIOINTEL);
2321 break;
2322 }
2323 case SPIRV::OpFAddS:
2324 case SPIRV::OpFSubS:
2325 case SPIRV::OpFMulS:
2326 case SPIRV::OpFDivS:
2327 case SPIRV::OpFRemS:
2328 case SPIRV::OpFMod:
2329 case SPIRV::OpFNegate:
2330 case SPIRV::OpFAddV:
2331 case SPIRV::OpFSubV:
2332 case SPIRV::OpFMulV:
2333 case SPIRV::OpFDivV:
2334 case SPIRV::OpFRemV:
2335 case SPIRV::OpFNegateV: {
2336 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2337 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
2338 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2339 TypeDef = MRI.getVRegDef(Reg: TypeDef->getOperand(i: 1).getReg());
2340 if (isBFloat16Type(TypeDef)) {
2341 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2342 report_fatal_error(
2343 reason: "Arithmetic instructions with bfloat16 arguments require the "
2344 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2345 gen_crash_diag: false);
2346 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2347 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16ArithmeticINTEL);
2348 }
2349 break;
2350 }
2351 case SPIRV::OpOrdered:
2352 case SPIRV::OpUnordered:
2353 case SPIRV::OpFOrdEqual:
2354 case SPIRV::OpFOrdNotEqual:
2355 case SPIRV::OpFOrdLessThan:
2356 case SPIRV::OpFOrdLessThanEqual:
2357 case SPIRV::OpFOrdGreaterThan:
2358 case SPIRV::OpFOrdGreaterThanEqual:
2359 case SPIRV::OpFUnordEqual:
2360 case SPIRV::OpFUnordNotEqual:
2361 case SPIRV::OpFUnordLessThan:
2362 case SPIRV::OpFUnordLessThanEqual:
2363 case SPIRV::OpFUnordGreaterThan:
2364 case SPIRV::OpFUnordGreaterThanEqual: {
2365 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2366 MachineInstr *OperandDef = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
2367 SPIRVTypeInst TypeDef = MRI.getVRegDef(Reg: OperandDef->getOperand(i: 1).getReg());
2368 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2369 TypeDef = MRI.getVRegDef(Reg: TypeDef->getOperand(i: 1).getReg());
2370 if (isBFloat16Type(TypeDef)) {
2371 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2372 report_fatal_error(
2373 reason: "Relational instructions with bfloat16 arguments require the "
2374 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2375 gen_crash_diag: false);
2376 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2377 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16ArithmeticINTEL);
2378 }
2379 break;
2380 }
2381 case SPIRV::OpDPdxCoarse:
2382 case SPIRV::OpDPdyCoarse:
2383 case SPIRV::OpDPdxFine:
2384 case SPIRV::OpDPdyFine: {
2385 Reqs.addCapability(ToAdd: SPIRV::Capability::DerivativeControl);
2386 break;
2387 }
2388 case SPIRV::OpLoopControlINTEL: {
2389 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_unstructured_loop_controls);
2390 Reqs.addCapability(ToAdd: SPIRV::Capability::UnstructuredLoopControlsINTEL);
2391 break;
2392 }
2393
2394 default:
2395 break;
2396 }
2397
2398 // If we require capability Shader, then we can remove the requirement for
2399 // the BitInstructions capability, since Shader is a superset capability
2400 // of BitInstructions.
2401 Reqs.removeCapabilityIf(ToRemove: SPIRV::Capability::BitInstructions,
2402 IfPresent: SPIRV::Capability::Shader);
2403}
2404
// Walk the whole module and record every SPIR-V extension/capability
// requirement implied by (1) the emitted machine instructions, (2) the
// "spirv.ExecutionMode" named metadata, and (3) per-function IR attributes
// and metadata (work-group hints, subgroup size, optnone, ...).
static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
                        MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
  // Collect requirements for existing instructions.
  for (const Function &F : M) {
    MachineFunction *MF = MMI->getMachineFunction(F);
    if (!MF)
      continue;
    for (const MachineBasicBlock &MBB : *MF)
      for (const MachineInstr &MI : MBB)
        addInstrRequirements(MI, MAI, ST);
  }
  // Collect requirements for OpExecutionMode instructions.
  auto Node = M.getNamedMetadata(Name: "spirv.ExecutionMode");
  if (Node) {
    // Track which float-controls flavor is actually used so the matching
    // extension is only requested when needed (see the end of this branch).
    bool RequireFloatControls = false, RequireIntelFloatControls2 = false,
         RequireKHRFloatControls2 = false,
         VerLower14 = !ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 4));
    bool HasIntelFloatControls2 =
        ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_float_controls2);
    bool HasKHRFloatControls2 =
        ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2);
    for (unsigned i = 0; i < Node->getNumOperands(); i++) {
      MDNode *MDN = cast<MDNode>(Val: Node->getOperand(i));
      // Operand 1 of each metadata node is the execution-mode id; operand 0
      // (the entry-point function) is not needed here.
      const MDOperand &MDOp = MDN->getOperand(I: 1);
      if (auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MDOp)) {
        Constant *C = CMeta->getValue();
        if (ConstantInt *Const = dyn_cast<ConstantInt>(Val: C)) {
          auto EM = Const->getZExtValue();
          // SPV_KHR_float_controls is not available until v1.4:
          // add SPV_KHR_float_controls if the version is too low
          switch (EM) {
          case SPIRV::ExecutionMode::DenormPreserve:
          case SPIRV::ExecutionMode::DenormFlushToZero:
          case SPIRV::ExecutionMode::RoundingModeRTE:
          case SPIRV::ExecutionMode::RoundingModeRTZ:
            RequireFloatControls = VerLower14;
            MAI.Reqs.getAndAddRequirements(
                Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
            break;
          case SPIRV::ExecutionMode::RoundingModeRTPINTEL:
          case SPIRV::ExecutionMode::RoundingModeRTNINTEL:
          case SPIRV::ExecutionMode::FloatingPointModeALTINTEL:
          case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL:
            // Intel-specific modes are silently dropped when the Intel
            // float-controls2 extension is unavailable.
            if (HasIntelFloatControls2) {
              RequireIntelFloatControls2 = true;
              MAI.Reqs.getAndAddRequirements(
                  Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
            }
            break;
          case SPIRV::ExecutionMode::FPFastMathDefault: {
            if (HasKHRFloatControls2) {
              RequireKHRFloatControls2 = true;
              MAI.Reqs.getAndAddRequirements(
                  Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
            }
            break;
          }
          case SPIRV::ExecutionMode::ContractionOff:
          case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
            // These two modes are deprecated by SPV_KHR_float_controls2; when
            // that extension is available they are expressed through
            // FPFastMathDefault instead.
            if (HasKHRFloatControls2) {
              RequireKHRFloatControls2 = true;
              MAI.Reqs.getAndAddRequirements(
                  Category: SPIRV::OperandCategory::ExecutionModeOperand,
                  i: SPIRV::ExecutionMode::FPFastMathDefault, ST);
            } else {
              MAI.Reqs.getAndAddRequirements(
                  Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
            }
            break;
          default:
            MAI.Reqs.getAndAddRequirements(
                Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
          }
        }
      }
    }
    if (RequireFloatControls &&
        ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls))
      MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_float_controls);
    if (RequireIntelFloatControls2)
      MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_float_controls2);
    if (RequireKHRFloatControls2)
      MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_float_controls2);
  }
  // Map per-function IR metadata/attributes onto execution-mode requirements.
  for (const Function &F : M) {
    if (F.isDeclaration())
      continue;
    if (F.getMetadata(Kind: "reqd_work_group_size"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::LocalSize, ST);
    if (F.getFnAttribute(Kind: "hlsl.numthreads").isValid()) {
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::LocalSize, ST);
    }
    if (F.getFnAttribute(Kind: "enable-maximal-reconvergence").getValueAsBool()) {
      MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_maximal_reconvergence);
    }
    if (F.getMetadata(Kind: "work_group_size_hint"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::LocalSizeHint, ST);
    if (F.getMetadata(Kind: "intel_reqd_sub_group_size"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::SubgroupSize, ST);
    if (F.getMetadata(Kind: "max_work_group_size"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::MaxWorkgroupSizeINTEL, ST);
    if (F.getMetadata(Kind: "vec_type_hint"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::VecTypeHint, ST);

    // optnone: prefer the INTEL spelling when both extensions are available.
    if (F.hasOptNone()) {
      if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_optnone)) {
        MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_optnone);
        MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::OptNoneINTEL);
      } else if (ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_optnone)) {
        MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_optnone);
        MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::OptNoneEXT);
      }
    }
  }
}
2532
2533static unsigned getFastMathFlags(const MachineInstr &I,
2534 const SPIRVSubtarget &ST) {
2535 unsigned Flags = SPIRV::FPFastMathMode::None;
2536 bool CanUseKHRFloatControls2 =
2537 ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2);
2538 if (I.getFlag(Flag: MachineInstr::MIFlag::FmNoNans))
2539 Flags |= SPIRV::FPFastMathMode::NotNaN;
2540 if (I.getFlag(Flag: MachineInstr::MIFlag::FmNoInfs))
2541 Flags |= SPIRV::FPFastMathMode::NotInf;
2542 if (I.getFlag(Flag: MachineInstr::MIFlag::FmNsz))
2543 Flags |= SPIRV::FPFastMathMode::NSZ;
2544 if (I.getFlag(Flag: MachineInstr::MIFlag::FmArcp))
2545 Flags |= SPIRV::FPFastMathMode::AllowRecip;
2546 if (I.getFlag(Flag: MachineInstr::MIFlag::FmContract) && CanUseKHRFloatControls2)
2547 Flags |= SPIRV::FPFastMathMode::AllowContract;
2548 if (I.getFlag(Flag: MachineInstr::MIFlag::FmReassoc)) {
2549 if (CanUseKHRFloatControls2)
2550 // LLVM reassoc maps to SPIRV transform, see
2551 // https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for details.
2552 // Because we are enabling AllowTransform, we must enable AllowReassoc and
2553 // AllowContract too, as required by SPIRV spec. Also, we used to map
2554 // MIFlag::FmReassoc to FPFastMathMode::Fast, which now should instead by
2555 // replaced by turning all the other bits instead. Therefore, we're
2556 // enabling every bit here except None and Fast.
2557 Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
2558 SPIRV::FPFastMathMode::NSZ | SPIRV::FPFastMathMode::AllowRecip |
2559 SPIRV::FPFastMathMode::AllowTransform |
2560 SPIRV::FPFastMathMode::AllowReassoc |
2561 SPIRV::FPFastMathMode::AllowContract;
2562 else
2563 Flags |= SPIRV::FPFastMathMode::Fast;
2564 }
2565
2566 if (CanUseKHRFloatControls2) {
2567 // Error out if SPIRV::FPFastMathMode::Fast is enabled.
2568 assert(!(Flags & SPIRV::FPFastMathMode::Fast) &&
2569 "SPIRV::FPFastMathMode::Fast is deprecated and should not be used "
2570 "anymore.");
2571
2572 // Error out if AllowTransform is enabled without AllowReassoc and
2573 // AllowContract.
2574 assert((!(Flags & SPIRV::FPFastMathMode::AllowTransform) ||
2575 ((Flags & SPIRV::FPFastMathMode::AllowReassoc &&
2576 Flags & SPIRV::FPFastMathMode::AllowContract))) &&
2577 "SPIRV::FPFastMathMode::AllowTransform requires AllowReassoc and "
2578 "AllowContract flags to be enabled as well.");
2579 }
2580
2581 return Flags;
2582}
2583
2584static bool isFastMathModeAvailable(const SPIRVSubtarget &ST) {
2585 if (ST.isKernel())
2586 return true;
2587 if (ST.getSPIRVVersion() < VersionTuple(1, 2))
2588 return false;
2589 return ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2);
2590}
2591
// Attach SPIR-V decorations derived from MI flags to instruction \p I:
// NoSignedWrap/NoUnsignedWrap from the nsw/nuw flags, and FPFastMathMode from
// the fast-math flags (or from a per-function FPFastMathDefault if the
// instruction itself carries no flags).
static void handleMIFlagDecoration(
    MachineInstr &I, const SPIRVSubtarget &ST, const SPIRVInstrInfo &TII,
    SPIRV::RequirementHandler &Reqs, const SPIRVGlobalRegistry *GR,
    SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec) {
  // Only decorate when the decoration's own capability/extension requirements
  // can be satisfied on this subtarget.
  if (I.getFlag(Flag: MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(MI: I) &&
      getSymbolicOperandRequirements(Category: SPIRV::OperandCategory::DecorationOperand,
                                     i: SPIRV::Decoration::NoSignedWrap, ST, Reqs)
          .IsSatisfiable) {
    buildOpDecorate(Reg: I.getOperand(i: 0).getReg(), I, TII,
                    Dec: SPIRV::Decoration::NoSignedWrap, DecArgs: {});
  }
  if (I.getFlag(Flag: MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(MI: I) &&
      getSymbolicOperandRequirements(Category: SPIRV::OperandCategory::DecorationOperand,
                                     i: SPIRV::Decoration::NoUnsignedWrap, ST,
                                     Reqs)
          .IsSatisfiable) {
    buildOpDecorate(Reg: I.getOperand(i: 0).getReg(), I, TII,
                    Dec: SPIRV::Decoration::NoUnsignedWrap, DecArgs: {});
  }
  if (!TII.canUseFastMathFlags(
          MI: I, KHRFloatControls2: ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2)))
    return;

  unsigned FMFlags = getFastMathFlags(I, ST);
  if (FMFlags == SPIRV::FPFastMathMode::None) {
    // We also need to check if any FPFastMathDefault info was set for the
    // types used in this instruction.
    if (FPFastMathDefaultInfoVec.empty())
      return;

    // There are three types of instructions that can use fast math flags:
    // 1. Arithmetic instructions (FAdd, FMul, FSub, FDiv, FRem, etc.)
    // 2. Relational instructions (FCmp, FOrd, FUnord, etc.)
    // 3. Extended instructions (ExtInst)
    // For arithmetic instructions, the floating point type can be in the
    // result type or in the operands, but they all must be the same.
    // For the relational and logical instructions, the floating point type
    // can only be in the operands 1 and 2, not the result type. Also, the
    // operands must have the same type. For the extended instructions, the
    // floating point type can be in the result type or in the operands. It's
    // unclear if the operands and the result type must be the same. Let's
    // assume they must be. Therefore, for 1. and 2., we can check the first
    // operand type, and for 3. we can check the result type.
    assert(I.getNumOperands() >= 3 && "Expected at least 3 operands");
    Register ResReg = I.getOpcode() == SPIRV::OpExtInst
                          ? I.getOperand(i: 1).getReg()
                          : I.getOperand(i: 2).getReg();
    // Reduce vectors to their element type so the lookup below matches the
    // scalar FP types stored in FPFastMathDefaultInfoVec.
    SPIRVTypeInst ResType = GR->getSPIRVTypeForVReg(VReg: ResReg, MF: I.getMF());
    const Type *Ty = GR->getTypeForSPIRVType(Ty: ResType);
    Ty = Ty->isVectorTy() ? cast<VectorType>(Val: Ty)->getElementType() : Ty;

    // Match instruction type with the FPFastMathDefaultInfoVec.
    bool Emit = false;
    for (SPIRV::FPFastMathDefaultInfo &Elem : FPFastMathDefaultInfoVec) {
      if (Ty == Elem.Ty) {
        FMFlags = Elem.FastMathFlags;
        // Deprecated ContractionOff/SignedZeroInfNanPreserve modes force the
        // decoration to be emitted even when the default flags are None.
        Emit = Elem.ContractionOff || Elem.SignedZeroInfNanPreserve ||
               Elem.FPFastMathDefault;
        break;
      }
    }

    if (FMFlags == SPIRV::FPFastMathMode::None && !Emit)
      return;
  }
  if (isFastMathModeAvailable(ST)) {
    Register DstReg = I.getOperand(i: 0).getReg();
    buildOpDecorate(Reg: DstReg, I, TII, Dec: SPIRV::Decoration::FPFastMathMode,
                    DecArgs: {FMFlags});
  }
}
2663
2664// Walk all functions and add decorations related to MI flags.
2665static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
2666 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2667 SPIRV::ModuleAnalysisInfo &MAI,
2668 const SPIRVGlobalRegistry *GR) {
2669 for (const Function &F : M) {
2670 MachineFunction *MF = MMI->getMachineFunction(F);
2671 if (!MF)
2672 continue;
2673
2674 for (auto &MBB : *MF)
2675 for (auto &MI : MBB)
2676 handleMIFlagDecoration(I&: MI, ST, TII, Reqs&: MAI.Reqs, GR,
2677 FPFastMathDefaultInfoVec&: MAI.FPFastMathDefaultInfoMap[&F]);
2678 }
2679}
2680
2681static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
2682 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2683 SPIRV::ModuleAnalysisInfo &MAI) {
2684 for (const Function &F : M) {
2685 MachineFunction *MF = MMI->getMachineFunction(F);
2686 if (!MF)
2687 continue;
2688 if (MF->getFunction()
2689 .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
2690 .isValid())
2691 continue;
2692 MachineRegisterInfo &MRI = MF->getRegInfo();
2693 for (auto &MBB : *MF) {
2694 if (!MBB.hasName() || MBB.empty())
2695 continue;
2696 // Emit basic block names.
2697 Register Reg = MRI.createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 64));
2698 MRI.setRegClass(Reg, RC: &SPIRV::IDRegClass);
2699 buildOpName(Target: Reg, Name: MBB.getName(), I&: *std::prev(x: MBB.end()), TII);
2700 MCRegister GlobalReg = MAI.getOrCreateMBBRegister(MBB);
2701 MAI.setRegisterAlias(MF, Reg, AliasReg: GlobalReg);
2702 }
2703 }
2704}
2705
2706// patching Instruction::PHI to SPIRV::OpPhi
2707static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
2708 const SPIRVInstrInfo &TII, MachineModuleInfo *MMI) {
2709 for (const Function &F : M) {
2710 MachineFunction *MF = MMI->getMachineFunction(F);
2711 if (!MF)
2712 continue;
2713 for (auto &MBB : *MF) {
2714 for (MachineInstr &MI : MBB.phis()) {
2715 MI.setDesc(TII.get(Opcode: SPIRV::OpPhi));
2716 Register ResTypeReg = GR->getSPIRVTypeID(
2717 SpirvType: GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 0).getReg(), MF));
2718 MI.insert(InsertBefore: MI.operands_begin() + 1,
2719 Ops: {MachineOperand::CreateReg(Reg: ResTypeReg, isDef: false)});
2720 }
2721 }
2722
2723 MF->getProperties().setNoPHIs();
2724 }
2725}
2726
2727static SPIRV::FPFastMathDefaultInfoVector &getOrCreateFPFastMathDefaultInfoVec(
2728 const Module &M, SPIRV::ModuleAnalysisInfo &MAI, const Function *F) {
2729 auto it = MAI.FPFastMathDefaultInfoMap.find(Val: F);
2730 if (it != MAI.FPFastMathDefaultInfoMap.end())
2731 return it->second;
2732
2733 // If the map does not contain the entry, create a new one. Initialize it to
2734 // contain all 3 elements sorted by bit width of target type: {half, float,
2735 // double}.
2736 SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec;
2737 FPFastMathDefaultInfoVec.emplace_back(Args: Type::getHalfTy(C&: M.getContext()),
2738 Args: SPIRV::FPFastMathMode::None);
2739 FPFastMathDefaultInfoVec.emplace_back(Args: Type::getFloatTy(C&: M.getContext()),
2740 Args: SPIRV::FPFastMathMode::None);
2741 FPFastMathDefaultInfoVec.emplace_back(Args: Type::getDoubleTy(C&: M.getContext()),
2742 Args: SPIRV::FPFastMathMode::None);
2743 return MAI.FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec);
2744}
2745
2746static SPIRV::FPFastMathDefaultInfo &getFPFastMathDefaultInfo(
2747 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec,
2748 const Type *Ty) {
2749 size_t BitWidth = Ty->getScalarSizeInBits();
2750 int Index =
2751 SPIRV::FPFastMathDefaultInfoVector::computeFPFastMathDefaultInfoVecIndex(
2752 BitWidth);
2753 assert(Index >= 0 && Index < 3 &&
2754 "Expected FPFastMathDefaultInfo for half, float, or double");
2755 assert(FPFastMathDefaultInfoVec.size() == 3 &&
2756 "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2757 return FPFastMathDefaultInfoVec[Index];
2758}
2759
// Parse the "spirv.ExecutionMode" metadata and fill MAI's per-function
// FPFastMathDefault tables. Only relevant when SPV_KHR_float_controls2 is
// available, since FPFastMathDefault belongs to that extension.
static void collectFPFastMathDefaults(const Module &M,
                                      SPIRV::ModuleAnalysisInfo &MAI,
                                      const SPIRVSubtarget &ST) {
  if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2))
    return;

  // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap.
  // We need the entry point (function) as the key, and the target
  // type and flags as the value.
  // We also need to check ContractionOff and SignedZeroInfNanPreserve
  // execution modes, as they are now deprecated and must be replaced
  // with FPFastMathDefaultInfo.
  auto Node = M.getNamedMetadata(Name: "spirv.ExecutionMode");
  if (!Node)
    return;

  for (unsigned i = 0; i < Node->getNumOperands(); i++) {
    MDNode *MDN = cast<MDNode>(Val: Node->getOperand(i));
    assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands");
    // Operand 0 is the entry-point function, operand 1 the execution mode id;
    // the remaining operands are mode-specific payload.
    const Function *F = cast<Function>(
        Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 0))->getValue());
    const auto EM =
        cast<ConstantInt>(
            Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 1))->getValue())
            ->getZExtValue();
    if (EM == SPIRV::ExecutionMode::FPFastMathDefault) {
      assert(MDN->getNumOperands() == 4 &&
             "Expected 4 operands for FPFastMathDefault");

      // Payload: operand 2 is the target FP type, operand 3 the flag mask.
      const Type *T = cast<ValueAsMetadata>(Val: MDN->getOperand(I: 2))->getType();
      unsigned Flags =
          cast<ConstantInt>(
              Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 3))->getValue())
              ->getZExtValue();
      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
          getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
      SPIRV::FPFastMathDefaultInfo &Info =
          getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, Ty: T);
      Info.FastMathFlags = Flags;
      Info.FPFastMathDefault = true;
    } else if (EM == SPIRV::ExecutionMode::ContractionOff) {
      assert(MDN->getNumOperands() == 2 &&
             "Expected no operands for ContractionOff");

      // We need to save this info for every possible FP type, i.e. {half,
      // float, double, fp128}.
      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
          getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
      for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) {
        Info.ContractionOff = true;
      }
    } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) {
      assert(MDN->getNumOperands() == 3 &&
             "Expected 1 operand for SignedZeroInfNanPreserve");
      // Payload: operand 2 is the FP bit width the mode applies to.
      unsigned TargetWidth =
          cast<ConstantInt>(
              Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 2))->getValue())
              ->getZExtValue();
      // We need to save this info only for the FP type with TargetWidth.
      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
          getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
      int Index = SPIRV::FPFastMathDefaultInfoVector::
          computeFPFastMathDefaultInfoVecIndex(BitWidth: TargetWidth);
      assert(Index >= 0 && Index < 3 &&
             "Expected FPFastMathDefaultInfo for half, float, or double");
      assert(FPFastMathDefaultInfoVec.size() == 3 &&
             "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
      FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true;
    }
  }
}
2831
// Declare pass dependencies: TargetPassConfig gives access to the SPIRV
// target machine/subtarget, and MachineModuleInfo owns the MIR we analyze.
void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}
2836
2837bool SPIRVModuleAnalysis::runOnModule(Module &M) {
2838 SPIRVTargetMachine &TM =
2839 getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
2840 ST = TM.getSubtargetImpl();
2841 GR = ST->getSPIRVGlobalRegistry();
2842 TII = ST->getInstrInfo();
2843
2844 MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
2845
2846 setBaseInfo(M);
2847
2848 patchPhis(M, GR, TII: *TII, MMI);
2849
2850 addMBBNames(M, TII: *TII, MMI, ST: *ST, MAI);
2851 collectFPFastMathDefaults(M, MAI, ST: *ST);
2852 addDecorations(M, TII: *TII, MMI, ST: *ST, MAI, GR);
2853
2854 collectReqs(M, MAI, MMI, ST: *ST);
2855
2856 // Process type/const/global var/func decl instructions, number their
2857 // destination registers from 0 to N, collect Extensions and Capabilities.
2858 collectReqs(M, MAI, MMI, ST: *ST);
2859 collectDeclarations(M);
2860
2861 // Number rest of registers from N+1 onwards.
2862 numberRegistersGlobally(M);
2863
2864 // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
2865 processOtherInstrs(M);
2866
2867 // If there are no entry points, we need the Linkage capability.
2868 if (MAI.MS[SPIRV::MB_EntryPoints].empty())
2869 MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::Linkage);
2870
2871 // Set maximum ID used.
2872 GR->setBound(MAI.MaxID);
2873
2874 return false;
2875}
2876