1//===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a printer that converts from our internal representation
10// of machine-dependent LLVM code to NVPTX assembly language.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXAsmPrinter.h"
15#include "MCTargetDesc/NVPTXBaseInfo.h"
16#include "MCTargetDesc/NVPTXInstPrinter.h"
17#include "MCTargetDesc/NVPTXMCAsmInfo.h"
18#include "MCTargetDesc/NVPTXTargetStreamer.h"
19#include "NVPTX.h"
20#include "NVPTXDwarfDebug.h"
21#include "NVPTXMCExpr.h"
22#include "NVPTXMachineFunctionInfo.h"
23#include "NVPTXRegisterInfo.h"
24#include "NVPTXSubtarget.h"
25#include "NVPTXTargetMachine.h"
26#include "NVPTXUtilities.h"
27#include "NVVMProperties.h"
28#include "TargetInfo/NVPTXTargetInfo.h"
29#include "cl_common_defines.h"
30#include "llvm/ADT/APFloat.h"
31#include "llvm/ADT/APInt.h"
32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/DenseMap.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/SmallString.h"
36#include "llvm/ADT/SmallVector.h"
37#include "llvm/ADT/StringExtras.h"
38#include "llvm/ADT/StringRef.h"
39#include "llvm/ADT/Twine.h"
40#include "llvm/ADT/iterator_range.h"
41#include "llvm/Analysis/ConstantFolding.h"
42#include "llvm/CodeGen/Analysis.h"
43#include "llvm/CodeGen/MachineBasicBlock.h"
44#include "llvm/CodeGen/MachineFrameInfo.h"
45#include "llvm/CodeGen/MachineFunction.h"
46#include "llvm/CodeGen/MachineInstr.h"
47#include "llvm/CodeGen/MachineLoopInfo.h"
48#include "llvm/CodeGen/MachineModuleInfo.h"
49#include "llvm/CodeGen/MachineOperand.h"
50#include "llvm/CodeGen/MachineRegisterInfo.h"
51#include "llvm/CodeGen/TargetRegisterInfo.h"
52#include "llvm/CodeGen/ValueTypes.h"
53#include "llvm/CodeGenTypes/MachineValueType.h"
54#include "llvm/IR/Argument.h"
55#include "llvm/IR/Attributes.h"
56#include "llvm/IR/BasicBlock.h"
57#include "llvm/IR/Constant.h"
58#include "llvm/IR/Constants.h"
59#include "llvm/IR/DataLayout.h"
60#include "llvm/IR/DebugInfo.h"
61#include "llvm/IR/DebugInfoMetadata.h"
62#include "llvm/IR/DebugLoc.h"
63#include "llvm/IR/DerivedTypes.h"
64#include "llvm/IR/Function.h"
65#include "llvm/IR/GlobalAlias.h"
66#include "llvm/IR/GlobalValue.h"
67#include "llvm/IR/GlobalVariable.h"
68#include "llvm/IR/Instruction.h"
69#include "llvm/IR/LLVMContext.h"
70#include "llvm/IR/Module.h"
71#include "llvm/IR/Operator.h"
72#include "llvm/IR/Type.h"
73#include "llvm/IR/User.h"
74#include "llvm/MC/MCExpr.h"
75#include "llvm/MC/MCInst.h"
76#include "llvm/MC/MCInstrDesc.h"
77#include "llvm/MC/MCStreamer.h"
78#include "llvm/MC/MCSymbol.h"
79#include "llvm/MC/TargetRegistry.h"
80#include "llvm/Support/Alignment.h"
81#include "llvm/Support/Casting.h"
82#include "llvm/Support/Compiler.h"
83#include "llvm/Support/Endian.h"
84#include "llvm/Support/ErrorHandling.h"
85#include "llvm/Support/NativeFormatting.h"
86#include "llvm/Support/raw_ostream.h"
87#include "llvm/Target/TargetLoweringObjectFile.h"
88#include "llvm/Target/TargetMachine.h"
89#include "llvm/Transforms/Utils/UnrollLoop.h"
90#include <cassert>
91#include <cstdint>
92#include <cstring>
93#include <string>
94
95using namespace llvm;
96
97#define DEPOTNAME "__local_depot"
98
99static StringRef getTextureName(const Value &V) {
100 assert(V.hasName() && "Found texture variable with no name");
101 return V.getName();
102}
103
104static StringRef getSurfaceName(const Value &V) {
105 assert(V.hasName() && "Found surface variable with no name");
106 return V.getName();
107}
108
109static StringRef getSamplerName(const Value &V) {
110 assert(V.hasName() && "Found sampler variable with no name");
111 return V.getName();
112}
113
114/// discoverDependentGlobals - Return a set of GlobalVariables on which \p V
115/// depends.
116static void
117discoverDependentGlobals(const Value *V,
118 DenseSet<const GlobalVariable *> &Globals) {
119 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: V)) {
120 Globals.insert(V: GV);
121 return;
122 }
123
124 if (const User *U = dyn_cast<User>(Val: V))
125 for (const auto &O : U->operands())
126 discoverDependentGlobals(V: O, Globals);
127}
128
/// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
/// instances to be emitted, but only after any dependents have been added
/// first.
static void
VisitGlobalVariableForEmission(const GlobalVariable *GV,
                               SmallVectorImpl<const GlobalVariable *> &Order,
                               DenseSet<const GlobalVariable *> &Visited,
                               DenseSet<const GlobalVariable *> &Visiting) {
  // Have we already visited this one?
  if (Visited.count(V: GV))
    return;

  // Do we have a circular dependency?
  if (!Visiting.insert(V: GV).second)
    report_fatal_error(reason: "Circular dependency found in global variable set");

  // Make sure we visit all dependents first
  DenseSet<const GlobalVariable *> Others;
  for (const auto &O : GV->operands())
    discoverDependentGlobals(V: O, Globals&: Others);

  // Depth-first recursion: every global referenced by GV's initializer is
  // appended to Order before GV itself.
  for (const GlobalVariable *GV : Others)
    VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);

  // Now we can visit ourself
  Order.push_back(Elt: GV);
  Visited.insert(V: GV);
  Visiting.erase(V: GV);
}
158
159void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
160 NVPTX_MC::verifyInstructionPredicates(Opcode: MI->getOpcode(),
161 Features: getSubtargetInfo().getFeatureBits());
162
163 MCInst Inst;
164 lowerToMCInst(MI, OutMI&: Inst);
165 EmitToStreamer(S&: *OutStreamer, Inst);
166}
167
168void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
169 OutMI.setOpcode(MI->getOpcode());
170 // Special: Do not mangle symbol operand of CALL_PROTOTYPE
171 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
172 const MachineOperand &MO = MI->getOperand(i: 0);
173 OutMI.addOperand(Op: GetSymbolRef(
174 Symbol: OutContext.getOrCreateSymbol(Name: Twine(MO.getSymbolName()))));
175 return;
176 }
177
178 for (const auto MO : MI->operands())
179 OutMI.addOperand(Op: lowerOperand(MO));
180}
181
// Translate one MachineOperand into an MCOperand. Register operands are
// re-encoded via encodeVirtualRegister (class tag + number); basic blocks
// and symbols become symbol-ref expressions; FP immediates become
// NVPTX-specific MCExprs so they print in PTX's exact hex form.
MCOperand NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO) {
  switch (MO.getType()) {
  default:
    llvm_unreachable("unknown operand type");
  case MachineOperand::MO_Register:
    return MCOperand::createReg(Reg: encodeVirtualRegister(Reg: MO.getReg()));
  case MachineOperand::MO_Immediate:
    return MCOperand::createImm(Val: MO.getImm());
  case MachineOperand::MO_MachineBasicBlock:
    return MCOperand::createExpr(
        Val: MCSymbolRefExpr::create(Symbol: MO.getMBB()->getSymbol(), Ctx&: OutContext));
  case MachineOperand::MO_ExternalSymbol:
    return GetSymbolRef(Symbol: GetExternalSymbolSymbol(Sym: MO.getSymbolName()));
  case MachineOperand::MO_GlobalAddress:
    return GetSymbolRef(Symbol: getSymbol(GV: MO.getGlobal()));
  case MachineOperand::MO_FPImmediate: {
    const ConstantFP *Cnt = MO.getFPImm();
    const APFloat &Val = Cnt->getValueAPF();

    // Each supported FP type has a dedicated NVPTXFloatMCExpr factory.
    switch (Cnt->getType()->getTypeID()) {
    default:
      report_fatal_error(reason: "Unsupported FP type");
      break;
    case Type::HalfTyID:
      return MCOperand::createExpr(
          Val: NVPTXFloatMCExpr::createConstantFPHalf(Flt: Val, Ctx&: OutContext));
    case Type::BFloatTyID:
      return MCOperand::createExpr(
          Val: NVPTXFloatMCExpr::createConstantBFPHalf(Flt: Val, Ctx&: OutContext));
    case Type::FloatTyID:
      return MCOperand::createExpr(
          Val: NVPTXFloatMCExpr::createConstantFPSingle(Flt: Val, Ctx&: OutContext));
    case Type::DoubleTyID:
      return MCOperand::createExpr(
          Val: NVPTXFloatMCExpr::createConstantFPDouble(Flt: Val, Ctx&: OutContext));
    }
    break;
  }
  }
}
222
// Encode \p Reg into the 32-bit value stored in the MCInst: a register-class
// tag in the top 4 bits and the per-class virtual-register number in the low
// 28 bits. NVPTXInstPrinter::printRegName decodes this value, so the two
// must stay in sync.
unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
  if (Register::isVirtualRegister(Reg)) {
    const TargetRegisterClass *RC = MRI->getRegClass(Reg);

    // Per-class numbering of vregs; VRegMapping is populated before
    // instruction emission (see the setAndEmitFunctionVirtualRegisters call
    // in emitFunctionEntryLabel).
    DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
    unsigned RegNum = RegMap[Reg];

    // Encode the register class in the upper 4 bits
    // Must be kept in sync with NVPTXInstPrinter::printRegName
    unsigned Ret = 0;
    if (RC == &NVPTX::B1RegClass) {
      Ret = (1 << 28);
    } else if (RC == &NVPTX::B16RegClass) {
      Ret = (2 << 28);
    } else if (RC == &NVPTX::B32RegClass) {
      Ret = (3 << 28);
    } else if (RC == &NVPTX::B64RegClass) {
      Ret = (4 << 28);
    } else if (RC == &NVPTX::B128RegClass) {
      Ret = (7 << 28);
    } else {
      report_fatal_error(reason: "Bad register class");
    }

    // Insert the vreg number
    Ret |= (RegNum & 0x0FFFFFFF);
    return Ret;
  } else {
    // Some special-use registers are actually physical registers.
    // Encode this as the register class ID of 0 and the real register ID.
    return Reg & 0x0FFFFFFF;
  }
}
256
257MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
258 const MCExpr *Expr;
259 Expr = MCSymbolRefExpr::create(Symbol, Ctx&: OutContext);
260 return MCOperand::createExpr(Val: Expr);
261}
262
// Print the PTX return-value declaration for \p F into \p O, e.g.
// " (.param .b32 func_retval0) ". Prints nothing for a void return.
void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(F: *F);
  const auto *TLI = cast<NVPTXTargetLowering>(Val: STI.getTargetLowering());

  Type *Ty = F->getReturnType();
  if (Ty->getTypeID() == Type::VoidTyID)
    return;
  O << " (";

  // Scalars are widened to a legal PTX parameter size before printing.
  auto PrintScalarRetVal = [&](unsigned Size) {
    O << ".param .b" << promoteScalarArgumentSize(size: Size) << " func_retval0";
  };
  if (shouldPassAsArray(Ty)) {
    // Aggregate-style returns become an aligned byte array.
    const unsigned TotalSize = DL.getTypeAllocSize(Ty);
    const Align RetAlignment =
        getFunctionArgumentAlignment(F, Ty, Idx: AttributeList::ReturnIndex, DL);
    O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
      << TotalSize << "]";
  } else if (Ty->isFloatingPointTy()) {
    PrintScalarRetVal(Ty->getPrimitiveSizeInBits());
  } else if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
    PrintScalarRetVal(ITy->getBitWidth());
  } else if (isa<PointerType>(Val: Ty)) {
    // Pointers are returned as an integer of pointer width.
    PrintScalarRetVal(TLI->getPointerTy(DL).getSizeInBits());
  } else
    llvm_unreachable("Unknown return type");
  O << ") ";
}
292
293void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
294 raw_ostream &O) {
295 const Function &F = MF.getFunction();
296 printReturnValStr(F: &F, O);
297}
298
// Return true if MBB is the header of a loop marked with
// llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
    const MachineBasicBlock &MBB) const {
  MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();
  // We insert .pragma "nounroll" only to the loop header.
  if (!LI.isLoopHeader(BB: &MBB))
    return false;

  // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
  // we iterate through each back edge of the loop with header MBB, and check
  // whether its metadata contains llvm.loop.unroll.disable.
  for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
    if (LI.getLoopFor(BB: PMBB) != LI.getLoopFor(BB: &MBB)) {
      // Edges from other loops to MBB are not back edges.
      continue;
    }
    if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
      if (MDNode *LoopID =
              PBB->getTerminator()->getMetadata(KindID: LLVMContext::MD_loop)) {
        // Explicit request to disable unrolling.
        if (GetUnrollMetadata(LoopID, Name: "llvm.loop.unroll.disable"))
          return true;
        // An unroll count of exactly 1 is treated the same as "disable".
        if (MDNode *UnrollCountMD =
                GetUnrollMetadata(LoopID, Name: "llvm.loop.unroll.count")) {
          if (mdconst::extract<ConstantInt>(MD: UnrollCountMD->getOperand(I: 1))
                  ->isOne())
            return true;
        }
      }
    }
  }
  return false;
}
332
333void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
334 AsmPrinter::emitBasicBlockStart(MBB);
335 if (isLoopHeaderOfNoUnroll(MBB))
336 OutStreamer->emitRawText(String: StringRef("\t.pragma \"nounroll\";\n"));
337}
338
// Emit the PTX function header for the current MachineFunction: linkage,
// .entry (kernels) or .func + return value (device functions), name,
// parameter list, kernel directives, and the opening brace of the body.
// Also lazily emits module-level globals before the first function, since
// ptxas requires globals to appear before any use.
void NVPTXAsmPrinter::emitFunctionEntryLabel() {
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  // Globals must be printed before the first function body.
  if (!GlobalsEmitted) {
    emitGlobals(M: *MF->getFunction().getParent());
    GlobalsEmitted = true;
  }

  // Set up per-function state used by the emission helpers below.
  MRI = &MF->getRegInfo();
  F = &MF->getFunction();
  emitLinkageDirective(V: F, O);
  // Kernels use .entry; device functions use .func and may return a value.
  if (isKernelFunction(F: *F))
    O << ".entry ";
  else {
    O << ".func ";
    printReturnValStr(MF: *MF, O);
  }

  CurrentFnSym->print(OS&: O, MAI);

  emitFunctionParamList(F, O);
  O << "\n";

  if (isKernelFunction(F: *F))
    emitKernelFunctionDirectives(F: *F, O);

  if (shouldEmitPTXNoReturn(V: F, TM))
    O << ".noreturn";

  OutStreamer->emitRawText(String: O.str());

  VRegMapping.clear();
  // Emit open brace for function body.
  OutStreamer->emitRawText(String: StringRef("{\n"));
  setAndEmitFunctionVirtualRegisters(*MF);
  encodeDebugInfoRegisterNumbers(MF: *MF);
  // Emit initial .loc debug directive for correct relocation symbol data.
  if (const DISubprogram *SP = MF->getFunction().getSubprogram()) {
    assert(SP->getUnit());
    if (!SP->getUnit()->isDebugDirectivesOnly())
      emitInitialRawDwarfLocDirective(MF: *MF);
  }
}
384
385bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
386 bool Result = AsmPrinter::runOnMachineFunction(MF&: F);
387 // Emit closing brace for the body of function F.
388 // The closing brace must be emitted here because we need to emit additional
389 // debug labels/data after the last basic block.
390 // We need to emit the closing brace here because we don't have function that
391 // finished emission of the function body.
392 OutStreamer->emitRawText(String: StringRef("}\n"));
393 return Result;
394}
395
396void NVPTXAsmPrinter::emitFunctionBodyStart() {
397 SmallString<128> Str;
398 raw_svector_ostream O(Str);
399 emitDemotedVars(&MF->getFunction(), O);
400 OutStreamer->emitRawText(String: O.str());
401}
402
void NVPTXAsmPrinter::emitFunctionBodyEnd() {
  // Discard the per-function virtual-register mapping; a fresh one is set up
  // for the next function in emitFunctionEntryLabel.
  VRegMapping.clear();
}
406
407const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
408 SmallString<128> Str;
409 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
410 return OutContext.getOrCreateSymbol(Name: Str);
411}
412
// Print an IMPLICIT_DEF as an assembly comment ("implicit-def: <reg>") so
// the generated PTX documents where undefined values come from.
void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
  Register RegNo = MI->getOperand(i: 0).getReg();
  if (RegNo.isVirtual()) {
    OutStreamer->AddComment(T: Twine("implicit-def: ") +
                            getVirtualRegisterName(RegNo));
  } else {
    // Physical registers are named by the target register info.
    const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
    OutStreamer->AddComment(T: Twine("implicit-def: ") +
                            STI.getRegisterInfo()->getName(RegNo));
  }
  OutStreamer->addBlankLine();
}
425
// Emit the launch-bound and cluster directives for kernel \p F
// (.reqntid/.maxntid, .minnctapersm, .maxnreg and, on SM_90+, the
// cluster-related directives), derived from the function's NVVM attributes.
void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
                                                   raw_ostream &O) const {
  // If the NVVM IR has some of reqntid* specified, then output
  // the reqntid directive, and set the unspecified ones to 1.
  // If none of Reqntid* is specified, don't output reqntid directive.
  const auto ReqNTID = getReqNTID(F);
  if (!ReqNTID.empty())
    O << formatv(Fmt: ".reqntid {0:$[, ]}\n",
                 Vals: make_range(x: ReqNTID.begin(), y: ReqNTID.end()));

  const auto MaxNTID = getMaxNTID(F);
  if (!MaxNTID.empty())
    O << formatv(Fmt: ".maxntid {0:$[, ]}\n",
                 Vals: make_range(x: MaxNTID.begin(), y: MaxNTID.end()));

  if (const auto Mincta = getMinCTASm(F))
    O << ".minnctapersm " << *Mincta << "\n";

  if (const auto Maxnreg = getMaxNReg(F))
    O << ".maxnreg " << *Maxnreg << "\n";

  // .maxclusterrank directive requires SM_90 or higher, make sure that we
  // filter it out for lower SM versions, as it causes a hard ptxas crash.
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget *STI = &NTM.getSubtarget<NVPTXSubtarget>(F);

  if (STI->getSmVersion() >= 90) {
    const auto ClusterDim = getClusterDim(F);
    const bool BlocksAreClusters = hasBlocksAreClusters(F);

    if (!ClusterDim.empty()) {
      // .explicitcluster is omitted when blocksareclusters is in effect.
      if (!BlocksAreClusters)
        O << ".explicitcluster\n";

      if (ClusterDim[0] != 0) {
        assert(llvm::all_of(ClusterDim, not_equal_to(0)) &&
               "cluster_dim_x != 0 implies cluster_dim_y and cluster_dim_z "
               "should be non-zero as well");

        O << formatv(Fmt: ".reqnctapercluster {0:$[, ]}\n",
                     Vals: make_range(x: ClusterDim.begin(), y: ClusterDim.end()));
      } else {
        assert(llvm::all_of(ClusterDim, equal_to(0)) &&
               "cluster_dim_x == 0 implies cluster_dim_y and cluster_dim_z "
               "should be 0 as well");
      }
    }

    if (BlocksAreClusters) {
      // blocksareclusters needs both launch-bound attributes and PTX >= 9.0;
      // otherwise raise a diagnostic instead of emitting invalid PTX.
      LLVMContext &Ctx = F.getContext();
      if (ReqNTID.empty() || ClusterDim.empty())
        Ctx.diagnose(DI: DiagnosticInfoUnsupported(
            F, "blocksareclusters requires reqntid and cluster_dim attributes",
            F.getSubprogram()));
      else if (STI->getPTXVersion() < 90)
        Ctx.diagnose(DI: DiagnosticInfoUnsupported(
            F, "blocksareclusters requires PTX version >= 9.0",
            F.getSubprogram()));
      else
        O << ".blocksareclusters\n";
    }

    if (const auto Maxclusterrank = getMaxClusterRank(F))
      O << ".maxclusterrank " << *Maxclusterrank << "\n";
  }
}
493
494std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
495 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
496
497 std::string Name;
498 raw_string_ostream NameStr(Name);
499
500 VRegRCMap::const_iterator I = VRegMapping.find(Val: RC);
501 assert(I != VRegMapping.end() && "Bad register class");
502 const DenseMap<unsigned, unsigned> &RegMap = I->second;
503
504 VRegMap::const_iterator VI = RegMap.find(Val: Reg);
505 assert(VI != RegMap.end() && "Bad virtual register");
506 unsigned MappedVR = VI->second;
507
508 NameStr << getNVPTXRegClassStr(RC) << MappedVR;
509
510 return Name;
511}
512
513void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
514 raw_ostream &O) {
515 O << getVirtualRegisterName(Reg: vr);
516}
517
// Emit a forward declaration for alias \p GA. PTX '.alias' only supports
// aliasing non-kernel function *definitions* with non-weak linkage; any
// other aliasee is a fatal error.
void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias *GA,
                                           raw_ostream &O) {
  const Function *F = dyn_cast_or_null<Function>(Val: GA->getAliaseeObject());
  if (!F || isKernelFunction(F: *F) || F->isDeclaration())
    report_fatal_error(
        reason: "NVPTX aliasee must be a non-kernel function definition");

  if (GA->hasLinkOnceLinkage() || GA->hasWeakLinkage() ||
      GA->hasAvailableExternallyLinkage() || GA->hasCommonLinkage())
    report_fatal_error(reason: "NVPTX aliasee must not be '.weak'");

  // Declare the aliasee's signature under the alias's own symbol.
  emitDeclarationWithName(F, getSymbol(GV: GA), O);
}
531
532void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
533 emitDeclarationWithName(F, getSymbol(GV: F), O);
534}
535
// Print a PTX forward declaration for \p F under symbol \p S: linkage,
// .entry/.func, return value, symbol name, parameter list and an optional
// .noreturn, terminated with ';'.
void NVPTXAsmPrinter::emitDeclarationWithName(const Function *F, MCSymbol *S,
                                              raw_ostream &O) {
  emitLinkageDirective(V: F, O);
  if (isKernelFunction(F: *F))
    O << ".entry ";
  else
    O << ".func ";
  printReturnValStr(F, O);
  S->print(OS&: O, MAI);
  O << "\n";
  emitFunctionParamList(F, O);
  O << "\n";
  if (shouldEmitPTXNoReturn(V: F, TM))
    O << ".noreturn";
  O << ";\n";
}
552
553static bool usedInGlobalVarDef(const Constant *C) {
554 if (!C)
555 return false;
556
557 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: C))
558 return GV->getName() != "llvm.used";
559
560 for (const User *U : C->users())
561 if (const Constant *C = dyn_cast<Constant>(Val: U))
562 if (usedInGlobalVarDef(C))
563 return true;
564
565 return false;
566}
567
568static bool usedInOneFunc(const User *U, Function const *&OneFunc) {
569 if (const GlobalVariable *OtherGV = dyn_cast<GlobalVariable>(Val: U))
570 if (OtherGV->getName() == "llvm.used")
571 return true;
572
573 if (const Instruction *I = dyn_cast<Instruction>(Val: U)) {
574 if (const Function *CurFunc = I->getFunction()) {
575 if (OneFunc && (CurFunc != OneFunc))
576 return false;
577 OneFunc = CurFunc;
578 return true;
579 }
580 return false;
581 }
582
583 for (const User *UU : U->users())
584 if (!usedInOneFunc(U: UU, OneFunc))
585 return false;
586
587 return true;
588}
589
590/* Find out if a global variable can be demoted to local scope.
591 * Currently, this is valid for CUDA shared variables, which have local
592 * scope and global lifetime. So the conditions to check are :
593 * 1. Is the global variable in shared address space?
594 * 2. Does it have local linkage?
595 * 3. Is the global variable referenced only in one function?
596 */
597static bool canDemoteGlobalVar(const GlobalVariable *GV, Function const *&f) {
598 if (!GV->hasLocalLinkage())
599 return false;
600 if (GV->getAddressSpace() != ADDRESS_SPACE_SHARED)
601 return false;
602
603 const Function *oneFunc = nullptr;
604
605 bool flag = usedInOneFunc(U: GV, OneFunc&: oneFunc);
606 if (!flag)
607 return false;
608 if (!oneFunc)
609 return false;
610 f = oneFunc;
611 return true;
612}
613
614static bool useFuncSeen(const Constant *C,
615 const SmallPtrSetImpl<const Function *> &SeenSet) {
616 for (const User *U : C->users()) {
617 if (const Constant *cu = dyn_cast<Constant>(Val: U)) {
618 if (useFuncSeen(C: cu, SeenSet))
619 return true;
620 } else if (const Instruction *I = dyn_cast<Instruction>(Val: U)) {
621 if (const Function *Caller = I->getFunction())
622 if (SeenSet.contains(Ptr: Caller))
623 return true;
624 }
625 }
626 return false;
627}
628
// Emit forward declarations for all functions that are referenced before
// their definition appears in the module (ptxas has no forward references),
// plus used external functions, libcall callees, and aliases.
void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
  SmallPtrSet<const Function *, 32> SeenSet;
  for (const Function &F : M) {
    // Libcall callees are always declared up front.
    if (F.getAttributes().hasFnAttr(Kind: "nvptx-libcall-callee")) {
      emitDeclaration(F: &F, O);
      continue;
    }

    if (F.isDeclaration()) {
      // External declarations: only declare when used, and never intrinsics.
      if (F.use_empty())
        continue;
      if (F.getIntrinsicID())
        continue;
      emitDeclaration(F: &F, O);
      continue;
    }
    for (const User *U : F.users()) {
      if (const Constant *C = dyn_cast<Constant>(Val: U)) {
        if (usedInGlobalVarDef(C)) {
          // The use is in the initialization of a global variable
          // that is a function pointer, so print a declaration
          // for the original function
          emitDeclaration(F: &F, O);
          break;
        }
        // Emit a declaration of this function if the function that
        // uses this constant expr has already been seen.
        if (useFuncSeen(C, SeenSet)) {
          emitDeclaration(F: &F, O);
          break;
        }
      }

      if (!isa<Instruction>(Val: U))
        continue;
      const Function *Caller = cast<Instruction>(Val: U)->getFunction();
      if (!Caller)
        continue;

      // If a caller has already been seen, then the caller is
      // appearing in the module before the callee. so print out
      // a declaration for the callee.
      if (SeenSet.contains(Ptr: Caller)) {
        emitDeclaration(F: &F, O);
        break;
      }
    }
    SeenSet.insert(Ptr: &F);
  }
  // Aliases are declared unconditionally.
  for (const GlobalAlias &GA : M.aliases())
    emitAliasDeclaration(GA: &GA, O);
}
681
682void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
683 // Construct a default subtarget off of the TargetMachine defaults. The
684 // rest of NVPTX isn't friendly to change subtargets per function and
685 // so the default TargetMachine will have all of the options.
686 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
687 const NVPTXSubtarget *STI = NTM.getSubtargetImpl();
688
689 // Emit header before any dwarf directives are emitted below.
690 emitHeader(M, STI: *STI);
691}
692
/// Create NVPTX-specific DwarfDebug handler. Overrides the base AsmPrinter
/// factory so NVPTX debug-info emission quirks are handled by
/// NVPTXDwarfDebug.
DwarfDebug *NVPTXAsmPrinter::createDwarfDebug() {
  return new NVPTXDwarfDebug(this);
}
697
bool NVPTXAsmPrinter::doInitialization(Module &M) {
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI = *NTM.getSubtargetImpl();
  // PTX '.alias' needs PTX 6.3 and sm_30; reject alias-bearing modules early.
  if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
    report_fatal_error(reason: ".alias requires PTX version >= 6.3 and sm_30");

  // We need to call the parent's one explicitly.
  bool Result = AsmPrinter::doInitialization(M);

  // Globals are emitted lazily: at the first function entry label, or in
  // doFinalization if the module has no functions.
  GlobalsEmitted = false;

  return Result;
}
711
// Emit all function declarations and module-level global variables, with
// globals ordered so each one appears after everything its initializer
// references.
void NVPTXAsmPrinter::emitGlobals(const Module &M) {
  SmallString<128> Str2;
  raw_svector_ostream OS2(Str2);

  emitDeclarations(M, O&: OS2);

  // As ptxas does not support forward references of globals, we need to first
  // sort the list of module-level globals in def-use order. We visit each
  // global variable in order, and ensure that we emit it *after* its dependent
  // globals. We use a little extra memory maintaining both a set and a list to
  // have fast searches while maintaining a strict ordering.
  SmallVector<const GlobalVariable *, 8> Globals;
  DenseSet<const GlobalVariable *> GVVisited;
  DenseSet<const GlobalVariable *> GVVisiting;

  // Visit each global variable, in order
  for (const GlobalVariable &I : M.globals())
    VisitGlobalVariableForEmission(GV: &I, Order&: Globals, Visited&: GVVisited, Visiting&: GVVisiting);

  assert(GVVisited.size() == M.global_size() && "Missed a global variable");
  assert(GVVisiting.size() == 0 && "Did not fully process a global variable");

  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI = *NTM.getSubtargetImpl();

  // Print out module-level global variables in proper order
  for (const GlobalVariable *GV : Globals)
    printModuleLevelGV(GVar: GV, O&: OS2, /*ProcessDemoted=*/processDemoted: false, STI);

  OS2 << '\n';

  OutStreamer->emitRawText(String: OS2.str());
}
745
746void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
747 SmallString<128> Str;
748 raw_svector_ostream OS(Str);
749
750 MCSymbol *Name = getSymbol(GV: &GA);
751
752 OS << ".alias " << Name->getName() << ", " << GA.getAliaseeObject()->getName()
753 << ";\n";
754
755 OutStreamer->emitRawText(String: OS.str());
756}
757
// Accessor for the NVPTX-specific target streamer attached to OutStreamer.
NVPTXTargetStreamer *NVPTXAsmPrinter::getTargetStreamer() const {
  return static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
}
761
762static bool hasFullDebugInfo(Module &M) {
763 for (DICompileUnit *CU : M.debug_compile_units()) {
764 switch(CU->getEmissionKind()) {
765 case DICompileUnit::NoDebug:
766 case DICompileUnit::DebugDirectivesOnly:
767 break;
768 case DICompileUnit::LineTablesOnly:
769 case DICompileUnit::FullDebug:
770 return true;
771 }
772 }
773
774 return false;
775}
776
// Emit the PTX module header: banner comment, ".version", ".target" and
// ".address_size" directives.
void NVPTXAsmPrinter::emitHeader(Module &M, const NVPTXSubtarget &STI) {
  auto *TS = getTargetStreamer();

  TS->emitBanner();

  const unsigned PTXVersion = STI.getPTXVersion();
  TS->emitVersionDirective(PTXVersion);

  // NVCL builds use texmode_independent in the .target directive.
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  bool TexModeIndependent = NTM.getDrvInterface() == NVPTX::NVCL;

  TS->emitTargetDirective(Target: STI.getTargetName(), TexModeIndependent,
                          HasDebug: hasFullDebugInfo(M));
  TS->emitAddressSizeDirective(AddrSize: M.getDataLayout().getPointerSizeInBits());
}
792
793bool NVPTXAsmPrinter::doFinalization(Module &M) {
794 // If we did not emit any functions, then the global declarations have not
795 // yet been emitted.
796 if (!GlobalsEmitted) {
797 emitGlobals(M);
798 GlobalsEmitted = true;
799 }
800
801 // call doFinalization
802 bool ret = AsmPrinter::doFinalization(M);
803
804 clearAnnotationCache(&M);
805
806 auto *TS =
807 static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
808 // Close the last emitted section
809 if (hasDebugInfo()) {
810 TS->closeLastSection();
811 // Emit empty .debug_macinfo section for better support of the empty files.
812 OutStreamer->emitRawText(String: "\t.section\t.debug_macinfo\t{\t}");
813 }
814
815 // Output last DWARF .file directives, if any.
816 TS->outputDwarfFileDirectives();
817
818 return ret;
819}
820
// This function emits appropriate linkage directives for
// functions and global variables.
//
// extern function declaration -> .extern
// extern function definition -> .visible
// external global variable with init -> .visible
// external without init -> .extern
// appending -> not allowed, assert.
// for any linkage other than
// internal, private, linker_private,
// linker_private_weak, linker_private_weak_def_auto,
// we emit -> .weak.

