1//===-- SPIRVAsmPrinter.cpp - SPIR-V LLVM assembly writer ------*- C++ -*--===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a printer that converts from our internal representation
10// of machine-dependent LLVM code to the SPIR-V assembly language.
11//
12//===----------------------------------------------------------------------===//
13
14#include "MCTargetDesc/SPIRVInstPrinter.h"
15#include "SPIRV.h"
16#include "SPIRVInstrInfo.h"
17#include "SPIRVMCInstLower.h"
18#include "SPIRVModuleAnalysis.h"
19#include "SPIRVSubtarget.h"
20#include "SPIRVTargetMachine.h"
21#include "SPIRVUtils.h"
22#include "TargetInfo/SPIRVTargetInfo.h"
23#include "llvm/ADT/DenseMap.h"
24#include "llvm/Analysis/ValueTracking.h"
25#include "llvm/CodeGen/AsmPrinter.h"
26#include "llvm/CodeGen/MachineConstantPool.h"
27#include "llvm/CodeGen/MachineInstr.h"
28#include "llvm/CodeGen/MachineModuleInfo.h"
29#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
30#include "llvm/MC/MCAsmInfo.h"
31#include "llvm/MC/MCAssembler.h"
32#include "llvm/MC/MCInst.h"
33#include "llvm/MC/MCObjectStreamer.h"
34#include "llvm/MC/MCSPIRVObjectWriter.h"
35#include "llvm/MC/MCStreamer.h"
36#include "llvm/MC/MCSymbol.h"
37#include "llvm/MC/TargetRegistry.h"
38#include "llvm/Support/Compiler.h"
39#include "llvm/Support/raw_ostream.h"
40
41using namespace llvm;
42
43#define DEBUG_TYPE "asm-printer"
44
45namespace {
46class SPIRVAsmPrinter : public AsmPrinter {
47 unsigned NLabels = 0;
48 SmallPtrSet<const MachineBasicBlock *, 8> LabeledMBB;
49
50public:
51 explicit SPIRVAsmPrinter(TargetMachine &TM,
52 std::unique_ptr<MCStreamer> Streamer)
53 : AsmPrinter(TM, std::move(Streamer), ID), ModuleSectionsEmitted(false),
54 ST(nullptr), TII(nullptr), MAI(nullptr) {}
55 static char ID;
56 bool ModuleSectionsEmitted;
57 const SPIRVSubtarget *ST;
58 const SPIRVInstrInfo *TII;
59
60 StringRef getPassName() const override { return "SPIRV Assembly Printer"; }
61 void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O);
62 bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
63 const char *ExtraCode, raw_ostream &O) override;
64
65 void outputMCInst(MCInst &Inst);
66 void outputInstruction(const MachineInstr *MI);
67 void outputModuleSection(SPIRV::ModuleSectionType MSType);
68 void outputGlobalRequirements();
69 void outputEntryPoints();
70 void outputDebugSourceAndStrings(const Module &M);
71 void outputOpExtInstImports(const Module &M);
72 void outputOpMemoryModel();
73 void outputOpFunctionEnd();
74 void outputExtFuncDecls();
75 void outputExecutionModeFromMDNode(MCRegister Reg, MDNode *Node,
76 SPIRV::ExecutionMode::ExecutionMode EM,
77 unsigned ExpectMDOps, int64_t DefVal);
78 void outputExecutionModeFromNumthreadsAttribute(
79 const MCRegister &Reg, const Attribute &Attr,
80 SPIRV::ExecutionMode::ExecutionMode EM);
81 void outputExecutionModeFromEnableMaximalReconvergenceAttr(
82 const MCRegister &Reg, const SPIRVSubtarget &ST);
83 void outputExecutionMode(const Module &M);
84 void outputAnnotations(const Module &M);
85 void outputModuleSections();
86 void outputFPFastMathDefaultInfo();
87 bool isHidden() {
88 return MF->getFunction()
89 .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
90 .isValid();
91 }
92
93 void emitInstruction(const MachineInstr *MI) override;
94 void emitFunctionEntryLabel() override {}
95 void emitFunctionHeader() override;
96 void emitFunctionBodyStart() override {}
97 void emitFunctionBodyEnd() override;
98 void emitBasicBlockStart(const MachineBasicBlock &MBB) override;
99 void emitBasicBlockEnd(const MachineBasicBlock &MBB) override {}
100 void emitGlobalVariable(const GlobalVariable *GV) override {}
101 void emitOpLabel(const MachineBasicBlock &MBB);
102 void emitEndOfAsmFile(Module &M) override;
103 bool doInitialization(Module &M) override;
104
105 void getAnalysisUsage(AnalysisUsage &AU) const override;
106 SPIRV::ModuleAnalysisInfo *MAI;
107
108protected:
109 void cleanUp(Module &M);
110};
111} // namespace
112
113void SPIRVAsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
114 AU.addRequired<SPIRVModuleAnalysis>();
115 AU.addPreserved<SPIRVModuleAnalysis>();
116 AsmPrinter::getAnalysisUsage(AU);
117}
118
119// If the module has no functions, we need output global info anyway.
120void SPIRVAsmPrinter::emitEndOfAsmFile(Module &M) {
121 if (ModuleSectionsEmitted == false) {
122 outputModuleSections();
123 ModuleSectionsEmitted = true;
124 }
125
126 ST = static_cast<const SPIRVTargetMachine &>(TM).getSubtargetImpl();
127 VersionTuple SPIRVVersion = ST->getSPIRVVersion();
128 uint32_t Major = SPIRVVersion.getMajor();
129 uint32_t Minor = SPIRVVersion.getMinor().value_or(u: 0);
130 // Bound is an approximation that accounts for the maximum used register
131 // number and number of generated OpLabels
132 unsigned Bound = 2 * (ST->getBound() + 1) + NLabels;
133 if (MCAssembler *Asm = OutStreamer->getAssemblerPtr())
134 static_cast<SPIRVObjectWriter &>(Asm->getWriter())
135 .setBuildVersion(Major, Minor, Bound);
136
137 cleanUp(M);
138}
139
140// Any cleanup actions with the Module after we don't care about its content
141// anymore.
142void SPIRVAsmPrinter::cleanUp(Module &M) {
143 // Verifier disallows uses of intrinsic global variables.
144 for (StringRef GVName :
145 {"llvm.global_ctors", "llvm.global_dtors", "llvm.used"}) {
146 if (GlobalVariable *GV = M.getNamedGlobal(Name: GVName))
147 GV->setName("");
148 }
149}
150
151void SPIRVAsmPrinter::emitFunctionHeader() {
152 if (ModuleSectionsEmitted == false) {
153 outputModuleSections();
154 ModuleSectionsEmitted = true;
155 }
156 // Get the subtarget from the current MachineFunction.
157 ST = &MF->getSubtarget<SPIRVSubtarget>();
158 TII = ST->getInstrInfo();
159 const Function &F = MF->getFunction();
160
161 if (isVerbose() && !isHidden()) {
162 OutStreamer->getCommentOS()
163 << "-- Begin function "
164 << GlobalValue::dropLLVMManglingEscape(Name: F.getName()) << '\n';
165 }
166
167 auto Section = getObjFileLowering().SectionForGlobal(GO: &F, TM);
168 MF->setSection(Section);
169}
170
171void SPIRVAsmPrinter::outputOpFunctionEnd() {
172 MCInst FunctionEndInst;
173 FunctionEndInst.setOpcode(SPIRV::OpFunctionEnd);
174 outputMCInst(Inst&: FunctionEndInst);
175}
176
177void SPIRVAsmPrinter::emitFunctionBodyEnd() {
178 if (!isHidden())
179 outputOpFunctionEnd();
180}
181
182void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
183 // Do not emit anything if it's an internal service function.
184 if (isHidden())
185 return;
186
187 MCInst LabelInst;
188 LabelInst.setOpcode(SPIRV::OpLabel);
189 LabelInst.addOperand(Op: MCOperand::createReg(Reg: MAI->getOrCreateMBBRegister(MBB)));
190 outputMCInst(Inst&: LabelInst);
191 ++NLabels;
192 LabeledMBB.insert(Ptr: &MBB);
193}
194
195void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
196 // Do not emit anything if it's an internal service function.
197 if (MBB.empty())
198 return;
199
200 // If it's the first MBB in MF, it has OpFunction and OpFunctionParameter, so
201 // OpLabel should be output after them.
202 if (MBB.getNumber() == MF->front().getNumber()) {
203 for (const MachineInstr &MI : MBB)
204 if (MI.getOpcode() == SPIRV::OpFunction)
205 return;
206 // TODO: this case should be checked by the verifier.
207 report_fatal_error(reason: "OpFunction is expected in the front MBB of MF");
208 }
209 emitOpLabel(MBB);
210}
211
212void SPIRVAsmPrinter::printOperand(const MachineInstr *MI, int OpNum,
213 raw_ostream &O) {
214 const MachineOperand &MO = MI->getOperand(i: OpNum);
215
216 switch (MO.getType()) {
217 case MachineOperand::MO_Register:
218 O << SPIRVInstPrinter::getRegisterName(Reg: MO.getReg());
219 break;
220
221 case MachineOperand::MO_Immediate:
222 O << MO.getImm();
223 break;
224
225 case MachineOperand::MO_FPImmediate:
226 O << MO.getFPImm();
227 break;
228
229 case MachineOperand::MO_MachineBasicBlock:
230 O << *MO.getMBB()->getSymbol();
231 break;
232
233 case MachineOperand::MO_GlobalAddress:
234 O << *getSymbol(GV: MO.getGlobal());
235 break;
236
237 case MachineOperand::MO_BlockAddress: {
238 MCSymbol *BA = GetBlockAddressSymbol(BA: MO.getBlockAddress());
239 O << BA->getName();
240 break;
241 }
242
243 case MachineOperand::MO_ExternalSymbol:
244 O << *GetExternalSymbolSymbol(Sym: MO.getSymbolName());
245 break;
246
247 case MachineOperand::MO_JumpTableIndex:
248 case MachineOperand::MO_ConstantPoolIndex:
249 default:
250 llvm_unreachable("<unknown operand type>");
251 }
252}
253
254bool SPIRVAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
255 const char *ExtraCode, raw_ostream &O) {
256 if (ExtraCode && ExtraCode[0])
257 return true; // Invalid instruction - SPIR-V does not have special modifiers
258
259 printOperand(MI, OpNum: OpNo, O);
260 return false;
261}
262
263static bool isFuncOrHeaderInstr(const MachineInstr *MI,
264 const SPIRVInstrInfo *TII) {
265 return TII->isHeaderInstr(MI: *MI) || MI->getOpcode() == SPIRV::OpFunction ||
266 MI->getOpcode() == SPIRV::OpFunctionParameter;
267}
268
269void SPIRVAsmPrinter::outputMCInst(MCInst &Inst) {
270 OutStreamer->emitInstruction(Inst, STI: *OutContext.getSubtargetInfo());
271}
272
273void SPIRVAsmPrinter::outputInstruction(const MachineInstr *MI) {
274 SPIRVMCInstLower MCInstLowering;
275 MCInst TmpInst;
276 MCInstLowering.lower(MI, OutMI&: TmpInst, MAI);
277 outputMCInst(Inst&: TmpInst);
278}
279
280void SPIRVAsmPrinter::emitInstruction(const MachineInstr *MI) {
281 SPIRV_MC::verifyInstructionPredicates(Opcode: MI->getOpcode(),
282 Features: getSubtargetInfo().getFeatureBits());
283
284 if (!MAI->getSkipEmission(MI))
285 outputInstruction(MI);
286
287 // Output OpLabel after OpFunction and OpFunctionParameter in the first MBB.
288 const MachineInstr *NextMI = MI->getNextNode();
289 if (!LabeledMBB.contains(Ptr: MI->getParent()) && isFuncOrHeaderInstr(MI, TII) &&
290 (!NextMI || !isFuncOrHeaderInstr(MI: NextMI, TII))) {
291 assert(MI->getParent()->getNumber() == MF->front().getNumber() &&
292 "OpFunction is not in the front MBB of MF");
293 emitOpLabel(MBB: *MI->getParent());
294 }
295}
296
297void SPIRVAsmPrinter::outputModuleSection(SPIRV::ModuleSectionType MSType) {
298 for (const MachineInstr *MI : MAI->getMSInstrs(MSType))
299 outputInstruction(MI);
300}
301
302void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) {
303 // Output OpSourceExtensions.
304 for (auto &Str : MAI->SrcExt) {
305 MCInst Inst;
306 Inst.setOpcode(SPIRV::OpSourceExtension);
307 addStringImm(Str: Str.first(), Inst);
308 outputMCInst(Inst);
309 }
310 // Output OpString.
311 outputModuleSection(MSType: SPIRV::MB_DebugStrings);
312 // Output OpSource.
313 MCInst Inst;
314 Inst.setOpcode(SPIRV::OpSource);
315 Inst.addOperand(Op: MCOperand::createImm(Val: static_cast<unsigned>(MAI->SrcLang)));
316 Inst.addOperand(
317 Op: MCOperand::createImm(Val: static_cast<unsigned>(MAI->SrcLangVersion)));
318 outputMCInst(Inst);
319}
320
321void SPIRVAsmPrinter::outputOpExtInstImports(const Module &M) {
322 for (auto &CU : MAI->ExtInstSetMap) {
323 unsigned Set = CU.first;
324 MCRegister Reg = CU.second;
325 MCInst Inst;
326 Inst.setOpcode(SPIRV::OpExtInstImport);
327 Inst.addOperand(Op: MCOperand::createReg(Reg));
328 addStringImm(Str: getExtInstSetName(
329 Set: static_cast<SPIRV::InstructionSet::InstructionSet>(Set)),
330 Inst);
331 outputMCInst(Inst);
332 }
333}
334
335void SPIRVAsmPrinter::outputOpMemoryModel() {
336 MCInst Inst;
337 Inst.setOpcode(SPIRV::OpMemoryModel);
338 Inst.addOperand(Op: MCOperand::createImm(Val: static_cast<unsigned>(MAI->Addr)));
339 Inst.addOperand(Op: MCOperand::createImm(Val: static_cast<unsigned>(MAI->Mem)));
340 outputMCInst(Inst);
341}
342
343// Before the OpEntryPoints' output, we need to add the entry point's
344// interfaces. The interface is a list of IDs of global OpVariable instructions.
345// These declare the set of global variables from a module that form
346// the interface of this entry point.
347void SPIRVAsmPrinter::outputEntryPoints() {
348 // Find all OpVariable IDs with required StorageClass.
349 DenseSet<MCRegister> InterfaceIDs;
350 for (const MachineInstr *MI : MAI->GlobalVarList) {
351 assert(MI->getOpcode() == SPIRV::OpVariable);
352 auto SC = static_cast<SPIRV::StorageClass::StorageClass>(
353 MI->getOperand(i: 2).getImm());
354 // Before version 1.4, the interface's storage classes are limited to
355 // the Input and Output storage classes. Starting with version 1.4,
356 // the interface's storage classes are all storage classes used in
357 // declaring all global variables referenced by the entry point call tree.
358 if (ST->isAtLeastSPIRVVer(VerToCompareTo: VersionTuple(1, 4)) ||
359 SC == SPIRV::StorageClass::Input || SC == SPIRV::StorageClass::Output) {
360 const MachineFunction *MF = MI->getMF();
361 MCRegister Reg = MAI->getRegisterAlias(MF, Reg: MI->getOperand(i: 0).getReg());
362 InterfaceIDs.insert(V: Reg);
363 }
364 }
365
366 // Output OpEntryPoints adding interface args to all of them.
367 for (const MachineInstr *MI : MAI->getMSInstrs(MSType: SPIRV::MB_EntryPoints)) {
368 SPIRVMCInstLower MCInstLowering;
369 MCInst TmpInst;
370 MCInstLowering.lower(MI, OutMI&: TmpInst, MAI);
371 for (MCRegister Reg : InterfaceIDs) {
372 assert(Reg.isValid());
373 TmpInst.addOperand(Op: MCOperand::createReg(Reg));
374 }
375 outputMCInst(Inst&: TmpInst);
376 }
377}
378
379// Create global OpCapability instructions for the required capabilities.
380void SPIRVAsmPrinter::outputGlobalRequirements() {
381 // Abort here if not all requirements can be satisfied.
382 MAI->Reqs.checkSatisfiable(ST: *ST);
383
384 for (const auto &Cap : MAI->Reqs.getMinimalCapabilities()) {
385 MCInst Inst;
386 Inst.setOpcode(SPIRV::OpCapability);
387 Inst.addOperand(Op: MCOperand::createImm(Val: Cap));
388 outputMCInst(Inst);
389 }
390
391 // Generate the final OpExtensions with strings instead of enums.
392 for (const auto &Ext : MAI->Reqs.getExtensions()) {
393 MCInst Inst;
394 Inst.setOpcode(SPIRV::OpExtension);
395 addStringImm(Str: getSymbolicOperandMnemonic(
396 Category: SPIRV::OperandCategory::ExtensionOperand, Value: Ext),
397 Inst);
398 outputMCInst(Inst);
399 }
400 // TODO add a pseudo instr for version number.
401}
402
403void SPIRVAsmPrinter::outputExtFuncDecls() {
404 // Insert OpFunctionEnd after each declaration.
405 auto I = MAI->getMSInstrs(MSType: SPIRV::MB_ExtFuncDecls).begin(),
406 E = MAI->getMSInstrs(MSType: SPIRV::MB_ExtFuncDecls).end();
407 for (; I != E; ++I) {
408 outputInstruction(MI: *I);
409 if ((I + 1) == E || (*(I + 1))->getOpcode() == SPIRV::OpFunction)
410 outputOpFunctionEnd();
411 }
412}
413
414// Encode LLVM type by SPIR-V execution mode VecTypeHint.
415static unsigned encodeVecTypeHint(Type *Ty) {
416 if (Ty->isHalfTy())
417 return 4;
418 if (Ty->isFloatTy())
419 return 5;
420 if (Ty->isDoubleTy())
421 return 6;
422 if (IntegerType *IntTy = dyn_cast<IntegerType>(Val: Ty)) {
423 switch (IntTy->getIntegerBitWidth()) {
424 case 8:
425 return 0;
426 case 16:
427 return 1;
428 case 32:
429 return 2;
430 case 64:
431 return 3;
432 default:
433 llvm_unreachable("invalid integer type");
434 }
435 }
436 if (FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Val: Ty)) {
437 Type *EleTy = VecTy->getElementType();
438 unsigned Size = VecTy->getNumElements();
439 return Size << 16 | encodeVecTypeHint(Ty: EleTy);
440 }
441 llvm_unreachable("invalid type");
442}
443
444static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst,
445 SPIRV::ModuleAnalysisInfo *MAI) {
446 for (const MDOperand &MDOp : MDN->operands()) {
447 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MDOp)) {
448 Constant *C = CMeta->getValue();
449 if (ConstantInt *Const = dyn_cast<ConstantInt>(Val: C)) {
450 Inst.addOperand(Op: MCOperand::createImm(Val: Const->getZExtValue()));
451 } else if (auto *CE = dyn_cast<Function>(Val: C)) {
452 MCRegister FuncReg = MAI->getFuncReg(F: CE);
453 assert(FuncReg.isValid());
454 Inst.addOperand(Op: MCOperand::createReg(Reg: FuncReg));
455 }
456 }
457 }
458}
459
460void SPIRVAsmPrinter::outputExecutionModeFromMDNode(
461 MCRegister Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM,
462 unsigned ExpectMDOps, int64_t DefVal) {
463 MCInst Inst;
464 Inst.setOpcode(SPIRV::OpExecutionMode);
465 Inst.addOperand(Op: MCOperand::createReg(Reg));
466 Inst.addOperand(Op: MCOperand::createImm(Val: static_cast<unsigned>(EM)));
467 addOpsFromMDNode(MDN: Node, Inst, MAI);
468 // reqd_work_group_size and work_group_size_hint require 3 operands,
469 // if metadata contains less operands, just add a default value
470 unsigned NodeSz = Node->getNumOperands();
471 if (ExpectMDOps > 0 && NodeSz < ExpectMDOps)
472 for (unsigned i = NodeSz; i < ExpectMDOps; ++i)
473 Inst.addOperand(Op: MCOperand::createImm(Val: DefVal));
474 outputMCInst(Inst);
475}
476
477void SPIRVAsmPrinter::outputExecutionModeFromNumthreadsAttribute(
478 const MCRegister &Reg, const Attribute &Attr,
479 SPIRV::ExecutionMode::ExecutionMode EM) {
480 assert(Attr.isValid() && "Function called with an invalid attribute.");
481
482 MCInst Inst;
483 Inst.setOpcode(SPIRV::OpExecutionMode);
484 Inst.addOperand(Op: MCOperand::createReg(Reg));
485 Inst.addOperand(Op: MCOperand::createImm(Val: static_cast<unsigned>(EM)));
486
487 SmallVector<StringRef> NumThreads;
488 Attr.getValueAsString().split(A&: NumThreads, Separator: ',');
489 assert(NumThreads.size() == 3 && "invalid numthreads");
490 for (uint32_t i = 0; i < 3; ++i) {
491 uint32_t V;
492 [[maybe_unused]] bool Result = NumThreads[i].getAsInteger(Radix: 10, Result&: V);
493 assert(!Result && "Failed to parse numthreads");
494 Inst.addOperand(Op: MCOperand::createImm(Val: V));
495 }
496
497 outputMCInst(Inst);
498}
499
500void SPIRVAsmPrinter::outputExecutionModeFromEnableMaximalReconvergenceAttr(
501 const MCRegister &Reg, const SPIRVSubtarget &ST) {
502 assert(ST.canUseExtension(SPIRV::Extension::SPV_KHR_maximal_reconvergence) &&
503 "Function called when SPV_KHR_maximal_reconvergence is not enabled.");
504
505 MCInst Inst;
506 Inst.setOpcode(SPIRV::OpExecutionMode);
507 Inst.addOperand(Op: MCOperand::createReg(Reg));
508 unsigned EM =
509 static_cast<unsigned>(SPIRV::ExecutionMode::MaximallyReconvergesKHR);
510 Inst.addOperand(Op: MCOperand::createImm(Val: EM));
511 outputMCInst(Inst);
512}
513
514void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
515 NamedMDNode *Node = M.getNamedMetadata(Name: "spirv.ExecutionMode");
516 if (Node) {
517 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
518 // If SPV_KHR_float_controls2 is enabled and we find any of
519 // FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution
520 // modes, skip it, it'll be done somewhere else.
521 if (ST->canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2)) {
522 const auto EM =
523 cast<ConstantInt>(
524 Val: cast<ConstantAsMetadata>(Val: (Node->getOperand(i))->getOperand(I: 1))
525 ->getValue())
526 ->getZExtValue();
527 if (EM == SPIRV::ExecutionMode::FPFastMathDefault ||
528 EM == SPIRV::ExecutionMode::ContractionOff ||
529 EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve)
530 continue;
531 }
532
533 MCInst Inst;
534 Inst.setOpcode(SPIRV::OpExecutionMode);
535 addOpsFromMDNode(MDN: cast<MDNode>(Val: Node->getOperand(i)), Inst, MAI);
536 outputMCInst(Inst);
537 }
538 outputFPFastMathDefaultInfo();
539 }
540 for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
541 const Function &F = *FI;
542 // Only operands of OpEntryPoint instructions are allowed to be
543 // <Entry Point> operands of OpExecutionMode
544 if (F.isDeclaration() || !isEntryPoint(F))
545 continue;
546 MCRegister FReg = MAI->getFuncReg(F: &F);
547 assert(FReg.isValid());
548
549 if (Attribute Attr = F.getFnAttribute(Kind: "hlsl.shader"); Attr.isValid()) {
550 // SPIR-V common validation: Fragment requires OriginUpperLeft or
551 // OriginLowerLeft.
552 // VUID-StandaloneSpirv-OriginLowerLeft-04653: Fragment must declare
553 // OriginUpperLeft.
554 if (Attr.getValueAsString() == "pixel") {
555 MCInst Inst;
556 Inst.setOpcode(SPIRV::OpExecutionMode);
557 Inst.addOperand(Op: MCOperand::createReg(Reg: FReg));
558 unsigned EM =
559 static_cast<unsigned>(SPIRV::ExecutionMode::OriginUpperLeft);
560 Inst.addOperand(Op: MCOperand::createImm(Val: EM));
561 outputMCInst(Inst);
562 }
563 }
564 if (MDNode *Node = F.getMetadata(Kind: "reqd_work_group_size"))
565 outputExecutionModeFromMDNode(Reg: FReg, Node, EM: SPIRV::ExecutionMode::LocalSize,
566 ExpectMDOps: 3, DefVal: 1);
567 if (Attribute Attr = F.getFnAttribute(Kind: "hlsl.numthreads"); Attr.isValid())
568 outputExecutionModeFromNumthreadsAttribute(
569 Reg: FReg, Attr, EM: SPIRV::ExecutionMode::LocalSize);
570 if (Attribute Attr = F.getFnAttribute(Kind: "enable-maximal-reconvergence");
571 Attr.getValueAsBool()) {
572 outputExecutionModeFromEnableMaximalReconvergenceAttr(Reg: FReg, ST: *ST);
573 }
574 if (MDNode *Node = F.getMetadata(Kind: "work_group_size_hint"))
575 outputExecutionModeFromMDNode(Reg: FReg, Node,
576 EM: SPIRV::ExecutionMode::LocalSizeHint, ExpectMDOps: 3, DefVal: 1);
577 if (MDNode *Node = F.getMetadata(Kind: "intel_reqd_sub_group_size"))
578 outputExecutionModeFromMDNode(Reg: FReg, Node,
579 EM: SPIRV::ExecutionMode::SubgroupSize, ExpectMDOps: 0, DefVal: 0);
580 if (MDNode *Node = F.getMetadata(Kind: "max_work_group_size")) {
581 if (ST->canUseExtension(E: SPIRV::Extension::SPV_INTEL_kernel_attributes))
582 outputExecutionModeFromMDNode(
583 Reg: FReg, Node, EM: SPIRV::ExecutionMode::MaxWorkgroupSizeINTEL, ExpectMDOps: 3, DefVal: 1);
584 }
585 if (MDNode *Node = F.getMetadata(Kind: "vec_type_hint")) {
586 MCInst Inst;
587 Inst.setOpcode(SPIRV::OpExecutionMode);
588 Inst.addOperand(Op: MCOperand::createReg(Reg: FReg));
589 unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::VecTypeHint);
590 Inst.addOperand(Op: MCOperand::createImm(Val: EM));
591 unsigned TypeCode = encodeVecTypeHint(Ty: getMDOperandAsType(N: Node, I: 0));
592 Inst.addOperand(Op: MCOperand::createImm(Val: TypeCode));
593 outputMCInst(Inst);
594 }
595 if (ST->isKernel() && !M.getNamedMetadata(Name: "spirv.ExecutionMode") &&
596 !M.getNamedMetadata(Name: "opencl.enable.FP_CONTRACT")) {
597 if (ST->canUseExtension(E: SPIRV::Extension::SPV_KHR_float_controls2)) {
598 // When SPV_KHR_float_controls2 is enabled, ContractionOff is
599 // deprecated. We need to use FPFastMathDefault with the appropriate
600 // flags instead. Since FPFastMathDefault takes a target type, we need
601 // to emit it for each floating-point type that exists in the module
602 // to match the effect of ContractionOff. As of now, there are 3 FP
603 // types: fp16, fp32 and fp64.
604
605 // We only end up here because there is no "spirv.ExecutionMode"
606 // metadata, so that means no FPFastMathDefault. Therefore, we only
607 // need to make sure AllowContract is set to 0, as the rest of flags.
608 // We still need to emit the OpExecutionMode instruction, otherwise
609 // it's up to the client API to define the flags. Therefore, we need
610 // to find the constant with 0 value.
611
612 // Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of
613 // type int32 with 0 value to represent the FP Fast Math Mode.
614 std::vector<const MachineInstr *> SPIRVFloatTypes;
615 const MachineInstr *ConstZeroInt32 = nullptr;
616 for (const MachineInstr *MI :
617 MAI->getMSInstrs(MSType: SPIRV::MB_TypeConstVars)) {
618 unsigned OpCode = MI->getOpcode();
619
620 // Collect the SPIRV type if it's a float.
621 if (OpCode == SPIRV::OpTypeFloat) {
622 // Skip if the target type is not fp16, fp32, fp64.
623 const unsigned OpTypeFloatSize = MI->getOperand(i: 1).getImm();
624 if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 &&
625 OpTypeFloatSize != 64) {
626 continue;
627 }
628 SPIRVFloatTypes.push_back(x: MI);
629 continue;
630 }
631
632 if (OpCode == SPIRV::OpConstantNull) {
633 // Check if the constant is int32, if not skip it.
634 const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
635 MachineInstr *TypeMI = MRI.getVRegDef(Reg: MI->getOperand(i: 1).getReg());
636 bool IsInt32Ty = TypeMI &&
637 TypeMI->getOpcode() == SPIRV::OpTypeInt &&
638 TypeMI->getOperand(i: 1).getImm() == 32;
639 if (IsInt32Ty)
640 ConstZeroInt32 = MI;
641 }
642 }
643
644 // When SPV_KHR_float_controls2 is enabled, ContractionOff is
645 // deprecated. We need to use FPFastMathDefault with the appropriate
646 // flags instead. Since FPFastMathDefault takes a target type, we need
647 // to emit it for each floating-point type that exists in the module
648 // to match the effect of ContractionOff. As of now, there are 3 FP
649 // types: fp16, fp32 and fp64.
650 for (const MachineInstr *MI : SPIRVFloatTypes) {
651 MCInst Inst;
652 Inst.setOpcode(SPIRV::OpExecutionModeId);
653 Inst.addOperand(Op: MCOperand::createReg(Reg: FReg));
654 unsigned EM =
655 static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
656 Inst.addOperand(Op: MCOperand::createImm(Val: EM));
657 const MachineFunction *MF = MI->getMF();
658 MCRegister TypeReg =
659 MAI->getRegisterAlias(MF, Reg: MI->getOperand(i: 0).getReg());
660 Inst.addOperand(Op: MCOperand::createReg(Reg: TypeReg));
661 assert(ConstZeroInt32 && "There should be a constant zero.");
662 MCRegister ConstReg = MAI->getRegisterAlias(
663 MF: ConstZeroInt32->getMF(), Reg: ConstZeroInt32->getOperand(i: 0).getReg());
664 Inst.addOperand(Op: MCOperand::createReg(Reg: ConstReg));
665 outputMCInst(Inst);
666 }
667 } else {
668 MCInst Inst;
669 Inst.setOpcode(SPIRV::OpExecutionMode);
670 Inst.addOperand(Op: MCOperand::createReg(Reg: FReg));
671 unsigned EM =
672 static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
673 Inst.addOperand(Op: MCOperand::createImm(Val: EM));
674 outputMCInst(Inst);
675 }
676 }
677 }
678}
679
680void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
681 outputModuleSection(MSType: SPIRV::MB_Annotations);
682 // Process llvm.global.annotations special global variable.
683 for (auto F = M.global_begin(), E = M.global_end(); F != E; ++F) {
684 if ((*F).getName() != "llvm.global.annotations")
685 continue;
686 const GlobalVariable *V = &(*F);
687 const ConstantArray *CA = cast<ConstantArray>(Val: V->getOperand(i_nocapture: 0));
688 for (Value *Op : CA->operands()) {
689 ConstantStruct *CS = cast<ConstantStruct>(Val: Op);
690 // The first field of the struct contains a pointer to
691 // the annotated variable.
692 Value *AnnotatedVar = CS->getOperand(i_nocapture: 0)->stripPointerCasts();
693 if (!isa<Function>(Val: AnnotatedVar))
694 report_fatal_error(reason: "Unsupported value in llvm.global.annotations");
695 Function *Func = cast<Function>(Val: AnnotatedVar);
696 MCRegister Reg = MAI->getFuncReg(F: Func);
697 if (!Reg.isValid()) {
698 std::string DiagMsg;
699 raw_string_ostream OS(DiagMsg);
700 AnnotatedVar->print(O&: OS);
701 DiagMsg = "Unknown function in llvm.global.annotations: " + DiagMsg;
702 report_fatal_error(reason: DiagMsg.c_str());
703 }
704
705 // The second field contains a pointer to a global annotation string.
706 GlobalVariable *GV =
707 cast<GlobalVariable>(Val: CS->getOperand(i_nocapture: 1)->stripPointerCasts());
708
709 StringRef AnnotationString;
710 [[maybe_unused]] bool Success =
711 getConstantStringInfo(V: GV, Str&: AnnotationString);
712 assert(Success && "Failed to get annotation string");
713 MCInst Inst;
714 Inst.setOpcode(SPIRV::OpDecorate);
715 Inst.addOperand(Op: MCOperand::createReg(Reg));
716 unsigned Dec = static_cast<unsigned>(SPIRV::Decoration::UserSemantic);
717 Inst.addOperand(Op: MCOperand::createImm(Val: Dec));
718 addStringImm(Str: AnnotationString, Inst);
719 outputMCInst(Inst);
720 }
721 }
722}
723
724void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() {
725 // Collect the SPIRVTypes that are OpTypeFloat and the constants of type
726 // int32, that might be used as FP Fast Math Mode.
727 std::vector<const MachineInstr *> SPIRVFloatTypes;
728 // Hashtable to associate immediate values with the constant holding them.
729 std::unordered_map<int, const MachineInstr *> ConstMap;
730 for (const MachineInstr *MI : MAI->getMSInstrs(MSType: SPIRV::MB_TypeConstVars)) {
731 // Skip if the instruction is not OpTypeFloat or OpConstant.
732 unsigned OpCode = MI->getOpcode();
733 if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantI &&
734 OpCode != SPIRV::OpConstantNull)
735 continue;
736
737 // Collect the SPIRV type if it's a float.
738 if (OpCode == SPIRV::OpTypeFloat) {
739 SPIRVFloatTypes.push_back(x: MI);
740 } else {
741 // Check if the constant is int32, if not skip it.
742 const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
743 MachineInstr *TypeMI = MRI.getVRegDef(Reg: MI->getOperand(i: 1).getReg());
744 if (!TypeMI || TypeMI->getOpcode() != SPIRV::OpTypeInt ||
745 TypeMI->getOperand(i: 1).getImm() != 32)
746 continue;
747
748 if (OpCode == SPIRV::OpConstantI)
749 ConstMap[MI->getOperand(i: 2).getImm()] = MI;
750 else
751 ConstMap[0] = MI;
752 }
753 }
754
755 for (const auto &[Func, FPFastMathDefaultInfoVec] :
756 MAI->FPFastMathDefaultInfoMap) {
757 if (FPFastMathDefaultInfoVec.empty())
758 continue;
759
760 for (const MachineInstr *MI : SPIRVFloatTypes) {
761 unsigned OpTypeFloatSize = MI->getOperand(i: 1).getImm();
762 unsigned Index = SPIRV::FPFastMathDefaultInfoVector::
763 computeFPFastMathDefaultInfoVecIndex(BitWidth: OpTypeFloatSize);
764 assert(Index < FPFastMathDefaultInfoVec.size() &&
765 "Index out of bounds for FPFastMathDefaultInfoVec");
766 const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index];
767 assert(FPFastMathDefaultInfo.Ty &&
768 "Expected target type for FPFastMathDefaultInfo");
769 assert(FPFastMathDefaultInfo.Ty->getScalarSizeInBits() ==
770 OpTypeFloatSize &&
771 "Mismatched float type size");
772 MCInst Inst;
773 Inst.setOpcode(SPIRV::OpExecutionModeId);
774 MCRegister FuncReg = MAI->getFuncReg(F: Func);
775 assert(FuncReg.isValid());
776 Inst.addOperand(Op: MCOperand::createReg(Reg: FuncReg));
777 Inst.addOperand(
778 Op: MCOperand::createImm(Val: SPIRV::ExecutionMode::FPFastMathDefault));
779 MCRegister TypeReg =
780 MAI->getRegisterAlias(MF: MI->getMF(), Reg: MI->getOperand(i: 0).getReg());
781 Inst.addOperand(Op: MCOperand::createReg(Reg: TypeReg));
782 unsigned Flags = FPFastMathDefaultInfo.FastMathFlags;
783 if (FPFastMathDefaultInfo.ContractionOff &&
784 (Flags & SPIRV::FPFastMathMode::AllowContract))
785 report_fatal_error(
786 reason: "Conflicting FPFastMathFlags: ContractionOff and AllowContract");
787
788 if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
789 !(Flags &
790 (SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
791 SPIRV::FPFastMathMode::NSZ))) {
792 if (FPFastMathDefaultInfo.FPFastMathDefault)
793 report_fatal_error(reason: "Conflicting FPFastMathFlags: "
794 "SignedZeroInfNanPreserve but at least one of "
795 "NotNaN/NotInf/NSZ is enabled.");
796 }
797
798 // Don't emit if none of the execution modes was used.
799 if (Flags == SPIRV::FPFastMathMode::None &&
800 !FPFastMathDefaultInfo.ContractionOff &&
801 !FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
802 !FPFastMathDefaultInfo.FPFastMathDefault)
803 continue;
804
805 // Retrieve the constant instruction for the immediate value.
806 auto It = ConstMap.find(x: Flags);
807 if (It == ConstMap.end())
808 report_fatal_error(reason: "Expected constant instruction for FP Fast Math "
809 "Mode operand of FPFastMathDefault execution mode.");
810 const MachineInstr *ConstMI = It->second;
811 MCRegister ConstReg = MAI->getRegisterAlias(
812 MF: ConstMI->getMF(), Reg: ConstMI->getOperand(i: 0).getReg());
813 Inst.addOperand(Op: MCOperand::createReg(Reg: ConstReg));
814 outputMCInst(Inst);
815 }
816 }
817}
818
819void SPIRVAsmPrinter::outputModuleSections() {
820 const Module *M = MMI->getModule();
821 // Get the global subtarget to output module-level info.
822 ST = static_cast<const SPIRVTargetMachine &>(TM).getSubtargetImpl();
823 TII = ST->getInstrInfo();
824 MAI = &SPIRVModuleAnalysis::MAI;
825 assert(ST && TII && MAI && M && "Module analysis is required");
826 // Output instructions according to the Logical Layout of a Module:
827 // 1,2. All OpCapability instructions, then optional OpExtension
828 // instructions.
829 outputGlobalRequirements();
830 // 3. Optional OpExtInstImport instructions.
831 outputOpExtInstImports(M: *M);
832 // 4. The single required OpMemoryModel instruction.
833 outputOpMemoryModel();
834 // 5. All entry point declarations, using OpEntryPoint.
835 outputEntryPoints();
836 // 6. Execution-mode declarations, using OpExecutionMode or
837 // OpExecutionModeId.
838 outputExecutionMode(M: *M);
839 // 7a. Debug: all OpString, OpSourceExtension, OpSource, and
840 // OpSourceContinued, without forward references.
841 outputDebugSourceAndStrings(M: *M);
842 // 7b. Debug: all OpName and all OpMemberName.
843 outputModuleSection(MSType: SPIRV::MB_DebugNames);
844 // 7c. Debug: all OpModuleProcessed instructions.
845 outputModuleSection(MSType: SPIRV::MB_DebugModuleProcessed);
846 // xxx. SPV_INTEL_memory_access_aliasing instructions go before 8.
847 // "All annotation instructions"
848 outputModuleSection(MSType: SPIRV::MB_AliasingInsts);
849 // 8. All annotation instructions (all decorations).
850 outputAnnotations(M: *M);
851 // 9. All type declarations (OpTypeXXX instructions), all constant
852 // instructions, and all global variable declarations. This section is
853 // the first section to allow use of: OpLine and OpNoLine debug information;
854 // non-semantic instructions with OpExtInst.
855 outputModuleSection(MSType: SPIRV::MB_TypeConstVars);
856 // 10. All global NonSemantic.Shader.DebugInfo.100 instructions.
857 outputModuleSection(MSType: SPIRV::MB_NonSemanticGlobalDI);
858 // 11. All function declarations (functions without a body).
859 outputExtFuncDecls();
860 // 12. All function definitions (functions with a body).
861 // This is done in regular function output.
862}
863
864bool SPIRVAsmPrinter::doInitialization(Module &M) {
865 ModuleSectionsEmitted = false;
866 // We need to call the parent's one explicitly.
867 return AsmPrinter::doInitialization(M);
868}
869
870char SPIRVAsmPrinter::ID = 0;
871
872INITIALIZE_PASS(SPIRVAsmPrinter, "spirv-asm-printer", "SPIRV Assembly Printer",
873 false, false)
874
875// Force static initialization.
876extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void
877LLVMInitializeSPIRVAsmPrinter() {
878 RegisterAsmPrinter<SPIRVAsmPrinter> X(getTheSPIRV32Target());
879 RegisterAsmPrinter<SPIRVAsmPrinter> Y(getTheSPIRV64Target());
880 RegisterAsmPrinter<SPIRVAsmPrinter> Z(getTheSPIRVLogicalTarget());
881}
882