1//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The analysis collects instructions that should be output at the module level
10// and performs the global register numbering.
11//
12// The results of this analysis are used in AsmPrinter to rename registers
13// globally and to output required instructions at the module level.
14//
15//===----------------------------------------------------------------------===//
16
// TODO: uses of report_fatal_error (which is also deprecated) /
// ReportFatalUsageError in this file should be refactored, as per LLVM
// best practices, to rely on the Diagnostic infrastructure.
20
21#include "SPIRVModuleAnalysis.h"
22#include "MCTargetDesc/SPIRVBaseInfo.h"
23#include "MCTargetDesc/SPIRVMCTargetDesc.h"
24#include "SPIRV.h"
25#include "SPIRVSubtarget.h"
26#include "SPIRVTargetMachine.h"
27#include "SPIRVUtils.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/CodeGen/MachineModuleInfo.h"
30#include "llvm/CodeGen/TargetPassConfig.h"
31
32using namespace llvm;
33
34#define DEBUG_TYPE "spirv-module-analysis"
35
// Debug switch: dump the MIR together with the SPIR-V dependency info
// gathered by this analysis.
static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(Val: false));

// User-specified capabilities to avoid emitting when some other capability
// (or extension) can enable the same feature.
static cl::list<SPIRV::Capability::Capability>
    AvoidCapabilities("avoid-spirv-capabilities",
                      cl::desc("SPIR-V capabilities to avoid if there are "
                               "other options enabling a feature"),
                      cl::ZeroOrMore, cl::Hidden,
                      cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
                                            "SPIR-V Shader capability")));
// Use sets instead of cl::list to check "if contains" condition
// (the cl::list itself has no efficient membership test).
struct AvoidCapabilitiesSet {
  // Capabilities to avoid, seeded from the -avoid-spirv-capabilities option.
  SmallSet<SPIRV::Capability::Capability, 4> S;
  AvoidCapabilitiesSet() { S.insert_range(R&: AvoidCapabilities); }
};
53
// Unique pass identifier used by the legacy pass manager.
char llvm::SPIRVModuleAnalysis::ID = 0;

