1//===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a printer that converts from our internal representation
10// of machine-dependent LLVM code to NVPTX assembly language.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXAsmPrinter.h"
15#include "MCTargetDesc/NVPTXBaseInfo.h"
16#include "MCTargetDesc/NVPTXInstPrinter.h"
17#include "MCTargetDesc/NVPTXMCAsmInfo.h"
18#include "MCTargetDesc/NVPTXTargetStreamer.h"
19#include "NVPTX.h"
20#include "NVPTXDwarfDebug.h"
21#include "NVPTXMCExpr.h"
22#include "NVPTXMachineFunctionInfo.h"
23#include "NVPTXRegisterInfo.h"
24#include "NVPTXSubtarget.h"
25#include "NVPTXTargetMachine.h"
26#include "NVPTXUtilities.h"
27#include "NVVMProperties.h"
28#include "TargetInfo/NVPTXTargetInfo.h"
29#include "cl_common_defines.h"
30#include "llvm/ADT/APFloat.h"
31#include "llvm/ADT/APInt.h"
32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/DenseMap.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/SmallString.h"
36#include "llvm/ADT/SmallVector.h"
37#include "llvm/ADT/StringExtras.h"
38#include "llvm/ADT/StringRef.h"
39#include "llvm/ADT/Twine.h"
40#include "llvm/ADT/iterator_range.h"
41#include "llvm/Analysis/ConstantFolding.h"
42#include "llvm/CodeGen/Analysis.h"
43#include "llvm/CodeGen/MachineBasicBlock.h"
44#include "llvm/CodeGen/MachineFrameInfo.h"
45#include "llvm/CodeGen/MachineFunction.h"
46#include "llvm/CodeGen/MachineInstr.h"
47#include "llvm/CodeGen/MachineLoopInfo.h"
48#include "llvm/CodeGen/MachineModuleInfo.h"
49#include "llvm/CodeGen/MachineOperand.h"
50#include "llvm/CodeGen/MachineRegisterInfo.h"
51#include "llvm/CodeGen/TargetRegisterInfo.h"
52#include "llvm/CodeGen/ValueTypes.h"
53#include "llvm/CodeGenTypes/MachineValueType.h"
54#include "llvm/IR/Argument.h"
55#include "llvm/IR/Attributes.h"
56#include "llvm/IR/BasicBlock.h"
57#include "llvm/IR/Constant.h"
58#include "llvm/IR/Constants.h"
59#include "llvm/IR/DataLayout.h"
60#include "llvm/IR/DebugInfo.h"
61#include "llvm/IR/DebugInfoMetadata.h"
62#include "llvm/IR/DebugLoc.h"
63#include "llvm/IR/DerivedTypes.h"
64#include "llvm/IR/Function.h"
65#include "llvm/IR/GlobalAlias.h"
66#include "llvm/IR/GlobalValue.h"
67#include "llvm/IR/GlobalVariable.h"
68#include "llvm/IR/Instruction.h"
69#include "llvm/IR/LLVMContext.h"
70#include "llvm/IR/Module.h"
71#include "llvm/IR/Operator.h"
72#include "llvm/IR/Type.h"
73#include "llvm/IR/User.h"
74#include "llvm/MC/MCExpr.h"
75#include "llvm/MC/MCInst.h"
76#include "llvm/MC/MCInstrDesc.h"
77#include "llvm/MC/MCStreamer.h"
78#include "llvm/MC/MCSymbol.h"
79#include "llvm/MC/TargetRegistry.h"
80#include "llvm/Support/Alignment.h"
81#include "llvm/Support/Casting.h"
82#include "llvm/Support/Compiler.h"
83#include "llvm/Support/Endian.h"
84#include "llvm/Support/ErrorHandling.h"
85#include "llvm/Support/NativeFormatting.h"
86#include "llvm/Support/raw_ostream.h"
87#include "llvm/Target/TargetLoweringObjectFile.h"
88#include "llvm/Target/TargetMachine.h"
89#include "llvm/Transforms/Utils/UnrollLoop.h"
90#include <cassert>
91#include <cstdint>
92#include <cstring>
93#include <string>
94
95using namespace llvm;
96
97#define DEPOTNAME "__local_depot"
98
99static StringRef getTextureName(const Value &V) {
100 assert(V.hasName() && "Found texture variable with no name");
101 return V.getName();
102}
103
104static StringRef getSurfaceName(const Value &V) {
105 assert(V.hasName() && "Found surface variable with no name");
106 return V.getName();
107}
108
109static StringRef getSamplerName(const Value &V) {
110 assert(V.hasName() && "Found sampler variable with no name");
111 return V.getName();
112}
113
114/// Emits initial debug location directive.
115static void emitInitialRawDwarfLocDirective(const MachineFunction &MF,
116 DwarfDebug *DD,
117 MCStreamer &OutStreamer) {
118 if (!DD)
119 return;
120
121 assert(OutStreamer.hasRawTextSupport() && "Expected assembly output mode.");
122 // This is NVPTX specific and it's unclear why.
123 // PR51079: If we have code without debug information we need to give up.
124 const DISubprogram *SP = MF.getFunction().getSubprogram();
125 if (!SP)
126 return;
127 assert(SP->getUnit());
128 // NoDebug and DebugDirectivesOnly do not require emitting the initial loc
129 // directive. NoDebug does not require any debug directives and the initial
130 // loc directive is not needed for DebugDirectivesOnly as it is redundant
131 // assuming this is a non-empty function.
132 if (SP->getUnit()->isDebugDirectivesOnly() || SP->getUnit()->isNoDebug())
133 return;
134
135 (void)DD->emitInitialLocDirective(MF, /*CUID=*/0);
136}
137
138/// discoverDependentGlobals - Return a set of GlobalVariables on which \p V
139/// depends.
140static void
141discoverDependentGlobals(const Value *V,
142 DenseSet<const GlobalVariable *> &Globals) {
143 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: V)) {
144 Globals.insert(V: GV);
145 return;
146 }
147
148 if (const User *U = dyn_cast<User>(Val: V))
149 for (const auto &O : U->operands())
150 discoverDependentGlobals(V: O, Globals);
151}
152
153/// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
154/// instances to be emitted, but only after any dependents have been added
155/// first.s
156static void
157VisitGlobalVariableForEmission(const GlobalVariable *GV,
158 SmallVectorImpl<const GlobalVariable *> &Order,
159 DenseSet<const GlobalVariable *> &Visited,
160 DenseSet<const GlobalVariable *> &Visiting) {
161 // Have we already visited this one?
162 if (Visited.count(V: GV))
163 return;
164
165 // Do we have a circular dependency?
166 if (!Visiting.insert(V: GV).second)
167 report_fatal_error(reason: "Circular dependency found in global variable set");
168
169 // Make sure we visit all dependents first
170 DenseSet<const GlobalVariable *> Others;
171 for (const auto &O : GV->operands())
172 discoverDependentGlobals(V: O, Globals&: Others);
173
174 for (const GlobalVariable *GV : Others)
175 VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
176
177 // Now we can visit ourself
178 Order.push_back(Elt: GV);
179 Visited.insert(V: GV);
180 Visiting.erase(V: GV);
181}
182
183void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
184 NVPTX_MC::verifyInstructionPredicates(Opcode: MI->getOpcode(),
185 Features: getSubtargetInfo().getFeatureBits());
186
187 MCInst Inst;
188 lowerToMCInst(MI, OutMI&: Inst);
189 EmitToStreamer(S&: *OutStreamer, Inst);
190}
191
192void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
193 OutMI.setOpcode(MI->getOpcode());
194 // Special: Do not mangle symbol operand of CALL_PROTOTYPE
195 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
196 const MachineOperand &MO = MI->getOperand(i: 0);
197 OutMI.addOperand(Op: GetSymbolRef(
198 Symbol: OutContext.getOrCreateSymbol(Name: Twine(MO.getSymbolName()))));
199 return;
200 }
201
202 for (const auto MO : MI->operands())
203 OutMI.addOperand(Op: lowerOperand(MO));
204}
205
206MCOperand NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO) {
207 switch (MO.getType()) {
208 default:
209 llvm_unreachable("unknown operand type");
210 case MachineOperand::MO_Register:
211 return MCOperand::createReg(Reg: encodeVirtualRegister(Reg: MO.getReg()));
212 case MachineOperand::MO_Immediate:
213 return MCOperand::createImm(Val: MO.getImm());
214 case MachineOperand::MO_MachineBasicBlock:
215 return MCOperand::createExpr(
216 Val: MCSymbolRefExpr::create(Symbol: MO.getMBB()->getSymbol(), Ctx&: OutContext));
217 case MachineOperand::MO_ExternalSymbol:
218 return GetSymbolRef(Symbol: GetExternalSymbolSymbol(Sym: MO.getSymbolName()));
219 case MachineOperand::MO_GlobalAddress:
220 return GetSymbolRef(Symbol: getSymbol(GV: MO.getGlobal()));
221 case MachineOperand::MO_FPImmediate: {
222 const ConstantFP *Cnt = MO.getFPImm();
223 const APFloat &Val = Cnt->getValueAPF();
224
225 switch (Cnt->getType()->getTypeID()) {
226 default:
227 report_fatal_error(reason: "Unsupported FP type");
228 break;
229 case Type::HalfTyID:
230 return MCOperand::createExpr(
231 Val: NVPTXFloatMCExpr::createConstantFPHalf(Flt: Val, Ctx&: OutContext));
232 case Type::BFloatTyID:
233 return MCOperand::createExpr(
234 Val: NVPTXFloatMCExpr::createConstantBFPHalf(Flt: Val, Ctx&: OutContext));
235 case Type::FloatTyID:
236 return MCOperand::createExpr(
237 Val: NVPTXFloatMCExpr::createConstantFPSingle(Flt: Val, Ctx&: OutContext));
238 case Type::DoubleTyID:
239 return MCOperand::createExpr(
240 Val: NVPTXFloatMCExpr::createConstantFPDouble(Flt: Val, Ctx&: OutContext));
241 }
242 break;
243 }
244 }
245}
246
247unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
248 if (Register::isVirtualRegister(Reg)) {
249 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
250
251 DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
252 unsigned RegNum = RegMap[Reg];
253
254 // Encode the register class in the upper 4 bits
255 // Must be kept in sync with NVPTXInstPrinter::printRegName
256 unsigned Ret = 0;
257 if (RC == &NVPTX::B1RegClass) {
258 Ret = (1 << 28);
259 } else if (RC == &NVPTX::B16RegClass) {
260 Ret = (2 << 28);
261 } else if (RC == &NVPTX::B32RegClass) {
262 Ret = (3 << 28);
263 } else if (RC == &NVPTX::B64RegClass) {
264 Ret = (4 << 28);
265 } else if (RC == &NVPTX::B128RegClass) {
266 Ret = (7 << 28);
267 } else {
268 report_fatal_error(reason: "Bad register class");
269 }
270
271 // Insert the vreg number
272 Ret |= (RegNum & 0x0FFFFFFF);
273 return Ret;
274 } else {
275 // Some special-use registers are actually physical registers.
276 // Encode this as the register class ID of 0 and the real register ID.
277 return Reg & 0x0FFFFFFF;
278 }
279}
280
281MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
282 const MCExpr *Expr;
283 Expr = MCSymbolRefExpr::create(Symbol, Ctx&: OutContext);
284 return MCOperand::createExpr(Val: Expr);
285}
286
287void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
288 const DataLayout &DL = getDataLayout();
289 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(F: *F);
290 const auto *TLI = cast<NVPTXTargetLowering>(Val: STI.getTargetLowering());
291
292 Type *Ty = F->getReturnType();
293 if (Ty->getTypeID() == Type::VoidTyID)
294 return;
295 O << " (";
296
297 auto PrintScalarRetVal = [&](unsigned Size) {
298 O << ".param .b" << promoteScalarArgumentSize(size: Size) << " func_retval0";
299 };
300 if (shouldPassAsArray(Ty)) {
301 const unsigned TotalSize = DL.getTypeAllocSize(Ty);
302 const Align RetAlignment =
303 getPTXParamAlign(F, Ty, AttrIdx: AttributeList::ReturnIndex, DL);
304 O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
305 << TotalSize << "]";
306 } else if (Ty->isFloatingPointTy()) {
307 PrintScalarRetVal(Ty->getPrimitiveSizeInBits());
308 } else if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
309 PrintScalarRetVal(ITy->getBitWidth());
310 } else if (isa<PointerType>(Val: Ty)) {
311 PrintScalarRetVal(TLI->getPointerTy(DL).getSizeInBits());
312 } else
313 llvm_unreachable("Unknown return type");
314 O << ") ";
315}
316
317void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
318 raw_ostream &O) {
319 const Function &F = MF.getFunction();
320 printReturnValStr(F: &F, O);
321}
322
323// Return true if MBB is the header of a loop marked with
324// llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
325bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
326 const MachineBasicBlock &MBB) const {
327 MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();
328 // We insert .pragma "nounroll" only to the loop header.
329 if (!LI.isLoopHeader(BB: &MBB))
330 return false;
331
332 // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
333 // we iterate through each back edge of the loop with header MBB, and check
334 // whether its metadata contains llvm.loop.unroll.disable.
335 for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
336 if (LI.getLoopFor(BB: PMBB) != LI.getLoopFor(BB: &MBB)) {
337 // Edges from other loops to MBB are not back edges.
338 continue;
339 }
340 if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
341 if (MDNode *LoopID =
342 PBB->getTerminator()->getMetadata(KindID: LLVMContext::MD_loop)) {
343 if (GetUnrollMetadata(LoopID, Name: "llvm.loop.unroll.disable"))
344 return true;
345 if (MDNode *UnrollCountMD =
346 GetUnrollMetadata(LoopID, Name: "llvm.loop.unroll.count")) {
347 if (mdconst::extract<ConstantInt>(MD: UnrollCountMD->getOperand(I: 1))
348 ->isOne())
349 return true;
350 }
351 }
352 }
353 }
354 return false;
355}
356
357void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
358 AsmPrinter::emitBasicBlockStart(MBB);
359 if (isLoopHeaderOfNoUnroll(MBB))
360 OutStreamer->emitRawText(String: StringRef("\t.pragma \"nounroll\";\n"));
361}
362
363void NVPTXAsmPrinter::emitFunctionEntryLabel() {
364 SmallString<128> Str;
365 raw_svector_ostream O(Str);
366
367 if (!GlobalsEmitted) {
368 emitGlobals(M: *MF->getFunction().getParent());
369 GlobalsEmitted = true;
370 }
371
372 // Set up
373 MRI = &MF->getRegInfo();
374 F = &MF->getFunction();
375 emitLinkageDirective(V: F, O);
376 if (isKernelFunction(F: *F))
377 O << ".entry ";
378 else {
379 O << ".func ";
380 printReturnValStr(MF: *MF, O);
381 }
382
383 CurrentFnSym->print(OS&: O, MAI);
384
385 emitFunctionParamList(F, O);
386 O << "\n";
387
388 if (isKernelFunction(F: *F))
389 emitKernelFunctionDirectives(F: *F, O);
390
391 if (shouldEmitPTXNoReturn(V: F, TM))
392 O << ".noreturn";
393
394 OutStreamer->emitRawText(String: O.str());
395
396 VRegMapping.clear();
397 // Emit open brace for function body.
398 OutStreamer->emitRawText(String: StringRef("{\n"));
399 setAndEmitFunctionVirtualRegisters(*MF);
400 encodeDebugInfoRegisterNumbers(MF: *MF);
401 // Emit initial .loc debug directive for correct relocation symbol data.
402 emitInitialRawDwarfLocDirective(MF: *MF, DD: getDwarfDebug(), OutStreamer&: *OutStreamer);
403}
404
405bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
406 bool Result = AsmPrinter::runOnMachineFunction(MF&: F);
407 // Emit closing brace for the body of function F.
408 // The closing brace must be emitted here because we need to emit additional
409 // debug labels/data after the last basic block.
410 // We need to emit the closing brace here because we don't have function that
411 // finished emission of the function body.
412 OutStreamer->emitRawText(String: StringRef("}\n"));
413 return Result;
414}
415
416void NVPTXAsmPrinter::emitFunctionBodyStart() {
417 SmallString<128> Str;
418 raw_svector_ostream O(Str);
419 emitDemotedVars(&MF->getFunction(), O);
420 OutStreamer->emitRawText(String: O.str());
421}
422
423void NVPTXAsmPrinter::emitFunctionBodyEnd() {
424 VRegMapping.clear();
425}
426
427const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
428 SmallString<128> Str;
429 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
430 return OutContext.getOrCreateSymbol(Name: Str);
431}
432
433void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
434 Register RegNo = MI->getOperand(i: 0).getReg();
435 if (RegNo.isVirtual()) {
436 OutStreamer->AddComment(T: Twine("implicit-def: ") +
437 getVirtualRegisterName(RegNo));
438 } else {
439 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
440 OutStreamer->AddComment(T: Twine("implicit-def: ") +
441 STI.getRegisterInfo()->getName(RegNo));
442 }
443 OutStreamer->addBlankLine();
444}
445
446void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
447 raw_ostream &O) const {
448 // If the NVVM IR has some of reqntid* specified, then output
449 // the reqntid directive, and set the unspecified ones to 1.
450 // If none of Reqntid* is specified, don't output reqntid directive.
451 const auto ReqNTID = getReqNTID(F);
452 if (!ReqNTID.empty())
453 O << formatv(Fmt: ".reqntid {0:$[, ]}\n",
454 Vals: make_range(x: ReqNTID.begin(), y: ReqNTID.end()));
455
456 const auto MaxNTID = getMaxNTID(F);
457 if (!MaxNTID.empty())
458 O << formatv(Fmt: ".maxntid {0:$[, ]}\n",
459 Vals: make_range(x: MaxNTID.begin(), y: MaxNTID.end()));
460
461 if (const auto Mincta = getMinCTASm(F))
462 O << ".minnctapersm " << *Mincta << "\n";
463
464 if (const auto Maxnreg = getMaxNReg(F))
465 O << ".maxnreg " << *Maxnreg << "\n";
466
467 // .maxclusterrank directive requires SM_90 or higher, make sure that we
468 // filter it out for lower SM versions, as it causes a hard ptxas crash.
469 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
470 const NVPTXSubtarget *STI = &NTM.getSubtarget<NVPTXSubtarget>(F);
471
472 if (STI->getSmVersion() >= 90) {
473 const auto ClusterDim = getClusterDim(F);
474 const bool BlocksAreClusters = hasBlocksAreClusters(F);
475
476 if (!ClusterDim.empty()) {
477
478 if (!BlocksAreClusters)
479 O << ".explicitcluster\n";
480
481 if (ClusterDim[0] != 0) {
482 assert(llvm::all_of(ClusterDim, not_equal_to(0)) &&
483 "cluster_dim_x != 0 implies cluster_dim_y and cluster_dim_z "
484 "should be non-zero as well");
485
486 O << formatv(Fmt: ".reqnctapercluster {0:$[, ]}\n",
487 Vals: make_range(x: ClusterDim.begin(), y: ClusterDim.end()));
488 } else {
489 assert(llvm::all_of(ClusterDim, equal_to(0)) &&
490 "cluster_dim_x == 0 implies cluster_dim_y and cluster_dim_z "
491 "should be 0 as well");
492 }
493 }
494
495 if (BlocksAreClusters) {
496 LLVMContext &Ctx = F.getContext();
497 if (ReqNTID.empty() || ClusterDim.empty())
498 Ctx.diagnose(DI: DiagnosticInfoUnsupported(
499 F, "blocksareclusters requires reqntid and cluster_dim attributes",
500 F.getSubprogram()));
501 else if (STI->getPTXVersion() < 90)
502 Ctx.diagnose(DI: DiagnosticInfoUnsupported(
503 F, "blocksareclusters requires PTX version >= 9.0",
504 F.getSubprogram()));
505 else
506 O << ".blocksareclusters\n";
507 }
508
509 if (const auto Maxclusterrank = getMaxClusterRank(F))
510 O << ".maxclusterrank " << *Maxclusterrank << "\n";
511 }
512}
513
514std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
515 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
516
517 std::string Name;
518 raw_string_ostream NameStr(Name);
519
520 VRegRCMap::const_iterator I = VRegMapping.find(Val: RC);
521 assert(I != VRegMapping.end() && "Bad register class");
522 const DenseMap<unsigned, unsigned> &RegMap = I->second;
523
524 VRegMap::const_iterator VI = RegMap.find(Val: Reg);
525 assert(VI != RegMap.end() && "Bad virtual register");
526 unsigned MappedVR = VI->second;
527
528 NameStr << getNVPTXRegClassStr(RC) << MappedVR;
529
530 return Name;
531}
532
533void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
534 raw_ostream &O) {
535 O << getVirtualRegisterName(Reg: vr);
536}
537
538void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias *GA,
539 raw_ostream &O) {
540 const Function *F = dyn_cast_or_null<Function>(Val: GA->getAliaseeObject());
541 if (!F || isKernelFunction(F: *F) || F->isDeclaration())
542 report_fatal_error(
543 reason: "NVPTX aliasee must be a non-kernel function definition");
544
545 if (GA->hasLinkOnceLinkage() || GA->hasWeakLinkage() ||
546 GA->hasAvailableExternallyLinkage() || GA->hasCommonLinkage())
547 report_fatal_error(reason: "NVPTX aliasee must not be '.weak'");
548
549 emitDeclarationWithName(F, getSymbol(GV: GA), O);
550}
551
552void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
553 emitDeclarationWithName(F, getSymbol(GV: F), O);
554}
555
556void NVPTXAsmPrinter::emitDeclarationWithName(const Function *F, MCSymbol *S,
557 raw_ostream &O) {
558 emitLinkageDirective(V: F, O);
559 if (isKernelFunction(F: *F))
560 O << ".entry ";
561 else
562 O << ".func ";
563 printReturnValStr(F, O);
564 S->print(OS&: O, MAI);
565 O << "\n";
566 emitFunctionParamList(F, O);
567 O << "\n";
568 if (shouldEmitPTXNoReturn(V: F, TM))
569 O << ".noreturn";
570 O << ";\n";
571}
572
573static bool usedInGlobalVarDef(const Constant *C) {
574 if (!C)
575 return false;
576
577 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: C))
578 return GV->getName() != "llvm.used";
579
580 for (const User *U : C->users())
581 if (const Constant *C = dyn_cast<Constant>(Val: U))
582 if (usedInGlobalVarDef(C))
583 return true;
584
585 return false;
586}
587
588static bool usedInOneFunc(const User *U, Function const *&OneFunc) {
589 if (const GlobalVariable *OtherGV = dyn_cast<GlobalVariable>(Val: U))
590 if (OtherGV->getName() == "llvm.used")
591 return true;
592
593 if (const Instruction *I = dyn_cast<Instruction>(Val: U)) {
594 if (const Function *CurFunc = I->getFunction()) {
595 if (OneFunc && (CurFunc != OneFunc))
596 return false;
597 OneFunc = CurFunc;
598 return true;
599 }
600 return false;
601 }
602
603 for (const User *UU : U->users())
604 if (!usedInOneFunc(U: UU, OneFunc))
605 return false;
606
607 return true;
608}
609
610/* Find out if a global variable can be demoted to local scope.
611 * Currently, this is valid for CUDA shared variables, which have local
612 * scope and global lifetime. So the conditions to check are :
613 * 1. Is the global variable in shared address space?
614 * 2. Does it have local linkage?
615 * 3. Is the global variable referenced only in one function?
616 */
617static bool canDemoteGlobalVar(const GlobalVariable *GV, Function const *&f) {
618 if (!GV->hasLocalLinkage())
619 return false;
620 if (GV->getAddressSpace() != ADDRESS_SPACE_SHARED)
621 return false;
622
623 const Function *oneFunc = nullptr;
624
625 bool flag = usedInOneFunc(U: GV, OneFunc&: oneFunc);
626 if (!flag)
627 return false;
628 if (!oneFunc)
629 return false;
630 f = oneFunc;
631 return true;
632}
633
634static bool useFuncSeen(const Constant *C,
635 const SmallPtrSetImpl<const Function *> &SeenSet) {
636 for (const User *U : C->users()) {
637 if (const Constant *cu = dyn_cast<Constant>(Val: U)) {
638 if (useFuncSeen(C: cu, SeenSet))
639 return true;
640 } else if (const Instruction *I = dyn_cast<Instruction>(Val: U)) {
641 if (const Function *Caller = I->getFunction())
642 if (SeenSet.contains(Ptr: Caller))
643 return true;
644 }
645 }
646 return false;
647}
648
649void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
650 SmallPtrSet<const Function *, 32> SeenSet;
651 for (const Function &F : M) {
652 if (F.getAttributes().hasFnAttr(Kind: "nvptx-libcall-callee")) {
653 emitDeclaration(F: &F, O);
654 continue;
655 }
656
657 if (F.isDeclaration()) {
658 if (F.use_empty())
659 continue;
660 if (F.getIntrinsicID())
661 continue;
662 // An unrecognized intrinsic would produce an invalid PTX declaration. Let
663 // the user know that, and skip it.
664 if (F.isIntrinsic()) {
665 LLVMContext &Ctx = F.getContext();
666 Ctx.diagnose(DI: DiagnosticInfoUnsupported(
667 F, "unknown intrinsic '" + F.getName() +
668 "' cannot be lowered by the NVPTX backend"));
669 continue;
670 }
671 emitDeclaration(F: &F, O);
672 continue;
673 }
674 for (const User *U : F.users()) {
675 if (const Constant *C = dyn_cast<Constant>(Val: U)) {
676 if (usedInGlobalVarDef(C)) {
677 // The use is in the initialization of a global variable
678 // that is a function pointer, so print a declaration
679 // for the original function
680 emitDeclaration(F: &F, O);
681 break;
682 }
683 // Emit a declaration of this function if the function that
684 // uses this constant expr has already been seen.
685 if (useFuncSeen(C, SeenSet)) {
686 emitDeclaration(F: &F, O);
687 break;
688 }
689 }
690
691 if (!isa<Instruction>(Val: U))
692 continue;
693 const Function *Caller = cast<Instruction>(Val: U)->getFunction();
694 if (!Caller)
695 continue;
696
697 // If a caller has already been seen, then the caller is
698 // appearing in the module before the callee. so print out
699 // a declaration for the callee.
700 if (SeenSet.contains(Ptr: Caller)) {
701 emitDeclaration(F: &F, O);
702 break;
703 }
704 }
705 SeenSet.insert(Ptr: &F);
706 }
707 for (const GlobalAlias &GA : M.aliases())
708 emitAliasDeclaration(GA: &GA, O);
709}
710
711void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
712 // Construct a default subtarget off of the TargetMachine defaults. The
713 // rest of NVPTX isn't friendly to change subtargets per function and
714 // so the default TargetMachine will have all of the options.
715 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
716 const NVPTXSubtarget *STI = NTM.getSubtargetImpl();
717
718 // Emit header before any dwarf directives are emitted below.
719 emitHeader(M, STI: *STI);
720}
721
722/// Create NVPTX-specific DwarfDebug handler.
723DwarfDebug *NVPTXAsmPrinter::createDwarfDebug() {
724 return new NVPTXDwarfDebug(this);
725}
726
727bool NVPTXAsmPrinter::doInitialization(Module &M) {
728 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
729 const NVPTXSubtarget &STI = *NTM.getSubtargetImpl();
730 if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
731 report_fatal_error(reason: ".alias requires PTX version >= 6.3 and sm_30");
732
733 // We need to call the parent's one explicitly.
734 bool Result = AsmPrinter::doInitialization(M);
735
736 GlobalsEmitted = false;
737
738 return Result;
739}
740
741void NVPTXAsmPrinter::emitGlobals(const Module &M) {
742 SmallString<128> Str2;
743 raw_svector_ostream OS2(Str2);
744
745 emitDeclarations(M, O&: OS2);
746
747 // As ptxas does not support forward references of globals, we need to first
748 // sort the list of module-level globals in def-use order. We visit each
749 // global variable in order, and ensure that we emit it *after* its dependent
750 // globals. We use a little extra memory maintaining both a set and a list to
751 // have fast searches while maintaining a strict ordering.
752 SmallVector<const GlobalVariable *, 8> Globals;
753 DenseSet<const GlobalVariable *> GVVisited;
754 DenseSet<const GlobalVariable *> GVVisiting;
755
756 // Visit each global variable, in order
757 for (const GlobalVariable &I : M.globals())
758 VisitGlobalVariableForEmission(GV: &I, Order&: Globals, Visited&: GVVisited, Visiting&: GVVisiting);
759
760 assert(GVVisited.size() == M.global_size() && "Missed a global variable");
761 assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
762
763 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
764 const NVPTXSubtarget &STI = *NTM.getSubtargetImpl();
765
766 // Print out module-level global variables in proper order
767 for (const GlobalVariable *GV : Globals)
768 printModuleLevelGV(GVar: GV, O&: OS2, /*ProcessDemoted=*/processDemoted: false, STI);
769
770 OS2 << '\n';
771
772 OutStreamer->emitRawText(String: OS2.str());
773}
774
775void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
776 SmallString<128> Str;
777 raw_svector_ostream OS(Str);
778
779 MCSymbol *Name = getSymbol(GV: &GA);
780
781 OS << ".alias " << Name->getName() << ", " << GA.getAliaseeObject()->getName()
782 << ";\n";
783
784 OutStreamer->emitRawText(String: OS.str());
785}
786
787NVPTXTargetStreamer *NVPTXAsmPrinter::getTargetStreamer() const {
788 return static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
789}
790
791static bool hasFullDebugInfo(Module &M) {
792 for (DICompileUnit *CU : M.debug_compile_units()) {
793 switch(CU->getEmissionKind()) {
794 case DICompileUnit::NoDebug:
795 case DICompileUnit::DebugDirectivesOnly:
796 break;
797 case DICompileUnit::LineTablesOnly:
798 case DICompileUnit::FullDebug:
799 return true;
800 }
801 }
802
803 return false;
804}
805
806void NVPTXAsmPrinter::emitHeader(Module &M, const NVPTXSubtarget &STI) {
807 auto *TS = getTargetStreamer();
808
809 TS->emitBanner();
810
811 const unsigned PTXVersion = STI.getPTXVersion();
812 TS->emitVersionDirective(PTXVersion);
813
814 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
815 bool TexModeIndependent = NTM.getDrvInterface() == NVPTX::NVCL;
816
817 TS->emitTargetDirective(Target: STI.getTargetName(), TexModeIndependent,
818 HasDebug: hasFullDebugInfo(M));
819 TS->emitAddressSizeDirective(AddrSize: M.getDataLayout().getPointerSizeInBits());
820}
821
822bool NVPTXAsmPrinter::doFinalization(Module &M) {
823 // If we did not emit any functions, then the global declarations have not
824 // yet been emitted.
825 if (!GlobalsEmitted) {
826 emitGlobals(M);
827 GlobalsEmitted = true;
828 }
829
830 // call doFinalization
831 bool ret = AsmPrinter::doFinalization(M);
832
833 clearAnnotationCache(&M);
834
835 auto *TS =
836 static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
837 // Close the last emitted section
838 if (hasDebugInfo()) {
839 TS->closeLastSection();
840 // Emit empty .debug_macinfo section for better support of the empty files.
841 OutStreamer->emitRawText(String: "\t.section\t.debug_macinfo\t{\t}");
842 }
843
844 // Output last DWARF .file directives, if any.
845 TS->outputDwarfFileDirectives();
846
847 return ret;
848}
849
850// This function emits appropriate linkage directives for
851// functions and global variables.
852//
853// extern function declaration -> .extern
854// extern function definition -> .visible
855// external global variable with init -> .visible
856// external without init -> .extern
857// appending -> not allowed, assert.
858// for any linkage other than
859// internal, private, linker_private,
860// linker_private_weak, linker_private_weak_def_auto,
861// we emit -> .weak.
862
863void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
864 raw_ostream &O) {
865 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
866 if (V->hasExternalLinkage()) {
867 if (const auto *GVar = dyn_cast<GlobalVariable>(Val: V))
868 O << (GVar->hasInitializer() ? ".visible " : ".extern ");
869 else if (V->isDeclaration())
870 O << ".extern ";
871 else
872 O << ".visible ";
873 } else if (V->hasAppendingLinkage()) {
874 report_fatal_error(reason: "Symbol '" + (V->hasName() ? V->getName() : "") +
875 "' has unsupported appending linkage type");
876 } else if (!V->hasInternalLinkage() && !V->hasPrivateLinkage()) {
877 O << ".weak ";
878 }
879 }
880}
881
/// Print the PTX text for the module-level global variable \p GVar to \p O.
///
/// Nothing is printed for globals placed in the "llvm.metadata" section or
/// whose name starts with "llvm." or "nvvm.". Otherwise a linkage directive is
/// written first, followed by exactly one of:
///   - a texture, surface or sampler reference,
///   - a bare declaration (delegated to emitPTXGlobalVariable), or
///   - a full definition: address space, optional .attribute(.managed),
///     alignment, element type and, where permitted, an initializer.
///
/// \param GVar           The global variable to print.
/// \param O              Stream receiving the PTX text.
/// \param ProcessDemoted When false, a variable accepted by
///                       canDemoteGlobalVar() is recorded in localDecls and
///                       only a "has been demoted" comment is printed; when
///                       true the demotion check is skipped.
/// \param STI            Subtarget queried for the PTX version, the SM version
///                       and mask() operator support.
void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
                                         raw_ostream &O, bool ProcessDemoted,
                                         const NVPTXSubtarget &STI) {
  // Skip metadata: globals in the "llvm.metadata" section are never printed.
  if (GVar->hasSection())
    if (GVar->getSection() == "llvm.metadata")
      return;

  // Skip LLVM intrinsic global variables
  if (GVar->getName().starts_with(Prefix: "llvm.") ||
      GVar->getName().starts_with(Prefix: "nvvm."))
    return;

  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  Type *ETy = GVar->getValueType();

  // Linkage directive. External symbols are ".visible" when they carry an
  // initializer and ".extern" otherwise; ".common" is only used for common
  // linkage in the global address space with PTX version >= 50; the remaining
  // weak-like linkages (including common when ".common" is not used) become
  // ".weak". Any other linkage prints no directive.
  if (GVar->hasExternalLinkage()) {
    if (GVar->hasInitializer())
      O << ".visible ";
    else
      O << ".extern ";
  } else if (STI.getPTXVersion() >= 50 && GVar->hasCommonLinkage() &&
             GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) {
    O << ".common ";
  } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
             GVar->hasAvailableExternallyLinkage() ||
             GVar->hasCommonLinkage()) {
    O << ".weak ";
  }

  const PTXOpaqueType OpaqueType = getPTXOpaqueType(*GVar);

  // Textures and surfaces are emitted as opaque references and need nothing
  // else.
  if (OpaqueType == PTXOpaqueType::Texture) {
    O << ".global .texref " << getTextureName(V: *GVar) << ";\n";
    return;
  }

  if (OpaqueType == PTXOpaqueType::Surface) {
    O << ".global .surfref " << getSurfaceName(V: *GVar) << ";\n";
    return;
  }

  if (GVar->isDeclaration()) {
    // (extern) declarations, no definition or initializer
    // Currently the only known declaration is for an automatic __local
    // (.shared) promoted to global.
    emitPTXGlobalVariable(GVar, O, STI);
    O << ";\n";
    return;
  }

  if (OpaqueType == PTXOpaqueType::Sampler) {
    O << ".global .samplerref " << getSamplerName(V: *GVar);

    // Only a ConstantInt initializer is decoded into sampler properties; any
    // other initializer (or none) leaves the reference without a "= { ... }"
    // part.
    const Constant *Initializer = nullptr;
    if (GVar->hasInitializer())
      Initializer = GVar->getInitializer();
    const ConstantInt *CI = nullptr;
    if (Initializer)
      CI = dyn_cast<ConstantInt>(Val: Initializer);
    if (CI) {
      unsigned sample = CI->getZExtValue();

      O << " = { ";

      // The same addressing mode, extracted once from the sampler word, is
      // printed for addr_mode_0, addr_mode_1 and addr_mode_2. Values outside
      // 0..4 print no mode name.
      for (int i = 0,
               addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
           i < 3; i++) {
        O << "addr_mode_" << i << " = ";
        switch (addr) {
        case 0:
          O << "wrap";
          break;
        case 1:
          O << "clamp_to_border";
          break;
        case 2:
          O << "clamp_to_edge";
          break;
        case 3:
          O << "wrap";
          break;
        case 4:
          O << "mirror";
          break;
        }
        O << ", ";
      }
      O << "filter_mode = ";
      switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
      case 0:
        O << "nearest";
        break;
      case 1:
        O << "linear";
        break;
      case 2:
        llvm_unreachable("Anisotropic filtering is not supported");
      default:
        O << "nearest";
        break;
      }
      // A cleared "normalized" field forces unnormalized coordinates.
      if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
        O << ", force_unnormalized_coords = 1";
      }
      O << " }";
    }

    O << ";\n";
    return;
  }

  // Private globals that are never printed: names starting with
  // "unrollpragma" or "filename", and any private global without uses.
  if (GVar->hasPrivateLinkage()) {
    if (GVar->getName().starts_with(Prefix: "unrollpragma"))
      return;

    // FIXME - need better way (e.g. Metadata) to avoid generating this global
    if (GVar->getName().starts_with(Prefix: "filename"))
      return;
    if (GVar->use_empty())
      return;
  }

  // A demotable variable is queued under the function reported by
  // canDemoteGlobalVar() (see localDecls) rather than printed here.
  const Function *DemotedFunc = nullptr;
  if (!ProcessDemoted && canDemoteGlobalVar(GV: GVar, f&: DemotedFunc)) {
    O << "// " << GVar->getName() << " has been demoted\n";
    localDecls[DemotedFunc].push_back(x: GVar);
    return;
  }

  O << ".";
  emitPTXAddressSpace(AddressSpace: GVar->getAddressSpace(), O);

  if (isManaged(*GVar)) {
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
      report_fatal_error(
          reason: ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
    O << " .attribute(.managed)";
  }

  // Use the explicit alignment when present, otherwise the preferred
  // alignment of the value type.
  O << " .align "
    << GVar->getAlign().value_or(u: DL.getPrefTypeAlign(Ty: ETy)).value();

  // Pointers and integer/floating-point scalars of at most 64 bits are printed
  // with a fundamental PTX type; everything else becomes an array.
  if (ETy->isPointerTy() || ((ETy->isIntegerTy() || ETy->isFloatingPointTy()) &&
                             ETy->getScalarSizeInBits() <= 64)) {
    O << " .";
    // Special case: ABI requires that we use .u8 for predicates
    if (ETy->isIntegerTy(BitWidth: 1))
      O << "u8";
    else
      O << getPTXFundamentalTypeStr(Ty: ETy, false);
    O << " ";
    getSymbol(GV: GVar)->print(OS&: O, MAI);

    // PTX allows variable initialization only for constant and global state
    // spaces.
    if (GVar->hasInitializer()) {
      if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
          (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) {
        const Constant *Initializer = GVar->getInitializer();
        // 'undef' is treated as there is no value specified.
        if (!Initializer->isNullValue() && !isa<UndefValue>(Val: Initializer)) {
          O << " = ";
          printScalarConstant(CPV: Initializer, O);
        }
      } else {
        // The frontend adds zero-initializer to device and constant variables
        // that don't have an initial value, and UndefValue to shared
        // variables, so skip warning for this case.
        if (!GVar->getInitializer()->isNullValue() &&
            !isa<UndefValue>(Val: GVar->getInitializer())) {
          report_fatal_error(reason: "initial value of '" + GVar->getName() +
                             "' is not allowed in addrspace(" +
                             Twine(GVar->getAddressSpace()) + ")");
        }
      }
    }
  } else {
    // Although PTX has direct support for struct type and array type and
    // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
    // targets that support these high level field accesses. Structs, arrays
    // and vectors are lowered into arrays of bytes.
    switch (ETy->getTypeID()) {
    case Type::IntegerTyID: // Integers larger than 64 bits
    case Type::FP128TyID:
    case Type::StructTyID:
    case Type::ArrayTyID:
    case Type::FixedVectorTyID: {
      const uint64_t ElementSize = DL.getTypeStoreSize(Ty: ETy);
      // PTX allows variable initialization only for constant and
      // global state spaces.
      if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
           (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
          GVar->hasInitializer()) {
        const Constant *Initializer = GVar->getInitializer();
        if (!isa<UndefValue>(Val: Initializer) && !Initializer->isNullValue()) {
          // Flatten the initializer into a byte buffer, then pick the output
          // form depending on whether it references symbols.
          AggBuffer aggBuffer(ElementSize, *this);
          bufferAggregateConstant(CV: Initializer, aggBuffer: &aggBuffer);
          if (aggBuffer.numSymbols()) {
            const unsigned int ptrSize = MAI.getCodePointerSize();
            if (ElementSize % ptrSize ||
                !aggBuffer.allSymbolsAligned(ptrSize)) {
              // Print in bytes and use the mask() operator for pointers.
              if (!STI.hasMaskOperator())
                report_fatal_error(
                    reason: "initialized packed aggregate with pointers '" +
                    GVar->getName() +
                    "' requires at least PTX ISA version 7.1");
              O << " .u8 ";
              getSymbol(GV: GVar)->print(OS&: O, MAI);
              O << "[" << ElementSize << "] = {";
              aggBuffer.printBytes(os&: O);
              O << "}";
            } else {
              // Every symbol is pointer-aligned and the size is a whole
              // number of pointers: print pointer-sized words.
              O << " .u" << ptrSize * 8 << " ";
              getSymbol(GV: GVar)->print(OS&: O, MAI);
              O << "[" << ElementSize / ptrSize << "] = {";
              aggBuffer.printWords(os&: O);
              O << "}";
            }
          } else {
            // No symbols: a plain byte array.
            O << " .b8 ";
            getSymbol(GV: GVar)->print(OS&: O, MAI);
            O << "[" << ElementSize << "] = {";
            aggBuffer.printBytes(os&: O);
            O << "}";
          }
        } else {
          // Undef or all-zero initializer: no initializer list is printed.
          O << " .b8 ";
          getSymbol(GV: GVar)->print(OS&: O, MAI);
          if (ElementSize)
            O << "[" << ElementSize << "]";
        }
      } else {
        // Address space without initializer support, or no initializer.
        O << " .b8 ";
        getSymbol(GV: GVar)->print(OS&: O, MAI);
        if (ElementSize)
          O << "[" << ElementSize << "]";
      }
      break;
    }
    default:
      llvm_unreachable("type not supported yet");
    }
  }
  O << ";\n";
}
1131
1132void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1133 const Value *v = Symbols[nSym];
1134 const Value *v0 = SymbolsBeforeStripping[nSym];
1135 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: v)) {
1136 MCSymbol *Name = AP.getSymbol(GV: GVar);
1137 PointerType *PTy = dyn_cast<PointerType>(Val: v0->getType());
1138 // Is v0 a generic pointer?
1139 bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1140 if (EmitGeneric && isGenericPointer && !isa<Function>(Val: v)) {
1141 os << "generic(";
1142 Name->print(OS&: os, MAI: AP.MAI);
1143 os << ")";
1144 } else {
1145 Name->print(OS&: os, MAI: AP.MAI);
1146 }
1147 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(Val: v0)) {
1148 const MCExpr *Expr = AP.lowerConstantForGV(CV: CExpr, ProcessingGeneric: false);
1149 AP.printMCExpr(Expr: *Expr, OS&: os);
1150 } else
1151 llvm_unreachable("symbol type unknown");
1152}
1153
1154void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
1155 unsigned int ptrSize = AP.MAI.getCodePointerSize();
1156 // Do not emit trailing zero initializers. They will be zero-initialized by
1157 // ptxas. This saves on both space requirements for the generated PTX and on
1158 // memory use by ptxas. (See:
1159 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
1160 unsigned int InitializerCount = Size;
1161 // TODO: symbols make this harder, but it would still be good to trim trailing
1162 // 0s for aggs with symbols as well.
1163 if (numSymbols() == 0)
1164 while (InitializerCount >= 1 && !buffer[InitializerCount - 1])
1165 InitializerCount--;
1166
1167 symbolPosInBuffer.push_back(Elt: InitializerCount);
1168 unsigned int nSym = 0;
1169 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1170 for (unsigned int pos = 0; pos < InitializerCount;) {
1171 if (pos)
1172 os << ", ";
1173 if (pos != nextSymbolPos) {
1174 os << (unsigned int)buffer[pos];
1175 ++pos;
1176 continue;
1177 }
1178 // Generate a per-byte mask() operator for the symbol, which looks like:
1179 // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1180 // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1181 std::string symText;
1182 llvm::raw_string_ostream oss(symText);
1183 printSymbol(nSym, os&: oss);
1184 for (unsigned i = 0; i < ptrSize; ++i) {
1185 if (i)
1186 os << ", ";
1187 llvm::write_hex(S&: os, N: 0xFFULL << i * 8, Style: HexPrintStyle::PrefixUpper);
1188 os << "(" << symText << ")";
1189 }
1190 pos += ptrSize;
1191 nextSymbolPos = symbolPosInBuffer[++nSym];
1192 assert(nextSymbolPos >= pos);
1193 }
1194}
1195
1196void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1197 unsigned int ptrSize = AP.MAI.getCodePointerSize();
1198 symbolPosInBuffer.push_back(Elt: Size);
1199 unsigned int nSym = 0;
1200 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1201 assert(nextSymbolPos % ptrSize == 0);
1202 for (unsigned int pos = 0; pos < Size; pos += ptrSize) {
1203 if (pos)
1204 os << ", ";
1205 if (pos == nextSymbolPos) {
1206 printSymbol(nSym, os);
1207 nextSymbolPos = symbolPosInBuffer[++nSym];
1208 assert(nextSymbolPos % ptrSize == 0);
1209 assert(nextSymbolPos >= pos + ptrSize);
1210 } else if (ptrSize == 4)
1211 os << support::endian::read32le(P: &buffer[pos]);
1212 else
1213 os << support::endian::read64le(P: &buffer[pos]);
1214 }
1215}
1216
1217void NVPTXAsmPrinter::emitDemotedVars(const Function *F, raw_ostream &O) {
1218 auto It = localDecls.find(x: F);
1219 if (It == localDecls.end())
1220 return;
1221
1222 ArrayRef<const GlobalVariable *> GVars = It->second;
1223
1224 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1225 const NVPTXSubtarget &STI = *NTM.getSubtargetImpl();
1226
1227 for (const GlobalVariable *GV : GVars) {
1228 O << "\t// demoted variable\n\t";
1229 printModuleLevelGV(GVar: GV, O, /*processDemoted=*/ProcessDemoted: true, STI);
1230 }
1231}
1232
1233void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1234 raw_ostream &O) const {
1235 switch (AddressSpace) {
1236 case ADDRESS_SPACE_LOCAL:
1237 O << "local";
1238 break;
1239 case ADDRESS_SPACE_GLOBAL:
1240 O << "global";
1241 break;
1242 case ADDRESS_SPACE_CONST:
1243 O << "const";
1244 break;
1245 case ADDRESS_SPACE_SHARED:
1246 O << "shared";
1247 break;
1248 default:
1249 report_fatal_error(reason: "Bad address space found while emitting PTX: " +
1250 llvm::Twine(AddressSpace));
1251 break;
1252 }
1253}
1254
1255std::string
1256NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1257 switch (Ty->getTypeID()) {
1258 case Type::IntegerTyID: {
1259 unsigned NumBits = cast<IntegerType>(Val: Ty)->getBitWidth();
1260 if (NumBits == 1)
1261 return "pred";
1262 if (NumBits <= 64) {
1263 std::string name = "u";
1264 return name + utostr(X: NumBits);
1265 }
1266 llvm_unreachable("Integer too large");
1267 break;
1268 }
1269 case Type::BFloatTyID:
1270 case Type::HalfTyID:
1271 // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1272 // PTX assembly.
1273 return "b16";
1274 case Type::FloatTyID:
1275 return "f32";
1276 case Type::DoubleTyID:
1277 return "f64";
1278 case Type::PointerTyID: {
1279 unsigned PtrSize = TM.getPointerSizeInBits(AS: Ty->getPointerAddressSpace());
1280 assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1281
1282 if (PtrSize == 64)
1283 if (useB4PTR)
1284 return "b64";
1285 else
1286 return "u64";
1287 else if (useB4PTR)
1288 return "b32";
1289 else
1290 return "u32";
1291 }
1292 default:
1293 break;
1294 }
1295 llvm_unreachable("unexpected type");
1296}
1297
1298void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1299 raw_ostream &O,
1300 const NVPTXSubtarget &STI) {
1301 const DataLayout &DL = getDataLayout();
1302
1303 // GlobalVariables are always constant pointers themselves.
1304 Type *ETy = GVar->getValueType();
1305
1306 O << ".";
1307 emitPTXAddressSpace(AddressSpace: GVar->getType()->getAddressSpace(), O);
1308 if (isManaged(*GVar)) {
1309 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
1310 report_fatal_error(
1311 reason: ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1312
1313 O << " .attribute(.managed)";
1314 }
1315 O << " .align "
1316 << GVar->getAlign().value_or(u: DL.getPrefTypeAlign(Ty: ETy)).value();
1317
1318 // Special case for i128/fp128
1319 if (ETy->getScalarSizeInBits() == 128) {
1320 O << " .b8 ";
1321 getSymbol(GV: GVar)->print(OS&: O, MAI);
1322 O << "[16]";
1323 return;
1324 }
1325
1326 if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1327 O << " ." << getPTXFundamentalTypeStr(Ty: ETy) << " ";
1328 getSymbol(GV: GVar)->print(OS&: O, MAI);
1329 return;
1330 }
1331
1332 int64_t ElementSize = 0;
1333
1334 // Although PTX has direct support for struct type and array type and LLVM IR
1335 // is very similar to PTX, the LLVM CodeGen does not support for targets that
1336 // support these high level field accesses. Structs and arrays are lowered
1337 // into arrays of bytes.
1338 switch (ETy->getTypeID()) {
1339 case Type::StructTyID:
1340 case Type::ArrayTyID:
1341 case Type::FixedVectorTyID:
1342 ElementSize = DL.getTypeStoreSize(Ty: ETy);
1343 O << " .b8 ";
1344 getSymbol(GV: GVar)->print(OS&: O, MAI);
1345 O << "[";
1346 if (ElementSize) {
1347 O << ElementSize;
1348 }
1349 O << "]";
1350 break;
1351 default:
1352 llvm_unreachable("type not supported yet");
1353 }
1354}
1355
/// Print the parenthesized PTX parameter list of \p F to \p O.
///
/// A function with no arguments that is not variadic prints "()". Otherwise
/// one ".param" declaration is printed per argument, separated by ",\n", and a
/// variadic function gets one extra unsized ".b8" array parameter at the end.
/// The form of each declaration depends on whether \p F is a kernel and on the
/// argument: image/sampler handle, byval aggregate, type passed as an array,
/// pointer, or other scalar.
///
/// \param F The function whose parameters are printed.
/// \param O Stream receiving the PTX text.
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(F: *F);
  const auto *TLI = cast<NVPTXTargetLowering>(Val: STI.getTargetLowering());
  // MFI stays null when there is no current machine function (MF is null).
  const NVPTXMachineFunctionInfo *MFI =
      MF ? MF->getInfo<NVPTXMachineFunctionInfo>() : nullptr;

  // Tracks whether a ",\n" separator is needed before the next parameter.
  bool IsFirst = true;
  const bool IsKernelFunc = isKernelFunction(F: *F);

  if (F->arg_empty() && !F->isVarArg()) {
    O << "()";
    return;
  }

  O << "(\n";

  for (const Argument &Arg : F->args()) {
    Type *Ty = Arg.getType();
    const std::string ParamSym = TLI->getParamName(F, Idx: Arg.getArgNo());

    if (!IsFirst)
      O << ",\n";

    IsFirst = false;

    // Handle image/sampler parameters
    if (IsKernelFunc) {
      const PTXOpaqueType ArgOpaqueType = getPTXOpaqueType(Arg);
      if (ArgOpaqueType != PTXOpaqueType::None) {
        // The ".u64 .ptr" prefix is printed unless the machine function info
        // already knows ParamSym as an image handle symbol.
        const bool EmitImgPtr = !MFI || !MFI->checkImageHandleSymbol(Symbol: ParamSym);
        O << "\t.param ";
        if (EmitImgPtr)
          O << ".u64 .ptr ";

        switch (ArgOpaqueType) {
        case PTXOpaqueType::Sampler:
          O << ".samplerref ";
          break;
        case PTXOpaqueType::Texture:
          O << ".texref ";
          break;
        case PTXOpaqueType::Surface:
          O << ".surfref ";
          break;
        case PTXOpaqueType::None:
          llvm_unreachable("handled above");
        }
        O << ParamSym;
        continue;
      }
    }

    if (Arg.hasByValAttr()) {
      // param has byVal attribute.
      Type *ETy = Arg.getParamByValType();
      assert(ETy && "Param should have byval type");

      // Print .param .align <a> .b8 .param[size];
      // <a>  = optimal alignment for the element type; always multiple of
      //        PAL.getParamAlignment
      // size = typeallocsize of element type
      // Kernels and device functions compute the alignment with different
      // helpers.
      const Align OptimalAlign =
          IsKernelFunc
              ? getPTXParamAlign(
                    F, Ty: ETy, AttrIdx: Arg.getArgNo() + AttributeList::FirstArgIndex, DL)
              : getDeviceByValParamAlign(F, ArgTy: ETy,
                                         InitialAlign: Arg.getParamAlign().valueOrOne(), DL);

      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
        << "[" << DL.getTypeAllocSize(Ty: ETy) << "]";
      continue;
    }

    if (shouldPassAsArray(Ty)) {
      // Just print .param .align <a> .b8 .param[size];
      // <a>  = optimal alignment for the element type; always multiple of
      //        PAL.getParamAlignment
      // size = typeallocsize of element type
      Align OptimalAlign = getPTXParamAlign(
          F, Ty, AttrIdx: Arg.getArgNo() + AttributeList::FirstArgIndex, DL);

      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
        << "[" << DL.getTypeAllocSize(Ty) << "]";

      continue;
    }
    // Just a scalar
    auto *PTy = dyn_cast<PointerType>(Val: Ty);
    // Pointer width for the argument's address space; stays 0 for
    // non-pointer arguments.
    unsigned PTySizeInBits = 0;
    if (PTy) {
      PTySizeInBits =
          TLI->getPointerTy(DL, AS: PTy->getAddressSpace()).getSizeInBits();
      assert(PTySizeInBits && "Invalid pointer size");
    }

    if (IsKernelFunc) {
      if (PTy) {
        O << "\t.param .u" << PTySizeInBits << " .ptr";

        // Name the state space for global, shared, const and local pointers;
        // any other address space gets no qualifier.
        switch (PTy->getAddressSpace()) {
        default:
          break;
        case ADDRESS_SPACE_GLOBAL:
          O << " .global";
          break;
        case ADDRESS_SPACE_SHARED:
          O << " .shared";
          break;
        case ADDRESS_SPACE_CONST:
          O << " .const";
          break;
        case ADDRESS_SPACE_LOCAL:
          O << " .local";
          break;
        }

        // The alignment defaults to 1 when the argument carries none.
        O << " .align " << Arg.getParamAlign().valueOrOne().value() << " "
          << ParamSym;
        continue;
      }

      // non-pointer scalar to kernel func
      O << "\t.param .";
      // Special case: predicate operands become .u8 types
      if (Ty->isIntegerTy(BitWidth: 1))
        O << "u8";
      else
        O << getPTXFundamentalTypeStr(Ty);
      O << " " << ParamSym;
      continue;
    }
    // Non-kernel function, just print .param .b<size> for ABI
    // and .reg .b<size> for non-ABI
    // Integers use the width returned by promoteScalarArgumentSize(), pointers
    // the pointer width computed above, everything else its primitive size.
    unsigned Size;
    if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
      Size = promoteScalarArgumentSize(size: ITy->getBitWidth());
    } else if (PTy) {
      assert(PTySizeInBits && "Invalid pointer size");
      Size = PTySizeInBits;
    } else
      Size = Ty->getPrimitiveSizeInBits();
    O << "\t.param .b" << Size << " " << ParamSym;
  }

  // Variadic functions receive their extra arguments through one unsized byte
  // array aligned to the subtarget's maximum required alignment.
  if (F->isVarArg()) {
    if (!IsFirst)
      O << ",\n";
    O << "\t.param .align " << STI.getMaxRequiredAlignment() << " .b8 "
      << TLI->getParamName(F, /* vararg */ Idx: -1) << "[]";
  }

  O << "\n)";
}
1510
1511void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1512 const MachineFunction &MF) {
1513 SmallString<128> Str;
1514 raw_svector_ostream O(Str);
1515
1516 // Map the global virtual register number to a register class specific
1517 // virtual register number starting from 1 with that class.
1518 const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1519
1520 // Emit the Fake Stack Object
1521 const MachineFrameInfo &MFI = MF.getFrameInfo();
1522 int64_t NumBytes = MFI.getStackSize();
1523 if (NumBytes) {
1524 O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1525 << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1526 if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1527 O << "\t.reg .b64 \t%SP;\n"
1528 << "\t.reg .b64 \t%SPL;\n";
1529 } else {
1530 O << "\t.reg .b32 \t%SP;\n"
1531 << "\t.reg .b32 \t%SPL;\n";
1532 }
1533 }
1534
1535 // Go through all virtual registers to establish the mapping between the
1536 // global virtual
1537 // register number and the per class virtual register number.
1538 // We use the per class virtual register number in the ptx output.
1539 for (unsigned I : llvm::seq(Size: MRI->getNumVirtRegs())) {
1540 Register VR = Register::index2VirtReg(Index: I);
1541 if (MRI->use_empty(RegNo: VR) && MRI->def_empty(RegNo: VR))
1542 continue;
1543 auto &RCRegMap = VRegMapping[MRI->getRegClass(Reg: VR)];
1544 RCRegMap[VR] = RCRegMap.size() + 1;
1545 }
1546
1547 // Emit declaration of the virtual registers or 'physical' registers for
1548 // each register class
1549 for (const TargetRegisterClass *RC : TRI->regclasses()) {
1550 const unsigned N = VRegMapping[RC].size();
1551
1552 // Only declare those registers that may be used.
1553 if (N) {
1554 const StringRef RCName = getNVPTXRegClassName(RC);
1555 const StringRef RCStr = getNVPTXRegClassStr(RC);
1556 O << "\t.reg " << RCName << " \t" << RCStr << "<" << (N + 1) << ">;\n";
1557 }
1558 }
1559
1560 OutStreamer->emitRawText(String: O.str());
1561}
1562
1563/// Translate virtual register numbers in DebugInfo locations to their printed
1564/// encodings, as used by CUDA-GDB.
1565void NVPTXAsmPrinter::encodeDebugInfoRegisterNumbers(
1566 const MachineFunction &MF) {
1567 const NVPTXSubtarget &STI = MF.getSubtarget<NVPTXSubtarget>();
1568 const NVPTXRegisterInfo *registerInfo = STI.getRegisterInfo();
1569
1570 // Clear the old mapping, and add the new one. This mapping is used after the
1571 // printing of the current function is complete, but before the next function
1572 // is printed.
1573 registerInfo->clearDebugRegisterMap();
1574
1575 for (auto &classMap : VRegMapping) {
1576 for (auto &registerMapping : classMap.getSecond()) {
1577 auto reg = registerMapping.getFirst();
1578 registerInfo->addToDebugRegisterMap(preEncodedVirtualRegister: reg, RegisterName: getVirtualRegisterName(Reg: reg));
1579 }
1580 }
1581}
1582
1583void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp,
1584 raw_ostream &O) const {
1585 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1586 bool ignored;
1587 unsigned int numHex;
1588 const char *lead;
1589
1590 if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1591 numHex = 8;
1592 lead = "0f";
1593 APF.convert(ToSemantics: APFloat::IEEEsingle(), RM: APFloat::rmNearestTiesToEven, losesInfo: &ignored);
1594 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1595 numHex = 16;
1596 lead = "0d";
1597 APF.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven, losesInfo: &ignored);
1598 } else
1599 llvm_unreachable("unsupported fp type");
1600
1601 APInt API = APF.bitcastToAPInt();
1602 O << lead << format_hex_no_prefix(N: API.getZExtValue(), Width: numHex, /*Upper=*/true);
1603}
1604
1605void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1606 if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CPV)) {
1607 O << CI->getValue();
1608 return;
1609 }
1610 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(Val: CPV)) {
1611 printFPConstant(Fp: CFP, O);
1612 return;
1613 }
1614 if (isa<ConstantPointerNull>(Val: CPV)) {
1615 O << "0";
1616 return;
1617 }
1618 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: CPV)) {
1619 const bool IsNonGenericPointer = GVar->getAddressSpace() != 0;
1620 if (EmitGeneric && !isa<Function>(Val: CPV) && !IsNonGenericPointer) {
1621 O << "generic(";
1622 getSymbol(GV: GVar)->print(OS&: O, MAI);
1623 O << ")";
1624 } else {
1625 getSymbol(GV: GVar)->print(OS&: O, MAI);
1626 }
1627 return;
1628 }
1629 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
1630 const MCExpr *E = lowerConstantForGV(CV: cast<Constant>(Val: Cexpr), ProcessingGeneric: false);
1631 printMCExpr(Expr: *E, OS&: O);
1632 return;
1633 }
1634 llvm_unreachable("Not scalar type found in printScalarConstant()");
1635}
1636
1637void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1638 AggBuffer *AggBuffer) {
1639 const DataLayout &DL = getDataLayout();
1640 int AllocSize = DL.getTypeAllocSize(Ty: CPV->getType());
1641 if (isa<UndefValue>(Val: CPV) || CPV->isNullValue()) {
1642 // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1643 // only the space allocated by CPV.
1644 AggBuffer->addZeros(Num: Bytes ? Bytes : AllocSize);
1645 return;
1646 }
1647
1648 // Helper for filling AggBuffer with APInts.
1649 auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1650 size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1651 SmallVector<unsigned char, 16> Buf(NumBytes);
1652 // `extractBitsAsZExtValue` does not allow the extraction of bits beyond the
1653 // input's bit width, and i1 arrays may not have a length that is a multuple
1654 // of 8. We handle the last byte separately, so we never request out of
1655 // bounds bits.
1656 for (unsigned I = 0; I < NumBytes - 1; ++I) {
1657 Buf[I] = Val.extractBitsAsZExtValue(numBits: 8, bitPosition: I * 8);
1658 }
1659 size_t LastBytePosition = (NumBytes - 1) * 8;
1660 size_t LastByteBits = Val.getBitWidth() - LastBytePosition;
1661 Buf[NumBytes - 1] =
1662 Val.extractBitsAsZExtValue(numBits: LastByteBits, bitPosition: LastBytePosition);
1663 AggBuffer->addBytes(Ptr: Buf.data(), Num: NumBytes, Bytes);
1664 };
1665
1666 switch (CPV->getType()->getTypeID()) {
1667 case Type::IntegerTyID:
1668 if (const auto *CI = dyn_cast<ConstantInt>(Val: CPV)) {
1669 AddIntToBuffer(CI->getValue());
1670 break;
1671 }
1672 if (const auto *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
1673 if (const auto *CI =
1674 dyn_cast<ConstantInt>(Val: ConstantFoldConstant(C: Cexpr, DL))) {
1675 AddIntToBuffer(CI->getValue());
1676 break;
1677 }
1678 if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1679 Value *V = Cexpr->getOperand(i_nocapture: 0)->stripPointerCasts();
1680 AggBuffer->addSymbol(GVar: V, GVarBeforeStripping: Cexpr->getOperand(i_nocapture: 0));
1681 AggBuffer->addZeros(Num: AllocSize);
1682 break;
1683 }
1684 // A symbol-relative integer whose offset is applied outside the
1685 // ptrtoint, e.g. add(ptrtoint(@g), C). It can't fold to a ConstantInt
1686 // because it references a symbol; emit it through lowerConstantForGV, the
1687 // same path scalar symbol-relative integer globals use.
1688 AggBuffer->addSymbol(GVar: Cexpr, GVarBeforeStripping: Cexpr);
1689 AggBuffer->addZeros(Num: AllocSize);
1690 break;
1691 }
1692 llvm_unreachable("unsupported integer const type");
1693 break;
1694
1695 case Type::HalfTyID:
1696 case Type::BFloatTyID:
1697 case Type::FloatTyID:
1698 case Type::DoubleTyID:
1699 AddIntToBuffer(cast<ConstantFP>(Val: CPV)->getValueAPF().bitcastToAPInt());
1700 break;
1701
1702 case Type::PointerTyID: {
1703 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: CPV)) {
1704 AggBuffer->addSymbol(GVar, GVarBeforeStripping: GVar);
1705 } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
1706 const Value *v = Cexpr->stripPointerCasts();
1707 AggBuffer->addSymbol(GVar: v, GVarBeforeStripping: Cexpr);
1708 }
1709 AggBuffer->addZeros(Num: AllocSize);
1710 break;
1711 }
1712
1713 case Type::ArrayTyID:
1714 case Type::FixedVectorTyID:
1715 case Type::StructTyID: {
1716 if (isa<ConstantAggregate>(Val: CPV) || isa<ConstantDataSequential>(Val: CPV)) {
1717 // bufferAggregateConstant doesn't emit tail-padding, i.e. it writes
1718 // `store_size` bytes, not `alloc_size` bytes. Do it ourselves here.
1719 unsigned StartPos = AggBuffer->getCurpos();
1720 bufferAggregateConstant(CV: CPV, aggBuffer: AggBuffer);
1721 unsigned Written = AggBuffer->getCurpos() - StartPos;
1722 unsigned SlotSize = std::max<int>(a: Bytes, b: AllocSize);
1723 if (SlotSize > Written)
1724 AggBuffer->addZeros(Num: SlotSize - Written);
1725 } else if (isa<ConstantAggregateZero>(Val: CPV))
1726 AggBuffer->addZeros(Num: Bytes);
1727 else
1728 llvm_unreachable("Unexpected Constant type");
1729 break;
1730 }
1731
1732 default:
1733 llvm_unreachable("unsupported type");
1734 }
1735}
1736
1737void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1738 AggBuffer *aggBuffer) {
1739 const DataLayout &DL = getDataLayout();
1740
1741 auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
1742 unsigned NumBytes = divideCeil(Numerator: Val.getBitWidth(), Denominator: 8);
1743 for (unsigned I : llvm::seq(Size: NumBytes)) {
1744 unsigned NumBits = std::min(a: 8u, b: Val.getBitWidth() - I * 8);
1745 Buffer->addByte(Byte: Val.extractBitsAsZExtValue(numBits: NumBits, bitPosition: I * 8));
1746 }
1747 };
1748
1749 // Integer or floating point vector splats.
1750 if (isa<ConstantInt, ConstantFP>(Val: CPV)) {
1751 if (auto *VTy = dyn_cast<FixedVectorType>(Val: CPV->getType())) {
1752 for (unsigned I : llvm::seq(Size: VTy->getNumElements()))
1753 bufferLEByte(CPV: CPV->getAggregateElement(Elt: I), Bytes: 0, AggBuffer: aggBuffer);
1754 return;
1755 }
1756 }
1757
1758 // Integers of arbitrary width
1759 if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CPV)) {
1760 assert(CI->getType()->isIntegerTy() && "Expected integer constant!");
1761 ExtendBuffer(CI->getValue(), aggBuffer);
1762 return;
1763 }
1764
1765 // f128
1766 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(Val: CPV)) {
1767 assert(CFP->getType()->isFloatingPointTy() && "Expected fp constant!");
1768 if (CFP->getType()->isFP128Ty()) {
1769 ExtendBuffer(CFP->getValueAPF().bitcastToAPInt(), aggBuffer);
1770 return;
1771 }
1772 }
1773
1774 // Buffer arrays one element at a time.
1775 if (isa<ConstantArray>(Val: CPV)) {
1776 for (const auto &Op : CPV->operands())
1777 bufferLEByte(CPV: cast<Constant>(Val: Op), Bytes: 0, AggBuffer: aggBuffer);
1778 return;
1779 }
1780
1781 // Constant vectors
1782 if (const auto *CVec = dyn_cast<ConstantVector>(Val: CPV)) {
1783 bufferAggregateConstVec(CV: CVec, aggBuffer);
1784 return;
1785 }
1786
1787 if (const auto *CDS = dyn_cast<ConstantDataSequential>(Val: CPV)) {
1788 for (unsigned I : llvm::seq(Size: CDS->getNumElements()))
1789 bufferLEByte(CPV: cast<Constant>(Val: CDS->getElementAsConstant(i: I)), Bytes: 0, AggBuffer: aggBuffer);
1790 return;
1791 }
1792
1793 if (isa<ConstantStruct>(Val: CPV)) {
1794 if (CPV->getNumOperands()) {
1795 StructType *ST = cast<StructType>(Val: CPV->getType());
1796 for (unsigned I : llvm::seq(Size: CPV->getNumOperands())) {
1797 int EndOffset = (I + 1 == CPV->getNumOperands())
1798 ? DL.getStructLayout(Ty: ST)->getElementOffset(Idx: 0) +
1799 DL.getTypeAllocSize(Ty: ST)
1800 : DL.getStructLayout(Ty: ST)->getElementOffset(Idx: I + 1);
1801 int Bytes = EndOffset - DL.getStructLayout(Ty: ST)->getElementOffset(Idx: I);
1802 bufferLEByte(CPV: cast<Constant>(Val: CPV->getOperand(i: I)), Bytes, AggBuffer: aggBuffer);
1803 }
1804 }
1805 return;
1806 }
1807 llvm_unreachable("unsupported constant type in printAggregateConstant()");
1808}
1809
1810void NVPTXAsmPrinter::bufferAggregateConstVec(const ConstantVector *CV,
1811 AggBuffer *aggBuffer) {
1812 unsigned NumElems = CV->getType()->getNumElements();
1813 const unsigned BuffSize = aggBuffer->getBufferSize();
1814
1815 // Buffer one element at a time if we have allocated enough buffer space.
1816 if (BuffSize >= NumElems) {
1817 for (const auto &Op : CV->operands())
1818 bufferLEByte(CPV: cast<Constant>(Val: Op), Bytes: 0, AggBuffer: aggBuffer);
1819 return;
1820 }
1821
1822 // Sub-byte datatypes will have more elements than bytes allocated for the
1823 // buffer. Merge consecutive elements to form a full byte. We expect that 8 %
1824 // sub-byte-elem-size should be 0 and current expected usage is for i4 (for
1825 // e2m1-fp4 types).
1826 Type *ElemTy = CV->getType()->getElementType();
1827 assert(ElemTy->isIntegerTy() && "Expected integer data type.");
1828 unsigned ElemTySize = ElemTy->getPrimitiveSizeInBits();
1829 assert(ElemTySize < 8 && "Expected sub-byte data type.");
1830 assert(8 % ElemTySize == 0 && "Element type size must evenly divide a byte.");
1831 // Number of elements to merge to form a full byte.
1832 unsigned NumElemsPerByte = 8 / ElemTySize;
1833 unsigned NumCompleteBytes = NumElems / NumElemsPerByte;
1834 unsigned NumTailElems = NumElems % NumElemsPerByte;
1835
1836 // Helper lambda to constant-fold sub-vector of sub-byte type elements into
1837 // i8. Start and end indices of the sub-vector is provided, along with number
1838 // of padding zeros if required.
1839 auto ConvertSubCVtoInt8 = [this, &ElemTy](const ConstantVector *CV,
1840 unsigned Start, unsigned End,
1841 unsigned NumPaddingZeros = 0) {
1842 // Collect elements to create sub-vector.
1843 SmallVector<Constant *, 8> SubCVElems;
1844 for (unsigned I : llvm::seq(Begin: Start, End))
1845 SubCVElems.push_back(Elt: CV->getAggregateElement(Elt: I));
1846
1847 // Optionally pad with zeros.
1848 if (NumPaddingZeros)
1849 SubCVElems.append(NumInputs: NumPaddingZeros, Elt: ConstantInt::getNullValue(Ty: ElemTy));
1850
1851 auto SubCV = ConstantVector::get(V: SubCVElems);
1852 Type *Int8Ty = IntegerType::get(C&: SubCV->getContext(), NumBits: 8);
1853
1854 // Merge elements of the sub-vector using ConstantFolding.
1855 ConstantInt *MergedElem =
1856 dyn_cast_or_null<ConstantInt>(Val: ConstantFoldConstant(
1857 C: ConstantExpr::getBitCast(C: const_cast<Constant *>(SubCV), Ty: Int8Ty),
1858 DL: getDataLayout()));
1859
1860 if (!MergedElem)
1861 report_fatal_error(
1862 reason: "Cannot lower vector global with unusual element type");
1863
1864 return MergedElem;
1865 };
1866
1867 // Iterate through elements of vector one chunk at a time and buffer that
1868 // chunk.
1869 for (unsigned ByteIdx : llvm::seq(Size: NumCompleteBytes))
1870 bufferLEByte(CPV: ConvertSubCVtoInt8(CV, ByteIdx * NumElemsPerByte,
1871 (ByteIdx + 1) * NumElemsPerByte),
1872 Bytes: 0, AggBuffer: aggBuffer);
1873
1874 // For unevenly sized vectors add tail padding zeros.
1875 if (NumTailElems > 0)
1876 bufferLEByte(CPV: ConvertSubCVtoInt8(CV, NumElems - NumTailElems, NumElems,
1877 NumElemsPerByte - NumTailElems),
1878 Bytes: 0, AggBuffer: aggBuffer);
1879}
1880
1881/// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
1882/// a copy from AsmPrinter::lowerConstant, except customized to only handle
1883/// expressions that are representable in PTX and create
1884/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1885const MCExpr *
1886NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV,
1887 bool ProcessingGeneric) const {
1888 MCContext &Ctx = OutContext;
1889
1890 if (CV->isNullValue() || isa<UndefValue>(Val: CV))
1891 return MCConstantExpr::create(Value: 0, Ctx);
1892
1893 if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CV))
1894 return MCConstantExpr::create(Value: CI->getZExtValue(), Ctx);
1895
1896 if (const GlobalValue *GV = dyn_cast<GlobalValue>(Val: CV)) {
1897 const MCSymbolRefExpr *Expr = MCSymbolRefExpr::create(Symbol: getSymbol(GV), Ctx);
1898 if (ProcessingGeneric)
1899 return NVPTXGenericMCSymbolRefExpr::create(SymExpr: Expr, Ctx);
1900 return Expr;
1901 }
1902
1903 const ConstantExpr *CE = dyn_cast<ConstantExpr>(Val: CV);
1904 if (!CE) {
1905 llvm_unreachable("Unknown constant value to lower!");
1906 }
1907
1908 switch (CE->getOpcode()) {
1909 default:
1910 break; // Error
1911
1912 case Instruction::AddrSpaceCast: {
1913 // Strip the addrspacecast and pass along the operand
1914 PointerType *DstTy = cast<PointerType>(Val: CE->getType());
1915 if (DstTy->getAddressSpace() == 0)
1916 return lowerConstantForGV(CV: cast<const Constant>(Val: CE->getOperand(i_nocapture: 0)), ProcessingGeneric: true);
1917
1918 break; // Error
1919 }
1920
1921 case Instruction::GetElementPtr: {
1922 const DataLayout &DL = getDataLayout();
1923
1924 // Generate a symbolic expression for the byte address
1925 APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
1926 cast<GEPOperator>(Val: CE)->accumulateConstantOffset(DL, Offset&: OffsetAI);
1927
1928 const MCExpr *Base = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0),
1929 ProcessingGeneric);
1930 if (!OffsetAI)
1931 return Base;
1932
1933 int64_t Offset = OffsetAI.getSExtValue();
1934 return MCBinaryExpr::createAdd(LHS: Base, RHS: MCConstantExpr::create(Value: Offset, Ctx),
1935 Ctx);
1936 }
1937
1938 case Instruction::Trunc:
1939 // We emit the value and depend on the assembler to truncate the generated
1940 // expression properly. This is important for differences between
1941 // blockaddress labels. Since the two labels are in the same function, it
1942 // is reasonable to treat their delta as a 32-bit value.
1943 [[fallthrough]];
1944 case Instruction::BitCast:
1945 return lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0), ProcessingGeneric);
1946
1947 case Instruction::IntToPtr: {
1948 const DataLayout &DL = getDataLayout();
1949
1950 // Handle casts to pointers by changing them into casts to the appropriate
1951 // integer type. This promotes constant folding and simplifies this code.
1952 Constant *Op = CE->getOperand(i_nocapture: 0);
1953 Op = ConstantFoldIntegerCast(C: Op, DestTy: DL.getIntPtrType(CV->getType()),
1954 /*IsSigned*/ false, DL);
1955 if (Op)
1956 return lowerConstantForGV(CV: Op, ProcessingGeneric);
1957
1958 break; // Error
1959 }
1960
1961 case Instruction::PtrToInt: {
1962 const DataLayout &DL = getDataLayout();
1963
1964 // Support only foldable casts to/from pointers that can be eliminated by
1965 // changing the pointer to the appropriately sized integer type.
1966 Constant *Op = CE->getOperand(i_nocapture: 0);
1967 Type *Ty = CE->getType();
1968
1969 const MCExpr *OpExpr = lowerConstantForGV(CV: Op, ProcessingGeneric);
1970
1971 // We can emit the pointer value into this slot if the slot is an
1972 // integer slot equal to the size of the pointer.
1973 if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Ty: Op->getType()))
1974 return OpExpr;
1975
1976 // Otherwise the pointer is smaller than the resultant integer, mask off
1977 // the high bits so we are sure to get a proper truncation if the input is
1978 // a constant expr.
1979 unsigned InBits = DL.getTypeAllocSizeInBits(Ty: Op->getType());
1980 const MCExpr *MaskExpr = MCConstantExpr::create(Value: ~0ULL >> (64-InBits), Ctx);
1981 return MCBinaryExpr::createAnd(LHS: OpExpr, RHS: MaskExpr, Ctx);
1982 }
1983
1984 // The MC library also has a right-shift operator, but it isn't consistently
1985 // signed or unsigned between different targets.
1986 case Instruction::Add: {
1987 const MCExpr *LHS = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0), ProcessingGeneric);
1988 const MCExpr *RHS = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 1), ProcessingGeneric);
1989 switch (CE->getOpcode()) {
1990 default: llvm_unreachable("Unknown binary operator constant cast expr");
1991 case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
1992 }
1993 }
1994 }
1995
1996 // If the code isn't optimized, there may be outstanding folding
1997 // opportunities. Attempt to fold the expression using DataLayout as a
1998 // last resort before giving up.
1999 Constant *C = ConstantFoldConstant(C: CE, DL: getDataLayout());
2000 if (C != CE)
2001 return lowerConstantForGV(CV: C, ProcessingGeneric);
2002
2003 // Otherwise report the problem to the user.
2004 std::string S;
2005 raw_string_ostream OS(S);
2006 OS << "Unsupported expression in static initializer: ";
2007 CE->printAsOperand(O&: OS, /*PrintType=*/false,
2008 M: !MF ? nullptr : MF->getFunction().getParent());
2009 report_fatal_error(reason: Twine(OS.str()));
2010}
2011
2012void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) const {
2013 OutContext.getAsmInfo().printExpr(OS, Expr);
2014}
2015
2016/// PrintAsmOperand - Print out an operand for an inline asm expression.
2017///
2018bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2019 const char *ExtraCode, raw_ostream &O) {
2020 if (ExtraCode && ExtraCode[0]) {
2021 if (ExtraCode[1] != 0)
2022 return true; // Unknown modifier.
2023
2024 switch (ExtraCode[0]) {
2025 default:
2026 // See if this is a generic print operand
2027 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, OS&: O);
2028 case 'r':
2029 break;
2030 }
2031 }
2032
2033 printOperand(MI, OpNum: OpNo, O);
2034
2035 return false;
2036}
2037
2038bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2039 unsigned OpNo,
2040 const char *ExtraCode,
2041 raw_ostream &O) {
2042 if (ExtraCode && ExtraCode[0])
2043 return true; // Unknown modifier
2044
2045 O << '[';
2046 printMemOperand(MI, OpNum: OpNo, O);
2047 O << ']';
2048
2049 return false;
2050}
2051
2052void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,
2053 raw_ostream &O) {
2054 const MachineOperand &MO = MI->getOperand(i: OpNum);
2055 switch (MO.getType()) {
2056 case MachineOperand::MO_Register:
2057 if (MO.getReg().isPhysical()) {
2058 if (MO.getReg() == NVPTX::VRDepot)
2059 O << DEPOTNAME << getFunctionNumber();
2060 else
2061 O << NVPTXInstPrinter::getRegisterName(Reg: MO.getReg());
2062 } else {
2063 emitVirtualRegister(vr: MO.getReg(), O);
2064 }
2065 break;
2066
2067 case MachineOperand::MO_Immediate:
2068 O << MO.getImm();
2069 break;
2070
2071 case MachineOperand::MO_FPImmediate:
2072 printFPConstant(Fp: MO.getFPImm(), O);
2073 break;
2074
2075 case MachineOperand::MO_GlobalAddress:
2076 PrintSymbolOperand(MO, OS&: O);
2077 break;
2078
2079 case MachineOperand::MO_MachineBasicBlock:
2080 MO.getMBB()->getSymbol()->print(OS&: O, MAI);
2081 break;
2082
2083 default:
2084 llvm_unreachable("Operand type not supported.");
2085 }
2086}
2087
2088void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,
2089 raw_ostream &O, const char *Modifier) {
2090 printOperand(MI, OpNum, O);
2091
2092 if (Modifier && strcmp(s1: Modifier, s2: "add") == 0) {
2093 O << ", ";
2094 printOperand(MI, OpNum: OpNum + 1, O);
2095 } else {
2096 if (MI->getOperand(i: OpNum + 1).isImm() &&
2097 MI->getOperand(i: OpNum + 1).getImm() == 0)
2098 return; // don't print ',0' or '+0'
2099 O << "+";
2100 printOperand(MI, OpNum: OpNum + 1, O);
2101 }
2102}
2103
2104/// Returns true if \p Line begins with an alphabetic character or underscore,
2105/// indicating it is a PTX instruction that should receive a .loc directive.
2106static bool isPTXInstruction(StringRef Line) {
2107 StringRef Trimmed = Line.ltrim();
2108 return !Trimmed.empty() &&
2109 (std::isalpha(static_cast<unsigned char>(Trimmed[0])) ||
2110 Trimmed[0] == '_');
2111}
2112
2113/// Returns the DILocation for an inline asm MachineInstr if debug line info
2114/// should be emitted, or nullptr otherwise.
2115static const DILocation *getInlineAsmDebugLoc(const MachineInstr *MI) {
2116 if (!MI || !MI->getDebugLoc())
2117 return nullptr;
2118 const DISubprogram *SP = MI->getMF()->getFunction().getSubprogram();
2119 if (!SP || SP->getUnit()->getEmissionKind() == DICompileUnit::NoDebug)
2120 return nullptr;
2121 const DILocation *DL = MI->getDebugLoc();
2122 if (!DL->getFile() || !DL->getLine())
2123 return nullptr;
2124 return DL;
2125}
2126
2127namespace {
2128struct InlineAsmInliningContext {
2129 MCSymbol *FuncNameSym = nullptr;
2130 unsigned FileIA = 0;
2131 unsigned LineIA = 0;
2132 unsigned ColIA = 0;
2133
2134 bool hasInlinedAt() const { return FuncNameSym != nullptr; }
2135};
2136} // namespace
2137
2138/// Resolves the enhanced-lineinfo inlining context for an inline asm debug
2139/// location. Returns a default (empty) context if inlining info is unavailable.
2140static InlineAsmInliningContext
2141getInlineAsmInliningContext(const DILocation *DL, const MachineFunction &MF,
2142 NVPTXDwarfDebug *NVDD, MCStreamer &Streamer,
2143 unsigned CUID) {
2144 InlineAsmInliningContext Ctx;
2145 const DILocation *InlinedAt = DL->getInlinedAt();
2146 if (!InlinedAt || !InlinedAt->getFile() || !NVDD ||
2147 !NVDD->isEnhancedLineinfo(MF))
2148 return Ctx;
2149 const auto *SubProg = getDISubprogram(Scope: DL->getScope());
2150 if (!SubProg)
2151 return Ctx;
2152 Ctx.FuncNameSym = NVDD->getOrCreateFuncNameSymbol(LinkageName: SubProg->getLinkageName());
2153 Ctx.FileIA = Streamer.emitDwarfFileDirective(
2154 FileNo: 0, Directory: InlinedAt->getFile()->getDirectory(),
2155 Filename: InlinedAt->getFile()->getFilename(), Checksum: std::nullopt, Source: std::nullopt, CUID);
2156 Ctx.LineIA = InlinedAt->getLine();
2157 Ctx.ColIA = InlinedAt->getColumn();
2158 return Ctx;
2159}
2160
2161void NVPTXAsmPrinter::emitInlineAsm(StringRef Str, const MCSubtargetInfo &STI,
2162 const MCTargetOptions &MCOptions,
2163 const MDNode *LocMDNode,
2164 InlineAsm::AsmDialect Dialect,
2165 const MachineInstr *MI) {
2166 assert(!Str.empty() && "Can't emit empty inline asm block");
2167 if (Str.back() == 0)
2168 Str = Str.substr(Start: 0, N: Str.size() - 1);
2169
2170 auto emitAsmStr = [&](StringRef AsmStr) {
2171 emitInlineAsmStart();
2172 OutStreamer->emitRawText(String: AsmStr);
2173 emitInlineAsmEnd(StartInfo: STI, EndInfo: nullptr, MI);
2174 };
2175
2176 const DILocation *DL = getInlineAsmDebugLoc(MI);
2177 if (!DL) {
2178 emitAsmStr(Str);
2179 return;
2180 }
2181
2182 const DIFile *File = DL->getFile();
2183 unsigned Line = DL->getLine();
2184 const unsigned Column = DL->getColumn();
2185 const unsigned CUID = OutStreamer->getContext().getDwarfCompileUnitID();
2186 const unsigned FileNumber = OutStreamer->emitDwarfFileDirective(
2187 FileNo: 0, Directory: File->getDirectory(), Filename: File->getFilename(), Checksum: std::nullopt, Source: std::nullopt,
2188 CUID);
2189
2190 auto *NVDD = static_cast<NVPTXDwarfDebug *>(getDwarfDebug());
2191 InlineAsmInliningContext InlineCtx =
2192 getInlineAsmInliningContext(DL, MF: *MI->getMF(), NVDD, Streamer&: *OutStreamer, CUID);
2193
2194 SmallVector<StringRef, 16> Lines;
2195 Str.split(A&: Lines, Separator: '\n');
2196 emitInlineAsmStart();
2197 for (const StringRef &L : Lines) {
2198 StringRef RTrimmed = L.rtrim(Char: '\r');
2199 if (isPTXInstruction(Line: L)) {
2200 if (InlineCtx.hasInlinedAt()) {
2201 OutStreamer->emitDwarfLocDirectiveWithInlinedAt(
2202 FileNo: FileNumber, Line, Column, FileIA: InlineCtx.FileIA, LineIA: InlineCtx.LineIA,
2203 ColumnIA: InlineCtx.ColIA, Sym: InlineCtx.FuncNameSym, DWARF2_FLAG_IS_STMT, Isa: 0, Discriminator: 0,
2204 FileName: File->getFilename());
2205 } else {
2206 OutStreamer->emitDwarfLocDirective(FileNo: FileNumber, Line, Column,
2207 DWARF2_FLAG_IS_STMT, Isa: 0, Discriminator: 0,
2208 FileName: File->getFilename());
2209 }
2210 }
2211 OutStreamer->emitRawText(String: RTrimmed);
2212 ++Line;
2213 }
2214 emitInlineAsmEnd(StartInfo: STI, EndInfo: nullptr, MI);
2215}
2216
2217char NVPTXAsmPrinter::ID = 0;
2218
2219INITIALIZE_PASS(NVPTXAsmPrinter, "nvptx-asm-printer", "NVPTX Assembly Printer",
2220 false, false)
2221
2222// Force static initialization.
2223extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void
2224LLVMInitializeNVPTXAsmPrinter() {
2225 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2226 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2227}
2228