1//===-- SPIRVAsmPrinter.cpp - SPIR-V LLVM assembly writer ------*- C++ -*--===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a printer that converts from our internal representation
10// of machine-dependent LLVM code to the SPIR-V assembly language.
11//
12//===----------------------------------------------------------------------===//
13
14#include "MCTargetDesc/SPIRVInstPrinter.h"
15#include "SPIRV.h"
16#include "SPIRVAuxDataHandler.h"
17#include "SPIRVInstrInfo.h"
18#include "SPIRVMCInstLower.h"
19#include "SPIRVModuleAnalysis.h"
20#include "SPIRVNonSemanticDebugHandler.h"
21#include "SPIRVSubtarget.h"
22#include "SPIRVTargetMachine.h"
23#include "SPIRVUtils.h"
24#include "TargetInfo/SPIRVTargetInfo.h"
25#include "llvm/ADT/DenseMap.h"
26#include "llvm/Analysis/ValueTracking.h"
27#include "llvm/CodeGen/AsmPrinter.h"
28#include "llvm/CodeGen/MachineConstantPool.h"
29#include "llvm/CodeGen/MachineInstr.h"
30#include "llvm/CodeGen/MachineModuleInfo.h"
31#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
32#include "llvm/MC/MCAsmInfo.h"
33#include "llvm/MC/MCAssembler.h"
34#include "llvm/MC/MCInst.h"
35#include "llvm/MC/MCObjectStreamer.h"
36#include "llvm/MC/MCSPIRVObjectWriter.h"
37#include "llvm/MC/MCStreamer.h"
38#include "llvm/MC/MCSymbol.h"
39#include "llvm/MC/TargetRegistry.h"
40#include "llvm/Support/CommandLine.h"
41#include "llvm/Support/Compiler.h"
42#include "llvm/Support/raw_ostream.h"
43
44using namespace llvm;
45
46#define DEBUG_TYPE "asm-printer"
47
48namespace {
49enum class SPIRVFPContractMode { On, Off, Fast };
50
51static cl::opt<SPIRVFPContractMode> SPIRVFPContract(
52 "spirv-fp-contract",
53 cl::desc("Override FP contraction policy for SPIR-V kernel entry points"),
54 cl::values(
55 clEnumValN(SPIRVFPContractMode::On, "on",
56 "Follow IR metadata (default)"),
57 clEnumValN(SPIRVFPContractMode::Off, "off",
58 "Force ContractionOff on all kernel entry points"),
59 clEnumValN(SPIRVFPContractMode::Fast, "fast",
60 "Suppress ContractionOff on all kernel entry points")),
61 cl::init(Val: SPIRVFPContractMode::On));
62
/// AsmPrinter for the SPIR-V target.
///
/// Rather than using AsmPrinter's generic function/section emission, this
/// printer writes the SPIR-V module layout itself:
///  - the module-level sections are emitted once, by outputModuleSections(),
///    either just before the first function header or, if no function was
///    printed, from emitEndOfAsmFile();
///  - each function is printed with explicit OpLabel instructions for its
///    basic blocks and a trailing OpFunctionEnd.
class SPIRVAsmPrinter : public AsmPrinter {
  // Number of OpLabel instructions emitted so far; it is added to the ID-bound
  // estimate computed in emitEndOfAsmFile().
  unsigned NLabels = 0;
  // Basic blocks that already received an OpLabel. emitInstruction() uses it
  // so the entry block is labeled only once, after its header instructions.
  SmallPtrSet<const MachineBasicBlock *, 8> LabeledMBB;

public:
  explicit SPIRVAsmPrinter(TargetMachine &TM,
                           std::unique_ptr<MCStreamer> Streamer)
      : AsmPrinter(TM, std::move(Streamer), ID), ModuleSectionsEmitted(false),
        ST(nullptr), TII(nullptr), MAI(nullptr) {}
  static char ID;
  // Set once outputModuleSections() has run, so the module-level sections are
  // printed exactly once.
  bool ModuleSectionsEmitted;
  // Subtarget / instruction info of the function being printed; refreshed in
  // emitFunctionHeader() (ST is also re-read in emitEndOfAsmFile()).
  const SPIRVSubtarget *ST;
  const SPIRVInstrInfo *TII;

  StringRef getPassName() const override { return "SPIRV Assembly Printer"; }
  void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O);
  bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
                       const char *ExtraCode, raw_ostream &O) override;

  // Low-level emission helpers.
  void outputMCInst(MCInst &Inst);
  void outputInstruction(const MachineInstr *MI);
  void outputModuleSection(SPIRV::ModuleSectionType MSType);
  // Builders for individual pieces of the module-level sections.
  void outputGlobalRequirements();
  void outputEntryPoints();
  void outputDebugSourceAndStrings(const Module &M);
  void outputOpExtInstImports(const Module &M);
  void outputOpMemoryModel();
  void outputOpFunctionEnd();
  void outputExtFuncDecls();
  void outputExecutionModeFromMDNode(MCRegister Reg, MDNode *Node,
                                     SPIRV::ExecutionMode::ExecutionMode EM,
                                     unsigned ExpectMDOps, int64_t DefVal);
  void outputExecutionModeFromNumthreadsAttribute(
      const MCRegister &Reg, const Attribute &Attr,
      SPIRV::ExecutionMode::ExecutionMode EM);
  void outputExecutionModeFromEnableMaximalReconvergenceAttr(
      const MCRegister &Reg, const SPIRVSubtarget &ST);
  void outputExecutionMode(const Module &M);
  void outputAnnotations(const Module &M);
  void outputModuleSections();
  void outputFPFastMathDefaultInfo();
  // True when the current function carries the
  // SPIRV_BACKEND_SERVICE_FUN_NAME attribute, i.e. it is an internal backend
  // service function. Its OpLabels, OpFunctionEnd and "Begin function"
  // comment are suppressed.
  bool isHidden() {
    return MF->getFunction()
        .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
        .isValid();
  }

  // AsmPrinter hooks. Several are deliberately empty because the SPIR-V
  // layout is produced by the methods above instead.
  void emitInstruction(const MachineInstr *MI) override;
  void emitFunctionEntryLabel() override {}
  void emitFunctionHeader() override;
  void emitFunctionBodyStart() override {}
  void emitFunctionBodyEnd() override;
  void emitBasicBlockStart(const MachineBasicBlock &MBB) override;
  void emitBasicBlockEnd(const MachineBasicBlock &MBB) override {}
  // No-op: presumably globals are emitted as part of the module sections
  // (OpVariable instructions) — TODO confirm in outputModuleSections().
  void emitGlobalVariable(const GlobalVariable *GV) override {}
  void emitOpLabel(const MachineBasicBlock &MBB);
  void emitEndOfAsmFile(Module &M) override;
  bool doInitialization(Module &M) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override;
  // Results of SPIRVModuleAnalysis (section instruction lists, register
  // aliases, requirements, ...). Null until initialized — presumably in
  // doInitialization(), whose body is not visible here.
  SPIRV::ModuleAnalysisInfo *MAI;

  // Non-owning pointer to the NSDI handler registered via addAsmPrinterHandler.
  // The handler's lifetime is managed by AsmPrinter (the base class of this
  // object), so this pointer cannot dangle.
  SPIRVNonSemanticDebugHandler *NSDebugHandler = nullptr;

  // Owned handler that contributes auxiliary-data strings in
  // outputDebugSourceAndStrings(); may be null.
  std::unique_ptr<SPIRVAuxDataHandler> AuxDataHandler;