// Register the pass with the legacy pass manager.
INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)
58
59// Retrieve an unsigned from an MDNode with a list of them as operands.
60static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
61 unsigned DefaultVal = 0) {
62 if (MdNode && OpIndex < MdNode->getNumOperands()) {
63 const auto &Op = MdNode->getOperand(I: OpIndex);
64 return mdconst::extract<ConstantInt>(MD: Op)->getZExtValue();
65 }
66 return DefaultVal;
67}
68
// Computes the requirements (capability, extensions, and SPIR-V version
// bounds) needed to use symbolic operand `i` of the given category on
// subtarget ST. Returns an unsatisfiable Requirements value when no
// combination of version, capability, or extension can enable the feature.
static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
                               unsigned i, const SPIRVSubtarget &ST,
                               SPIRV::RequirementHandler &Reqs) {
  // A set of capabilities to avoid if there is another option.
  AvoidCapabilitiesSet AvoidCaps;
  if (!ST.isShader())
    AvoidCaps.S.insert(V: SPIRV::Capability::Shader);
  else
    AvoidCaps.S.insert(V: SPIRV::Capability::Kernel);

  // An empty VersionTuple means "no bound" on that side.
  VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, Value: i);
  VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, Value: i);
  VersionTuple SPIRVVersion = ST.getSPIRVVersion();
  bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
  bool MaxVerOK =
      ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
  CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, Value: i);
  ExtensionList ReqExts = getSymbolicOperandExtensions(Category, Value: i);
  if (ReqCaps.empty()) {
    if (ReqExts.empty()) {
      // No capability or extension needed: only the version window matters.
      if (MinVerOK && MaxVerOK)
        return {true, {}, {}, ReqMinVer, ReqMaxVer};
      return {false, {}, {}, VersionTuple(), VersionTuple()};
    }
  } else if (MinVerOK && MaxVerOK) {
    if (ReqCaps.size() == 1) {
      auto Cap = ReqCaps[0];
      if (Reqs.isCapabilityAvailable(Cap)) {
        // Also pull in the extensions required by the chosen capability.
        ReqExts.append(RHS: getSymbolicOperandExtensions(
            Category: SPIRV::OperandCategory::CapabilityOperand, Value: Cap));
        return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
      }
    } else {
      // By SPIR-V specification: "If an instruction, enumerant, or other
      // feature specifies multiple enabling capabilities, only one such
      // capability needs to be declared to use the feature." However, one
      // capability may be preferred over another. We use command line
      // argument(s) and AvoidCapabilities to avoid selection of certain
      // capabilities if there are other options.
      CapabilityList UseCaps;
      for (auto Cap : ReqCaps)
        if (Reqs.isCapabilityAvailable(Cap))
          UseCaps.push_back(Elt: Cap);
      // Pick the first available capability not on the avoid list; the last
      // candidate is taken unconditionally as a fallback.
      for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
        auto Cap = UseCaps[i];
        if (i == Sz - 1 || !AvoidCaps.S.contains(V: Cap)) {
          ReqExts.append(RHS: getSymbolicOperandExtensions(
              Category: SPIRV::OperandCategory::CapabilityOperand, Value: Cap));
          return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
        }
      }
    }
  }
  // If there are no capabilities, or we can't satisfy the version or
  // capability requirements, use the list of extensions (if the subtarget
  // can handle them all).
  if (llvm::all_of(Range&: ReqExts, P: [&ST](const SPIRV::Extension::Extension &Ext) {
        return ST.canUseExtension(E: Ext);
      })) {
    return {true,
            {},
            std::move(ReqExts),
            VersionTuple(),
            VersionTuple()}; // TODO: add versions to extensions.
  }
  return {false, {}, {}, VersionTuple(), VersionTuple()};
}
137
// Resets MAI and (re)computes the module-level base info: memory model,
// addressing model, source language/version, used extensions, and the
// requirements those choices imply.
void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.FuncMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(ST: *ST);

  // TODO: determine memory model and source language from the configuration.
  if (auto MemModel = M.getNamedMetadata(Name: "spirv.MemoryModel")) {
    // Explicit module metadata takes precedence over subtarget defaults.
    auto MemMD = MemModel->getOperand(i: 0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MdNode: MemMD, OpIndex: 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MdNode: MemMD, OpIndex: 1));
  } else {
    // TODO: Add support for VulkanMemoryModel.
    MAI.Mem = ST->isShader() ? SPIRV::MemoryModel::GLSL450
                             : SPIRV::MemoryModel::OpenCL;
    if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
      // Physical addressing width follows the target pointer size.
      unsigned PtrSize = ST->getPointerSize();
      MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
                 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
                                 : SPIRV::AddressingModel::Logical;
    } else {
      // TODO: Add support for PhysicalStorageBufferAddress.
      MAI.Addr = SPIRV::AddressingModel::Logical;
    }
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata(Name: "opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct version literal in accordance with SPIRV-LLVM-Translator.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(i: 0);
    unsigned MajorNum = getMetadataUInt(MdNode: VersionMD, OpIndex: 0, DefaultVal: 2);
    unsigned MinorNum = getMetadataUInt(MdNode: VersionMD, OpIndex: 1);
    unsigned RevNum = getMetadataUInt(MdNode: VersionMD, OpIndex: 2);
    // Prevent Major part of OpenCL version to be 0
    MAI.SrcLangVersion =
        (std::max(a: 1U, b: MajorNum) * 100 + MinorNum) * 1000 + RevNum;
  } else {
    // If there is no information about OpenCL version we are forced to generate
    // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
    // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
    // Translator avoids potential issues with run-times in a similar manner.
    if (!ST->isShader()) {
      MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
      MAI.SrcLangVersion = 100000;
    } else {
      MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
      MAI.SrcLangVersion = 0;
    }
  }

  // Record every extension string listed in the "opencl.used.extensions"
  // metadata node.
  if (auto ExtNode = M.getNamedMetadata(Name: "opencl.used.extensions")) {
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(i: I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(key: cast<MDString>(Val: MD->getOperand(I: J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::MemoryModelOperand,
                                 i: MAI.Mem, ST: *ST);
  MAI.Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::SourceLanguageOperand,
                                 i: MAI.SrcLang, ST: *ST);
  MAI.Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::AddressingModelOperand,
                                 i: MAI.Addr, ST: *ST);

  if (!ST->isShader()) {
    // TODO: check if it's required by default.
    // Pre-assign a global id for the OpenCL extended instruction set.
    MAI.ExtInstSetMap[static_cast<unsigned>(
        SPIRV::InstructionSet::OpenCL_std)] = MAI.getNextIDRegister();
  }
}
224
225// Appends the signature of the decoration instructions that decorate R to
226// Signature.
227static void appendDecorationsForReg(const MachineRegisterInfo &MRI, Register R,
228 InstrSignature &Signature) {
229 for (MachineInstr &UseMI : MRI.use_instructions(Reg: R)) {
230 // We don't handle OpDecorateId because getting the register alias for the
231 // ID can cause problems, and we do not need it for now.
232 if (UseMI.getOpcode() != SPIRV::OpDecorate &&
233 UseMI.getOpcode() != SPIRV::OpMemberDecorate)
234 continue;
235
236 for (unsigned I = 0; I < UseMI.getNumOperands(); ++I) {
237 const MachineOperand &MO = UseMI.getOperand(i: I);
238 if (MO.isReg())
239 continue;
240 Signature.push_back(Elt: hash_value(MO));
241 }
242 }
243}
244
245// Returns a representation of an instruction as a vector of MachineOperand
246// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
247// This creates a signature of the instruction with the same content
248// that MachineOperand::isIdenticalTo uses for comparison.
249static InstrSignature instrToSignature(const MachineInstr &MI,
250 SPIRV::ModuleAnalysisInfo &MAI,
251 bool UseDefReg) {
252 Register DefReg;
253 InstrSignature Signature{MI.getOpcode()};
254 for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
255 // The only decorations that can be applied more than once to a given <id>
256 // or structure member are FuncParamAttr (38), UserSemantic (5635),
257 // CacheControlLoadINTEL (6442), and CacheControlStoreINTEL (6443). For all
258 // the rest of decorations, we will only add to the signature the Opcode,
259 // the id to which it applies, and the decoration id, disregarding any
260 // decoration flags. This will ensure that any subsequent decoration with
261 // the same id will be deemed as a duplicate. Then, at the call site, we
262 // will be able to handle duplicates in the best way.
263 unsigned Opcode = MI.getOpcode();
264 if ((Opcode == SPIRV::OpDecorate) && i >= 2) {
265 unsigned DecorationID = MI.getOperand(i: 1).getImm();
266 if (DecorationID != SPIRV::Decoration::FuncParamAttr &&
267 DecorationID != SPIRV::Decoration::UserSemantic &&
268 DecorationID != SPIRV::Decoration::CacheControlLoadINTEL &&
269 DecorationID != SPIRV::Decoration::CacheControlStoreINTEL)
270 continue;
271 }
272 const MachineOperand &MO = MI.getOperand(i);
273 size_t h;
274 if (MO.isReg()) {
275 if (!UseDefReg && MO.isDef()) {
276 assert(!DefReg.isValid() && "Multiple def registers.");
277 DefReg = MO.getReg();
278 continue;
279 }
280 Register RegAlias = MAI.getRegisterAlias(MF: MI.getMF(), Reg: MO.getReg());
281 if (!RegAlias.isValid()) {
282 LLVM_DEBUG({
283 dbgs() << "Unexpectedly, no global id found for the operand ";
284 MO.print(dbgs());
285 dbgs() << "\nInstruction: ";
286 MI.print(dbgs());
287 dbgs() << "\n";
288 });
289 report_fatal_error(reason: "All v-regs must have been mapped to global id's");
290 }
291 // mimic llvm::hash_value(const MachineOperand &MO)
292 h = hash_combine(args: MO.getType(), args: (unsigned)RegAlias, args: MO.getSubReg(),
293 args: MO.isDef());
294 } else {
295 h = hash_value(MO);
296 }
297 Signature.push_back(Elt: h);
298 }
299
300 if (DefReg.isValid()) {
301 // Decorations change the semantics of the current instruction. So two
302 // identical instruction with different decorations cannot be merged. That
303 // is why we add the decorations to the signature.
304 appendDecorationsForReg(MRI: MI.getMF()->getRegInfo(), R: DefReg, Signature);
305 }
306 return Signature;
307}
308
// Returns true when MI belongs to the module-level declaration section
// (types, constants, non-local variables, function declarations) and must be
// hoisted out of function bodies for emission.
bool SPIRVModuleAnalysis::isDeclSection(const MachineRegisterInfo &MRI,
                                        const MachineInstr &MI) {
  unsigned Opcode = MI.getOpcode();
  switch (Opcode) {
  case SPIRV::OpTypeForwardPointer:
    // omit now, collect later
    return false;
  case SPIRV::OpVariable:
    // Function-local variables stay inside the function body.
    return static_cast<SPIRV::StorageClass::StorageClass>(
               MI.getOperand(i: 2).getImm()) != SPIRV::StorageClass::Function;
  case SPIRV::OpFunction:
  case SPIRV::OpFunctionParameter:
    return true;
  }
  // An OpUndef used only by OpConstantFunctionPointerINTEL is a placeholder.
  if (GR->hasConstFunPtr() && Opcode == SPIRV::OpUndef) {
    Register DefReg = MI.getOperand(i: 0).getReg();
    for (MachineInstr &UseMI : MRI.use_instructions(Reg: DefReg)) {
      if (UseMI.getOpcode() != SPIRV::OpConstantFunctionPointerINTEL)
        continue;
      // it's a dummy definition, FP constant refers to a function,
      // and this is resolved in another way; let's skip this definition
      assert(UseMI.getOperand(2).isReg() &&
             UseMI.getOperand(2).getReg() == DefReg);
      MAI.setSkipEmission(&MI);
      return false;
    }
  }
  return TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
         TII->isInlineAsmDefInstr(MI);
}
339
// This is a special case of a function pointer referring to a possibly
// forward function declaration. The operand is a dummy OpUndef that
// requires a special treatment: the use is resolved to the referenced
// function's global id rather than numbering the OpUndef itself.
void SPIRVModuleAnalysis::visitFunPtrUse(
    Register OpReg, InstrGRegsMap &SignatureToGReg,
    std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
    const MachineInstr &MI) {
  const MachineOperand *OpFunDef =
      GR->getFunctionDefinitionByUse(Use: &MI.getOperand(i: 2));
  assert(OpFunDef && OpFunDef->isReg());
  // find the actual function definition and number it globally in advance
  const MachineInstr *OpDefMI = OpFunDef->getParent();
  assert(OpDefMI && OpDefMI->getOpcode() == SPIRV::OpFunction);
  const MachineFunction *FunDefMF = OpDefMI->getParent()->getParent();
  const MachineRegisterInfo &FunDefMRI = FunDefMF->getRegInfo();
  // Number the OpFunction and all of its OpFunctionParameter instructions.
  do {
    visitDecl(MRI: FunDefMRI, SignatureToGReg, GlobalToGReg, MF: FunDefMF, MI: *OpDefMI);
    OpDefMI = OpDefMI->getNextNode();
  } while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction ||
                       OpDefMI->getOpcode() == SPIRV::OpFunctionParameter));
  // associate the function pointer with the newly assigned global number
  MCRegister GlobalFunDefReg =
      MAI.getRegisterAlias(MF: FunDefMF, Reg: OpFunDef->getReg());
  assert(GlobalFunDefReg.isValid() &&
         "Function definition must refer to a global register");
  MAI.setRegisterAlias(MF, Reg: OpReg, AliasReg: GlobalFunDefReg);
}
367
// Depth first recursive traversal of dependencies. Repeated visits are guarded
// by MAI.hasRegisterAlias(). On return, the def register of MI has a global
// register alias and MI has been placed into the appropriate module section.
void SPIRVModuleAnalysis::visitDecl(
    const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg,
    std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
    const MachineInstr &MI) {
  unsigned Opcode = MI.getOpcode();

  // Process each operand of the instruction to resolve dependencies
  for (const MachineOperand &MO : MI.operands()) {
    if (!MO.isReg() || MO.isDef())
      continue;
    Register OpReg = MO.getReg();
    // Handle function pointers special case
    if (Opcode == SPIRV::OpConstantFunctionPointerINTEL &&
        MRI.getRegClass(Reg: OpReg) == &SPIRV::pIDRegClass) {
      visitFunPtrUse(OpReg, SignatureToGReg, GlobalToGReg, MF, MI);
      continue;
    }
    // Skip already processed instructions
    if (MAI.hasRegisterAlias(MF, Reg: MO.getReg()))
      continue;
    // Recursively visit dependencies
    if (const MachineInstr *OpDefMI = MRI.getUniqueVRegDef(Reg: OpReg)) {
      if (isDeclSection(MRI, MI: *OpDefMI))
        visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI: *OpDefMI);
      continue;
    }
    // Handle the unexpected case of no unique definition for the SPIR-V
    // instruction
    LLVM_DEBUG({
      dbgs() << "Unexpectedly, no unique definition for the operand ";
      MO.print(dbgs());
      dbgs() << "\nInstruction: ";
      MI.print(dbgs());
      dbgs() << "\n";
    });
    report_fatal_error(
        reason: "No unique definition is found for the virtual register");
  }

  // All dependencies are numbered; now assign a global id to MI itself.
  MCRegister GReg;
  bool IsFunDef = false;
  if (TII->isSpecConstantInstr(MI)) {
    // Spec constants always get a fresh id; they are never deduplicated.
    GReg = MAI.getNextIDRegister();
    MAI.MS[SPIRV::MB_TypeConstVars].push_back(Elt: &MI);
  } else if (Opcode == SPIRV::OpFunction ||
             Opcode == SPIRV::OpFunctionParameter) {
    GReg = handleFunctionOrParameter(MF, MI, GlobalToGReg, IsFunDef);
  } else if (Opcode == SPIRV::OpTypeStruct ||
             Opcode == SPIRV::OpConstantComposite) {
    GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
    // Also number the "Continued" companion instructions that carry extra
    // members of a large struct/composite, and skip their own emission.
    const MachineInstr *NextInstr = MI.getNextNode();
    while (NextInstr &&
           ((Opcode == SPIRV::OpTypeStruct &&
             NextInstr->getOpcode() == SPIRV::OpTypeStructContinuedINTEL) ||
            (Opcode == SPIRV::OpConstantComposite &&
             NextInstr->getOpcode() ==
                 SPIRV::OpConstantCompositeContinuedINTEL))) {
      MCRegister Tmp = handleTypeDeclOrConstant(MI: *NextInstr, SignatureToGReg);
      MAI.setRegisterAlias(MF, Reg: NextInstr->getOperand(i: 0).getReg(), AliasReg: Tmp);
      MAI.setSkipEmission(NextInstr);
      NextInstr = NextInstr->getNextNode();
    }
  } else if (TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
             TII->isInlineAsmDefInstr(MI)) {
    GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
  } else if (Opcode == SPIRV::OpVariable) {
    GReg = handleVariable(MF, MI, GlobalToGReg);
  } else {
    LLVM_DEBUG({
      dbgs() << "\nInstruction: ";
      MI.print(dbgs());
      dbgs() << "\n";
    });
    llvm_unreachable("Unexpected instruction is visited");
  }
  MAI.setRegisterAlias(MF, Reg: MI.getOperand(i: 0).getReg(), AliasReg: GReg);
  // Function definitions are emitted in place; everything else is hoisted.
  if (!IsFunDef)
    MAI.setSkipEmission(&MI);
}
449
450MCRegister SPIRVModuleAnalysis::handleFunctionOrParameter(
451 const MachineFunction *MF, const MachineInstr &MI,
452 std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) {
453 const Value *GObj = GR->getGlobalObject(MF, R: MI.getOperand(i: 0).getReg());
454 assert(GObj && "Unregistered global definition");
455 const Function *F = dyn_cast<Function>(Val: GObj);
456 if (!F)
457 F = dyn_cast<Argument>(Val: GObj)->getParent();
458 assert(F && "Expected a reference to a function or an argument");
459 IsFunDef = !F->isDeclaration();
460 auto [It, Inserted] = GlobalToGReg.try_emplace(k: GObj);
461 if (!Inserted)
462 return It->second;
463 MCRegister GReg = MAI.getNextIDRegister();
464 It->second = GReg;
465 if (!IsFunDef)
466 MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(Elt: &MI);
467 return GReg;
468}
469
470MCRegister
471SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI,
472 InstrGRegsMap &SignatureToGReg) {
473 InstrSignature MISign = instrToSignature(MI, MAI, UseDefReg: false);
474 auto [It, Inserted] = SignatureToGReg.try_emplace(k: MISign);
475 if (!Inserted)
476 return It->second;
477 MCRegister GReg = MAI.getNextIDRegister();
478 It->second = GReg;
479 MAI.MS[SPIRV::MB_TypeConstVars].push_back(Elt: &MI);
480 return GReg;
481}
482
483MCRegister SPIRVModuleAnalysis::handleVariable(
484 const MachineFunction *MF, const MachineInstr &MI,
485 std::map<const Value *, unsigned> &GlobalToGReg) {
486 MAI.GlobalVarList.push_back(Elt: &MI);
487 const Value *GObj = GR->getGlobalObject(MF, R: MI.getOperand(i: 0).getReg());
488 assert(GObj && "Unregistered global definition");
489 auto [It, Inserted] = GlobalToGReg.try_emplace(k: GObj);
490 if (!Inserted)
491 return It->second;
492 MCRegister GReg = MAI.getNextIDRegister();
493 It->second = GReg;
494 MAI.MS[SPIRV::MB_TypeConstVars].push_back(Elt: &MI);
495 return GReg;
496}
497
// Walks all machine functions, collects OpExtension/OpCapability requirements,
// and numbers every declaration (types, constants, globals, functions) via
// visitDecl.
void SPIRVModuleAnalysis::collectDeclarations(const Module &M) {
  InstrGRegsMap SignatureToGReg;
  std::map<const Value *, unsigned> GlobalToGReg;
  for (const Function &F : M) {
    MachineFunction *MF = MMI->getMachineFunction(F);
    if (!MF)
      continue;
    const MachineRegisterInfo &MRI = MF->getRegInfo();
    // PastHeader tracks the position relative to this function's own header:
    // 0 - before its OpFunction, 1 - inside the header, 2 - past the header.
    unsigned PastHeader = 0;
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getNumOperands() == 0)
          continue;
        unsigned Opcode = MI.getOpcode();
        if (Opcode == SPIRV::OpFunction) {
          // The first OpFunction is the enclosing function itself.
          if (PastHeader == 0) {
            PastHeader = 1;
            continue;
          }
        } else if (Opcode == SPIRV::OpFunctionParameter) {
          // Parameters of the enclosing function's header are not visited.
          if (PastHeader < 2)
            continue;
        } else if (PastHeader > 0) {
          PastHeader = 2;
        }

        const MachineOperand &DefMO = MI.getOperand(i: 0);
        switch (Opcode) {
        case SPIRV::OpExtension:
          MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::Extension(DefMO.getImm()));
          MAI.setSkipEmission(&MI);
          break;
        case SPIRV::OpCapability:
          MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::Capability(DefMO.getImm()));
          MAI.setSkipEmission(&MI);
          if (PastHeader > 0)
            PastHeader = 2;
          break;
        default:
          if (DefMO.isReg() && isDeclSection(MRI, MI) &&
              !MAI.hasRegisterAlias(MF, Reg: DefMO.getReg()))
            visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI);
        }
      }
    }
  }
}
545
// Look for IDs declared with Import linkage, and map the corresponding function
// to the register defining that variable (which will usually be the result of
// an OpFunction). This lets us call externally imported functions using
// the correct ID registers.
void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
                                           const Function *F) {
  if (MI.getOpcode() == SPIRV::OpDecorate) {
    // If it's got Import linkage.
    auto Dec = MI.getOperand(i: 1).getImm();
    if (Dec == SPIRV::Decoration::LinkageAttributes) {
      // The linkage type is the last operand of the decoration.
      auto Lnk = MI.getOperand(i: MI.getNumOperands() - 1).getImm();
      if (Lnk == SPIRV::LinkageType::Import) {
        // Map imported function name to function ID register.
        const Function *ImportedFunc =
            F->getParent()->getFunction(Name: getStringImm(MI, StartIndex: 2));
        Register Target = MI.getOperand(i: 0).getReg();
        MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MF: MI.getMF(), Reg: Target);
      }
    }
  } else if (MI.getOpcode() == SPIRV::OpFunction) {
    // Record all internal OpFunction declarations.
    Register Reg = MI.defs().begin()->getReg();
    MCRegister GlobalReg = MAI.getRegisterAlias(MF: MI.getMF(), Reg);
    assert(GlobalReg.isValid());
    MAI.FuncMap[F] = GlobalReg;
  }
}
573
// Collect the given instruction in the specified MS. We assume global register
// numbering has already occurred by this point. We can directly compare reg
// arguments when detecting duplicates. When Append is false, the instruction
// is prepended to the section instead of appended. Duplicate FPFastMathMode
// OpDecorate instructions are merged by OR-ing their flag operands; all other
// duplicates are simply skipped.
static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
                              SPIRV::ModuleSectionType MSType, InstrTraces &IS,
                              bool Append = true) {
  MAI.setSkipEmission(&MI);
  InstrSignature MISign = instrToSignature(MI, MAI, UseDefReg: true);
  auto FoundMI = IS.insert(x: std::move(MISign));
  if (!FoundMI.second) {
    // The signature was already present: this is a duplicate.
    if (MI.getOpcode() == SPIRV::OpDecorate) {
      assert(MI.getNumOperands() >= 2 &&
             "Decoration instructions must have at least 2 operands");
      assert(MSType == SPIRV::MB_Annotations &&
             "Only OpDecorate instructions can be duplicates");
      // For FPFastMathMode decoration, we need to merge the flags of the
      // duplicate decoration with the original one, so we need to find the
      // original instruction that has the same signature. For the rest of
      // instructions, we will simply skip the duplicate.
      if (MI.getOperand(i: 1).getImm() != SPIRV::Decoration::FPFastMathMode)
        return; // Skip duplicates of other decorations.

      const SPIRV::InstrList &Decorations = MAI.MS[MSType];
      for (const MachineInstr *OrigMI : Decorations) {
        if (instrToSignature(MI: *OrigMI, MAI, UseDefReg: true) == MISign) {
          assert(OrigMI->getNumOperands() == MI.getNumOperands() &&
                 "Original instruction must have the same number of operands");
          assert(
              OrigMI->getNumOperands() == 3 &&
              "FPFastMathMode decoration must have 3 operands for OpDecorate");
          unsigned OrigFlags = OrigMI->getOperand(i: 2).getImm();
          unsigned NewFlags = MI.getOperand(i: 2).getImm();
          if (OrigFlags == NewFlags)
            return; // No need to merge, the flags are the same.

          // Emit warning about possible conflict between flags.
          unsigned FinalFlags = OrigFlags | NewFlags;
          llvm::errs()
              << "Warning: Conflicting FPFastMathMode decoration flags "
                 "in instruction: "
              << *OrigMI << "Original flags: " << OrigFlags
              << ", new flags: " << NewFlags
              << ". They will be merged on a best effort basis, but not "
                 "validated. Final flags: "
              << FinalFlags << "\n";
          // Patch the already-collected instruction in place with the merged
          // flags.
          MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI);
          MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(i: 2);
          OrigFlagsOp = MachineOperand::CreateImm(Val: FinalFlags);
          return; // Merge done, so we found a duplicate; don't add it to MAI.MS
        }
      }
      assert(false && "No original instruction found for the duplicate "
                      "OpDecorate, but we found one in IS.");
    }
    return; // insert failed, so we found a duplicate; don't add it to MAI.MS
  }
  // No duplicates, so add it.
  if (Append)
    MAI.MS[MSType].push_back(Elt: &MI);
  else
    MAI.MS[MSType].insert(I: MAI.MS[MSType].begin(), Elt: &MI);
}
636
// Some global instructions make reference to function-local ID regs, so cannot
// be correctly collected until these registers are globally numbered.
// Dispatches every not-yet-skipped instruction into its module section.
void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
  InstrTraces IS;
  for (const Function &F : M) {
    if (F.isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(F);
    assert(MF);

    for (MachineBasicBlock &MBB : *MF)
      for (MachineInstr &MI : MBB) {
        if (MAI.getSkipEmission(MI: &MI))
          continue;
        const unsigned OpCode = MI.getOpcode();
        if (OpCode == SPIRV::OpString) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_DebugStrings, IS);
        } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(i: 2).isImm() &&
                   MI.getOperand(i: 2).getImm() ==
                       SPIRV::InstructionSet::
                           NonSemantic_Shader_DebugInfo_100) {
          // Only module-scope debug-info instructions are hoisted into the
          // global DI section; the rest stay in place.
          MachineOperand Ins = MI.getOperand(i: 3);
          namespace NS = SPIRV::NonSemanticExtInst;
          static constexpr int64_t GlobalNonSemanticDITy[] = {
              NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
              NS::DebugTypeBasic, NS::DebugTypePointer};
          bool IsGlobalDI = false;
          for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
            IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
          if (IsGlobalDI)
            collectOtherInstr(MI, MAI, MSType: SPIRV::MB_NonSemanticGlobalDI, IS);
        } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_DebugNames, IS);
        } else if (OpCode == SPIRV::OpEntryPoint) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_EntryPoints, IS);
        } else if (TII->isAliasingInstr(MI)) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_AliasingInsts, IS);
        } else if (TII->isDecorationInstr(MI)) {
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_Annotations, IS);
          collectFuncNames(MI, F: &F);
        } else if (TII->isConstantInstr(MI)) {
          // Now OpSpecConstant*s are not in DT,
          // but they need to be collected anyway.
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_TypeConstVars, IS);
        } else if (OpCode == SPIRV::OpFunction) {
          collectFuncNames(MI, F: &F);
        } else if (OpCode == SPIRV::OpTypeForwardPointer) {
          // Forward pointers are prepended so they precede their users.
          collectOtherInstr(MI, MAI, MSType: SPIRV::MB_TypeConstVars, IS, Append: false);
        }
      }
  }
}
689
690// Number registers in all functions globally from 0 onwards and store
691// the result in global register alias table. Some registers are already
692// numbered.
693void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
694 for (const Function &F : M) {
695 if (F.isDeclaration())
696 continue;
697 MachineFunction *MF = MMI->getMachineFunction(F);
698 assert(MF);
699 for (MachineBasicBlock &MBB : *MF) {
700 for (MachineInstr &MI : MBB) {
701 for (MachineOperand &Op : MI.operands()) {
702 if (!Op.isReg())
703 continue;
704 Register Reg = Op.getReg();
705 if (MAI.hasRegisterAlias(MF, Reg))
706 continue;
707 MCRegister NewReg = MAI.getNextIDRegister();
708 MAI.setRegisterAlias(MF, Reg, AliasReg: NewReg);
709 }
710 if (MI.getOpcode() != SPIRV::OpExtInst)
711 continue;
712 auto Set = MI.getOperand(i: 2).getImm();
713 auto [It, Inserted] = MAI.ExtInstSetMap.try_emplace(Key: Set);
714 if (Inserted)
715 It->second = MAI.getNextIDRegister();
716 }
717 }
718 }
719}
720
721// RequirementHandler implementations.
722void SPIRV::RequirementHandler::getAndAddRequirements(
723 SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
724 const SPIRVSubtarget &ST) {
725 addRequirements(Req: getSymbolicOperandRequirements(Category, i, ST, Reqs&: *this));
726}
727
728void SPIRV::RequirementHandler::recursiveAddCapabilities(
729 const CapabilityList &ToPrune) {
730 for (const auto &Cap : ToPrune) {
731 AllCaps.insert(V: Cap);
732 CapabilityList ImplicitDecls =
733 getSymbolicOperandCapabilities(Category: OperandCategory::CapabilityOperand, Value: Cap);
734 recursiveAddCapabilities(ToPrune: ImplicitDecls);
735 }
736}
737
738void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
739 for (const auto &Cap : ToAdd) {
740 bool IsNewlyInserted = AllCaps.insert(V: Cap).second;
741 if (!IsNewlyInserted) // Don't re-add if it's already been declared.
742 continue;
743 CapabilityList ImplicitDecls =
744 getSymbolicOperandCapabilities(Category: OperandCategory::CapabilityOperand, Value: Cap);
745 recursiveAddCapabilities(ToPrune: ImplicitDecls);
746 MinimalCaps.push_back(Elt: Cap);
747 }
748}
749
// Merges Req into the accumulated state: declares its capability and
// extensions and tightens the [MinVersion, MaxVersion] window. Aborts
// compilation when Req is unsatisfiable or conflicts with the current window.
void SPIRV::RequirementHandler::addRequirements(
    const SPIRV::Requirements &Req) {
  if (!Req.IsSatisfiable)
    report_fatal_error(reason: "Adding SPIR-V requirements this target can't satisfy.");

  if (Req.Cap.has_value())
    addCapabilities(ToAdd: {Req.Cap.value()});

  addExtensions(ToAdd: Req.Exts);

  // Raise the lower version bound, checking against the upper bound first.
  if (!Req.MinVer.empty()) {
    if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
                        << " and <= " << MaxVersion << "\n");
      report_fatal_error(reason: "Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MinVersion.empty() || Req.MinVer > MinVersion)
      MinVersion = Req.MinVer;
  }

  // Lower the upper version bound, checking against the lower bound first.
  if (!Req.MaxVer.empty()) {
    if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
                        << " and >= " << MinVersion << "\n");
      report_fatal_error(reason: "Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
      MaxVersion = Req.MaxVer;
  }
}
782
// Validates the accumulated requirements against the target: the version
// window, all minimal capabilities, and all extensions. Collects every
// violation before aborting so users see the complete list at once.
void SPIRV::RequirementHandler::checkSatisfiable(
    const SPIRVSubtarget &ST) const {
  // Report as many errors as possible before aborting the compilation.
  bool IsSatisfiable = true;
  auto TargetVer = ST.getSPIRVVersion();

  if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
    LLVM_DEBUG(
        dbgs() << "Target SPIR-V version too high for required features\n"
               << "Required max version: " << MaxVersion << " target version "
               << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
    LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
                      << "Required min version: " << MinVersion
                      << " target version " << TargetVer << "\n");
    IsSatisfiable = false;
  }

  if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
    LLVM_DEBUG(
        dbgs()
        << "Version is too low for some features and too high for others.\n"
        << "Required SPIR-V min version: " << MinVersion
        << " required SPIR-V max version " << MaxVersion << "\n");
    IsSatisfiable = false;
  }

  // The environment-defining capability (Shader vs Kernel) of the *other*
  // environment must not appear among the minimal capabilities.
  AvoidCapabilitiesSet AvoidCaps;
  if (!ST.isShader())
    AvoidCaps.S.insert(V: SPIRV::Capability::Shader);
  else
    AvoidCaps.S.insert(V: SPIRV::Capability::Kernel);

  for (auto Cap : MinimalCaps) {
    if (AvailableCaps.contains(V: Cap) && !AvoidCaps.S.contains(V: Cap))
      continue;
    LLVM_DEBUG(dbgs() << "Capability not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::CapabilityOperand, Cap)
                      << "\n");
    IsSatisfiable = false;
  }

  for (auto Ext : AllExtensions) {
    if (ST.canUseExtension(E: Ext))
      continue;
    LLVM_DEBUG(dbgs() << "Extension not supported: "
                      << getSymbolicOperandMnemonic(
                             OperandCategory::ExtensionOperand, Ext)
                      << "\n");
    IsSatisfiable = false;
  }

  if (!IsSatisfiable)
    report_fatal_error(reason: "Unable to meet SPIR-V requirements for this target.");
}
842
843// Add the given capabilities and all their implicitly defined capabilities too.
844void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
845 for (const auto Cap : ToAdd)
846 if (AvailableCaps.insert(V: Cap).second)
847 addAvailableCaps(ToAdd: getSymbolicOperandCapabilities(
848 Category: SPIRV::OperandCategory::CapabilityOperand, Value: Cap));
849}
850
851void SPIRV::RequirementHandler::removeCapabilityIf(
852 const Capability::Capability ToRemove,
853 const Capability::Capability IfPresent) {
854 if (AllCaps.contains(V: IfPresent))
855 AllCaps.erase(V: ToRemove);
856}
857
858namespace llvm {
859namespace SPIRV {
860void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
861 // Provided by both all supported Vulkan versions and OpenCl.
862 addAvailableCaps(ToAdd: {Capability::Shader, Capability::Linkage, Capability::Int8,
863 Capability::Int16});
864
865 if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 3)))
866 addAvailableCaps(ToAdd: {Capability::GroupNonUniform,
867 Capability::GroupNonUniformVote,
868 Capability::GroupNonUniformArithmetic,
869 Capability::GroupNonUniformBallot,
870 Capability::GroupNonUniformClustered,
871 Capability::GroupNonUniformShuffle,
872 Capability::GroupNonUniformShuffleRelative});
873
874 if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 6)))
875 addAvailableCaps(ToAdd: {Capability::DotProduct, Capability::DotProductInputAll,
876 Capability::DotProductInput4x8Bit,
877 Capability::DotProductInput4x8BitPacked,
878 Capability::DemoteToHelperInvocation});
879
880 // Add capabilities enabled by extensions.
881 for (auto Extension : ST.getAllAvailableExtensions()) {
882 CapabilityList EnabledCapabilities =
883 getCapabilitiesEnabledByExtension(Extension);
884 addAvailableCaps(ToAdd: EnabledCapabilities);
885 }
886
887 if (!ST.isShader()) {
888 initAvailableCapabilitiesForOpenCL(ST);
889 return;
890 }
891
892 if (ST.isShader()) {
893 initAvailableCapabilitiesForVulkan(ST);
894 return;
895 }
896
897 report_fatal_error(reason: "Unimplemented environment for SPIR-V generation.");
898}
899
// Capabilities available in the OpenCL (kernel) environment, gated on the
// subtarget's OpenCL profile, image support, and OpenCL/SPIR-V versions.
void RequirementHandler::initAvailableCapabilitiesForOpenCL(
    const SPIRVSubtarget &ST) {
  // Add the min requirements for different OpenCL and SPIR-V versions.
  addAvailableCaps(ToAdd: {Capability::Addresses, Capability::Float16Buffer,
                    Capability::Kernel, Capability::Vector16,
                    Capability::Groups, Capability::GenericPointer,
                    Capability::StorageImageWriteWithoutFormat,
                    Capability::StorageImageReadWithoutFormat});
  // 64-bit integers (and their atomics) require the full profile.
  if (ST.hasOpenCLFullProfile())
    addAvailableCaps(ToAdd: {Capability::Int64, Capability::Int64Atomics});
  if (ST.hasOpenCLImageSupport()) {
    addAvailableCaps(ToAdd: {Capability::ImageBasic, Capability::LiteralSampler,
                      Capability::Image1D, Capability::SampledBuffer,
                      Capability::ImageBuffer});
    // Read-write images require OpenCL 2.0 or later.
    if (ST.isAtLeastOpenCLVer(VerToCompareTo: VersionTuple(2, 0)))
      addAvailableCaps(ToAdd: {Capability::ImageReadWrite});
  }
  // Subgroup dispatch and pipe storage need both SPIR-V 1.1 and OpenCL 2.2.
  if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 1)) &&
      ST.isAtLeastOpenCLVer(VerToCompareTo: VersionTuple(2, 2)))
    addAvailableCaps(ToAdd: {Capability::SubgroupDispatch, Capability::PipeStorage});
  // Float-controls capabilities became core in SPIR-V 1.4.
  if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 4)))
    addAvailableCaps(ToAdd: {Capability::DenormPreserve, Capability::DenormFlushToZero,
                      Capability::SignedZeroInfNanPreserve,
                      Capability::RoundingModeRTE,
                      Capability::RoundingModeRTZ});
  // TODO: verify if this needs some checks.
  addAvailableCaps(ToAdd: {Capability::Float16, Capability::Float64});

  // TODO: add OpenCL extensions.
}
930
// Capabilities available in the Vulkan (shader) environment, gated on the
// SPIR-V version corresponding to each Vulkan core version.
void RequirementHandler::initAvailableCapabilitiesForVulkan(
    const SPIRVSubtarget &ST) {

  // Core in Vulkan 1.1 and earlier.
  addAvailableCaps(ToAdd: {Capability::Int64, Capability::Float16, Capability::Float64,
                    Capability::GroupNonUniform, Capability::Image1D,
                    Capability::SampledBuffer, Capability::ImageBuffer,
                    Capability::UniformBufferArrayDynamicIndexing,
                    Capability::SampledImageArrayDynamicIndexing,
                    Capability::StorageBufferArrayDynamicIndexing,
                    Capability::StorageImageArrayDynamicIndexing,
                    Capability::DerivativeControl});

  // Became core in Vulkan 1.2 (descriptor-indexing family; Vulkan 1.2
  // requires SPIR-V 1.5).
  if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 5))) {
    addAvailableCaps(
        ToAdd: {Capability::ShaderNonUniformEXT, Capability::RuntimeDescriptorArrayEXT,
         Capability::InputAttachmentArrayDynamicIndexingEXT,
         Capability::UniformTexelBufferArrayDynamicIndexingEXT,
         Capability::StorageTexelBufferArrayDynamicIndexingEXT,
         Capability::UniformBufferArrayNonUniformIndexingEXT,
         Capability::SampledImageArrayNonUniformIndexingEXT,
         Capability::StorageBufferArrayNonUniformIndexingEXT,
         Capability::StorageImageArrayNonUniformIndexingEXT,
         Capability::InputAttachmentArrayNonUniformIndexingEXT,
         Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
         Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
  }

  // Became core in Vulkan 1.3 (Vulkan 1.3 requires SPIR-V 1.6).
  if (ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 6)))
    addAvailableCaps(ToAdd: {Capability::StorageImageWriteWithoutFormat,
                      Capability::StorageImageReadWithoutFormat});
}
965
966} // namespace SPIRV
967} // namespace llvm
968
969// Add the required capabilities from a decoration instruction (including
970// BuiltIns).
971static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
972 SPIRV::RequirementHandler &Reqs,
973 const SPIRVSubtarget &ST) {
974 int64_t DecOp = MI.getOperand(i: DecIndex).getImm();
975 auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
976 Reqs.addRequirements(Req: getSymbolicOperandRequirements(
977 Category: SPIRV::OperandCategory::DecorationOperand, i: Dec, ST, Reqs));
978
979 if (Dec == SPIRV::Decoration::BuiltIn) {
980 int64_t BuiltInOp = MI.getOperand(i: DecIndex + 1).getImm();
981 auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
982 Reqs.addRequirements(Req: getSymbolicOperandRequirements(
983 Category: SPIRV::OperandCategory::BuiltInOperand, i: BuiltIn, ST, Reqs));
984 } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
985 int64_t LinkageOp = MI.getOperand(i: MI.getNumOperands() - 1).getImm();
986 SPIRV::LinkageType::LinkageType LnkType =
987 static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
988 if (LnkType == SPIRV::LinkageType::LinkOnceODR)
989 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_linkonce_odr);
990 } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
991 Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
992 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_cache_controls);
993 } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
994 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_global_variable_host_access);
995 } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
996 Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
997 Reqs.addExtension(
998 ToAdd: SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
999 } else if (Dec == SPIRV::Decoration::NonUniformEXT) {
1000 Reqs.addRequirements(Req: SPIRV::Capability::ShaderNonUniformEXT);
1001 } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) {
1002 Reqs.addRequirements(Req: SPIRV::Capability::FPMaxErrorINTEL);
1003 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_fp_max_error);
1004 } else if (Dec == SPIRV::Decoration::FPFastMathMode) {
1005 if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2)) {
1006 Reqs.addRequirements(Req: SPIRV::Capability::FloatControls2);
1007 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_float_controls2);
1008 }
1009 }
1010}
1011
// Add requirements for image handling.
static void addOpTypeImageReqs(const MachineInstr &MI,
                               SPIRV::RequirementHandler &Reqs,
                               const SPIRVSubtarget &ST) {
  assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
  // The operand indices used here are based on the OpTypeImage layout, which
  // the MachineInstr follows as well.
  int64_t ImgFormatOp = MI.getOperand(i: 7).getImm();
  auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
  Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ImageFormatOperand,
                             i: ImgFormat, ST);

  bool IsArrayed = MI.getOperand(i: 4).getImm() == 1;
  bool IsMultisampled = MI.getOperand(i: 5).getImm() == 1;
  // Sampled operand value 2 means the image is accessed without a sampler.
  bool NoSampler = MI.getOperand(i: 6).getImm() == 2;
  // Add dimension requirements.
  assert(MI.getOperand(2).isImm());
  switch (MI.getOperand(i: 2).getImm()) {
  case SPIRV::Dim::DIM_1D:
    Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::Image1D
                                      : SPIRV::Capability::Sampled1D);
    break;
  case SPIRV::Dim::DIM_2D:
    // Plain 2D needs nothing extra; multisampled storage images do.
    if (IsMultisampled && NoSampler)
      Reqs.addRequirements(Req: SPIRV::Capability::ImageMSArray);
    break;
  case SPIRV::Dim::DIM_Cube:
    Reqs.addRequirements(Req: SPIRV::Capability::Shader);
    if (IsArrayed)
      Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::ImageCubeArray
                                        : SPIRV::Capability::SampledCubeArray);
    break;
  case SPIRV::Dim::DIM_Rect:
    Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::ImageRect
                                      : SPIRV::Capability::SampledRect);
    break;
  case SPIRV::Dim::DIM_Buffer:
    Reqs.addRequirements(Req: NoSampler ? SPIRV::Capability::ImageBuffer
                                      : SPIRV::Capability::SampledBuffer);
    break;
  case SPIRV::Dim::DIM_SubpassData:
    Reqs.addRequirements(Req: SPIRV::Capability::InputAttachment);
    break;
  }

  // Has optional access qualifier.
  // The access-qualifier operand (index 8) only exists in the kernel
  // environment.
  if (!ST.isShader()) {
    if (MI.getNumOperands() > 8 &&
        MI.getOperand(i: 8).getImm() == SPIRV::AccessQualifier::ReadWrite)
      Reqs.addRequirements(Req: SPIRV::Capability::ImageReadWrite);
    else
      Reqs.addRequirements(Req: SPIRV::Capability::ImageBasic);
  }
}
1066
1067static bool isBFloat16Type(const SPIRVType *TypeDef) {
1068 return TypeDef && TypeDef->getNumOperands() == 3 &&
1069 TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1070 TypeDef->getOperand(i: 1).getImm() == 16 &&
1071 TypeDef->getOperand(i: 2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
1072}
1073
// Add requirements for handling atomic float instructions
#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
  "The atomic float instruction requires the following SPIR-V " \
  "extension: SPV_EXT_shader_atomic_float" ExtName
// Validates and adds requirements for a floating-point *vector* atomic:
// only 2- or 4-component float16 vectors are accepted, via
// SPV_NV_shader_atomic_fp16_vector.
static void AddAtomicVectorFloatRequirements(const MachineInstr &MI,
                                             SPIRV::RequirementHandler &Reqs,
                                             const SPIRVSubtarget &ST) {
  // Operand 1 is the result-type vreg; its def is the OpTypeVector.
  SPIRVType *VecTypeDef =
      MI.getMF()->getRegInfo().getVRegDef(Reg: MI.getOperand(i: 1).getReg());

  // Component count is operand 2 of OpTypeVector.
  const unsigned Rank = VecTypeDef->getOperand(i: 2).getImm();
  if (Rank != 2 && Rank != 4)
    reportFatalUsageError(reason: "Result type of an atomic vector float instruction "
                          "must be a 2-component or 4 component vector");

  SPIRVType *EltTypeDef =
      MI.getMF()->getRegInfo().getVRegDef(Reg: VecTypeDef->getOperand(i: 1).getReg());

  // Element type must be a 16-bit float scalar...
  if (EltTypeDef->getOpcode() != SPIRV::OpTypeFloat ||
      EltTypeDef->getOperand(i: 1).getImm() != 16)
    reportFatalUsageError(
        reason: "The element type for the result type of an atomic vector float "
        "instruction must be a 16-bit floating-point scalar");

  // ...but not bfloat16 (which is also encoded as a 16-bit OpTypeFloat).
  if (isBFloat16Type(TypeDef: EltTypeDef))
    reportFatalUsageError(
        reason: "The element type for the result type of an atomic vector float "
        "instruction cannot be a bfloat16 scalar");
  if (!ST.canUseExtension(E: SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector))
    reportFatalUsageError(
        reason: "The atomic float16 vector instruction requires the following SPIR-V "
        "extension: SPV_NV_shader_atomic_fp16_vector");

  Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector);
  Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat16VectorNV);
}
1110
// Add the extension/capability requirements for a scalar floating-point
// atomic (FAdd vs. FMin/FMax families), dispatching on the element bit width;
// vector result types are delegated to AddAtomicVectorFloatRequirements.
static void AddAtomicFloatRequirements(const MachineInstr &MI,
                                       SPIRV::RequirementHandler &Reqs,
                                       const SPIRVSubtarget &ST) {
  assert(MI.getOperand(1).isReg() &&
         "Expect register operand in atomic float instruction");
  Register TypeReg = MI.getOperand(i: 1).getReg();
  SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(Reg: TypeReg);

  // Vector results are handled by the dedicated NV fp16-vector path.
  if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
    return AddAtomicVectorFloatRequirements(MI, Reqs, ST);

  if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
    report_fatal_error(reason: "Result type of an atomic float instruction must be a "
                       "floating-point type scalar");

  unsigned BitWidth = TypeDef->getOperand(i: 1).getImm();
  unsigned Op = MI.getOpcode();
  if (Op == SPIRV::OpAtomicFAddEXT) {
    // Atomic float add family.
    if (!ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), gen_crash_diag: false);
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
    switch (BitWidth) {
    case 16:
      // 16-bit adds split further: bfloat16 vs. regular half.
      if (isBFloat16Type(TypeDef)) {
        if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_16bit_atomics))
          report_fatal_error(
              reason: "The atomic bfloat16 instruction requires the following SPIR-V "
              "extension: SPV_INTEL_16bit_atomics",
              gen_crash_diag: false);
        Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_16bit_atomics);
        Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicBFloat16AddINTEL);
      } else {
        if (!ST.canUseExtension(
                E: SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
          report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), gen_crash_diag: false);
        Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
        Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat16AddEXT);
      }
      break;
    case 32:
      Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat32AddEXT);
      break;
    case 64:
      Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat64AddEXT);
      break;
    default:
      report_fatal_error(
          reason: "Unexpected floating-point type width in atomic float instruction");
    }
  } else {
    // All other opcodes reaching here are the min/max family.
    if (!ST.canUseExtension(
            E: SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), gen_crash_diag: false);
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
    switch (BitWidth) {
    case 16:
      if (isBFloat16Type(TypeDef)) {
        if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_16bit_atomics))
          report_fatal_error(
              reason: "The atomic bfloat16 instruction requires the following SPIR-V "
              "extension: SPV_INTEL_16bit_atomics",
              gen_crash_diag: false);
        Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_16bit_atomics);
        Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicBFloat16MinMaxINTEL);
      } else {
        Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat16MinMaxEXT);
      }
      break;
    case 32:
      Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat32MinMaxEXT);
      break;
    case 64:
      Reqs.addCapability(ToAdd: SPIRV::Capability::AtomicFloat64MinMaxEXT);
      break;
    default:
      report_fatal_error(
          reason: "Unexpected floating-point type width in atomic float instruction");
    }
  }
}
1191
1192bool isUniformTexelBuffer(MachineInstr *ImageInst) {
1193 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1194 return false;
1195 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1196 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1197 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
1198}
1199
1200bool isStorageTexelBuffer(MachineInstr *ImageInst) {
1201 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1202 return false;
1203 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1204 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1205 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
1206}
1207
1208bool isSampledImage(MachineInstr *ImageInst) {
1209 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1210 return false;
1211 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1212 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1213 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
1214}
1215
1216bool isInputAttachment(MachineInstr *ImageInst) {
1217 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1218 return false;
1219 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1220 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1221 return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
1222}
1223
1224bool isStorageImage(MachineInstr *ImageInst) {
1225 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1226 return false;
1227 uint32_t Dim = ImageInst->getOperand(i: 2).getImm();
1228 uint32_t Sampled = ImageInst->getOperand(i: 6).getImm();
1229 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
1230}
1231
1232bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
1233 if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
1234 return false;
1235
1236 const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
1237 Register ImageReg = SampledImageInst->getOperand(i: 1).getReg();
1238 auto *ImageInst = MRI.getUniqueVRegDef(Reg: ImageReg);
1239 return isSampledImage(ImageInst);
1240}
1241
1242bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
1243 for (const auto &MI : MRI.reg_instructions(Reg)) {
1244 if (MI.getOpcode() != SPIRV::OpDecorate)
1245 continue;
1246
1247 uint32_t Dec = MI.getOperand(i: 1).getImm();
1248 if (Dec == SPIRV::Decoration::NonUniformEXT)
1249 return true;
1250 }
1251 return false;
1252}
1253
// For an access chain into a descriptor array, add the descriptor-indexing
// capability matching the resource kind (texel buffer, input attachment,
// sampled/storage image, sampler) and the indexing style: NonUniform-decorated
// results need the NonUniformIndexing capabilities, non-constant first
// indices need the DynamicIndexing ones, and constant indices need nothing.
void addOpAccessChainReqs(const MachineInstr &Instr,
                          SPIRV::RequirementHandler &Handler,
                          const SPIRVSubtarget &Subtarget) {
  const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
  // Get the result type. If it is an image type, then the shader uses
  // descriptor indexing. The appropriate capabilities will be added based
  // on the specifics of the image.
  Register ResTypeReg = Instr.getOperand(i: 1).getReg();
  MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(Reg: ResTypeReg);

  assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
  uint32_t StorageClass = ResTypeInst->getOperand(i: 1).getImm();
  // Only resource storage classes participate in descriptor indexing.
  if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
      StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
      StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
    return;
  }

  bool IsNonUniform =
      hasNonUniformDecoration(Reg: Instr.getOperand(i: 0).getReg(), MRI);

  // Operand 3 is the first index into the chain; a non-constant def makes the
  // indexing "dynamic".
  auto FirstIndexReg = Instr.getOperand(i: 3).getReg();
  bool FirstIndexIsConstant =
      Subtarget.getInstrInfo()->isConstantInstr(MI: *MRI.getVRegDef(Reg: FirstIndexReg));

  // Storage buffers are classified by storage class alone.
  if (StorageClass == SPIRV::StorageClass::StorageClass::StorageBuffer) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageBufferArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageBufferArrayDynamicIndexing);
    return;
  }

  // Otherwise classify by the pointee (image/sampler) type.
  Register PointeeTypeReg = ResTypeInst->getOperand(i: 2).getReg();
  MachineInstr *PointeeType = MRI.getUniqueVRegDef(Reg: PointeeTypeReg);
  if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
      PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
      PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
    return;
  }

  if (isUniformTexelBuffer(ImageInst: PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
  } else if (isInputAttachment(ImageInst: PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
  } else if (isStorageTexelBuffer(ImageInst: PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
  } else if (isSampledImage(ImageInst: PointeeType) ||
             isCombinedImageSampler(SampledImageInst: PointeeType) ||
             PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::SampledImageArrayDynamicIndexing);
  } else if (isStorageImage(ImageInst: PointeeType)) {
    if (IsNonUniform)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
    else if (!FirstIndexIsConstant)
      Handler.addRequirements(
          Req: SPIRV::Capability::StorageImageArrayDynamicIndexing);
  }
}
1336
1337static bool isImageTypeWithUnknownFormat(SPIRVType *TypeInst) {
1338 if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
1339 return false;
1340 assert(TypeInst->getOperand(7).isImm() && "The image format must be an imm.");
1341 return TypeInst->getOperand(i: 7).getImm() == 0;
1342}
1343
// Add the integer-dot-product capability matching the input operand type:
// packed-int32 inputs, 4x8-bit vectors, or arbitrary integer vectors.
static void AddDotProductRequirements(const MachineInstr &MI,
                                      SPIRV::RequirementHandler &Reqs,
                                      const SPIRVSubtarget &ST) {
  // The extension is optional (the capability may be core in SPIR-V >= 1.6).
  if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_integer_dot_product))
    Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_integer_dot_product);
  Reqs.addCapability(ToAdd: SPIRV::Capability::DotProduct);

  const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
  // We do not consider what the previous instruction is. This is just used
  // to get the input register and to check the type.
  const MachineInstr *Input = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
  assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
  Register InputReg = Input->getOperand(i: 1).getReg();

  SPIRVType *TypeDef = MRI.getVRegDef(Reg: InputReg);
  if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
    // A scalar int input means the 4x8-bit packed form (packed into int32).
    assert(TypeDef->getOperand(1).getImm() == 32);
    Reqs.addCapability(ToAdd: SPIRV::Capability::DotProductInput4x8BitPacked);
  } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
    SPIRVType *ScalarTypeDef = MRI.getVRegDef(Reg: TypeDef->getOperand(i: 1).getReg());
    assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
    if (ScalarTypeDef->getOperand(i: 1).getImm() == 8) {
      assert(TypeDef->getOperand(2).getImm() == 4 &&
             "Dot operand of 8-bit integer type requires 4 components");
      Reqs.addCapability(ToAdd: SPIRV::Capability::DotProductInput4x8Bit);
    } else {
      Reqs.addCapability(ToAdd: SPIRV::Capability::DotProductInputAll);
    }
  }
}
1375
// printf's format-string pointer (operand 4) is normally in UniformConstant
// storage; any other address space requires (and adds) the
// SPV_EXT_relaxed_printf_string_address_space extension.
void addPrintfRequirements(const MachineInstr &MI,
                           SPIRV::RequirementHandler &Reqs,
                           const SPIRVSubtarget &ST) {
  SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
  const SPIRVType *PtrType = GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 4).getReg());
  // Silently skip when the pointer type or its storage class can't be
  // determined here.
  if (PtrType) {
    MachineOperand ASOp = PtrType->getOperand(i: 1);
    if (ASOp.isImm()) {
      unsigned AddrSpace = ASOp.getImm();
      if (AddrSpace != SPIRV::StorageClass::UniformConstant) {
        if (!ST.canUseExtension(
                E: SPIRV::Extension::
                    SPV_EXT_relaxed_printf_string_address_space)) {
          report_fatal_error(reason: "SPV_EXT_relaxed_printf_string_address_space is "
                             "required because printf uses a format string not "
                             "in constant address space.",
                             gen_crash_diag: false);
        }
        Reqs.addExtension(
            ToAdd: SPIRV::Extension::SPV_EXT_relaxed_printf_string_address_space);
      }
    }
  }
}
1400
1401void addInstrRequirements(const MachineInstr &MI,
1402 SPIRV::ModuleAnalysisInfo &MAI,
1403 const SPIRVSubtarget &ST) {
1404 SPIRV::RequirementHandler &Reqs = MAI.Reqs;
1405 switch (MI.getOpcode()) {
1406 case SPIRV::OpMemoryModel: {
1407 int64_t Addr = MI.getOperand(i: 0).getImm();
1408 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::AddressingModelOperand,
1409 i: Addr, ST);
1410 int64_t Mem = MI.getOperand(i: 1).getImm();
1411 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::MemoryModelOperand, i: Mem,
1412 ST);
1413 break;
1414 }
1415 case SPIRV::OpEntryPoint: {
1416 int64_t Exe = MI.getOperand(i: 0).getImm();
1417 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ExecutionModelOperand,
1418 i: Exe, ST);
1419 break;
1420 }
1421 case SPIRV::OpExecutionMode:
1422 case SPIRV::OpExecutionModeId: {
1423 int64_t Exe = MI.getOperand(i: 1).getImm();
1424 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::ExecutionModeOperand,
1425 i: Exe, ST);
1426 break;
1427 }
1428 case SPIRV::OpTypeMatrix:
1429 Reqs.addCapability(ToAdd: SPIRV::Capability::Matrix);
1430 break;
1431 case SPIRV::OpTypeInt: {
1432 unsigned BitWidth = MI.getOperand(i: 1).getImm();
1433 if (BitWidth == 64)
1434 Reqs.addCapability(ToAdd: SPIRV::Capability::Int64);
1435 else if (BitWidth == 16)
1436 Reqs.addCapability(ToAdd: SPIRV::Capability::Int16);
1437 else if (BitWidth == 8)
1438 Reqs.addCapability(ToAdd: SPIRV::Capability::Int8);
1439 else if (BitWidth == 4 &&
1440 ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_int4)) {
1441 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_int4);
1442 Reqs.addCapability(ToAdd: SPIRV::Capability::Int4TypeINTEL);
1443 } else if (BitWidth != 32) {
1444 if (!ST.canUseExtension(
1445 E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers))
1446 reportFatalUsageError(
1447 reason: "OpTypeInt type with a width other than 8, 16, 32 or 64 bits "
1448 "requires the following SPIR-V extension: "
1449 "SPV_ALTERA_arbitrary_precision_integers");
1450 Reqs.addExtension(
1451 ToAdd: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers);
1452 Reqs.addCapability(ToAdd: SPIRV::Capability::ArbitraryPrecisionIntegersALTERA);
1453 }
1454 break;
1455 }
1456 case SPIRV::OpDot: {
1457 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1458 SPIRVType *TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
1459 if (isBFloat16Type(TypeDef))
1460 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16DotProductKHR);
1461 break;
1462 }
1463 case SPIRV::OpTypeFloat: {
1464 unsigned BitWidth = MI.getOperand(i: 1).getImm();
1465 if (BitWidth == 64)
1466 Reqs.addCapability(ToAdd: SPIRV::Capability::Float64);
1467 else if (BitWidth == 16) {
1468 if (isBFloat16Type(TypeDef: &MI)) {
1469 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_bfloat16))
1470 report_fatal_error(reason: "OpTypeFloat type with bfloat requires the "
1471 "following SPIR-V extension: SPV_KHR_bfloat16",
1472 gen_crash_diag: false);
1473 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_bfloat16);
1474 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16TypeKHR);
1475 } else {
1476 Reqs.addCapability(ToAdd: SPIRV::Capability::Float16);
1477 }
1478 }
1479 break;
1480 }
1481 case SPIRV::OpTypeVector: {
1482 unsigned NumComponents = MI.getOperand(i: 2).getImm();
1483 if (NumComponents == 8 || NumComponents == 16)
1484 Reqs.addCapability(ToAdd: SPIRV::Capability::Vector16);
1485 break;
1486 }
1487 case SPIRV::OpTypePointer: {
1488 auto SC = MI.getOperand(i: 1).getImm();
1489 Reqs.getAndAddRequirements(Category: SPIRV::OperandCategory::StorageClassOperand, i: SC,
1490 ST);
1491 // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
1492 // capability.
1493 if (ST.isShader())
1494 break;
1495 assert(MI.getOperand(2).isReg());
1496 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1497 SPIRVType *TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
1498 if ((TypeDef->getNumOperands() == 2) &&
1499 (TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
1500 (TypeDef->getOperand(i: 1).getImm() == 16))
1501 Reqs.addCapability(ToAdd: SPIRV::Capability::Float16Buffer);
1502 break;
1503 }
1504 case SPIRV::OpExtInst: {
1505 if (MI.getOperand(i: 2).getImm() ==
1506 static_cast<int64_t>(
1507 SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
1508 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_non_semantic_info);
1509 break;
1510 }
1511 if (MI.getOperand(i: 3).getImm() ==
1512 static_cast<int64_t>(SPIRV::OpenCLExtInst::printf)) {
1513 addPrintfRequirements(MI, Reqs, ST);
1514 break;
1515 }
1516 // TODO: handle bfloat16 extended instructions when
1517 // SPV_INTEL_bfloat16_arithmetic is enabled.
1518 break;
1519 }
1520 case SPIRV::OpAliasDomainDeclINTEL:
1521 case SPIRV::OpAliasScopeDeclINTEL:
1522 case SPIRV::OpAliasScopeListDeclINTEL: {
1523 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_memory_access_aliasing);
1524 Reqs.addCapability(ToAdd: SPIRV::Capability::MemoryAccessAliasingINTEL);
1525 break;
1526 }
1527 case SPIRV::OpBitReverse:
1528 case SPIRV::OpBitFieldInsert:
1529 case SPIRV::OpBitFieldSExtract:
1530 case SPIRV::OpBitFieldUExtract:
1531 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_bit_instructions)) {
1532 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
1533 break;
1534 }
1535 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_bit_instructions);
1536 Reqs.addCapability(ToAdd: SPIRV::Capability::BitInstructions);
1537 break;
1538 case SPIRV::OpTypeRuntimeArray:
1539 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
1540 break;
1541 case SPIRV::OpTypeOpaque:
1542 case SPIRV::OpTypeEvent:
1543 Reqs.addCapability(ToAdd: SPIRV::Capability::Kernel);
1544 break;
1545 case SPIRV::OpTypePipe:
1546 case SPIRV::OpTypeReserveId:
1547 Reqs.addCapability(ToAdd: SPIRV::Capability::Pipes);
1548 break;
1549 case SPIRV::OpTypeDeviceEvent:
1550 case SPIRV::OpTypeQueue:
1551 case SPIRV::OpBuildNDRange:
1552 Reqs.addCapability(ToAdd: SPIRV::Capability::DeviceEnqueue);
1553 break;
1554 case SPIRV::OpDecorate:
1555 case SPIRV::OpDecorateId:
1556 case SPIRV::OpDecorateString:
1557 addOpDecorateReqs(MI, DecIndex: 1, Reqs, ST);
1558 break;
1559 case SPIRV::OpMemberDecorate:
1560 case SPIRV::OpMemberDecorateString:
1561 addOpDecorateReqs(MI, DecIndex: 2, Reqs, ST);
1562 break;
1563 case SPIRV::OpInBoundsPtrAccessChain:
1564 Reqs.addCapability(ToAdd: SPIRV::Capability::Addresses);
1565 break;
1566 case SPIRV::OpConstantSampler:
1567 Reqs.addCapability(ToAdd: SPIRV::Capability::LiteralSampler);
1568 break;
1569 case SPIRV::OpInBoundsAccessChain:
1570 case SPIRV::OpAccessChain:
1571 addOpAccessChainReqs(Instr: MI, Handler&: Reqs, Subtarget: ST);
1572 break;
1573 case SPIRV::OpTypeImage:
1574 addOpTypeImageReqs(MI, Reqs, ST);
1575 break;
1576 case SPIRV::OpTypeSampler:
1577 if (!ST.isShader()) {
1578 Reqs.addCapability(ToAdd: SPIRV::Capability::ImageBasic);
1579 }
1580 break;
1581 case SPIRV::OpTypeForwardPointer:
1582 // TODO: check if it's OpenCL's kernel.
1583 Reqs.addCapability(ToAdd: SPIRV::Capability::Addresses);
1584 break;
1585 case SPIRV::OpAtomicFlagTestAndSet:
1586 case SPIRV::OpAtomicLoad:
1587 case SPIRV::OpAtomicStore:
1588 case SPIRV::OpAtomicExchange:
1589 case SPIRV::OpAtomicCompareExchange:
1590 case SPIRV::OpAtomicIIncrement:
1591 case SPIRV::OpAtomicIDecrement:
1592 case SPIRV::OpAtomicIAdd:
1593 case SPIRV::OpAtomicISub:
1594 case SPIRV::OpAtomicUMin:
1595 case SPIRV::OpAtomicUMax:
1596 case SPIRV::OpAtomicSMin:
1597 case SPIRV::OpAtomicSMax:
1598 case SPIRV::OpAtomicAnd:
1599 case SPIRV::OpAtomicOr:
1600 case SPIRV::OpAtomicXor: {
1601 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1602 const MachineInstr *InstrPtr = &MI;
1603 if (MI.getOpcode() == SPIRV::OpAtomicStore) {
1604 assert(MI.getOperand(3).isReg());
1605 InstrPtr = MRI.getVRegDef(Reg: MI.getOperand(i: 3).getReg());
1606 assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
1607 }
1608 assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
1609 Register TypeReg = InstrPtr->getOperand(i: 1).getReg();
1610 SPIRVType *TypeDef = MRI.getVRegDef(Reg: TypeReg);
1611 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1612 unsigned BitWidth = TypeDef->getOperand(i: 1).getImm();
1613 if (BitWidth == 64)
1614 Reqs.addCapability(ToAdd: SPIRV::Capability::Int64Atomics);
1615 }
1616 break;
1617 }
1618 case SPIRV::OpGroupNonUniformIAdd:
1619 case SPIRV::OpGroupNonUniformFAdd:
1620 case SPIRV::OpGroupNonUniformIMul:
1621 case SPIRV::OpGroupNonUniformFMul:
1622 case SPIRV::OpGroupNonUniformSMin:
1623 case SPIRV::OpGroupNonUniformUMin:
1624 case SPIRV::OpGroupNonUniformFMin:
1625 case SPIRV::OpGroupNonUniformSMax:
1626 case SPIRV::OpGroupNonUniformUMax:
1627 case SPIRV::OpGroupNonUniformFMax:
1628 case SPIRV::OpGroupNonUniformBitwiseAnd:
1629 case SPIRV::OpGroupNonUniformBitwiseOr:
1630 case SPIRV::OpGroupNonUniformBitwiseXor:
1631 case SPIRV::OpGroupNonUniformLogicalAnd:
1632 case SPIRV::OpGroupNonUniformLogicalOr:
1633 case SPIRV::OpGroupNonUniformLogicalXor: {
1634 assert(MI.getOperand(3).isImm());
1635 int64_t GroupOp = MI.getOperand(i: 3).getImm();
1636 switch (GroupOp) {
1637 case SPIRV::GroupOperation::Reduce:
1638 case SPIRV::GroupOperation::InclusiveScan:
1639 case SPIRV::GroupOperation::ExclusiveScan:
1640 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformArithmetic);
1641 break;
1642 case SPIRV::GroupOperation::ClusteredReduce:
1643 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformClustered);
1644 break;
1645 case SPIRV::GroupOperation::PartitionedReduceNV:
1646 case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1647 case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1648 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformPartitionedNV);
1649 break;
1650 }
1651 break;
1652 }
1653 case SPIRV::OpImageQueryFormat: {
1654 Register ResultReg = MI.getOperand(i: 0).getReg();
1655 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1656 static const unsigned CompareOps[] = {
1657 SPIRV::OpIEqual, SPIRV::OpINotEqual,
1658 SPIRV::OpUGreaterThan, SPIRV::OpUGreaterThanEqual,
1659 SPIRV::OpULessThan, SPIRV::OpULessThanEqual,
1660 SPIRV::OpSGreaterThan, SPIRV::OpSGreaterThanEqual,
1661 SPIRV::OpSLessThan, SPIRV::OpSLessThanEqual};
1662
1663 auto CheckAndAddExtension = [&](int64_t ImmVal) {
1664 if (ImmVal == 4323 || ImmVal == 4324) {
1665 if (ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_image_raw10_raw12))
1666 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_image_raw10_raw12);
1667 else
1668 report_fatal_error(reason: "This requires the "
1669 "SPV_EXT_image_raw10_raw12 extension");
1670 }
1671 };
1672
1673 for (MachineInstr &UseInst : MRI.use_instructions(Reg: ResultReg)) {
1674 unsigned Opc = UseInst.getOpcode();
1675
1676 if (Opc == SPIRV::OpSwitch) {
1677 for (const MachineOperand &Op : UseInst.operands())
1678 if (Op.isImm())
1679 CheckAndAddExtension(Op.getImm());
1680 } else if (llvm::is_contained(Range: CompareOps, Element: Opc)) {
1681 for (unsigned i = 1; i < UseInst.getNumOperands(); ++i) {
1682 Register UseReg = UseInst.getOperand(i).getReg();
1683 MachineInstr *ConstInst = MRI.getVRegDef(Reg: UseReg);
1684 if (ConstInst && ConstInst->getOpcode() == SPIRV::OpConstantI) {
1685 int64_t ImmVal = ConstInst->getOperand(i: 2).getImm();
1686 if (ImmVal)
1687 CheckAndAddExtension(ImmVal);
1688 }
1689 }
1690 }
1691 }
1692 break;
1693 }
1694
1695 case SPIRV::OpGroupNonUniformShuffle:
1696 case SPIRV::OpGroupNonUniformShuffleXor:
1697 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformShuffle);
1698 break;
1699 case SPIRV::OpGroupNonUniformShuffleUp:
1700 case SPIRV::OpGroupNonUniformShuffleDown:
1701 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformShuffleRelative);
1702 break;
1703 case SPIRV::OpGroupAll:
1704 case SPIRV::OpGroupAny:
1705 case SPIRV::OpGroupBroadcast:
1706 case SPIRV::OpGroupIAdd:
1707 case SPIRV::OpGroupFAdd:
1708 case SPIRV::OpGroupFMin:
1709 case SPIRV::OpGroupUMin:
1710 case SPIRV::OpGroupSMin:
1711 case SPIRV::OpGroupFMax:
1712 case SPIRV::OpGroupUMax:
1713 case SPIRV::OpGroupSMax:
1714 Reqs.addCapability(ToAdd: SPIRV::Capability::Groups);
1715 break;
1716 case SPIRV::OpGroupNonUniformElect:
1717 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniform);
1718 break;
1719 case SPIRV::OpGroupNonUniformAll:
1720 case SPIRV::OpGroupNonUniformAny:
1721 case SPIRV::OpGroupNonUniformAllEqual:
1722 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformVote);
1723 break;
1724 case SPIRV::OpGroupNonUniformBroadcast:
1725 case SPIRV::OpGroupNonUniformBroadcastFirst:
1726 case SPIRV::OpGroupNonUniformBallot:
1727 case SPIRV::OpGroupNonUniformInverseBallot:
1728 case SPIRV::OpGroupNonUniformBallotBitExtract:
1729 case SPIRV::OpGroupNonUniformBallotBitCount:
1730 case SPIRV::OpGroupNonUniformBallotFindLSB:
1731 case SPIRV::OpGroupNonUniformBallotFindMSB:
1732 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformBallot);
1733 break;
1734 case SPIRV::OpSubgroupShuffleINTEL:
1735 case SPIRV::OpSubgroupShuffleDownINTEL:
1736 case SPIRV::OpSubgroupShuffleUpINTEL:
1737 case SPIRV::OpSubgroupShuffleXorINTEL:
1738 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_subgroups)) {
1739 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_subgroups);
1740 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupShuffleINTEL);
1741 }
1742 break;
1743 case SPIRV::OpSubgroupBlockReadINTEL:
1744 case SPIRV::OpSubgroupBlockWriteINTEL:
1745 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_subgroups)) {
1746 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_subgroups);
1747 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1748 }
1749 break;
1750 case SPIRV::OpSubgroupImageBlockReadINTEL:
1751 case SPIRV::OpSubgroupImageBlockWriteINTEL:
1752 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_subgroups)) {
1753 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_subgroups);
1754 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupImageBlockIOINTEL);
1755 }
1756 break;
1757 case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
1758 case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
1759 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_media_block_io)) {
1760 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_media_block_io);
1761 Reqs.addCapability(ToAdd: SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
1762 }
1763 break;
1764 case SPIRV::OpAssumeTrueKHR:
1765 case SPIRV::OpExpectKHR:
1766 if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_expect_assume)) {
1767 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_expect_assume);
1768 Reqs.addCapability(ToAdd: SPIRV::Capability::ExpectAssumeKHR);
1769 }
1770 break;
1771 case SPIRV::OpFmaKHR:
1772 if (ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_fma)) {
1773 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_fma);
1774 Reqs.addCapability(ToAdd: SPIRV::Capability::FmaKHR);
1775 }
1776 break;
1777 case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1778 case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1779 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1780 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1781 Reqs.addCapability(ToAdd: SPIRV::Capability::USMStorageClassesINTEL);
1782 }
1783 break;
1784 case SPIRV::OpConstantFunctionPointerINTEL:
1785 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers)) {
1786 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_function_pointers);
1787 Reqs.addCapability(ToAdd: SPIRV::Capability::FunctionPointersINTEL);
1788 }
1789 break;
1790 case SPIRV::OpGroupNonUniformRotateKHR:
1791 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_subgroup_rotate))
1792 report_fatal_error(reason: "OpGroupNonUniformRotateKHR instruction requires the "
1793 "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1794 gen_crash_diag: false);
1795 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_subgroup_rotate);
1796 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniformRotateKHR);
1797 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupNonUniform);
1798 break;
1799 case SPIRV::OpFixedCosALTERA:
1800 case SPIRV::OpFixedSinALTERA:
1801 case SPIRV::OpFixedCosPiALTERA:
1802 case SPIRV::OpFixedSinPiALTERA:
1803 case SPIRV::OpFixedExpALTERA:
1804 case SPIRV::OpFixedLogALTERA:
1805 case SPIRV::OpFixedRecipALTERA:
1806 case SPIRV::OpFixedSqrtALTERA:
1807 case SPIRV::OpFixedSinCosALTERA:
1808 case SPIRV::OpFixedSinCosPiALTERA:
1809 case SPIRV::OpFixedRsqrtALTERA:
1810 if (!ST.canUseExtension(
1811 E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_fixed_point))
1812 report_fatal_error(reason: "This instruction requires the "
1813 "following SPIR-V extension: "
1814 "SPV_ALTERA_arbitrary_precision_fixed_point",
1815 gen_crash_diag: false);
1816 Reqs.addExtension(
1817 ToAdd: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_fixed_point);
1818 Reqs.addCapability(ToAdd: SPIRV::Capability::ArbitraryPrecisionFixedPointALTERA);
1819 break;
1820 case SPIRV::OpGroupIMulKHR:
1821 case SPIRV::OpGroupFMulKHR:
1822 case SPIRV::OpGroupBitwiseAndKHR:
1823 case SPIRV::OpGroupBitwiseOrKHR:
1824 case SPIRV::OpGroupBitwiseXorKHR:
1825 case SPIRV::OpGroupLogicalAndKHR:
1826 case SPIRV::OpGroupLogicalOrKHR:
1827 case SPIRV::OpGroupLogicalXorKHR:
1828 if (ST.canUseExtension(
1829 E: SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1830 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1831 Reqs.addCapability(ToAdd: SPIRV::Capability::GroupUniformArithmeticKHR);
1832 }
1833 break;
1834 case SPIRV::OpReadClockKHR:
1835 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_shader_clock))
1836 report_fatal_error(reason: "OpReadClockKHR instruction requires the "
1837 "following SPIR-V extension: SPV_KHR_shader_clock",
1838 gen_crash_diag: false);
1839 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_shader_clock);
1840 Reqs.addCapability(ToAdd: SPIRV::Capability::ShaderClockKHR);
1841 break;
1842 case SPIRV::OpFunctionPointerCallINTEL:
1843 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_function_pointers)) {
1844 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_function_pointers);
1845 Reqs.addCapability(ToAdd: SPIRV::Capability::FunctionPointersINTEL);
1846 }
1847 break;
1848 case SPIRV::OpAtomicFAddEXT:
1849 case SPIRV::OpAtomicFMinEXT:
1850 case SPIRV::OpAtomicFMaxEXT:
1851 AddAtomicFloatRequirements(MI, Reqs, ST);
1852 break;
1853 case SPIRV::OpConvertBF16ToFINTEL:
1854 case SPIRV::OpConvertFToBF16INTEL:
1855 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1856 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1857 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16ConversionINTEL);
1858 }
1859 break;
1860 case SPIRV::OpRoundFToTF32INTEL:
1861 if (ST.canUseExtension(
1862 E: SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
1863 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
1864 Reqs.addCapability(ToAdd: SPIRV::Capability::TensorFloat32RoundingINTEL);
1865 }
1866 break;
1867 case SPIRV::OpVariableLengthArrayINTEL:
1868 case SPIRV::OpSaveMemoryINTEL:
1869 case SPIRV::OpRestoreMemoryINTEL:
1870 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_variable_length_array)) {
1871 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_variable_length_array);
1872 Reqs.addCapability(ToAdd: SPIRV::Capability::VariableLengthArrayINTEL);
1873 }
1874 break;
1875 case SPIRV::OpAsmTargetINTEL:
1876 case SPIRV::OpAsmINTEL:
1877 case SPIRV::OpAsmCallINTEL:
1878 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_inline_assembly)) {
1879 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_inline_assembly);
1880 Reqs.addCapability(ToAdd: SPIRV::Capability::AsmINTEL);
1881 }
1882 break;
1883 case SPIRV::OpTypeCooperativeMatrixKHR: {
1884 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_cooperative_matrix))
1885 report_fatal_error(
1886 reason: "OpTypeCooperativeMatrixKHR type requires the "
1887 "following SPIR-V extension: SPV_KHR_cooperative_matrix",
1888 gen_crash_diag: false);
1889 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_cooperative_matrix);
1890 Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixKHR);
1891 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1892 SPIRVType *TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
1893 if (isBFloat16Type(TypeDef))
1894 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16CooperativeMatrixKHR);
1895 break;
1896 }
1897 case SPIRV::OpArithmeticFenceEXT:
1898 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_arithmetic_fence))
1899 report_fatal_error(reason: "OpArithmeticFenceEXT requires the "
1900 "following SPIR-V extension: SPV_EXT_arithmetic_fence",
1901 gen_crash_diag: false);
1902 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_arithmetic_fence);
1903 Reqs.addCapability(ToAdd: SPIRV::Capability::ArithmeticFenceEXT);
1904 break;
1905 case SPIRV::OpControlBarrierArriveINTEL:
1906 case SPIRV::OpControlBarrierWaitINTEL:
1907 if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_split_barrier)) {
1908 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_split_barrier);
1909 Reqs.addCapability(ToAdd: SPIRV::Capability::SplitBarrierINTEL);
1910 }
1911 break;
1912 case SPIRV::OpCooperativeMatrixMulAddKHR: {
1913 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_cooperative_matrix))
1914 report_fatal_error(reason: "Cooperative matrix instructions require the "
1915 "following SPIR-V extension: "
1916 "SPV_KHR_cooperative_matrix",
1917 gen_crash_diag: false);
1918 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_cooperative_matrix);
1919 Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixKHR);
1920 constexpr unsigned MulAddMaxSize = 6;
1921 if (MI.getNumOperands() != MulAddMaxSize)
1922 break;
1923 const int64_t CoopOperands = MI.getOperand(i: MulAddMaxSize - 1).getImm();
1924 if (CoopOperands &
1925 SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
1926 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
1927 report_fatal_error(reason: "MatrixAAndBTF32ComponentsINTEL type interpretation "
1928 "require the following SPIR-V extension: "
1929 "SPV_INTEL_joint_matrix",
1930 gen_crash_diag: false);
1931 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
1932 Reqs.addCapability(
1933 ToAdd: SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
1934 }
1935 if (CoopOperands & SPIRV::CooperativeMatrixOperands::
1936 MatrixAAndBBFloat16ComponentsINTEL ||
1937 CoopOperands &
1938 SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
1939 CoopOperands & SPIRV::CooperativeMatrixOperands::
1940 MatrixResultBFloat16ComponentsINTEL) {
1941 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
1942 report_fatal_error(reason: "***BF16ComponentsINTEL type interpretations "
1943 "require the following SPIR-V extension: "
1944 "SPV_INTEL_joint_matrix",
1945 gen_crash_diag: false);
1946 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
1947 Reqs.addCapability(
1948 ToAdd: SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
1949 }
1950 break;
1951 }
1952 case SPIRV::OpCooperativeMatrixLoadKHR:
1953 case SPIRV::OpCooperativeMatrixStoreKHR:
1954 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1955 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1956 case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
1957 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_cooperative_matrix))
1958 report_fatal_error(reason: "Cooperative matrix instructions require the "
1959 "following SPIR-V extension: "
1960 "SPV_KHR_cooperative_matrix",
1961 gen_crash_diag: false);
1962 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_cooperative_matrix);
1963 Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixKHR);
1964
1965 // Check Layout operand in case if it's not a standard one and add the
1966 // appropriate capability.
1967 std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
1968 {SPIRV::OpCooperativeMatrixLoadKHR, 3},
1969 {SPIRV::OpCooperativeMatrixStoreKHR, 2},
1970 {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
1971 {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
1972 {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};
1973
1974 const auto OpCode = MI.getOpcode();
1975 const unsigned LayoutNum = LayoutToInstMap[OpCode];
1976 Register RegLayout = MI.getOperand(i: LayoutNum).getReg();
1977 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1978 MachineInstr *MILayout = MRI.getUniqueVRegDef(Reg: RegLayout);
1979 if (MILayout->getOpcode() == SPIRV::OpConstantI) {
1980 const unsigned LayoutVal = MILayout->getOperand(i: 2).getImm();
1981 if (LayoutVal ==
1982 static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
1983 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
1984 report_fatal_error(reason: "PackedINTEL layout require the following SPIR-V "
1985 "extension: SPV_INTEL_joint_matrix",
1986 gen_crash_diag: false);
1987 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
1988 Reqs.addCapability(ToAdd: SPIRV::Capability::PackedCooperativeMatrixINTEL);
1989 }
1990 }
1991
1992 // Nothing to do.
1993 if (OpCode == SPIRV::OpCooperativeMatrixLoadKHR ||
1994 OpCode == SPIRV::OpCooperativeMatrixStoreKHR)
1995 break;
1996
1997 std::string InstName;
1998 switch (OpCode) {
1999 case SPIRV::OpCooperativeMatrixPrefetchINTEL:
2000 InstName = "OpCooperativeMatrixPrefetchINTEL";
2001 break;
2002 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
2003 InstName = "OpCooperativeMatrixLoadCheckedINTEL";
2004 break;
2005 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
2006 InstName = "OpCooperativeMatrixStoreCheckedINTEL";
2007 break;
2008 }
2009
2010 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix)) {
2011 const std::string ErrorMsg =
2012 InstName + " instruction requires the "
2013 "following SPIR-V extension: SPV_INTEL_joint_matrix";
2014 report_fatal_error(reason: ErrorMsg.c_str(), gen_crash_diag: false);
2015 }
2016 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2017 if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
2018 Reqs.addCapability(ToAdd: SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
2019 break;
2020 }
2021 Reqs.addCapability(
2022 ToAdd: SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
2023 break;
2024 }
2025 case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
2026 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2027 report_fatal_error(reason: "OpCooperativeMatrixConstructCheckedINTEL "
2028 "instructions require the following SPIR-V extension: "
2029 "SPV_INTEL_joint_matrix",
2030 gen_crash_diag: false);
2031 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2032 Reqs.addCapability(
2033 ToAdd: SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
2034 break;
2035 case SPIRV::OpReadPipeBlockingALTERA:
2036 case SPIRV::OpWritePipeBlockingALTERA:
2037 if (ST.canUseExtension(E: SPIRV::Extension::SPV_ALTERA_blocking_pipes)) {
2038 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_ALTERA_blocking_pipes);
2039 Reqs.addCapability(ToAdd: SPIRV::Capability::BlockingPipesALTERA);
2040 }
2041 break;
2042 case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
2043 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_joint_matrix))
2044 report_fatal_error(reason: "OpCooperativeMatrixGetElementCoordINTEL requires the "
2045 "following SPIR-V extension: SPV_INTEL_joint_matrix",
2046 gen_crash_diag: false);
2047 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_joint_matrix);
2048 Reqs.addCapability(
2049 ToAdd: SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
2050 break;
2051 case SPIRV::OpConvertHandleToImageINTEL:
2052 case SPIRV::OpConvertHandleToSamplerINTEL:
2053 case SPIRV::OpConvertHandleToSampledImageINTEL: {
2054 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bindless_images))
2055 report_fatal_error(reason: "OpConvertHandleTo[Image/Sampler/SampledImage]INTEL "
2056 "instructions require the following SPIR-V extension: "
2057 "SPV_INTEL_bindless_images",
2058 gen_crash_diag: false);
2059 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
2060 SPIRV::AddressingModel::AddressingModel AddrModel = MAI.Addr;
2061 SPIRVType *TyDef = GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 1).getReg());
2062 if (MI.getOpcode() == SPIRV::OpConvertHandleToImageINTEL &&
2063 TyDef->getOpcode() != SPIRV::OpTypeImage) {
2064 report_fatal_error(reason: "Incorrect return type for the instruction "
2065 "OpConvertHandleToImageINTEL",
2066 gen_crash_diag: false);
2067 } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSamplerINTEL &&
2068 TyDef->getOpcode() != SPIRV::OpTypeSampler) {
2069 report_fatal_error(reason: "Incorrect return type for the instruction "
2070 "OpConvertHandleToSamplerINTEL",
2071 gen_crash_diag: false);
2072 } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSampledImageINTEL &&
2073 TyDef->getOpcode() != SPIRV::OpTypeSampledImage) {
2074 report_fatal_error(reason: "Incorrect return type for the instruction "
2075 "OpConvertHandleToSampledImageINTEL",
2076 gen_crash_diag: false);
2077 }
2078 SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 2).getReg());
2079 unsigned Bitwidth = GR->getScalarOrVectorBitWidth(Type: SpvTy);
2080 if (!(Bitwidth == 32 && AddrModel == SPIRV::AddressingModel::Physical32) &&
2081 !(Bitwidth == 64 && AddrModel == SPIRV::AddressingModel::Physical64)) {
2082 report_fatal_error(
2083 reason: "Parameter value must be a 32-bit scalar in case of "
2084 "Physical32 addressing model or a 64-bit scalar in case of "
2085 "Physical64 addressing model",
2086 gen_crash_diag: false);
2087 }
2088 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bindless_images);
2089 Reqs.addCapability(ToAdd: SPIRV::Capability::BindlessImagesINTEL);
2090 break;
2091 }
2092 case SPIRV::OpSubgroup2DBlockLoadINTEL:
2093 case SPIRV::OpSubgroup2DBlockLoadTransposeINTEL:
2094 case SPIRV::OpSubgroup2DBlockLoadTransformINTEL:
2095 case SPIRV::OpSubgroup2DBlockPrefetchINTEL:
2096 case SPIRV::OpSubgroup2DBlockStoreINTEL: {
2097 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_2d_block_io))
2098 report_fatal_error(reason: "OpSubgroup2DBlock[Load/LoadTranspose/LoadTransform/"
2099 "Prefetch/Store]INTEL instructions require the "
2100 "following SPIR-V extension: SPV_INTEL_2d_block_io",
2101 gen_crash_diag: false);
2102 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_2d_block_io);
2103 Reqs.addCapability(ToAdd: SPIRV::Capability::Subgroup2DBlockIOINTEL);
2104
2105 const auto OpCode = MI.getOpcode();
2106 if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransposeINTEL) {
2107 Reqs.addCapability(ToAdd: SPIRV::Capability::Subgroup2DBlockTransposeINTEL);
2108 break;
2109 }
2110 if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransformINTEL) {
2111 Reqs.addCapability(ToAdd: SPIRV::Capability::Subgroup2DBlockTransformINTEL);
2112 break;
2113 }
2114 break;
2115 }
2116 case SPIRV::OpKill: {
2117 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2118 } break;
2119 case SPIRV::OpDemoteToHelperInvocation:
2120 Reqs.addCapability(ToAdd: SPIRV::Capability::DemoteToHelperInvocation);
2121
2122 if (ST.canUseExtension(
2123 E: SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
2124 if (!ST.isAtLeastSPIRVVer(VerToCompareTo: llvm::VersionTuple(1, 6)))
2125 Reqs.addExtension(
2126 ToAdd: SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
2127 }
2128 break;
2129 case SPIRV::OpSDot:
2130 case SPIRV::OpUDot:
2131 case SPIRV::OpSUDot:
2132 case SPIRV::OpSDotAccSat:
2133 case SPIRV::OpUDotAccSat:
2134 case SPIRV::OpSUDotAccSat:
2135 AddDotProductRequirements(MI, Reqs, ST);
2136 break;
2137 case SPIRV::OpImageSampleImplicitLod:
2138 Reqs.addCapability(ToAdd: SPIRV::Capability::Shader);
2139 break;
2140 case SPIRV::OpImageRead: {
2141 Register ImageReg = MI.getOperand(i: 2).getReg();
2142 SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2143 VReg: ImageReg, MF: const_cast<MachineFunction *>(MI.getMF()));
2144 // OpImageRead and OpImageWrite can use Unknown Image Formats
2145 // when the Kernel capability is declared. In the OpenCL environment we are
2146 // not allowed to produce
2147 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2148 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2149
2150 if (isImageTypeWithUnknownFormat(TypeInst: TypeDef) && ST.isShader())
2151 Reqs.addCapability(ToAdd: SPIRV::Capability::StorageImageReadWithoutFormat);
2152 break;
2153 }
2154 case SPIRV::OpImageWrite: {
2155 Register ImageReg = MI.getOperand(i: 0).getReg();
2156 SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2157 VReg: ImageReg, MF: const_cast<MachineFunction *>(MI.getMF()));
2158 // OpImageRead and OpImageWrite can use Unknown Image Formats
2159 // when the Kernel capability is declared. In the OpenCL environment we are
2160 // not allowed to produce
2161 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2162 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2163
2164 if (isImageTypeWithUnknownFormat(TypeInst: TypeDef) && ST.isShader())
2165 Reqs.addCapability(ToAdd: SPIRV::Capability::StorageImageWriteWithoutFormat);
2166 break;
2167 }
2168 case SPIRV::OpTypeStructContinuedINTEL:
2169 case SPIRV::OpConstantCompositeContinuedINTEL:
2170 case SPIRV::OpSpecConstantCompositeContinuedINTEL:
2171 case SPIRV::OpCompositeConstructContinuedINTEL: {
2172 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_long_composites))
2173 report_fatal_error(
2174 reason: "Continued instructions require the "
2175 "following SPIR-V extension: SPV_INTEL_long_composites",
2176 gen_crash_diag: false);
2177 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_long_composites);
2178 Reqs.addCapability(ToAdd: SPIRV::Capability::LongCompositesINTEL);
2179 break;
2180 }
2181 case SPIRV::OpArbitraryFloatEQALTERA:
2182 case SPIRV::OpArbitraryFloatGEALTERA:
2183 case SPIRV::OpArbitraryFloatGTALTERA:
2184 case SPIRV::OpArbitraryFloatLEALTERA:
2185 case SPIRV::OpArbitraryFloatLTALTERA:
2186 case SPIRV::OpArbitraryFloatCbrtALTERA:
2187 case SPIRV::OpArbitraryFloatCosALTERA:
2188 case SPIRV::OpArbitraryFloatCosPiALTERA:
2189 case SPIRV::OpArbitraryFloatExp10ALTERA:
2190 case SPIRV::OpArbitraryFloatExp2ALTERA:
2191 case SPIRV::OpArbitraryFloatExpALTERA:
2192 case SPIRV::OpArbitraryFloatExpm1ALTERA:
2193 case SPIRV::OpArbitraryFloatHypotALTERA:
2194 case SPIRV::OpArbitraryFloatLog10ALTERA:
2195 case SPIRV::OpArbitraryFloatLog1pALTERA:
2196 case SPIRV::OpArbitraryFloatLog2ALTERA:
2197 case SPIRV::OpArbitraryFloatLogALTERA:
2198 case SPIRV::OpArbitraryFloatRecipALTERA:
2199 case SPIRV::OpArbitraryFloatSinCosALTERA:
2200 case SPIRV::OpArbitraryFloatSinCosPiALTERA:
2201 case SPIRV::OpArbitraryFloatSinALTERA:
2202 case SPIRV::OpArbitraryFloatSinPiALTERA:
2203 case SPIRV::OpArbitraryFloatSqrtALTERA:
2204 case SPIRV::OpArbitraryFloatACosALTERA:
2205 case SPIRV::OpArbitraryFloatACosPiALTERA:
2206 case SPIRV::OpArbitraryFloatAddALTERA:
2207 case SPIRV::OpArbitraryFloatASinALTERA:
2208 case SPIRV::OpArbitraryFloatASinPiALTERA:
2209 case SPIRV::OpArbitraryFloatATan2ALTERA:
2210 case SPIRV::OpArbitraryFloatATanALTERA:
2211 case SPIRV::OpArbitraryFloatATanPiALTERA:
2212 case SPIRV::OpArbitraryFloatCastFromIntALTERA:
2213 case SPIRV::OpArbitraryFloatCastALTERA:
2214 case SPIRV::OpArbitraryFloatCastToIntALTERA:
2215 case SPIRV::OpArbitraryFloatDivALTERA:
2216 case SPIRV::OpArbitraryFloatMulALTERA:
2217 case SPIRV::OpArbitraryFloatPowALTERA:
2218 case SPIRV::OpArbitraryFloatPowNALTERA:
2219 case SPIRV::OpArbitraryFloatPowRALTERA:
2220 case SPIRV::OpArbitraryFloatRSqrtALTERA:
2221 case SPIRV::OpArbitraryFloatSubALTERA: {
2222 if (!ST.canUseExtension(
2223 E: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_floating_point))
2224 report_fatal_error(
2225 reason: "Floating point instructions can't be translated correctly without "
2226 "enabled SPV_ALTERA_arbitrary_precision_floating_point extension!",
2227 gen_crash_diag: false);
2228 Reqs.addExtension(
2229 ToAdd: SPIRV::Extension::SPV_ALTERA_arbitrary_precision_floating_point);
2230 Reqs.addCapability(
2231 ToAdd: SPIRV::Capability::ArbitraryPrecisionFloatingPointALTERA);
2232 break;
2233 }
2234 case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
2235 if (!ST.canUseExtension(
2236 E: SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
2237 report_fatal_error(
2238 reason: "OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
2239 "following SPIR-V "
2240 "extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
2241 gen_crash_diag: false);
2242 Reqs.addExtension(
2243 ToAdd: SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
2244 Reqs.addCapability(
2245 ToAdd: SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
2246 break;
2247 }
2248 case SPIRV::OpBitwiseFunctionINTEL: {
2249 if (!ST.canUseExtension(
2250 E: SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))
2251 report_fatal_error(
2252 reason: "OpBitwiseFunctionINTEL instruction requires the following SPIR-V "
2253 "extension: SPV_INTEL_ternary_bitwise_function",
2254 gen_crash_diag: false);
2255 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_ternary_bitwise_function);
2256 Reqs.addCapability(ToAdd: SPIRV::Capability::TernaryBitwiseFunctionINTEL);
2257 break;
2258 }
2259 case SPIRV::OpCopyMemorySized: {
2260 Reqs.addCapability(ToAdd: SPIRV::Capability::Addresses);
2261 // TODO: Add UntypedPointersKHR when implemented.
2262 break;
2263 }
2264 case SPIRV::OpPredicatedLoadINTEL:
2265 case SPIRV::OpPredicatedStoreINTEL: {
2266 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_predicated_io))
2267 report_fatal_error(
2268 reason: "OpPredicated[Load/Store]INTEL instructions require "
2269 "the following SPIR-V extension: SPV_INTEL_predicated_io",
2270 gen_crash_diag: false);
2271 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_predicated_io);
2272 Reqs.addCapability(ToAdd: SPIRV::Capability::PredicatedIOINTEL);
2273 break;
2274 }
2275 case SPIRV::OpFAddS:
2276 case SPIRV::OpFSubS:
2277 case SPIRV::OpFMulS:
2278 case SPIRV::OpFDivS:
2279 case SPIRV::OpFRemS:
2280 case SPIRV::OpFMod:
2281 case SPIRV::OpFNegate:
2282 case SPIRV::OpFAddV:
2283 case SPIRV::OpFSubV:
2284 case SPIRV::OpFMulV:
2285 case SPIRV::OpFDivV:
2286 case SPIRV::OpFRemV:
2287 case SPIRV::OpFNegateV: {
2288 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2289 SPIRVType *TypeDef = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg());
2290 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2291 TypeDef = MRI.getVRegDef(Reg: TypeDef->getOperand(i: 1).getReg());
2292 if (isBFloat16Type(TypeDef)) {
2293 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2294 report_fatal_error(
2295 reason: "Arithmetic instructions with bfloat16 arguments require the "
2296 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2297 gen_crash_diag: false);
2298 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2299 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16ArithmeticINTEL);
2300 }
2301 break;
2302 }
2303 case SPIRV::OpOrdered:
2304 case SPIRV::OpUnordered:
2305 case SPIRV::OpFOrdEqual:
2306 case SPIRV::OpFOrdNotEqual:
2307 case SPIRV::OpFOrdLessThan:
2308 case SPIRV::OpFOrdLessThanEqual:
2309 case SPIRV::OpFOrdGreaterThan:
2310 case SPIRV::OpFOrdGreaterThanEqual:
2311 case SPIRV::OpFUnordEqual:
2312 case SPIRV::OpFUnordNotEqual:
2313 case SPIRV::OpFUnordLessThan:
2314 case SPIRV::OpFUnordLessThanEqual:
2315 case SPIRV::OpFUnordGreaterThan:
2316 case SPIRV::OpFUnordGreaterThanEqual: {
2317 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2318 MachineInstr *OperandDef = MRI.getVRegDef(Reg: MI.getOperand(i: 2).getReg());
2319 SPIRVType *TypeDef = MRI.getVRegDef(Reg: OperandDef->getOperand(i: 1).getReg());
2320 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2321 TypeDef = MRI.getVRegDef(Reg: TypeDef->getOperand(i: 1).getReg());
2322 if (isBFloat16Type(TypeDef)) {
2323 if (!ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2324 report_fatal_error(
2325 reason: "Relational instructions with bfloat16 arguments require the "
2326 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2327 gen_crash_diag: false);
2328 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2329 Reqs.addCapability(ToAdd: SPIRV::Capability::BFloat16ArithmeticINTEL);
2330 }
2331 break;
2332 }
2333 case SPIRV::OpDPdxCoarse:
2334 case SPIRV::OpDPdyCoarse:
2335 case SPIRV::OpDPdxFine:
2336 case SPIRV::OpDPdyFine: {
2337 Reqs.addCapability(ToAdd: SPIRV::Capability::DerivativeControl);
2338 break;
2339 }
2340 case SPIRV::OpLoopControlINTEL: {
2341 Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_unstructured_loop_controls);
2342 Reqs.addCapability(ToAdd: SPIRV::Capability::UnstructuredLoopControlsINTEL);
2343 break;
2344 }
2345
2346 default:
2347 break;
2348 }
2349
2350 // If we require capability Shader, then we can remove the requirement for
2351 // the BitInstructions capability, since Shader is a superset capability
2352 // of BitInstructions.
2353 Reqs.removeCapabilityIf(ToRemove: SPIRV::Capability::BitInstructions,
2354 IfPresent: SPIRV::Capability::Shader);
2355}
2356
// Walk the whole module and accumulate every SPIR-V extension/capability
// requirement into MAI.Reqs. Three sources are scanned, in order:
//   1. every machine instruction (via addInstrRequirements),
//   2. the "spirv.ExecutionMode" named metadata (OpExecutionMode operands),
//   3. per-function IR metadata/attributes that map to execution modes
//      (work-group sizes, subgroup size, optnone, ...).
static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
                        MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
  // Collect requirements for existing instructions.
  for (const Function &F : M) {
    MachineFunction *MF = MMI->getMachineFunction(F);
    // Declarations and IR-only functions have no machine function.
    if (!MF)
      continue;
    for (const MachineBasicBlock &MBB : *MF)
      for (const MachineInstr &MI : MBB)
        addInstrRequirements(MI, MAI, ST);
  }
  // Collect requirements for OpExecutionMode instructions.
  auto Node = M.getNamedMetadata(Name: "spirv.ExecutionMode");
  if (Node) {
    // Track which float-controls extension(s) end up being needed so the
    // matching OpExtension can be emitted once, after the loop.
    bool RequireFloatControls = false, RequireIntelFloatControls2 = false,
         RequireKHRFloatControls2 = false,
         VerLower14 = !ST.isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 4));
    bool HasIntelFloatControls2 =
        ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_float_controls2);
    bool HasKHRFloatControls2 =
        ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2);
    for (unsigned i = 0; i < Node->getNumOperands(); i++) {
      MDNode *MDN = cast<MDNode>(Val: Node->getOperand(i));
      // Operand 1 of each MD node is the execution-mode enum value.
      const MDOperand &MDOp = MDN->getOperand(I: 1);
      if (auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MDOp)) {
        Constant *C = CMeta->getValue();
        if (ConstantInt *Const = dyn_cast<ConstantInt>(Val: C)) {
          auto EM = Const->getZExtValue();
          // SPV_KHR_float_controls is not available until v1.4:
          // add SPV_KHR_float_controls if the version is too low
          switch (EM) {
          case SPIRV::ExecutionMode::DenormPreserve:
          case SPIRV::ExecutionMode::DenormFlushToZero:
          case SPIRV::ExecutionMode::RoundingModeRTE:
          case SPIRV::ExecutionMode::RoundingModeRTZ:
            RequireFloatControls = VerLower14;
            MAI.Reqs.getAndAddRequirements(
                Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
            break;
          case SPIRV::ExecutionMode::RoundingModeRTPINTEL:
          case SPIRV::ExecutionMode::RoundingModeRTNINTEL:
          case SPIRV::ExecutionMode::FloatingPointModeALTINTEL:
          case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL:
            // Intel-specific modes are silently dropped when the Intel
            // float-controls2 extension is not enabled.
            if (HasIntelFloatControls2) {
              RequireIntelFloatControls2 = true;
              MAI.Reqs.getAndAddRequirements(
                  Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
            }
            break;
          case SPIRV::ExecutionMode::FPFastMathDefault: {
            if (HasKHRFloatControls2) {
              RequireKHRFloatControls2 = true;
              MAI.Reqs.getAndAddRequirements(
                  Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
            }
            break;
          }
          case SPIRV::ExecutionMode::ContractionOff:
          case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
            // These two modes are deprecated by SPV_KHR_float_controls2: when
            // the extension is available they are replaced by
            // FPFastMathDefault requirements, otherwise kept as-is.
            if (HasKHRFloatControls2) {
              RequireKHRFloatControls2 = true;
              MAI.Reqs.getAndAddRequirements(
                  Category: SPIRV::OperandCategory::ExecutionModeOperand,
                  i: SPIRV::ExecutionMode::FPFastMathDefault, ST);
            } else {
              MAI.Reqs.getAndAddRequirements(
                  Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
            }
            break;
          default:
            MAI.Reqs.getAndAddRequirements(
                Category: SPIRV::OperandCategory::ExecutionModeOperand, i: EM, ST);
          }
        }
      }
    }
    if (RequireFloatControls &&
        ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls))
      MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_float_controls);
    if (RequireIntelFloatControls2)
      MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_float_controls2);
    if (RequireKHRFloatControls2)
      MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_float_controls2);
  }
  // Per-function metadata/attributes that imply execution modes.
  for (const Function &F : M) {
    if (F.isDeclaration())
      continue;
    if (F.getMetadata(Kind: "reqd_work_group_size"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::LocalSize, ST);
    // HLSL numthreads maps to the same LocalSize execution mode.
    if (F.getFnAttribute(Kind: "hlsl.numthreads").isValid()) {
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::LocalSize, ST);
    }
    if (F.getFnAttribute(Kind: "enable-maximal-reconvergence").getValueAsBool()) {
      MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_KHR_maximal_reconvergence);
    }
    if (F.getMetadata(Kind: "work_group_size_hint"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::LocalSizeHint, ST);
    if (F.getMetadata(Kind: "intel_reqd_sub_group_size"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::SubgroupSize, ST);
    if (F.getMetadata(Kind: "max_work_group_size"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::MaxWorkgroupSizeINTEL, ST);
    if (F.getMetadata(Kind: "vec_type_hint"))
      MAI.Reqs.getAndAddRequirements(
          Category: SPIRV::OperandCategory::ExecutionModeOperand,
          i: SPIRV::ExecutionMode::VecTypeHint, ST);

    // optnone: prefer the Intel extension when both are available.
    if (F.hasOptNone()) {
      if (ST.canUseExtension(E: SPIRV::Extension::SPV_INTEL_optnone)) {
        MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_INTEL_optnone);
        MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::OptNoneINTEL);
      } else if (ST.canUseExtension(E: SPIRV::Extension::SPV_EXT_optnone)) {
        MAI.Reqs.addExtension(ToAdd: SPIRV::Extension::SPV_EXT_optnone);
        MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::OptNoneEXT);
      }
    }
  }
}
2484
2485static unsigned getFastMathFlags(const MachineInstr &I,
2486 const SPIRVSubtarget &ST) {
2487 unsigned Flags = SPIRV::FPFastMathMode::None;
2488 bool CanUseKHRFloatControls2 =
2489 ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2);
2490 if (I.getFlag(Flag: MachineInstr::MIFlag::FmNoNans))
2491 Flags |= SPIRV::FPFastMathMode::NotNaN;
2492 if (I.getFlag(Flag: MachineInstr::MIFlag::FmNoInfs))
2493 Flags |= SPIRV::FPFastMathMode::NotInf;
2494 if (I.getFlag(Flag: MachineInstr::MIFlag::FmNsz))
2495 Flags |= SPIRV::FPFastMathMode::NSZ;
2496 if (I.getFlag(Flag: MachineInstr::MIFlag::FmArcp))
2497 Flags |= SPIRV::FPFastMathMode::AllowRecip;
2498 if (I.getFlag(Flag: MachineInstr::MIFlag::FmContract) && CanUseKHRFloatControls2)
2499 Flags |= SPIRV::FPFastMathMode::AllowContract;
2500 if (I.getFlag(Flag: MachineInstr::MIFlag::FmReassoc)) {
2501 if (CanUseKHRFloatControls2)
2502 // LLVM reassoc maps to SPIRV transform, see
2503 // https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for details.
2504 // Because we are enabling AllowTransform, we must enable AllowReassoc and
2505 // AllowContract too, as required by SPIRV spec. Also, we used to map
2506 // MIFlag::FmReassoc to FPFastMathMode::Fast, which now should instead by
2507 // replaced by turning all the other bits instead. Therefore, we're
2508 // enabling every bit here except None and Fast.
2509 Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
2510 SPIRV::FPFastMathMode::NSZ | SPIRV::FPFastMathMode::AllowRecip |
2511 SPIRV::FPFastMathMode::AllowTransform |
2512 SPIRV::FPFastMathMode::AllowReassoc |
2513 SPIRV::FPFastMathMode::AllowContract;
2514 else
2515 Flags |= SPIRV::FPFastMathMode::Fast;
2516 }
2517
2518 if (CanUseKHRFloatControls2) {
2519 // Error out if SPIRV::FPFastMathMode::Fast is enabled.
2520 assert(!(Flags & SPIRV::FPFastMathMode::Fast) &&
2521 "SPIRV::FPFastMathMode::Fast is deprecated and should not be used "
2522 "anymore.");
2523
2524 // Error out if AllowTransform is enabled without AllowReassoc and
2525 // AllowContract.
2526 assert((!(Flags & SPIRV::FPFastMathMode::AllowTransform) ||
2527 ((Flags & SPIRV::FPFastMathMode::AllowReassoc &&
2528 Flags & SPIRV::FPFastMathMode::AllowContract))) &&
2529 "SPIRV::FPFastMathMode::AllowTransform requires AllowReassoc and "
2530 "AllowContract flags to be enabled as well.");
2531 }
2532
2533 return Flags;
2534}
2535
2536static bool isFastMathModeAvailable(const SPIRVSubtarget &ST) {
2537 if (ST.isKernel())
2538 return true;
2539 if (ST.getSPIRVVersion() < VersionTuple(1, 2))
2540 return false;
2541 return ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2);
2542}
2543
// Convert the MI flags of a single instruction (nsw/nuw/fast-math) into the
// corresponding SPIR-V decorations (NoSignedWrap, NoUnsignedWrap,
// FPFastMathMode) built right next to the instruction, registering any
// version/extension requirements in Reqs along the way. If the instruction
// itself carries no fast-math flags, the per-function FPFastMathDefault info
// collected earlier may still force a decoration.
static void handleMIFlagDecoration(
    MachineInstr &I, const SPIRVSubtarget &ST, const SPIRVInstrInfo &TII,
    SPIRV::RequirementHandler &Reqs, const SPIRVGlobalRegistry *GR,
    SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec) {
  // NoSignedWrap: only when the opcode supports it and the decoration's
  // requirements are satisfiable on this subtarget.
  if (I.getFlag(Flag: MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(MI: I) &&
      getSymbolicOperandRequirements(Category: SPIRV::OperandCategory::DecorationOperand,
                                     i: SPIRV::Decoration::NoSignedWrap, ST, Reqs)
          .IsSatisfiable) {
    buildOpDecorate(Reg: I.getOperand(i: 0).getReg(), I, TII,
                    Dec: SPIRV::Decoration::NoSignedWrap, DecArgs: {});
  }
  // Same logic for NoUnsignedWrap.
  if (I.getFlag(Flag: MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(MI: I) &&
      getSymbolicOperandRequirements(Category: SPIRV::OperandCategory::DecorationOperand,
                                     i: SPIRV::Decoration::NoUnsignedWrap, ST,
                                     Reqs)
          .IsSatisfiable) {
    buildOpDecorate(Reg: I.getOperand(i: 0).getReg(), I, TII,
                    Dec: SPIRV::Decoration::NoUnsignedWrap, DecArgs: {});
  }
  // Everything below only applies to opcodes that accept fast-math flags.
  if (!TII.canUseFastMathFlags(
          MI: I, KHRFloatControls2: ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2)))
    return;

  unsigned FMFlags = getFastMathFlags(I, ST);
  if (FMFlags == SPIRV::FPFastMathMode::None) {
    // We also need to check if any FPFastMathDefault info was set for the
    // types used in this instruction.
    if (FPFastMathDefaultInfoVec.empty())
      return;

    // There are three types of instructions that can use fast math flags:
    // 1. Arithmetic instructions (FAdd, FMul, FSub, FDiv, FRem, etc.)
    // 2. Relational instructions (FCmp, FOrd, FUnord, etc.)
    // 3. Extended instructions (ExtInst)
    // For arithmetic instructions, the floating point type can be in the
    // result type or in the operands, but they all must be the same.
    // For the relational and logical instructions, the floating point type
    // can only be in the operands 1 and 2, not the result type. Also, the
    // operands must have the same type. For the extended instructions, the
    // floating point type can be in the result type or in the operands. It's
    // unclear if the operands and the result type must be the same. Let's
    // assume they must be. Therefore, for 1. and 2., we can check the first
    // operand type, and for 3. we can check the result type.
    assert(I.getNumOperands() >= 3 && "Expected at least 3 operands");
    Register ResReg = I.getOpcode() == SPIRV::OpExtInst
                          ? I.getOperand(i: 1).getReg()
                          : I.getOperand(i: 2).getReg();
    SPIRVType *ResType = GR->getSPIRVTypeForVReg(VReg: ResReg, MF: I.getMF());
    const Type *Ty = GR->getTypeForSPIRVType(Ty: ResType);
    // Vector operations inherit the default of their scalar element type.
    Ty = Ty->isVectorTy() ? cast<VectorType>(Val: Ty)->getElementType() : Ty;

    // Match instruction type with the FPFastMathDefaultInfoVec.
    bool Emit = false;
    for (SPIRV::FPFastMathDefaultInfo &Elem : FPFastMathDefaultInfoVec) {
      if (Ty == Elem.Ty) {
        FMFlags = Elem.FastMathFlags;
        // A decoration is still emitted with None flags when one of the
        // deprecated modes (ContractionOff/SignedZeroInfNanPreserve) or an
        // explicit FPFastMathDefault was requested for this type.
        Emit = Elem.ContractionOff || Elem.SignedZeroInfNanPreserve ||
               Elem.FPFastMathDefault;
        break;
      }
    }

    if (FMFlags == SPIRV::FPFastMathMode::None && !Emit)
      return;
  }
  if (isFastMathModeAvailable(ST)) {
    Register DstReg = I.getOperand(i: 0).getReg();
    buildOpDecorate(Reg: DstReg, I, TII, Dec: SPIRV::Decoration::FPFastMathMode,
                    DecArgs: {FMFlags});
  }
}
2615
2616// Walk all functions and add decorations related to MI flags.
2617static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
2618 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2619 SPIRV::ModuleAnalysisInfo &MAI,
2620 const SPIRVGlobalRegistry *GR) {
2621 for (const Function &F : M) {
2622 MachineFunction *MF = MMI->getMachineFunction(F);
2623 if (!MF)
2624 continue;
2625
2626 for (auto &MBB : *MF)
2627 for (auto &MI : MBB)
2628 handleMIFlagDecoration(I&: MI, ST, TII, Reqs&: MAI.Reqs, GR,
2629 FPFastMathDefaultInfoVec&: MAI.FPFastMathDefaultInfoMap[&F]);
2630 }
2631}
2632
2633static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
2634 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2635 SPIRV::ModuleAnalysisInfo &MAI) {
2636 for (const Function &F : M) {
2637 MachineFunction *MF = MMI->getMachineFunction(F);
2638 if (!MF)
2639 continue;
2640 MachineRegisterInfo &MRI = MF->getRegInfo();
2641 for (auto &MBB : *MF) {
2642 if (!MBB.hasName() || MBB.empty())
2643 continue;
2644 // Emit basic block names.
2645 Register Reg = MRI.createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 64));
2646 MRI.setRegClass(Reg, RC: &SPIRV::IDRegClass);
2647 buildOpName(Target: Reg, Name: MBB.getName(), I&: *std::prev(x: MBB.end()), TII);
2648 MCRegister GlobalReg = MAI.getOrCreateMBBRegister(MBB);
2649 MAI.setRegisterAlias(MF, Reg, AliasReg: GlobalReg);
2650 }
2651 }
2652}
2653
2654// patching Instruction::PHI to SPIRV::OpPhi
2655static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
2656 const SPIRVInstrInfo &TII, MachineModuleInfo *MMI) {
2657 for (const Function &F : M) {
2658 MachineFunction *MF = MMI->getMachineFunction(F);
2659 if (!MF)
2660 continue;
2661 for (auto &MBB : *MF) {
2662 for (MachineInstr &MI : MBB.phis()) {
2663 MI.setDesc(TII.get(Opcode: SPIRV::OpPhi));
2664 Register ResTypeReg = GR->getSPIRVTypeID(
2665 SpirvType: GR->getSPIRVTypeForVReg(VReg: MI.getOperand(i: 0).getReg(), MF));
2666 MI.insert(InsertBefore: MI.operands_begin() + 1,
2667 Ops: {MachineOperand::CreateReg(Reg: ResTypeReg, isDef: false)});
2668 }
2669 }
2670
2671 MF->getProperties().setNoPHIs();
2672 }
2673}
2674
2675static SPIRV::FPFastMathDefaultInfoVector &getOrCreateFPFastMathDefaultInfoVec(
2676 const Module &M, SPIRV::ModuleAnalysisInfo &MAI, const Function *F) {
2677 auto it = MAI.FPFastMathDefaultInfoMap.find(Val: F);
2678 if (it != MAI.FPFastMathDefaultInfoMap.end())
2679 return it->second;
2680
2681 // If the map does not contain the entry, create a new one. Initialize it to
2682 // contain all 3 elements sorted by bit width of target type: {half, float,
2683 // double}.
2684 SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec;
2685 FPFastMathDefaultInfoVec.emplace_back(Args: Type::getHalfTy(C&: M.getContext()),
2686 Args: SPIRV::FPFastMathMode::None);
2687 FPFastMathDefaultInfoVec.emplace_back(Args: Type::getFloatTy(C&: M.getContext()),
2688 Args: SPIRV::FPFastMathMode::None);
2689 FPFastMathDefaultInfoVec.emplace_back(Args: Type::getDoubleTy(C&: M.getContext()),
2690 Args: SPIRV::FPFastMathMode::None);
2691 return MAI.FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec);
2692}
2693
2694static SPIRV::FPFastMathDefaultInfo &getFPFastMathDefaultInfo(
2695 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec,
2696 const Type *Ty) {
2697 size_t BitWidth = Ty->getScalarSizeInBits();
2698 int Index =
2699 SPIRV::FPFastMathDefaultInfoVector::computeFPFastMathDefaultInfoVecIndex(
2700 BitWidth);
2701 assert(Index >= 0 && Index < 3 &&
2702 "Expected FPFastMathDefaultInfo for half, float, or double");
2703 assert(FPFastMathDefaultInfoVec.size() == 3 &&
2704 "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2705 return FPFastMathDefaultInfoVec[Index];
2706}
2707
// Parse the "spirv.ExecutionMode" named metadata and populate
// MAI.FPFastMathDefaultInfoMap with the per-function, per-type fast-math
// defaults. Handles the FPFastMathDefault execution mode as well as the two
// modes it deprecates (ContractionOff, SignedZeroInfNanPreserve). No-op
// unless SPV_KHR_float_controls2 is usable.
static void collectFPFastMathDefaults(const Module &M,
                                      SPIRV::ModuleAnalysisInfo &MAI,
                                      const SPIRVSubtarget &ST) {
  if (!ST.canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2))
    return;

  // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap.
  // We need the entry point (function) as the key, and the target
  // type and flags as the value.
  // We also need to check ContractionOff and SignedZeroInfNanPreserve
  // execution modes, as they are now deprecated and must be replaced
  // with FPFastMathDefaultInfo.
  auto Node = M.getNamedMetadata(Name: "spirv.ExecutionMode");
  if (!Node)
    return;

  for (unsigned i = 0; i < Node->getNumOperands(); i++) {
    MDNode *MDN = cast<MDNode>(Val: Node->getOperand(i));
    assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands");
    // Operand 0: the entry-point function; operand 1: the execution mode id.
    const Function *F = cast<Function>(
        Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 0))->getValue());
    const auto EM =
        cast<ConstantInt>(
            Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 1))->getValue())
            ->getZExtValue();
    if (EM == SPIRV::ExecutionMode::FPFastMathDefault) {
      assert(MDN->getNumOperands() == 4 &&
             "Expected 4 operands for FPFastMathDefault");

      // Operand 2: the target FP type; operand 3: the fast-math flag mask.
      const Type *T = cast<ValueAsMetadata>(Val: MDN->getOperand(I: 2))->getType();
      unsigned Flags =
          cast<ConstantInt>(
              Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 3))->getValue())
              ->getZExtValue();
      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
          getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
      SPIRV::FPFastMathDefaultInfo &Info =
          getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, Ty: T);
      Info.FastMathFlags = Flags;
      Info.FPFastMathDefault = true;
    } else if (EM == SPIRV::ExecutionMode::ContractionOff) {
      assert(MDN->getNumOperands() == 2 &&
             "Expected no operands for ContractionOff");

      // We need to save this info for every possible FP type, i.e. {half,
      // float, double, fp128}.
      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
          getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
      for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) {
        Info.ContractionOff = true;
      }
    } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) {
      assert(MDN->getNumOperands() == 3 &&
             "Expected 1 operand for SignedZeroInfNanPreserve");
      // Operand 2: the bit width of the FP type the mode applies to.
      unsigned TargetWidth =
          cast<ConstantInt>(
              Val: cast<ConstantAsMetadata>(Val: MDN->getOperand(I: 2))->getValue())
              ->getZExtValue();
      // We need to save this info only for the FP type with TargetWidth.
      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
          getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
      int Index = SPIRV::FPFastMathDefaultInfoVector::
          computeFPFastMathDefaultInfoVecIndex(BitWidth: TargetWidth);
      assert(Index >= 0 && Index < 3 &&
             "Expected FPFastMathDefaultInfo for half, float, or double");
      assert(FPFastMathDefaultInfoVec.size() == 3 &&
             "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
      FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true;
    }
  }
}
2779
2780struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
2781
// Declare the analyses this pass depends on: TargetPassConfig (to reach the
// SPIRVTargetMachine/subtarget) and MachineModuleInfo (to iterate over the
// already-selected machine functions).
void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}
2786
2787bool SPIRVModuleAnalysis::runOnModule(Module &M) {
2788 SPIRVTargetMachine &TM =
2789 getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
2790 ST = TM.getSubtargetImpl();
2791 GR = ST->getSPIRVGlobalRegistry();
2792 TII = ST->getInstrInfo();
2793
2794 MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
2795
2796 setBaseInfo(M);
2797
2798 patchPhis(M, GR, TII: *TII, MMI);
2799
2800 addMBBNames(M, TII: *TII, MMI, ST: *ST, MAI);
2801 collectFPFastMathDefaults(M, MAI, ST: *ST);
2802 addDecorations(M, TII: *TII, MMI, ST: *ST, MAI, GR);
2803
2804 collectReqs(M, MAI, MMI, ST: *ST);
2805
2806 // Process type/const/global var/func decl instructions, number their
2807 // destination registers from 0 to N, collect Extensions and Capabilities.
2808 collectReqs(M, MAI, MMI, ST: *ST);
2809 collectDeclarations(M);
2810
2811 // Number rest of registers from N+1 onwards.
2812 numberRegistersGlobally(M);
2813
2814 // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
2815 processOtherInstrs(M);
2816
2817 // If there are no entry points, we need the Linkage capability.
2818 if (MAI.MS[SPIRV::MB_EntryPoints].empty())
2819 MAI.Reqs.addCapability(ToAdd: SPIRV::Capability::Linkage);
2820
2821 // Set maximum ID used.
2822 GR->setBound(MAI.MaxID);
2823
2824 return false;
2825}
2826