void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
                                           raw_ostream &O) {
  // Linkage directives are printed only for the CUDA driver interface;
  // nothing is emitted otherwise (e.g. for NVCL).
  if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
    if (V->hasExternalLinkage()) {
      if (const auto *GVar = dyn_cast<GlobalVariable>(Val: V))
        O << (GVar->hasInitializer() ? ".visible " : ".extern ");
      else if (V->isDeclaration())
        O << ".extern ";
      else
        O << ".visible ";
    } else if (V->hasAppendingLinkage()) {
      report_fatal_error(reason: "Symbol '" + (V->hasName() ? V->getName() : "") +
                         "' has unsupported appending linkage type");
    } else if (!V->hasInternalLinkage() && !V->hasPrivateLinkage()) {
      // All remaining linkage kinds (weak, linkonce, common, ...) -> .weak.
      O << ".weak ";
    }
  }
}
852
853void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
854 raw_ostream &O, bool ProcessDemoted,
855 const NVPTXSubtarget &STI) {
856 // Skip meta data
857 if (GVar->hasSection())
858 if (GVar->getSection() == "llvm.metadata")
859 return;
860
861 // Skip LLVM intrinsic global variables
862 if (GVar->getName().starts_with(Prefix: "llvm.") ||
863 GVar->getName().starts_with(Prefix: "nvvm."))
864 return;
865
866 const DataLayout &DL = getDataLayout();
867
868 // GlobalVariables are always constant pointers themselves.
869 Type *ETy = GVar->getValueType();
870
871 if (GVar->hasExternalLinkage()) {
872 if (GVar->hasInitializer())
873 O << ".visible ";
874 else
875 O << ".extern ";
876 } else if (STI.getPTXVersion() >= 50 && GVar->hasCommonLinkage() &&
877 GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) {
878 O << ".common ";
879 } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
880 GVar->hasAvailableExternallyLinkage() ||
881 GVar->hasCommonLinkage()) {
882 O << ".weak ";
883 }
884
885 const PTXOpaqueType OpaqueType = getPTXOpaqueType(*GVar);
886
887 if (OpaqueType == PTXOpaqueType::Texture) {
888 O << ".global .texref " << getTextureName(V: *GVar) << ";\n";
889 return;
890 }
891
892 if (OpaqueType == PTXOpaqueType::Surface) {
893 O << ".global .surfref " << getSurfaceName(V: *GVar) << ";\n";
894 return;
895 }
896
897 if (GVar->isDeclaration()) {
898 // (extern) declarations, no definition or initializer
899 // Currently the only known declaration is for an automatic __local
900 // (.shared) promoted to global.
901 emitPTXGlobalVariable(GVar, O, STI);
902 O << ";\n";
903 return;
904 }
905
906 if (OpaqueType == PTXOpaqueType::Sampler) {
907 O << ".global .samplerref " << getSamplerName(V: *GVar);
908
909 const Constant *Initializer = nullptr;
910 if (GVar->hasInitializer())
911 Initializer = GVar->getInitializer();
912 const ConstantInt *CI = nullptr;
913 if (Initializer)
914 CI = dyn_cast<ConstantInt>(Val: Initializer);
915 if (CI) {
916 unsigned sample = CI->getZExtValue();
917
918 O << " = { ";
919
920 for (int i = 0,
921 addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
922 i < 3; i++) {
923 O << "addr_mode_" << i << " = ";
924 switch (addr) {
925 case 0:
926 O << "wrap";
927 break;
928 case 1:
929 O << "clamp_to_border";
930 break;
931 case 2:
932 O << "clamp_to_edge";
933 break;
934 case 3:
935 O << "wrap";
936 break;
937 case 4:
938 O << "mirror";
939 break;
940 }
941 O << ", ";
942 }
943 O << "filter_mode = ";
944 switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
945 case 0:
946 O << "nearest";
947 break;
948 case 1:
949 O << "linear";
950 break;
951 case 2:
952 llvm_unreachable("Anisotropic filtering is not supported");
953 default:
954 O << "nearest";
955 break;
956 }
957 if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
958 O << ", force_unnormalized_coords = 1";
959 }
960 O << " }";
961 }
962
963 O << ";\n";
964 return;
965 }
966
967 if (GVar->hasPrivateLinkage()) {
968 if (GVar->getName().starts_with(Prefix: "unrollpragma"))
969 return;
970
971 // FIXME - need better way (e.g. Metadata) to avoid generating this global
972 if (GVar->getName().starts_with(Prefix: "filename"))
973 return;
974 if (GVar->use_empty())
975 return;
976 }
977
978 const Function *DemotedFunc = nullptr;
979 if (!ProcessDemoted && canDemoteGlobalVar(GV: GVar, f&: DemotedFunc)) {
980 O << "// " << GVar->getName() << " has been demoted\n";
981 localDecls[DemotedFunc].push_back(x: GVar);
982 return;
983 }
984
985 O << ".";
986 emitPTXAddressSpace(AddressSpace: GVar->getAddressSpace(), O);
987
988 if (isManaged(*GVar)) {
989 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
990 report_fatal_error(
991 reason: ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
992 O << " .attribute(.managed)";
993 }
994
995 O << " .align "
996 << GVar->getAlign().value_or(u: DL.getPrefTypeAlign(Ty: ETy)).value();
997
998 if (ETy->isPointerTy() || ((ETy->isIntegerTy() || ETy->isFloatingPointTy()) &&
999 ETy->getScalarSizeInBits() <= 64)) {
1000 O << " .";
1001 // Special case: ABI requires that we use .u8 for predicates
1002 if (ETy->isIntegerTy(Bitwidth: 1))
1003 O << "u8";
1004 else
1005 O << getPTXFundamentalTypeStr(Ty: ETy, false);
1006 O << " ";
1007 getSymbol(GV: GVar)->print(OS&: O, MAI);
1008
1009 // Ptx allows variable initilization only for constant and global state
1010 // spaces.
1011 if (GVar->hasInitializer()) {
1012 if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1013 (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1014 const Constant *Initializer = GVar->getInitializer();
1015 // 'undef' is treated as there is no value specified.
1016 if (!Initializer->isNullValue() && !isa<UndefValue>(Val: Initializer)) {
1017 O << " = ";
1018 printScalarConstant(CPV: Initializer, O);
1019 }
1020 } else {
1021 // The frontend adds zero-initializer to device and constant variables
1022 // that don't have an initial value, and UndefValue to shared
1023 // variables, so skip warning for this case.
1024 if (!GVar->getInitializer()->isNullValue() &&
1025 !isa<UndefValue>(Val: GVar->getInitializer())) {
1026 report_fatal_error(reason: "initial value of '" + GVar->getName() +
1027 "' is not allowed in addrspace(" +
1028 Twine(GVar->getAddressSpace()) + ")");
1029 }
1030 }
1031 }
1032 } else {
1033 // Although PTX has direct support for struct type and array type and
1034 // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1035 // targets that support these high level field accesses. Structs, arrays
1036 // and vectors are lowered into arrays of bytes.
1037 switch (ETy->getTypeID()) {
1038 case Type::IntegerTyID: // Integers larger than 64 bits
1039 case Type::FP128TyID:
1040 case Type::StructTyID:
1041 case Type::ArrayTyID:
1042 case Type::FixedVectorTyID: {
1043 const uint64_t ElementSize = DL.getTypeStoreSize(Ty: ETy);
1044 // Ptx allows variable initilization only for constant and
1045 // global state spaces.
1046 if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1047 (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1048 GVar->hasInitializer()) {
1049 const Constant *Initializer = GVar->getInitializer();
1050 if (!isa<UndefValue>(Val: Initializer) && !Initializer->isNullValue()) {
1051 AggBuffer aggBuffer(ElementSize, *this);
1052 bufferAggregateConstant(CV: Initializer, aggBuffer: &aggBuffer);
1053 if (aggBuffer.numSymbols()) {
1054 const unsigned int ptrSize = MAI->getCodePointerSize();
1055 if (ElementSize % ptrSize ||
1056 !aggBuffer.allSymbolsAligned(ptrSize)) {
1057 // Print in bytes and use the mask() operator for pointers.
1058 if (!STI.hasMaskOperator())
1059 report_fatal_error(
1060 reason: "initialized packed aggregate with pointers '" +
1061 GVar->getName() +
1062 "' requires at least PTX ISA version 7.1");
1063 O << " .u8 ";
1064 getSymbol(GV: GVar)->print(OS&: O, MAI);
1065 O << "[" << ElementSize << "] = {";
1066 aggBuffer.printBytes(os&: O);
1067 O << "}";
1068 } else {
1069 O << " .u" << ptrSize * 8 << " ";
1070 getSymbol(GV: GVar)->print(OS&: O, MAI);
1071 O << "[" << ElementSize / ptrSize << "] = {";
1072 aggBuffer.printWords(os&: O);
1073 O << "}";
1074 }
1075 } else {
1076 O << " .b8 ";
1077 getSymbol(GV: GVar)->print(OS&: O, MAI);
1078 O << "[" << ElementSize << "] = {";
1079 aggBuffer.printBytes(os&: O);
1080 O << "}";
1081 }
1082 } else {
1083 O << " .b8 ";
1084 getSymbol(GV: GVar)->print(OS&: O, MAI);
1085 if (ElementSize)
1086 O << "[" << ElementSize << "]";
1087 }
1088 } else {
1089 O << " .b8 ";
1090 getSymbol(GV: GVar)->print(OS&: O, MAI);
1091 if (ElementSize)
1092 O << "[" << ElementSize << "]";
1093 }
1094 break;
1095 }
1096 default:
1097 llvm_unreachable("type not supported yet");
1098 }
1099 }
1100 O << ";\n";
1101}
1102
1103void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1104 const Value *v = Symbols[nSym];
1105 const Value *v0 = SymbolsBeforeStripping[nSym];
1106 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: v)) {
1107 MCSymbol *Name = AP.getSymbol(GV: GVar);
1108 PointerType *PTy = dyn_cast<PointerType>(Val: v0->getType());
1109 // Is v0 a generic pointer?
1110 bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1111 if (EmitGeneric && isGenericPointer && !isa<Function>(Val: v)) {
1112 os << "generic(";
1113 Name->print(OS&: os, MAI: AP.MAI);
1114 os << ")";
1115 } else {
1116 Name->print(OS&: os, MAI: AP.MAI);
1117 }
1118 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(Val: v0)) {
1119 const MCExpr *Expr = AP.lowerConstantForGV(CV: CExpr, ProcessingGeneric: false);
1120 AP.printMCExpr(Expr: *Expr, OS&: os);
1121 } else
1122 llvm_unreachable("symbol type unknown");
1123}
1124
void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
  // Print the buffered aggregate initializer as a comma-separated list of
  // byte values, emitting a per-byte mask() expression for every embedded
  // symbol (pointer).
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  // Do not emit trailing zero initializers. They will be zero-initialized by
  // ptxas. This saves on both space requirements for the generated PTX and on
  // memory use by ptxas. (See:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
  unsigned int InitializerCount = size;
  // TODO: symbols make this harder, but it would still be good to trim trailing
  // 0s for aggs with symbols as well.
  if (numSymbols() == 0)
    while (InitializerCount >= 1 && !buffer[InitializerCount - 1])
      InitializerCount--;

  // Push a sentinel position so nextSymbolPos stays valid after the last
  // real symbol has been consumed.
  symbolPosInBuffer.push_back(Elt: InitializerCount);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  for (unsigned int pos = 0; pos < InitializerCount;) {
    if (pos)
      os << ", ";
    if (pos != nextSymbolPos) {
      // Plain data byte.
      os << (unsigned int)buffer[pos];
      ++pos;
      continue;
    }
    // Generate a per-byte mask() operator for the symbol, which looks like:
    // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
    // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
    std::string symText;
    llvm::raw_string_ostream oss(symText);
    printSymbol(nSym, os&: oss);
    for (unsigned i = 0; i < ptrSize; ++i) {
      if (i)
        os << ", ";
      llvm::write_hex(S&: os, N: 0xFFULL << i * 8, Style: HexPrintStyle::PrefixUpper);
      os << "(" << symText << ")";
    }
    // A symbol occupies a full pointer's worth of bytes in the buffer.
    pos += ptrSize;
    nextSymbolPos = symbolPosInBuffer[++nSym];
    assert(nextSymbolPos >= pos);
  }
}
1166
1167void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1168 unsigned int ptrSize = AP.MAI->getCodePointerSize();
1169 symbolPosInBuffer.push_back(Elt: size);
1170 unsigned int nSym = 0;
1171 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1172 assert(nextSymbolPos % ptrSize == 0);
1173 for (unsigned int pos = 0; pos < size; pos += ptrSize) {
1174 if (pos)
1175 os << ", ";
1176 if (pos == nextSymbolPos) {
1177 printSymbol(nSym, os);
1178 nextSymbolPos = symbolPosInBuffer[++nSym];
1179 assert(nextSymbolPos % ptrSize == 0);
1180 assert(nextSymbolPos >= pos + ptrSize);
1181 } else if (ptrSize == 4)
1182 os << support::endian::read32le(P: &buffer[pos]);
1183 else
1184 os << support::endian::read64le(P: &buffer[pos]);
1185 }
1186}
1187
1188void NVPTXAsmPrinter::emitDemotedVars(const Function *F, raw_ostream &O) {
1189 auto It = localDecls.find(x: F);
1190 if (It == localDecls.end())
1191 return;
1192
1193 ArrayRef<const GlobalVariable *> GVars = It->second;
1194
1195 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1196 const NVPTXSubtarget &STI = *NTM.getSubtargetImpl();
1197
1198 for (const GlobalVariable *GV : GVars) {
1199 O << "\t// demoted variable\n\t";
1200 printModuleLevelGV(GVar: GV, O, /*processDemoted=*/ProcessDemoted: true, STI);
1201 }
1202}
1203
1204void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1205 raw_ostream &O) const {
1206 switch (AddressSpace) {
1207 case ADDRESS_SPACE_LOCAL:
1208 O << "local";
1209 break;
1210 case ADDRESS_SPACE_GLOBAL:
1211 O << "global";
1212 break;
1213 case ADDRESS_SPACE_CONST:
1214 O << "const";
1215 break;
1216 case ADDRESS_SPACE_SHARED:
1217 O << "shared";
1218 break;
1219 default:
1220 report_fatal_error(reason: "Bad address space found while emitting PTX: " +
1221 llvm::Twine(AddressSpace));
1222 break;
1223 }
1224}
1225
1226std::string
1227NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1228 switch (Ty->getTypeID()) {
1229 case Type::IntegerTyID: {
1230 unsigned NumBits = cast<IntegerType>(Val: Ty)->getBitWidth();
1231 if (NumBits == 1)
1232 return "pred";
1233 if (NumBits <= 64) {
1234 std::string name = "u";
1235 return name + utostr(X: NumBits);
1236 }
1237 llvm_unreachable("Integer too large");
1238 break;
1239 }
1240 case Type::BFloatTyID:
1241 case Type::HalfTyID:
1242 // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1243 // PTX assembly.
1244 return "b16";
1245 case Type::FloatTyID:
1246 return "f32";
1247 case Type::DoubleTyID:
1248 return "f64";
1249 case Type::PointerTyID: {
1250 unsigned PtrSize = TM.getPointerSizeInBits(AS: Ty->getPointerAddressSpace());
1251 assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1252
1253 if (PtrSize == 64)
1254 if (useB4PTR)
1255 return "b64";
1256 else
1257 return "u64";
1258 else if (useB4PTR)
1259 return "b32";
1260 else
1261 return "u32";
1262 }
1263 default:
1264 break;
1265 }
1266 llvm_unreachable("unexpected type");
1267}
1268
void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
                                            raw_ostream &O,
                                            const NVPTXSubtarget &STI) {
  // Emit the declaration portion of a global variable: state space, optional
  // .attribute(.managed), alignment, type and symbol name — but no
  // initializer and no trailing ';' (the caller appends those as needed).
  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  Type *ETy = GVar->getValueType();

  O << ".";
  emitPTXAddressSpace(AddressSpace: GVar->getType()->getAddressSpace(), O);
  if (isManaged(*GVar)) {
    // .attribute(.managed) is only valid on sufficiently new toolchains.
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
      report_fatal_error(
          reason: ".attribute(.managed) requires PTX version >= 4.0 and sm_30");

    O << " .attribute(.managed)";
  }
  // Explicit alignment wins; otherwise fall back to the preferred alignment
  // for the value type.
  O << " .align "
    << GVar->getAlign().value_or(u: DL.getPrefTypeAlign(Ty: ETy)).value();

  // Special case for i128/fp128: emitted as a 16-byte .b8 array.
  if (ETy->getScalarSizeInBits() == 128) {
    O << " .b8 ";
    getSymbol(GV: GVar)->print(OS&: O, MAI);
    O << "[16]";
    return;
  }

  if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
    O << " ." << getPTXFundamentalTypeStr(Ty: ETy) << " ";
    getSymbol(GV: GVar)->print(OS&: O, MAI);
    return;
  }

  int64_t ElementSize = 0;

  // Although PTX has direct support for struct type and array type and LLVM IR
  // is very similar to PTX, the LLVM CodeGen does not support for targets that
  // support these high level field accesses. Structs and arrays are lowered
  // into arrays of bytes.
  switch (ETy->getTypeID()) {
  case Type::StructTyID:
  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
    ElementSize = DL.getTypeStoreSize(Ty: ETy);
    O << " .b8 ";
    getSymbol(GV: GVar)->print(OS&: O, MAI);
    O << "[";
    if (ElementSize) {
      O << ElementSize;
    }
    O << "]";
    break;
  default:
    llvm_unreachable("type not supported yet");
  }
}
1326
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
  // Emit the PTX parameter list "( ... )" for F, one ".param" declaration per
  // argument. Kernels get explicit PTX types and pointer state spaces;
  // device functions get ABI-sized .b<N> params. A trailing byte-array param
  // is emitted for varargs.
  const DataLayout &DL = getDataLayout();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(F: *F);
  const auto *TLI = cast<NVPTXTargetLowering>(Val: STI.getTargetLowering());
  // MFI is only available while a machine function is being printed; it is
  // consulted to detect image-handle parameter symbols.
  const NVPTXMachineFunctionInfo *MFI =
      MF ? MF->getInfo<NVPTXMachineFunctionInfo>() : nullptr;

  bool IsFirst = true;
  const bool IsKernelFunc = isKernelFunction(F: *F);

  if (F->arg_empty() && !F->isVarArg()) {
    O << "()";
    return;
  }

  O << "(\n";

  for (const Argument &Arg : F->args()) {
    Type *Ty = Arg.getType();
    const std::string ParamSym = TLI->getParamName(F, Idx: Arg.getArgNo());

    if (!IsFirst)
      O << ",\n";

    IsFirst = false;

    // Handle image/sampler parameters
    if (IsKernelFunc) {
      const PTXOpaqueType ArgOpaqueType = getPTXOpaqueType(Arg);
      if (ArgOpaqueType != PTXOpaqueType::None) {
        // Emit a ".u64 .ptr" prefix unless this parameter is a recognized
        // image-handle symbol, in which case only the opaque ref type is
        // printed.
        const bool EmitImgPtr = !MFI || !MFI->checkImageHandleSymbol(Symbol: ParamSym);
        O << "\t.param ";
        if (EmitImgPtr)
          O << ".u64 .ptr ";

        switch (ArgOpaqueType) {
        case PTXOpaqueType::Sampler:
          O << ".samplerref ";
          break;
        case PTXOpaqueType::Texture:
          O << ".texref ";
          break;
        case PTXOpaqueType::Surface:
          O << ".surfref ";
          break;
        case PTXOpaqueType::None:
          llvm_unreachable("handled above");
        }
        O << ParamSym;
        continue;
      }
    }

    // Pick the alignment for a kernel parameter: an explicit alignment
    // attribute on the argument wins; otherwise take the larger of the
    // optimized type alignment and any byval parameter alignment.
    auto GetOptimalAlignForParam = [&DL, F, &Arg](Type *Ty) -> Align {
      if (MaybeAlign StackAlign =
              getAlign(F: *F, Index: Arg.getArgNo() + AttributeList::FirstArgIndex))
        return StackAlign.value();

      Align TypeAlign = getFunctionParamOptimizedAlign(F, ArgTy: Ty, DL);
      MaybeAlign ParamAlign =
          Arg.hasByValAttr() ? Arg.getParamAlign() : MaybeAlign();
      return std::max(a: TypeAlign, b: ParamAlign.valueOrOne());
    };

    if (Arg.hasByValAttr()) {
      // param has byVal attribute.
      Type *ETy = Arg.getParamByValType();
      assert(ETy && "Param should have byval type");

      // Print .param .align <a> .b8 .param[size];
      // <a> = optimal alignment for the element type; always multiple of
      // PAL.getParamAlignment
      // size = typeallocsize of element type
      const Align OptimalAlign =
          IsKernelFunc ? GetOptimalAlignForParam(ETy)
                       : getFunctionByValParamAlign(
                             F, ArgTy: ETy, InitialAlign: Arg.getParamAlign().valueOrOne(), DL);

      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
        << "[" << DL.getTypeAllocSize(Ty: ETy) << "]";
      continue;
    }

    if (shouldPassAsArray(Ty)) {
      // Just print .param .align <a> .b8 .param[size];
      // <a> = optimal alignment for the element type; always multiple of
      // PAL.getParamAlignment
      // size = typeallocsize of element type
      Align OptimalAlign = GetOptimalAlignForParam(Ty);

      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
        << "[" << DL.getTypeAllocSize(Ty) << "]";

      continue;
    }
    // Just a scalar
    auto *PTy = dyn_cast<PointerType>(Val: Ty);
    unsigned PTySizeInBits = 0;
    if (PTy) {
      PTySizeInBits =
          TLI->getPointerTy(DL, AS: PTy->getAddressSpace()).getSizeInBits();
      assert(PTySizeInBits && "Invalid pointer size");
    }

    if (IsKernelFunc) {
      if (PTy) {
        // Kernel pointer parameters carry ".ptr", the pointee state space
        // when it is statically known, and the declared alignment.
        O << "\t.param .u" << PTySizeInBits << " .ptr";

        switch (PTy->getAddressSpace()) {
        default:
          break;
        case ADDRESS_SPACE_GLOBAL:
          O << " .global";
          break;
        case ADDRESS_SPACE_SHARED:
          O << " .shared";
          break;
        case ADDRESS_SPACE_CONST:
          O << " .const";
          break;
        case ADDRESS_SPACE_LOCAL:
          O << " .local";
          break;
        }

        O << " .align " << Arg.getParamAlign().valueOrOne().value() << " "
          << ParamSym;
        continue;
      }

      // non-pointer scalar to kernel func
      O << "\t.param .";
      // Special case: predicate operands become .u8 types
      if (Ty->isIntegerTy(Bitwidth: 1))
        O << "u8";
      else
        O << getPTXFundamentalTypeStr(Ty);
      O << " " << ParamSym;
      continue;
    }
    // Non-kernel function, just print .param .b<size> for ABI
    // and .reg .b<size> for non-ABI
    unsigned Size;
    if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
      // Small integer widths are widened to an ABI-legal size (see
      // promoteScalarArgumentSize).
      Size = promoteScalarArgumentSize(size: ITy->getBitWidth());
    } else if (PTy) {
      assert(PTySizeInBits && "Invalid pointer size");
      Size = PTySizeInBits;
    } else
      Size = Ty->getPrimitiveSizeInBits();
    O << "\t.param .b" << Size << " " << ParamSym;
  }

  if (F->isVarArg()) {
    // Varargs are passed through a single byte array aligned to the maximum
    // alignment the target may require.
    if (!IsFirst)
      O << ",\n";
    O << "\t.param .align " << STI.getMaxRequiredAlignment() << " .b8 "
      << TLI->getParamName(F, /* vararg */ Idx: -1) << "[]";
  }

  O << "\n)";
}
1489
void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
    const MachineFunction &MF) {
  // Build VRegMapping (global vreg -> 1-based per-class index) for MF, and
  // emit the local depot declaration plus one ".reg" declaration per register
  // class that is actually used.
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  // Map the global virtual register number to a register class specific
  // virtual register number starting from 1 with that class.
  const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();

  // Emit the Fake Stack Object
  const MachineFrameInfo &MFI = MF.getFrameInfo();
  int64_t NumBytes = MFI.getStackSize();
  if (NumBytes) {
    // Declare the per-function local depot and the SP/SPL registers that
    // address it, sized by the target pointer width.
    O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
      << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
    if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
      O << "\t.reg .b64 \t%SP;\n"
        << "\t.reg .b64 \t%SPL;\n";
    } else {
      O << "\t.reg .b32 \t%SP;\n"
        << "\t.reg .b32 \t%SPL;\n";
    }
  }

  // Go through all virtual registers to establish the mapping between the
  // global virtual
  // register number and the per class virtual register number.
  // We use the per class virtual register number in the ptx output.
  for (unsigned I : llvm::seq(Size: MRI->getNumVirtRegs())) {
    Register VR = Register::index2VirtReg(Index: I);
    // Skip vregs that are never defined or used.
    if (MRI->use_empty(RegNo: VR) && MRI->def_empty(RegNo: VR))
      continue;
    auto &RCRegMap = VRegMapping[MRI->getRegClass(Reg: VR)];
    RCRegMap[VR] = RCRegMap.size() + 1;
  }

  // Emit declaration of the virtual registers or 'physical' registers for
  // each register class
  for (const TargetRegisterClass *RC : TRI->regclasses()) {
    const unsigned N = VRegMapping[RC].size();

    // Only declare those registers that may be used.
    if (N) {
      const StringRef RCName = getNVPTXRegClassName(RC);
      const StringRef RCStr = getNVPTXRegClassStr(RC);
      O << "\t.reg " << RCName << " \t" << RCStr << "<" << (N + 1) << ">;\n";
    }
  }

  OutStreamer->emitRawText(String: O.str());
}
1541
1542/// Translate virtual register numbers in DebugInfo locations to their printed
1543/// encodings, as used by CUDA-GDB.
1544void NVPTXAsmPrinter::encodeDebugInfoRegisterNumbers(
1545 const MachineFunction &MF) {
1546 const NVPTXSubtarget &STI = MF.getSubtarget<NVPTXSubtarget>();
1547 const NVPTXRegisterInfo *registerInfo = STI.getRegisterInfo();
1548
1549 // Clear the old mapping, and add the new one. This mapping is used after the
1550 // printing of the current function is complete, but before the next function
1551 // is printed.
1552 registerInfo->clearDebugRegisterMap();
1553
1554 for (auto &classMap : VRegMapping) {
1555 for (auto &registerMapping : classMap.getSecond()) {
1556 auto reg = registerMapping.getFirst();
1557 registerInfo->addToDebugRegisterMap(preEncodedVirtualRegister: reg, RegisterName: getVirtualRegisterName(Reg: reg));
1558 }
1559 }
1560}
1561
1562void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp,
1563 raw_ostream &O) const {
1564 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1565 bool ignored;
1566 unsigned int numHex;
1567 const char *lead;
1568
1569 if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1570 numHex = 8;
1571 lead = "0f";
1572 APF.convert(ToSemantics: APFloat::IEEEsingle(), RM: APFloat::rmNearestTiesToEven, losesInfo: &ignored);
1573 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1574 numHex = 16;
1575 lead = "0d";
1576 APF.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven, losesInfo: &ignored);
1577 } else
1578 llvm_unreachable("unsupported fp type");
1579
1580 APInt API = APF.bitcastToAPInt();
1581 O << lead << format_hex_no_prefix(N: API.getZExtValue(), Width: numHex, /*Upper=*/true);
1582}
1583
1584void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1585 if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CPV)) {
1586 O << CI->getValue();
1587 return;
1588 }
1589 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(Val: CPV)) {
1590 printFPConstant(Fp: CFP, O);
1591 return;
1592 }
1593 if (isa<ConstantPointerNull>(Val: CPV)) {
1594 O << "0";
1595 return;
1596 }
1597 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: CPV)) {
1598 const bool IsNonGenericPointer = GVar->getAddressSpace() != 0;
1599 if (EmitGeneric && !isa<Function>(Val: CPV) && !IsNonGenericPointer) {
1600 O << "generic(";
1601 getSymbol(GV: GVar)->print(OS&: O, MAI);
1602 O << ")";
1603 } else {
1604 getSymbol(GV: GVar)->print(OS&: O, MAI);
1605 }
1606 return;
1607 }
1608 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
1609 const MCExpr *E = lowerConstantForGV(CV: cast<Constant>(Val: Cexpr), ProcessingGeneric: false);
1610 printMCExpr(Expr: *E, OS&: O);
1611 return;
1612 }
1613 llvm_unreachable("Not scalar type found in printScalarConstant()");
1614}
1615
void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
                                   AggBuffer *AggBuffer) {
  // Serialize the constant CPV into AggBuffer in little-endian byte order.
  // A non-zero Bytes gives the total space this element must occupy (used by
  // the struct path to cover padding); zero means "use the type alloc size".
  const DataLayout &DL = getDataLayout();
  int AllocSize = DL.getTypeAllocSize(Ty: CPV->getType());
  if (isa<UndefValue>(Val: CPV) || CPV->isNullValue()) {
    // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
    // only the space allocated by CPV.
    AggBuffer->addZeros(Num: Bytes ? Bytes : AllocSize);
    return;
  }

  // Helper for filling AggBuffer with APInts.
  auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
    size_t NumBytes = (Val.getBitWidth() + 7) / 8;
    SmallVector<unsigned char, 16> Buf(NumBytes);
    // `extractBitsAsZExtValue` does not allow the extraction of bits beyond the
    // input's bit width, and i1 arrays may not have a length that is a multuple
    // of 8. We handle the last byte separately, so we never request out of
    // bounds bits.
    for (unsigned I = 0; I < NumBytes - 1; ++I) {
      Buf[I] = Val.extractBitsAsZExtValue(numBits: 8, bitPosition: I * 8);
    }
    size_t LastBytePosition = (NumBytes - 1) * 8;
    size_t LastByteBits = Val.getBitWidth() - LastBytePosition;
    Buf[NumBytes - 1] =
        Val.extractBitsAsZExtValue(numBits: LastByteBits, bitPosition: LastBytePosition);
    AggBuffer->addBytes(Ptr: Buf.data(), Num: NumBytes, Bytes);
  };

  switch (CPV->getType()->getTypeID()) {
  case Type::IntegerTyID:
    if (const auto *CI = dyn_cast<ConstantInt>(Val: CPV)) {
      AddIntToBuffer(CI->getValue());
      break;
    }
    // Constant expressions: try to fold to a plain integer first; otherwise
    // a ptrtoint records a symbol plus zero-filled space.
    if (const auto *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
      if (const auto *CI =
              dyn_cast<ConstantInt>(Val: ConstantFoldConstant(C: Cexpr, DL))) {
        AddIntToBuffer(CI->getValue());
        break;
      }
      if (Cexpr->getOpcode() == Instruction::PtrToInt) {
        Value *V = Cexpr->getOperand(i_nocapture: 0)->stripPointerCasts();
        AggBuffer->addSymbol(GVar: V, GVarBeforeStripping: Cexpr->getOperand(i_nocapture: 0));
        AggBuffer->addZeros(Num: AllocSize);
        break;
      }
    }
    llvm_unreachable("unsupported integer const type");
    break;

  case Type::HalfTyID:
  case Type::BFloatTyID:
  case Type::FloatTyID:
  case Type::DoubleTyID:
    // Floats are buffered via their raw bit patterns.
    AddIntToBuffer(cast<ConstantFP>(Val: CPV)->getValueAPF().bitcastToAPInt());
    break;

  case Type::PointerTyID: {
    // Record the symbol (if any) and reserve pointer-sized zero space; the
    // printers later substitute the symbol at this position.
    if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: CPV)) {
      AggBuffer->addSymbol(GVar, GVarBeforeStripping: GVar);
    } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
      const Value *v = Cexpr->stripPointerCasts();
      AggBuffer->addSymbol(GVar: v, GVarBeforeStripping: Cexpr);
    }
    AggBuffer->addZeros(Num: AllocSize);
    break;
  }

  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
  case Type::StructTyID: {
    // Aggregates recurse; any requested size beyond the alloc size is
    // zero-filled (struct padding).
    if (isa<ConstantAggregate>(Val: CPV) || isa<ConstantDataSequential>(Val: CPV)) {
      bufferAggregateConstant(CV: CPV, aggBuffer: AggBuffer);
      if (Bytes > AllocSize)
        AggBuffer->addZeros(Num: Bytes - AllocSize);
    } else if (isa<ConstantAggregateZero>(Val: CPV))
      AggBuffer->addZeros(Num: Bytes);
    else
      llvm_unreachable("Unexpected Constant type");
    break;
  }

  default:
    llvm_unreachable("unsupported type");
  }
}
1703
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
                                              AggBuffer *aggBuffer) {
  // Serialize an aggregate (array/vector/struct) or wide scalar constant
  // into aggBuffer, little-endian, element by element.
  const DataLayout &DL = getDataLayout();

  // Emit an APInt's bytes in little-endian order (bit width assumed to be a
  // multiple of 8 here).
  auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
    for (unsigned I : llvm::seq(Size: Val.getBitWidth() / 8))
      Buffer->addByte(Byte: Val.extractBitsAsZExtValue(numBits: 8, bitPosition: I * 8));
  };

  // Integers of arbitrary width
  if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CPV)) {
    ExtendBuffer(CI->getValue(), aggBuffer);
    return;
  }

  // f128
  if (const ConstantFP *CFP = dyn_cast<ConstantFP>(Val: CPV)) {
    if (CFP->getType()->isFP128Ty()) {
      ExtendBuffer(CFP->getValueAPF().bitcastToAPInt(), aggBuffer);
      return;
    }
  }

  // Old constants
  if (isa<ConstantArray>(Val: CPV) || isa<ConstantVector>(Val: CPV)) {
    for (const auto &Op : CPV->operands())
      bufferLEByte(CPV: cast<Constant>(Val: Op), Bytes: 0, AggBuffer: aggBuffer);
    return;
  }

  if (const auto *CDS = dyn_cast<ConstantDataSequential>(Val: CPV)) {
    for (unsigned I : llvm::seq(Size: CDS->getNumElements()))
      bufferLEByte(CPV: cast<Constant>(Val: CDS->getElementAsConstant(i: I)), Bytes: 0, AggBuffer: aggBuffer);
    return;
  }

  if (isa<ConstantStruct>(Val: CPV)) {
    if (CPV->getNumOperands()) {
      StructType *ST = cast<StructType>(Val: CPV->getType());
      for (unsigned I : llvm::seq(Size: CPV->getNumOperands())) {
        // Each field is written into the span [offset(I), EndOffset); the
        // last field's span extends to the struct's full alloc size so tail
        // padding gets zero-filled by bufferLEByte.
        int EndOffset = (I + 1 == CPV->getNumOperands())
                            ? DL.getStructLayout(Ty: ST)->getElementOffset(Idx: 0) +
                                  DL.getTypeAllocSize(Ty: ST)
                            : DL.getStructLayout(Ty: ST)->getElementOffset(Idx: I + 1);
        int Bytes = EndOffset - DL.getStructLayout(Ty: ST)->getElementOffset(Idx: I);
        bufferLEByte(CPV: cast<Constant>(Val: CPV->getOperand(i: I)), Bytes, AggBuffer: aggBuffer);
      }
    }
    return;
  }
  llvm_unreachable("unsupported constant type in printAggregateConstant()");
}
1756
/// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
/// a copy from AsmPrinter::lowerConstant, except customized to only handle
/// expressions that are representable in PTX and create
/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
/// \param CV the constant to lower.
/// \param ProcessingGeneric true while lowering through an addrspacecast to
///        the generic address space; symbol references are then wrapped in
///        an NVPTXGenericMCSymbolRefExpr.
const MCExpr *
NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV,
                                    bool ProcessingGeneric) const {
  MCContext &Ctx = OutContext;

  if (CV->isNullValue() || isa<UndefValue>(Val: CV))
    return MCConstantExpr::create(Value: 0, Ctx);

  if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CV))
    return MCConstantExpr::create(Value: CI->getZExtValue(), Ctx);

  if (const GlobalValue *GV = dyn_cast<GlobalValue>(Val: CV)) {
    const MCSymbolRefExpr *Expr = MCSymbolRefExpr::create(Symbol: getSymbol(GV), Ctx);
    if (ProcessingGeneric)
      return NVPTXGenericMCSymbolRefExpr::create(SymExpr: Expr, Ctx);
    return Expr;
  }

  // From here on only constant expressions can be lowered.
  const ConstantExpr *CE = dyn_cast<ConstantExpr>(Val: CV);
  if (!CE) {
    llvm_unreachable("Unknown constant value to lower!");
  }

  switch (CE->getOpcode()) {
  default:
    break; // Error

  case Instruction::AddrSpaceCast: {
    // Strip the addrspacecast and pass along the operand
    PointerType *DstTy = cast<PointerType>(Val: CE->getType());
    if (DstTy->getAddressSpace() == 0)
      return lowerConstantForGV(CV: cast<const Constant>(Val: CE->getOperand(i_nocapture: 0)), ProcessingGeneric: true);

    break; // Error
  }

  case Instruction::GetElementPtr: {
    const DataLayout &DL = getDataLayout();

    // Generate a symbolic expression for the byte address
    APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
    cast<GEPOperator>(Val: CE)->accumulateConstantOffset(DL, Offset&: OffsetAI);

    const MCExpr *Base = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0),
                                            ProcessingGeneric);
    if (!OffsetAI)
      return Base;

    // Base + constant byte offset.
    int64_t Offset = OffsetAI.getSExtValue();
    return MCBinaryExpr::createAdd(LHS: Base, RHS: MCConstantExpr::create(Value: Offset, Ctx),
                                   Ctx);
  }

  case Instruction::Trunc:
    // We emit the value and depend on the assembler to truncate the generated
    // expression properly. This is important for differences between
    // blockaddress labels. Since the two labels are in the same function, it
    // is reasonable to treat their delta as a 32-bit value.
    [[fallthrough]];
  case Instruction::BitCast:
    return lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0), ProcessingGeneric);

  case Instruction::IntToPtr: {
    const DataLayout &DL = getDataLayout();

    // Handle casts to pointers by changing them into casts to the appropriate
    // integer type. This promotes constant folding and simplifies this code.
    Constant *Op = CE->getOperand(i_nocapture: 0);
    Op = ConstantFoldIntegerCast(C: Op, DestTy: DL.getIntPtrType(CV->getType()),
                                 /*IsSigned*/ false, DL);
    if (Op)
      return lowerConstantForGV(CV: Op, ProcessingGeneric);

    break; // Error
  }

  case Instruction::PtrToInt: {
    const DataLayout &DL = getDataLayout();

    // Support only foldable casts to/from pointers that can be eliminated by
    // changing the pointer to the appropriately sized integer type.
    Constant *Op = CE->getOperand(i_nocapture: 0);
    Type *Ty = CE->getType();

    const MCExpr *OpExpr = lowerConstantForGV(CV: Op, ProcessingGeneric);

    // We can emit the pointer value into this slot if the slot is an
    // integer slot equal to the size of the pointer.
    if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Ty: Op->getType()))
      return OpExpr;

    // Otherwise the pointer is smaller than the resultant integer, mask off
    // the high bits so we are sure to get a proper truncation if the input is
    // a constant expr.
    unsigned InBits = DL.getTypeAllocSizeInBits(Ty: Op->getType());
    const MCExpr *MaskExpr = MCConstantExpr::create(Value: ~0ULL >> (64-InBits), Ctx);
    return MCBinaryExpr::createAnd(LHS: OpExpr, RHS: MaskExpr, Ctx);
  }

  // The MC library also has a right-shift operator, but it isn't consistently
  // signed or unsigned between different targets.
  case Instruction::Add: {
    const MCExpr *LHS = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0), ProcessingGeneric);
    const MCExpr *RHS = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 1), ProcessingGeneric);
    switch (CE->getOpcode()) {
    default: llvm_unreachable("Unknown binary operator constant cast expr");
    case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
    }
  }
  }

  // If the code isn't optimized, there may be outstanding folding
  // opportunities. Attempt to fold the expression using DataLayout as a
  // last resort before giving up.
  Constant *C = ConstantFoldConstant(C: CE, DL: getDataLayout());
  if (C != CE)
    return lowerConstantForGV(CV: C, ProcessingGeneric);

  // Otherwise report the problem to the user.
  std::string S;
  raw_string_ostream OS(S);
  OS << "Unsupported expression in static initializer: ";
  CE->printAsOperand(O&: OS, /*PrintType=*/false,
                     M: !MF ? nullptr : MF->getFunction().getParent());
  report_fatal_error(reason: Twine(OS.str()));
}
1887
1888void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) const {
1889 OutContext.getAsmInfo()->printExpr(OS, Expr);
1890}
1891
1892/// PrintAsmOperand - Print out an operand for an inline asm expression.
1893///
1894bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
1895 const char *ExtraCode, raw_ostream &O) {
1896 if (ExtraCode && ExtraCode[0]) {
1897 if (ExtraCode[1] != 0)
1898 return true; // Unknown modifier.
1899
1900 switch (ExtraCode[0]) {
1901 default:
1902 // See if this is a generic print operand
1903 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, OS&: O);
1904 case 'r':
1905 break;
1906 }
1907 }
1908
1909 printOperand(MI, OpNum: OpNo, O);
1910
1911 return false;
1912}
1913
1914bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
1915 unsigned OpNo,
1916 const char *ExtraCode,
1917 raw_ostream &O) {
1918 if (ExtraCode && ExtraCode[0])
1919 return true; // Unknown modifier
1920
1921 O << '[';
1922 printMemOperand(MI, OpNum: OpNo, O);
1923 O << ']';
1924
1925 return false;
1926}
1927
1928void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,
1929 raw_ostream &O) {
1930 const MachineOperand &MO = MI->getOperand(i: OpNum);
1931 switch (MO.getType()) {
1932 case MachineOperand::MO_Register:
1933 if (MO.getReg().isPhysical()) {
1934 if (MO.getReg() == NVPTX::VRDepot)
1935 O << DEPOTNAME << getFunctionNumber();
1936 else
1937 O << NVPTXInstPrinter::getRegisterName(Reg: MO.getReg());
1938 } else {
1939 emitVirtualRegister(vr: MO.getReg(), O);
1940 }
1941 break;
1942
1943 case MachineOperand::MO_Immediate:
1944 O << MO.getImm();
1945 break;
1946
1947 case MachineOperand::MO_FPImmediate:
1948 printFPConstant(Fp: MO.getFPImm(), O);
1949 break;
1950
1951 case MachineOperand::MO_GlobalAddress:
1952 PrintSymbolOperand(MO, OS&: O);
1953 break;
1954
1955 case MachineOperand::MO_MachineBasicBlock:
1956 MO.getMBB()->getSymbol()->print(OS&: O, MAI);
1957 break;
1958
1959 default:
1960 llvm_unreachable("Operand type not supported.");
1961 }
1962}
1963
1964void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,
1965 raw_ostream &O, const char *Modifier) {
1966 printOperand(MI, OpNum, O);
1967
1968 if (Modifier && strcmp(s1: Modifier, s2: "add") == 0) {
1969 O << ", ";
1970 printOperand(MI, OpNum: OpNum + 1, O);
1971 } else {
1972 if (MI->getOperand(i: OpNum + 1).isImm() &&
1973 MI->getOperand(i: OpNum + 1).getImm() == 0)
1974 return; // don't print ',0' or '+0'
1975 O << "+";
1976 printOperand(MI, OpNum: OpNum + 1, O);
1977 }
1978}
1979
// Pass identification; the address of ID is used as a unique pass ID.
char NVPTXAsmPrinter::ID = 0;

// Register the pass with the LLVM pass registry (not CFG-only, not analysis).
INITIALIZE_PASS(NVPTXAsmPrinter, "nvptx-asm-printer", "NVPTX Assembly Printer",
                false, false)
1984
1985// Force static initialization.
1986extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void
1987LLVMInitializeNVPTXAsmPrinter() {
1988 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
1989 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
1990}
1991