protected:
  // Final IR clean-up performed at the end of emitEndOfAsmFile().
  void cleanUp(Module &M);
};
135} // namespace
136
137void SPIRVAsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
138 AU.addRequired<SPIRVModuleAnalysis>();
139 AU.addPreserved<SPIRVModuleAnalysis>();
140 AsmPrinter::getAnalysisUsage(AU);
141}
142
143// If the module has no functions, we need output global info anyway.
144void SPIRVAsmPrinter::emitEndOfAsmFile(Module &M) {
145 if (!ModuleSectionsEmitted) {
146 outputModuleSections();
147 ModuleSectionsEmitted = true;
148 }
149
150 ST = static_cast<const SPIRVTargetMachine &>(TM).getSubtargetImpl();
151 // SPIRVModuleAnalysis sets GR->Bound = MAI->MaxID before printing. Any IDs
152 // allocated by AsmPrinter handlers (e.g. SPIRVNonSemanticDebugHandler) during
153 // outputModuleSections() are not counted. Refresh the bound here so the
154 // formula below sees the final allocation count.
155 if (MAI)
156 ST->getSPIRVGlobalRegistry()->setBound(MAI->MaxID);
157 VersionTuple SPIRVVersion = ST->getSPIRVVersion();
158 uint32_t Major = SPIRVVersion.getMajor();
159 uint32_t Minor = SPIRVVersion.getMinor().value_or(u: 0);
160 // Bound is an approximation that accounts for the maximum used register
161 // number and number of generated OpLabels
162 unsigned Bound = 2 * (ST->getBound() + 1) + NLabels;
163 if (MCAssembler *Asm = OutStreamer->getAssemblerPtr())
164 static_cast<SPIRVObjectWriter &>(Asm->getWriter())
165 .setBuildVersion(Major, Minor, Bound);
166
167 cleanUp(M);
168}
169
170// Any cleanup actions with the Module after we don't care about its content
171// anymore.
172void SPIRVAsmPrinter::cleanUp(Module &M) {
173 // Verifier disallows uses of intrinsic global variables.
174 for (StringRef GVName :
175 {"llvm.global_ctors", "llvm.global_dtors", "llvm.used"}) {
176 if (GlobalVariable *GV = M.getNamedGlobal(Name: GVName))
177 GV->setName("");
178 }
179}
180
181void SPIRVAsmPrinter::emitFunctionHeader() {
182 if (!ModuleSectionsEmitted) {
183 outputModuleSections();
184 ModuleSectionsEmitted = true;
185 }
186 // Get the subtarget from the current MachineFunction.
187 ST = &MF->getSubtarget<SPIRVSubtarget>();
188 TII = ST->getInstrInfo();
189 const Function &F = MF->getFunction();
190
191 if (isVerbose() && !isHidden()) {
192 OutStreamer->getCommentOS()
193 << "-- Begin function "
194 << GlobalValue::dropLLVMManglingEscape(Name: F.getName()) << '\n';
195 }
196
197 auto Section = getObjFileLowering().SectionForGlobal(GO: &F, TM);
198 MF->setSection(Section);
199
200 // SPIRVAsmPrinter::emitFunctionHeader() does not call the base class,
201 // so handlers never receive beginFunction() from the normal path. Drive the
202 // per-function lifecycle here, matching what AsmPrinter::emitFunctionHeader()
203 // does for other targets.
204 for (auto &Handler : Handlers) {
205 Handler->beginFunction(MF);
206 Handler->beginBasicBlockSection(MBB: MF->front());
207 }
208}
209
210void SPIRVAsmPrinter::outputOpFunctionEnd() {
211 MCInst FunctionEndInst;
212 FunctionEndInst.setOpcode(SPIRV::OpFunctionEnd);
213 outputMCInst(Inst&: FunctionEndInst);
214}
215
216void SPIRVAsmPrinter::emitFunctionBodyEnd() {
217 if (!isHidden())
218 outputOpFunctionEnd();
219}
220
221void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
222 // Do not emit anything if it's an internal service function.
223 if (isHidden())
224 return;
225
226 MCInst LabelInst;
227 LabelInst.setOpcode(SPIRV::OpLabel);
228 LabelInst.addOperand(Op: MCOperand::createReg(Reg: MAI->getOrCreateMBBRegister(MBB)));
229 outputMCInst(Inst&: LabelInst);
230 ++NLabels;
231 LabeledMBB.insert(Ptr: &MBB);
232}
233
234void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
235 // Do not emit anything if it's an internal service function.
236 if (MBB.empty() || isHidden())
237 return;
238
239 // If it's the first MBB in MF, it has OpFunction and OpFunctionParameter, so
240 // OpLabel should be output after them.
241 if (MBB.getNumber() == MF->front().getNumber()) {
242 for (const MachineInstr &MI : MBB)
243 if (MI.getOpcode() == SPIRV::OpFunction)
244 return;
245 // TODO: this case should be checked by the verifier.
246 report_fatal_error(reason: "OpFunction is expected in the front MBB of MF");
247 }
248 emitOpLabel(MBB);
249}
250
251void SPIRVAsmPrinter::printOperand(const MachineInstr *MI, int OpNum,
252 raw_ostream &O) {
253 const MachineOperand &MO = MI->getOperand(i: OpNum);
254
255 switch (MO.getType()) {
256 case MachineOperand::MO_Register:
257 O << SPIRVInstPrinter::getRegisterName(Reg: MO.getReg());
258 break;
259
260 case MachineOperand::MO_Immediate:
261 O << MO.getImm();
262 break;
263
264 case MachineOperand::MO_FPImmediate:
265 O << MO.getFPImm();
266 break;
267
268 case MachineOperand::MO_MachineBasicBlock:
269 O << *MO.getMBB()->getSymbol();
270 break;
271
272 case MachineOperand::MO_GlobalAddress:
273 O << *getSymbol(GV: MO.getGlobal());
274 break;
275
276 case MachineOperand::MO_BlockAddress: {
277 MCSymbol *BA = GetBlockAddressSymbol(BA: MO.getBlockAddress());
278 O << BA->getName();
279 break;
280 }
281
282 case MachineOperand::MO_ExternalSymbol:
283 O << *GetExternalSymbolSymbol(Sym: MO.getSymbolName());
284 break;
285
286 case MachineOperand::MO_JumpTableIndex:
287 case MachineOperand::MO_ConstantPoolIndex:
288 default:
289 llvm_unreachable("<unknown operand type>");
290 }
291}
292
293bool SPIRVAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
294 const char *ExtraCode, raw_ostream &O) {
295 if (ExtraCode && ExtraCode[0])
296 return true; // Invalid instruction - SPIR-V does not have special modifiers
297
298 printOperand(MI, OpNum: OpNo, O);
299 return false;
300}
301
302static bool isFuncOrHeaderInstr(const MachineInstr *MI,
303 const SPIRVInstrInfo *TII) {
304 return TII->isHeaderInstr(MI: *MI) || MI->getOpcode() == SPIRV::OpFunction ||
305 MI->getOpcode() == SPIRV::OpFunctionParameter;
306}
307
308void SPIRVAsmPrinter::outputMCInst(MCInst &Inst) {
309 OutStreamer->emitInstruction(Inst, STI: *OutContext.getSubtargetInfo());
310}
311
312void SPIRVAsmPrinter::outputInstruction(const MachineInstr *MI) {
313 SPIRVMCInstLower MCInstLowering;
314 MCInst TmpInst;
315 MCInstLowering.lower(MI, OutMI&: TmpInst, MAI);
316 outputMCInst(Inst&: TmpInst);
317}
318
319void SPIRVAsmPrinter::emitInstruction(const MachineInstr *MI) {
320 SPIRV_MC::verifyInstructionPredicates(Opcode: MI->getOpcode(),
321 Features: getSubtargetInfo().getFeatureBits());
322
323 if (!MAI->getSkipEmission(MI))
324 outputInstruction(MI);
325
326 // Output OpLabel after OpFunction and OpFunctionParameter in the first MBB.
327 const MachineInstr *NextMI = MI->getNextNode();
328 if (!LabeledMBB.contains(Ptr: MI->getParent()) && isFuncOrHeaderInstr(MI, TII) &&
329 (!NextMI || !isFuncOrHeaderInstr(MI: NextMI, TII))) {
330 assert(MI->getParent()->getNumber() == MF->front().getNumber() &&
331 "OpFunction is not in the front MBB of MF");
332 emitOpLabel(MBB: *MI->getParent());
333 }
334}
335
336void SPIRVAsmPrinter::outputModuleSection(SPIRV::ModuleSectionType MSType) {
337 for (const MachineInstr *MI : MAI->getMSInstrs(MSType))
338 outputInstruction(MI);
339}
340
341void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) {
342 // Output OpSourceExtensions.
343 for (auto &Str : MAI->SrcExt) {
344 MCInst Inst;
345 Inst.setOpcode(SPIRV::OpSourceExtension);
346 addStringImm(Str: Str.first(), Inst);
347 outputMCInst(Inst);
348 }
349 // Output OpString.
350 outputModuleSection(MSType: SPIRV::MB_DebugStrings);
351 // Output OpSource.
352 MCInst Inst;
353 Inst.setOpcode(SPIRV::OpSource);
354 Inst.addOperand(Op: MCOperand::createImm(Val: static_cast<unsigned>(MAI->SrcLang)));
355 Inst.addOperand(
356 Op: MCOperand::createImm(Val: static_cast<unsigned>(MAI->SrcLangVersion)));
357 outputMCInst(Inst);
358 // Emit OpString instructions for NSDI file paths and type names here, in
359 // section 7. OpString must precede type/constant declarations per the SPIR-V
360 // module layout (section 2.4). The OpExtInst instructions that reference
361 // these strings are emitted later at section 10 by
362 // emitNonSemanticGlobalDebugInfo().
363 if (NSDebugHandler)
364 NSDebugHandler->emitNonSemanticDebugStrings(MAI&: *MAI);
365 if (AuxDataHandler)
366 AuxDataHandler->emitAuxDataStrings(MAI&: *MAI);
367}
368
369void SPIRVAsmPrinter::outputOpExtInstImports(const Module &M) {
370 for (auto &CU : MAI->ExtInstSetMap) {
371 unsigned Set = CU.first;
372 MCRegister Reg = CU.second;
373 MCInst Inst;
374 Inst.setOpcode(SPIRV::OpExtInstImport);
375 Inst.addOperand(Op: MCOperand::createReg(Reg));
376 addStringImm(Str: getExtInstSetName(
377 Set: static_cast<SPIRV::InstructionSet::InstructionSet>(Set)),
378 Inst);
379 outputMCInst(Inst);
380 }
381}
382
383void SPIRVAsmPrinter::outputOpMemoryModel() {
384 MCInst Inst;
385 Inst.setOpcode(SPIRV::OpMemoryModel);
386 Inst.addOperand(Op: MCOperand::createImm(Val: static_cast<unsigned>(MAI->Addr)));
387 Inst.addOperand(Op: MCOperand::createImm(Val: static_cast<unsigned>(MAI->Mem)));
388 outputMCInst(Inst);
389}
390
391// Before the OpEntryPoints' output, we need to add the entry point's
392// interfaces. The interface is a list of IDs of global OpVariable instructions.
393// These declare the set of global variables from a module that form
394// the interface of this entry point.
395void SPIRVAsmPrinter::outputEntryPoints() {
396 // Find all OpVariable IDs with required StorageClass.
397 DenseSet<MCRegister> InterfaceIDs;
398 for (const MachineInstr *MI : MAI->GlobalVarList) {
399 assert(MI->getOpcode() == SPIRV::OpVariable);
400 auto SC = static_cast<SPIRV::StorageClass::StorageClass>(
401 MI->getOperand(i: 2).getImm());
402 // Before version 1.4, the interface's storage classes are limited to
403 // the Input and Output storage classes. Starting with version 1.4,
404 // the interface's storage classes are all storage classes used in
405 // declaring all global variables referenced by the entry point call tree.
406 if (ST->isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 4)) ||
407 SC == SPIRV::StorageClass::Input || SC == SPIRV::StorageClass::Output) {
408 const MachineFunction *MF = MI->getMF();
409 MCRegister Reg = MAI->getRegisterAlias(MF, Reg: MI->getOperand(i: 0).getReg());
410 InterfaceIDs.insert(V: Reg);
411 }
412 }
413
414 // Output OpEntryPoints adding interface args to all of them.
415 for (const MachineInstr *MI : MAI->getMSInstrs(MSType: SPIRV::MB_EntryPoints)) {
416 SPIRVMCInstLower MCInstLowering;
417 MCInst TmpInst;
418 MCInstLowering.lower(MI, OutMI&: TmpInst, MAI);
419 for (MCRegister Reg : InterfaceIDs) {
420 assert(Reg.isValid());
421 TmpInst.addOperand(Op: MCOperand::createReg(Reg));
422 }
423 outputMCInst(Inst&: TmpInst);
424 }
425}
426
427// Create global OpCapability instructions for the required capabilities.
428void SPIRVAsmPrinter::outputGlobalRequirements() {
429 // Abort here if not all requirements can be satisfied.
430 MAI->Reqs.checkSatisfiable(ST: *ST);
431
432 for (const auto &Cap : MAI->Reqs.getMinimalCapabilities()) {
433 MCInst Inst;
434 Inst.setOpcode(SPIRV::OpCapability);
435 Inst.addOperand(Op: MCOperand::createImm(Val: Cap));
436 outputMCInst(Inst);
437 }
438
439 // Generate the final OpExtensions with strings instead of enums.
440 for (const auto &Ext : MAI->Reqs.getExtensions()) {
441 MCInst Inst;
442 Inst.setOpcode(SPIRV::OpExtension);
443 addStringImm(Str: getSymbolicOperandMnemonic(
444 Category: SPIRV::OperandCategory::ExtensionOperand, Value: Ext),
445 Inst);
446 outputMCInst(Inst);
447 }
448 // TODO add a pseudo instr for version number.
449}
450
451void SPIRVAsmPrinter::outputExtFuncDecls() {
452 // Insert OpFunctionEnd after each declaration.
453 auto I = MAI->getMSInstrs(MSType: SPIRV::MB_ExtFuncDecls).begin(),
454 E = MAI->getMSInstrs(MSType: SPIRV::MB_ExtFuncDecls).end();
455 for (; I != E; ++I) {
456 outputInstruction(MI: *I);
457 if ((I + 1) == E || (*(I + 1))->getOpcode() == SPIRV::OpFunction)
458 outputOpFunctionEnd();
459 }
460}
461
462// Encode LLVM type by SPIR-V execution mode VecTypeHint.
463static unsigned encodeVecTypeHint(Type *Ty) {
464 if (Ty->isHalfTy())
465 return 4;
466 if (Ty->isFloatTy())
467 return 5;
468 if (Ty->isDoubleTy())
469 return 6;
470 if (IntegerType *IntTy = dyn_cast<IntegerType>(Val: Ty)) {
471 switch (IntTy->getIntegerBitWidth()) {
472 case 8:
473 return 0;
474 case 16:
475 return 1;
476 case 32:
477 return 2;
478 case 64:
479 return 3;
480 default:
481 llvm_unreachable("invalid integer type");
482 }
483 }
484 if (FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Val: Ty)) {
485 Type *EleTy = VecTy->getElementType();
486 unsigned Size = VecTy->getNumElements();
487 return Size << 16 | encodeVecTypeHint(Ty: EleTy);
488 }
489 llvm_unreachable("invalid type");
490}
491
492static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst,
493 SPIRV::ModuleAnalysisInfo *MAI) {
494 for (const MDOperand &MDOp : MDN->operands()) {
495 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MDOp)) {
496 Constant *C = CMeta->getValue();
497 if (ConstantInt *Const = dyn_cast<ConstantInt>(Val: C)) {
498 Inst.addOperand(Op: MCOperand::createImm(Val: Const->getZExtValue()));
499 } else if (auto *CE = dyn_cast<Function>(Val: C)) {
500 MCRegister FuncReg = MAI->getGlobalObjReg(GO: CE);
501 assert(FuncReg.isValid());
502 Inst.addOperand(Op: MCOperand::createReg(Reg: FuncReg));
503 }
504 }
505 }
506}
507
508void SPIRVAsmPrinter::outputExecutionModeFromMDNode(
509 MCRegister Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM,
510 unsigned ExpectMDOps, int64_t DefVal) {
511 MCInst Inst;
512 Inst.setOpcode(SPIRV::OpExecutionMode);
513 Inst.addOperand(Op: MCOperand::createReg(Reg));
514 Inst.addOperand(Op: MCOperand::createImm(Val: static_cast<unsigned>(EM)));
515 addOpsFromMDNode(MDN: Node, Inst, MAI);
516 // reqd_work_group_size and work_group_size_hint require 3 operands,
517 // if metadata contains less operands, just add a default value
518 unsigned NodeSz = Node->getNumOperands();
519 if (ExpectMDOps > 0 && NodeSz < ExpectMDOps)
520 for (unsigned i = NodeSz; i < ExpectMDOps; ++i)
521 Inst.addOperand(Op: MCOperand::createImm(Val: DefVal));
522 outputMCInst(Inst);
523}
524
525void SPIRVAsmPrinter::outputExecutionModeFromNumthreadsAttribute(
526 const MCRegister &Reg, const Attribute &Attr,
527 SPIRV::ExecutionMode::ExecutionMode EM) {
528 assert(Attr.isValid() && "Function called with an invalid attribute.");
529
530 MCInst Inst;
531 Inst.setOpcode(SPIRV::OpExecutionMode);
532 Inst.addOperand(Op: MCOperand::createReg(Reg));
533 Inst.addOperand(Op: MCOperand::createImm(Val: static_cast<unsigned>(EM)));
534
535 SmallVector<StringRef> NumThreads;
536 Attr.getValueAsString().split(A&: NumThreads, Separator: ',');
537 assert(NumThreads.size() == 3 && "invalid numthreads");
538 for (uint32_t i = 0; i < 3; ++i) {
539 uint32_t V;
540 [[maybe_unused]] bool Result = NumThreads[i].getAsInteger(Radix: 10, Result&: V);
541 assert(!Result && "Failed to parse numthreads");
542 Inst.addOperand(Op: MCOperand::createImm(Val: V));
543 }
544
545 outputMCInst(Inst);
546}
547
548void SPIRVAsmPrinter::outputExecutionModeFromEnableMaximalReconvergenceAttr(
549 const MCRegister &Reg, const SPIRVSubtarget &ST) {
550 assert(ST.canUseExtension(SPIRV::Extension::SPV_KHR_maximal_reconvergence) &&
551 "Function called when SPV_KHR_maximal_reconvergence is not enabled.");
552
553 MCInst Inst;
554 Inst.setOpcode(SPIRV::OpExecutionMode);
555 Inst.addOperand(Op: MCOperand::createReg(Reg));
556 unsigned EM =
557 static_cast<unsigned>(SPIRV::ExecutionMode::MaximallyReconvergesKHR);
558 Inst.addOperand(Op: MCOperand::createImm(Val: EM));
559 outputMCInst(Inst);
560}
561
// Emits the module's OpExecutionMode / OpExecutionModeId instructions.
//  1. Modes listed in the module-level "spirv.ExecutionMode" named metadata.
//     Each entry is an MDNode whose operand 1 is the execution mode. All of
//     its ConstantInt/Function operands are forwarded via addOpsFromMDNode()
//     (presumably operand 0 is the entry-point function — TODO confirm).
//     This stage ends with outputFPFastMathDefaultInfo().
//  2. For every defined entry-point function, modes derived from its
//     attributes and metadata: hlsl.shader, reqd_work_group_size,
//     hlsl.numthreads, enable-maximal-reconvergence, work_group_size_hint,
//     reqd_sub_group_size, intel_reqd_sub_group_size, max_work_group_size and
//     vec_type_hint. ArithmeticPoisonKHR is added when PoisonFreezeKHR is
//     required, and for kernels ContractionOff (or its
//     SPV_KHR_float_controls2 replacement) depending on --spirv-fp-contract.
// The emission order is visible in the output, so statements should not be
// reordered casually.
void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
  NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
  if (Node) {
    for (unsigned i = 0; i < Node->getNumOperands(); i++) {
      // Operand 1 of each entry holds the execution mode as an integer.
      const auto EM =
          cast<ConstantInt>(
              cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1))
                  ->getValue())
              ->getZExtValue();
      // Skip ArithmeticPoisonKHR to avoid a duplicate: it is emitted below
      // for every entry point when the PoisonFreezeKHR capability is required.
      if (EM == SPIRV::ExecutionMode::ArithmeticPoisonKHR)
        continue;
      // If SPV_KHR_float_controls2 is enabled and we find any of
      // FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution
      // modes, skip it, it'll be done somewhere else (presumably in
      // outputFPFastMathDefaultInfo() below — TODO confirm).
      if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
        if (EM == SPIRV::ExecutionMode::FPFastMathDefault ||
            EM == SPIRV::ExecutionMode::ContractionOff ||
            EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve)
          continue;
      }

      MCInst Inst;
      Inst.setOpcode(SPIRV::OpExecutionMode);
      addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI);
      outputMCInst(Inst);
    }
    outputFPFastMathDefaultInfo();
  }
  for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
    const Function &F = *FI;
    // Only operands of OpEntryPoint instructions are allowed to be
    // <Entry Point> operands of OpExecutionMode
    if (F.isDeclaration() || !isEntryPoint(F))
      continue;
    MCRegister FReg = MAI->getGlobalObjReg(&F);
    assert(FReg.isValid());

    if (Attribute Attr = F.getFnAttribute("hlsl.shader"); Attr.isValid()) {
      // SPIR-V common validation: Fragment requires OriginUpperLeft or
      // OriginLowerLeft.
      // VUID-StandaloneSpirv-OriginLowerLeft-04653: Fragment must declare
      // OriginUpperLeft.
      if (Attr.getValueAsString() == "pixel") {
        MCInst Inst;
        Inst.setOpcode(SPIRV::OpExecutionMode);
        Inst.addOperand(MCOperand::createReg(FReg));
        unsigned EM =
            static_cast<unsigned>(SPIRV::ExecutionMode::OriginUpperLeft);
        Inst.addOperand(MCOperand::createImm(EM));
        outputMCInst(Inst);
      }
    }
    // Work-group size modes take three literals; missing ones default to 1.
    if (MDNode *Node = F.getMetadata("reqd_work_group_size"))
      outputExecutionModeFromMDNode(FReg, Node, SPIRV::ExecutionMode::LocalSize,
                                    3, 1);
    if (Attribute Attr = F.getFnAttribute("hlsl.numthreads"); Attr.isValid())
      outputExecutionModeFromNumthreadsAttribute(
          FReg, Attr, SPIRV::ExecutionMode::LocalSize);
    if (Attribute Attr = F.getFnAttribute("enable-maximal-reconvergence");
        Attr.getValueAsBool()) {
      outputExecutionModeFromEnableMaximalReconvergenceAttr(FReg, *ST);
    }
    if (MDNode *Node = F.getMetadata("work_group_size_hint"))
      outputExecutionModeFromMDNode(FReg, Node,
                                    SPIRV::ExecutionMode::LocalSizeHint, 3, 1);
    // Subgroup size: operands are forwarded as-is, no padding (ExpectMDOps 0).
    if (MDNode *Node = F.getMetadata("reqd_sub_group_size"))
      outputExecutionModeFromMDNode(FReg, Node,
                                    SPIRV::ExecutionMode::SubgroupSize, 0, 0);
    if (MDNode *Node = F.getMetadata("intel_reqd_sub_group_size"))
      outputExecutionModeFromMDNode(FReg, Node,
                                    SPIRV::ExecutionMode::SubgroupSize, 0, 0);
    // max_work_group_size is only honored when SPV_INTEL_kernel_attributes is
    // usable; otherwise the metadata is silently ignored.
    if (MDNode *Node = F.getMetadata("max_work_group_size")) {
      if (ST->canUseExtension(SPIRV::Extension::SPV_INTEL_kernel_attributes))
        outputExecutionModeFromMDNode(
            FReg, Node, SPIRV::ExecutionMode::MaxWorkgroupSizeINTEL, 3, 1);
    }
    if (MDNode *Node = F.getMetadata("vec_type_hint")) {
      MCInst Inst;
      Inst.setOpcode(SPIRV::OpExecutionMode);
      Inst.addOperand(MCOperand::createReg(FReg));
      unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::VecTypeHint);
      Inst.addOperand(MCOperand::createImm(EM));
      unsigned TypeCode = encodeVecTypeHint(getMDOperandAsType(Node, 0));
      Inst.addOperand(MCOperand::createImm(TypeCode));
      outputMCInst(Inst);
    }
    // Per SPV_KHR_poison_freeze description of PoisonFreezeKHR "If declared,
    // all entry points must use the ArithmeticPoisonKHR execution mode".
    if (llvm::is_contained(MAI->Reqs.getMinimalCapabilities(),
                           SPIRV::Capability::PoisonFreezeKHR)) {
      MCInst Inst;
      Inst.setOpcode(SPIRV::OpExecutionMode);
      Inst.addOperand(MCOperand::createReg(FReg));
      unsigned EM =
          static_cast<unsigned>(SPIRV::ExecutionMode::ArithmeticPoisonKHR);
      Inst.addOperand(MCOperand::createImm(EM));
      outputMCInst(Inst);
    }
    // ContractionOff is considered only for kernels, and only when the module
    // has no "spirv.ExecutionMode" metadata. --spirv-fp-contract=fast
    // suppresses it. --spirv-fp-contract=off forces it regardless of
    // "opencl.enable.FP_CONTRACT". The default (on) emits it only when
    // "opencl.enable.FP_CONTRACT" is absent.
    bool EmitContractionOff =
        ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") &&
        SPIRVFPContract != SPIRVFPContractMode::Fast &&
        (SPIRVFPContract == SPIRVFPContractMode::Off ||
         !M.getNamedMetadata("opencl.enable.FP_CONTRACT"));
    if (EmitContractionOff) {
      if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
        // When SPV_KHR_float_controls2 is enabled, ContractionOff is
        // deprecated. We need to use FPFastMathDefault with the appropriate
        // flags instead. Since FPFastMathDefault takes a target type, we need
        // to emit it for each floating-point type that exists in the module
        // to match the effect of ContractionOff. As of now, there are 3 FP
        // types: fp16, fp32 and fp64.

        // We only end up here because there is no "spirv.ExecutionMode"
        // metadata, so that means no FPFastMathDefault. Therefore, we only
        // need to make sure AllowContract is set to 0, like the rest of the
        // flags. We still need to emit the OpExecutionMode instruction,
        // otherwise it's up to the client API to define the flags. Therefore,
        // we need to find the constant with 0 value.

        // Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of
        // type int32 with 0 value to represent the FP Fast Math Mode.
        // NOTE(review): this scan is repeated for every kernel entry point
        // although its result does not depend on F.
        std::vector<const MachineInstr *> SPIRVFloatTypes;
        const MachineInstr *ConstZeroInt32 = nullptr;
        for (const MachineInstr *MI :
             MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
          unsigned OpCode = MI->getOpcode();

          // Collect the SPIRV type if it's a float.
          if (OpCode == SPIRV::OpTypeFloat) {
            // Skip if the target type is not fp16, fp32, fp64.
            const unsigned OpTypeFloatSize = MI->getOperand(1).getImm();
            if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 &&
                OpTypeFloatSize != 64) {
              continue;
            }
            SPIRVFloatTypes.push_back(MI);
            continue;
          }

          // The zero flags value is an OpConstantNull of 32-bit OpTypeInt.
          // If several exist, the last one in the section is used.
          if (OpCode == SPIRV::OpConstantNull) {
            // Check if the constant is int32, if not skip it.
            const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
            MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg());
            bool IsInt32Ty = TypeMI &&
                             TypeMI->getOpcode() == SPIRV::OpTypeInt &&
                             TypeMI->getOperand(1).getImm() == 32;
            if (IsInt32Ty)
              ConstZeroInt32 = MI;
          }
        }

        // Emit "OpExecutionModeId FReg FPFastMathDefault <float type> <zero>"
        // for each collected float type, i.e. no fast-math flags at all
        // (AllowContract included).
        for (const MachineInstr *MI : SPIRVFloatTypes) {
          MCInst Inst;
          Inst.setOpcode(SPIRV::OpExecutionModeId);
          Inst.addOperand(MCOperand::createReg(FReg));
          unsigned EM =
              static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
          Inst.addOperand(MCOperand::createImm(EM));
          const MachineFunction *MF = MI->getMF();
          MCRegister TypeReg =
              MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
          Inst.addOperand(MCOperand::createReg(TypeReg));
          // NOTE(review): only asserted; in release builds a missing int32
          // zero constant would be a null dereference. Presumably an earlier
          // pass guarantees it exists — confirm.
          assert(ConstZeroInt32 && "There should be a constant zero.");
          MCRegister ConstReg = MAI->getRegisterAlias(
              ConstZeroInt32->getMF(), ConstZeroInt32->getOperand(0).getReg());
          Inst.addOperand(MCOperand::createReg(ConstReg));
          outputMCInst(Inst);
        }
      } else {
        // Without SPV_KHR_float_controls2, emit the classic ContractionOff.
        MCInst Inst;
        Inst.setOpcode(SPIRV::OpExecutionMode);
        Inst.addOperand(MCOperand::createReg(FReg));
        unsigned EM =
            static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
        Inst.addOperand(MCOperand::createImm(EM));
        outputMCInst(Inst);
      }
    }
  }
}
751
752void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
753 outputModuleSection(MSType: SPIRV::MB_Annotations);
754 // Process llvm.global.annotations special global variable.
755 for (auto F = M.global_begin(), E = M.global_end(); F != E; ++F) {
756 if ((*F).getName() != "llvm.global.annotations")
757 continue;
758 const GlobalVariable *V = &(*F);
759 const ConstantArray *CA = cast<ConstantArray>(Val: V->getOperand(i_nocapture: 0));
760 for (Value *Op : CA->operands()) {
761 ConstantStruct *CS = cast<ConstantStruct>(Val: Op);
762 // The first field of the struct contains a pointer to
763 // the annotated variable.
764 Value *AnnotatedVar = CS->getOperand(i_nocapture: 0)->stripPointerCasts();
765 auto *GO = dyn_cast<GlobalObject>(Val: AnnotatedVar);
766 MCRegister Reg = GO ? MAI->getGlobalObjReg(GO) : MCRegister();
767 if (!Reg.isValid()) {
768 std::string DiagMsg;
769 raw_string_ostream OS(DiagMsg);
770 AnnotatedVar->print(O&: OS);
771 DiagMsg = "Unsupported value in llvm.global.annotations: " + DiagMsg;
772 report_fatal_error(reason: DiagMsg.c_str());
773 }
774
775 // The second field contains a pointer to a global annotation string.
776 GlobalVariable *GV =
777 cast<GlobalVariable>(Val: CS->getOperand(i_nocapture: 1)->stripPointerCasts());
778
779 StringRef AnnotationString;
780 [[maybe_unused]] bool Success =
781 getConstantStringInfo(V: GV, Str&: AnnotationString);
782 assert(Success && "Failed to get annotation string");
783 MCInst Inst;
784 Inst.setOpcode(SPIRV::OpDecorate);
785 Inst.addOperand(Op: MCOperand::createReg(Reg));
786 unsigned Dec = static_cast<unsigned>(SPIRV::Decoration::UserSemantic);
787 Inst.addOperand(Op: MCOperand::createImm(Val: Dec));
788 addStringImm(Str: AnnotationString, Inst);
789 outputMCInst(Inst);
790 }
791 }
792}
793
794void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() {
795 // Collect the SPIRVTypes that are OpTypeFloat and the constants of type
796 // int32, that might be used as FP Fast Math Mode.
797 std::vector<const MachineInstr *> SPIRVFloatTypes;
798 // Hashtable to associate immediate values with the constant holding them.
799 DenseMap<int, const MachineInstr *> ConstMap;
800 for (const MachineInstr *MI : MAI->getMSInstrs(MSType: SPIRV::MB_TypeConstVars)) {
801 // Skip if the instruction is not OpTypeFloat or OpConstant.
802 unsigned OpCode = MI->getOpcode();
803 if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantI &&
804 OpCode != SPIRV::OpConstantNull)
805 continue;
806
807 // Collect the SPIRV type if it's a float.
808 if (OpCode == SPIRV::OpTypeFloat) {
809 SPIRVFloatTypes.push_back(x: MI);
810 } else {
811 // Check if the constant is int32, if not skip it.
812 const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
813 MachineInstr *TypeMI = MRI.getVRegDef(Reg: MI->getOperand(i: 1).getReg());
814 if (!TypeMI || TypeMI->getOpcode() != SPIRV::OpTypeInt ||
815 TypeMI->getOperand(i: 1).getImm() != 32)
816 continue;
817
818 if (OpCode == SPIRV::OpConstantI)
819 ConstMap[MI->getOperand(i: 2).getImm()] = MI;
820 else
821 ConstMap[0] = MI;
822 }
823 }
824
825 for (const auto &[Func, FPFastMathDefaultInfoVec] :
826 MAI->FPFastMathDefaultInfoMap) {
827 if (FPFastMathDefaultInfoVec.empty())
828 continue;
829
830 for (const MachineInstr *MI : SPIRVFloatTypes) {
831 unsigned OpTypeFloatSize = MI->getOperand(i: 1).getImm();
832 unsigned Index = SPIRV::FPFastMathDefaultInfoVector::
833 computeFPFastMathDefaultInfoVecIndex(BitWidth: OpTypeFloatSize);
834 assert(Index < FPFastMathDefaultInfoVec.size() &&
835 "Index out of bounds for FPFastMathDefaultInfoVec");
836 const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index];
837 assert(FPFastMathDefaultInfo.Ty &&
838 "Expected target type for FPFastMathDefaultInfo");
839 assert(FPFastMathDefaultInfo.Ty->getScalarSizeInBits() ==
840 OpTypeFloatSize &&
841 "Mismatched float type size");
842 MCInst Inst;
843 Inst.setOpcode(SPIRV::OpExecutionModeId);
844 MCRegister FuncReg = MAI->getGlobalObjReg(GO: Func);
845 assert(FuncReg.isValid());
846 Inst.addOperand(Op: MCOperand::createReg(Reg: FuncReg));
847 Inst.addOperand(
848 Op: MCOperand::createImm(Val: SPIRV::ExecutionMode::FPFastMathDefault));
849 MCRegister TypeReg =
850 MAI->getRegisterAlias(MF: MI->getMF(), Reg: MI->getOperand(i: 0).getReg());
851 Inst.addOperand(Op: MCOperand::createReg(Reg: TypeReg));
852 unsigned Flags = FPFastMathDefaultInfo.FastMathFlags;
853 if (FPFastMathDefaultInfo.ContractionOff &&
854 (Flags & SPIRV::FPFastMathMode::AllowContract))
855 report_fatal_error(
856 reason: "Conflicting FPFastMathFlags: ContractionOff and AllowContract");
857
858 if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
859 !(Flags &
860 (SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
861 SPIRV::FPFastMathMode::NSZ))) {
862 if (FPFastMathDefaultInfo.FPFastMathDefault)
863 report_fatal_error(reason: "Conflicting FPFastMathFlags: "
864 "SignedZeroInfNanPreserve but at least one of "
865 "NotNaN/NotInf/NSZ is enabled.");
866 }
867
868 // Don't emit if none of the execution modes was used.
869 if (Flags == SPIRV::FPFastMathMode::None &&
870 !FPFastMathDefaultInfo.ContractionOff &&
871 !FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
872 !FPFastMathDefaultInfo.FPFastMathDefault)
873 continue;
874
875 // Retrieve the constant instruction for the immediate value.
876 auto It = ConstMap.find(Val: Flags);
877 if (It == ConstMap.end())
878 report_fatal_error(reason: "Expected constant instruction for FP Fast Math "
879 "Mode operand of FPFastMathDefault execution mode.");
880 const MachineInstr *ConstMI = It->second;
881 MCRegister ConstReg = MAI->getRegisterAlias(
882 MF: ConstMI->getMF(), Reg: ConstMI->getOperand(i: 0).getReg());
883 Inst.addOperand(Op: MCOperand::createReg(Reg: ConstReg));
884 outputMCInst(Inst);
885 }
886 }
887}
888
889void SPIRVAsmPrinter::outputModuleSections() {
890 const Module *M = MMI->getModule();
891 // Get the global subtarget to output module-level info.
892 ST = static_cast<const SPIRVTargetMachine &>(TM).getSubtargetImpl();
893 TII = ST->getInstrInfo();
894 MAI = &getAnalysis<SPIRVModuleAnalysis>().MAI;
895 assert(ST && TII && MAI && M && "Module analysis is required");
896
897 if (!AuxDataHandler) {
898 auto Handler = std::make_unique<SPIRVAuxDataHandler>(args&: *this, args: *M);
899 if (Handler->hasWork())
900 AuxDataHandler = std::move(Handler);
901 }
902
903 // Let the NSDI handler add its extension and ext inst import entry to MAI
904 // before the module header sections are emitted.
905 if (NSDebugHandler)
906 NSDebugHandler->prepareModuleOutput(ST: *ST, MAI&: *MAI);
907 if (AuxDataHandler)
908 AuxDataHandler->prepareModuleOutput(ST: *ST, MAI&: *MAI);
909
910 // Output instructions according to the Logical Layout of a Module:
911 // 1,2. All OpCapability instructions, then optional OpExtension
912 // instructions.
913 outputGlobalRequirements();
914 // 3. Optional OpExtInstImport instructions.
915 outputOpExtInstImports(M: *M);
916 // 4. The single required OpMemoryModel instruction.
917 outputOpMemoryModel();
918 // 5. All entry point declarations, using OpEntryPoint.
919 outputEntryPoints();
920 // 6. Execution-mode declarations, using OpExecutionMode or
921 // OpExecutionModeId.
922 outputExecutionMode(M: *M);
923 // 7a. Debug: all OpString, OpSourceExtension, OpSource, and
924 // OpSourceContinued, without forward references.
925 outputDebugSourceAndStrings(M: *M);
926 // 7b. Debug: all OpName and all OpMemberName.
927 outputModuleSection(MSType: SPIRV::MB_DebugNames);
928 // 7c. Debug: all OpModuleProcessed instructions.
929 outputModuleSection(MSType: SPIRV::MB_DebugModuleProcessed);
930 // xxx. SPV_INTEL_memory_access_aliasing instructions go before 8.
931 // "All annotation instructions"
932 outputModuleSection(MSType: SPIRV::MB_AliasingInsts);
933 // 8. All annotation instructions (all decorations).
934 outputAnnotations(M: *M);
935 // 9. All type declarations (OpTypeXXX instructions), all constant
936 // instructions, and all global variable declarations. This section is
937 // the first section to allow use of: OpLine and OpNoLine debug information;
938 // non-semantic instructions with OpExtInst.
939 outputModuleSection(MSType: SPIRV::MB_TypeConstVars);
940 // 10. All global NonSemantic.Shader.DebugInfo.100 instructions. The
941 // SPIRVNonSemanticDebugHandler emits these directly as MCInsts; the
942 // MB_NonSemanticGlobalDI section in MAI is intentionally left empty.
943 if (NSDebugHandler)
944 NSDebugHandler->emitNonSemanticGlobalDebugInfo(MAI&: *MAI);
945 if (AuxDataHandler)
946 AuxDataHandler->emitAuxData(MAI&: *MAI);
947 // 11. All function declarations (functions without a body).
948 outputExtFuncDecls();
949 // 12. All function definitions (functions with a body).
950 // This is done in regular function output.
951}
952
953bool SPIRVAsmPrinter::doInitialization(Module &M) {
954 ModuleSectionsEmitted = false;
955 if (!M.getModuleInlineAsm().empty()) {
956 M.getContext().emitError(
957 ErrorStr: "SPIR-V does not support module-level inline assembly");
958 M.setModuleInlineAsm("");
959 }
960
961 // Register the NSDI handler before calling the base class so that
962 // AsmPrinter::doInitialization() calls Handler->beginModule(M) for it.
963 if (M.getNamedMetadata(Name: "llvm.dbg.cu")) {
964 auto Handler = std::make_unique<SPIRVNonSemanticDebugHandler>(args&: *this);
965 NSDebugHandler = Handler.get();
966 addAsmPrinterHandler(Handler: std::move(Handler));
967 }
968 // We need to call the parent's one explicitly.
969 return AsmPrinter::doInitialization(M);
970}
971
972char SPIRVAsmPrinter::ID = 0;
973
974INITIALIZE_PASS(SPIRVAsmPrinter, "spirv-asm-printer", "SPIRV Assembly Printer",
975 false, false)
976
977// Force static initialization.
978extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void
979LLVMInitializeSPIRVAsmPrinter() {
980 RegisterAsmPrinter<SPIRVAsmPrinter> X(getTheSPIRV32Target());
981 RegisterAsmPrinter<SPIRVAsmPrinter> Y(getTheSPIRV64Target());
982 RegisterAsmPrinter<SPIRVAsmPrinter> Z(getTheSPIRVLogicalTarget());
983}
984