1//===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a printer that converts from our internal representation
10// of machine-dependent LLVM code to NVPTX assembly language.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXAsmPrinter.h"
15#include "MCTargetDesc/NVPTXBaseInfo.h"
16#include "MCTargetDesc/NVPTXInstPrinter.h"
17#include "MCTargetDesc/NVPTXMCAsmInfo.h"
18#include "MCTargetDesc/NVPTXTargetStreamer.h"
19#include "NVPTX.h"
20#include "NVPTXMCExpr.h"
21#include "NVPTXMachineFunctionInfo.h"
22#include "NVPTXRegisterInfo.h"
23#include "NVPTXSubtarget.h"
24#include "NVPTXTargetMachine.h"
25#include "NVPTXUtilities.h"
26#include "TargetInfo/NVPTXTargetInfo.h"
27#include "cl_common_defines.h"
28#include "llvm/ADT/APFloat.h"
29#include "llvm/ADT/APInt.h"
30#include "llvm/ADT/ArrayRef.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/DenseSet.h"
33#include "llvm/ADT/SmallString.h"
34#include "llvm/ADT/SmallVector.h"
35#include "llvm/ADT/StringExtras.h"
36#include "llvm/ADT/StringRef.h"
37#include "llvm/ADT/Twine.h"
38#include "llvm/ADT/iterator_range.h"
39#include "llvm/Analysis/ConstantFolding.h"
40#include "llvm/CodeGen/Analysis.h"
41#include "llvm/CodeGen/MachineBasicBlock.h"
42#include "llvm/CodeGen/MachineFrameInfo.h"
43#include "llvm/CodeGen/MachineFunction.h"
44#include "llvm/CodeGen/MachineInstr.h"
45#include "llvm/CodeGen/MachineLoopInfo.h"
46#include "llvm/CodeGen/MachineModuleInfo.h"
47#include "llvm/CodeGen/MachineOperand.h"
48#include "llvm/CodeGen/MachineRegisterInfo.h"
49#include "llvm/CodeGen/TargetRegisterInfo.h"
50#include "llvm/CodeGen/ValueTypes.h"
51#include "llvm/CodeGenTypes/MachineValueType.h"
52#include "llvm/IR/Argument.h"
53#include "llvm/IR/Attributes.h"
54#include "llvm/IR/BasicBlock.h"
55#include "llvm/IR/Constant.h"
56#include "llvm/IR/Constants.h"
57#include "llvm/IR/DataLayout.h"
58#include "llvm/IR/DebugInfo.h"
59#include "llvm/IR/DebugInfoMetadata.h"
60#include "llvm/IR/DebugLoc.h"
61#include "llvm/IR/DerivedTypes.h"
62#include "llvm/IR/Function.h"
63#include "llvm/IR/GlobalAlias.h"
64#include "llvm/IR/GlobalValue.h"
65#include "llvm/IR/GlobalVariable.h"
66#include "llvm/IR/Instruction.h"
67#include "llvm/IR/LLVMContext.h"
68#include "llvm/IR/Module.h"
69#include "llvm/IR/Operator.h"
70#include "llvm/IR/Type.h"
71#include "llvm/IR/User.h"
72#include "llvm/MC/MCExpr.h"
73#include "llvm/MC/MCInst.h"
74#include "llvm/MC/MCInstrDesc.h"
75#include "llvm/MC/MCStreamer.h"
76#include "llvm/MC/MCSymbol.h"
77#include "llvm/MC/TargetRegistry.h"
78#include "llvm/Support/Alignment.h"
79#include "llvm/Support/Casting.h"
80#include "llvm/Support/Compiler.h"
81#include "llvm/Support/Endian.h"
82#include "llvm/Support/ErrorHandling.h"
83#include "llvm/Support/NativeFormatting.h"
84#include "llvm/Support/raw_ostream.h"
85#include "llvm/Target/TargetLoweringObjectFile.h"
86#include "llvm/Target/TargetMachine.h"
87#include "llvm/Transforms/Utils/UnrollLoop.h"
88#include <cassert>
89#include <cstdint>
90#include <cstring>
91#include <string>
92#include <utility>
93#include <vector>
94
95using namespace llvm;
96
97#define DEPOTNAME "__local_depot"
98
99/// discoverDependentGlobals - Return a set of GlobalVariables on which \p V
100/// depends.
101static void
102discoverDependentGlobals(const Value *V,
103 DenseSet<const GlobalVariable *> &Globals) {
104 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: V)) {
105 Globals.insert(V: GV);
106 return;
107 }
108
109 if (const User *U = dyn_cast<User>(Val: V))
110 for (const auto &O : U->operands())
111 discoverDependentGlobals(V: O, Globals);
112}
113
114/// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
115/// instances to be emitted, but only after any dependents have been added
116/// first.s
117static void
118VisitGlobalVariableForEmission(const GlobalVariable *GV,
119 SmallVectorImpl<const GlobalVariable *> &Order,
120 DenseSet<const GlobalVariable *> &Visited,
121 DenseSet<const GlobalVariable *> &Visiting) {
122 // Have we already visited this one?
123 if (Visited.count(V: GV))
124 return;
125
126 // Do we have a circular dependency?
127 if (!Visiting.insert(V: GV).second)
128 report_fatal_error(reason: "Circular dependency found in global variable set");
129
130 // Make sure we visit all dependents first
131 DenseSet<const GlobalVariable *> Others;
132 for (const auto &O : GV->operands())
133 discoverDependentGlobals(V: O, Globals&: Others);
134
135 for (const GlobalVariable *GV : Others)
136 VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
137
138 // Now we can visit ourself
139 Order.push_back(Elt: GV);
140 Visited.insert(V: GV);
141 Visiting.erase(V: GV);
142}
143
144void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
145 NVPTX_MC::verifyInstructionPredicates(Opcode: MI->getOpcode(),
146 Features: getSubtargetInfo().getFeatureBits());
147
148 MCInst Inst;
149 lowerToMCInst(MI, OutMI&: Inst);
150 EmitToStreamer(S&: *OutStreamer, Inst);
151}
152
153void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
154 OutMI.setOpcode(MI->getOpcode());
155 // Special: Do not mangle symbol operand of CALL_PROTOTYPE
156 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
157 const MachineOperand &MO = MI->getOperand(i: 0);
158 OutMI.addOperand(Op: GetSymbolRef(
159 Symbol: OutContext.getOrCreateSymbol(Name: Twine(MO.getSymbolName()))));
160 return;
161 }
162
163 for (const auto MO : MI->operands())
164 OutMI.addOperand(Op: lowerOperand(MO));
165}
166
167MCOperand NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO) {
168 switch (MO.getType()) {
169 default:
170 llvm_unreachable("unknown operand type");
171 case MachineOperand::MO_Register:
172 return MCOperand::createReg(Reg: encodeVirtualRegister(Reg: MO.getReg()));
173 case MachineOperand::MO_Immediate:
174 return MCOperand::createImm(Val: MO.getImm());
175 case MachineOperand::MO_MachineBasicBlock:
176 return MCOperand::createExpr(
177 Val: MCSymbolRefExpr::create(Symbol: MO.getMBB()->getSymbol(), Ctx&: OutContext));
178 case MachineOperand::MO_ExternalSymbol:
179 return GetSymbolRef(Symbol: GetExternalSymbolSymbol(Sym: MO.getSymbolName()));
180 case MachineOperand::MO_GlobalAddress:
181 return GetSymbolRef(Symbol: getSymbol(GV: MO.getGlobal()));
182 case MachineOperand::MO_FPImmediate: {
183 const ConstantFP *Cnt = MO.getFPImm();
184 const APFloat &Val = Cnt->getValueAPF();
185
186 switch (Cnt->getType()->getTypeID()) {
187 default:
188 report_fatal_error(reason: "Unsupported FP type");
189 break;
190 case Type::HalfTyID:
191 return MCOperand::createExpr(
192 Val: NVPTXFloatMCExpr::createConstantFPHalf(Flt: Val, Ctx&: OutContext));
193 case Type::BFloatTyID:
194 return MCOperand::createExpr(
195 Val: NVPTXFloatMCExpr::createConstantBFPHalf(Flt: Val, Ctx&: OutContext));
196 case Type::FloatTyID:
197 return MCOperand::createExpr(
198 Val: NVPTXFloatMCExpr::createConstantFPSingle(Flt: Val, Ctx&: OutContext));
199 case Type::DoubleTyID:
200 return MCOperand::createExpr(
201 Val: NVPTXFloatMCExpr::createConstantFPDouble(Flt: Val, Ctx&: OutContext));
202 }
203 break;
204 }
205 }
206}
207
208unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
209 if (Register::isVirtualRegister(Reg)) {
210 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
211
212 DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
213 unsigned RegNum = RegMap[Reg];
214
215 // Encode the register class in the upper 4 bits
216 // Must be kept in sync with NVPTXInstPrinter::printRegName
217 unsigned Ret = 0;
218 if (RC == &NVPTX::B1RegClass) {
219 Ret = (1 << 28);
220 } else if (RC == &NVPTX::B16RegClass) {
221 Ret = (2 << 28);
222 } else if (RC == &NVPTX::B32RegClass) {
223 Ret = (3 << 28);
224 } else if (RC == &NVPTX::B64RegClass) {
225 Ret = (4 << 28);
226 } else if (RC == &NVPTX::B128RegClass) {
227 Ret = (7 << 28);
228 } else {
229 report_fatal_error(reason: "Bad register class");
230 }
231
232 // Insert the vreg number
233 Ret |= (RegNum & 0x0FFFFFFF);
234 return Ret;
235 } else {
236 // Some special-use registers are actually physical registers.
237 // Encode this as the register class ID of 0 and the real register ID.
238 return Reg & 0x0FFFFFFF;
239 }
240}
241
242MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
243 const MCExpr *Expr;
244 Expr = MCSymbolRefExpr::create(Symbol, Ctx&: OutContext);
245 return MCOperand::createExpr(Val: Expr);
246}
247
248void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
249 const DataLayout &DL = getDataLayout();
250 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(F: *F);
251 const auto *TLI = cast<NVPTXTargetLowering>(Val: STI.getTargetLowering());
252
253 Type *Ty = F->getReturnType();
254 if (Ty->getTypeID() == Type::VoidTyID)
255 return;
256 O << " (";
257
258 auto PrintScalarRetVal = [&](unsigned Size) {
259 O << ".param .b" << promoteScalarArgumentSize(size: Size) << " func_retval0";
260 };
261 if (shouldPassAsArray(Ty)) {
262 const unsigned TotalSize = DL.getTypeAllocSize(Ty);
263 const Align RetAlignment = TLI->getFunctionArgumentAlignment(
264 F, Ty, Idx: AttributeList::ReturnIndex, DL);
265 O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
266 << TotalSize << "]";
267 } else if (Ty->isFloatingPointTy()) {
268 PrintScalarRetVal(Ty->getPrimitiveSizeInBits());
269 } else if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
270 PrintScalarRetVal(ITy->getBitWidth());
271 } else if (isa<PointerType>(Val: Ty)) {
272 PrintScalarRetVal(TLI->getPointerTy(DL).getSizeInBits());
273 } else
274 llvm_unreachable("Unknown return type");
275 O << ") ";
276}
277
278void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
279 raw_ostream &O) {
280 const Function &F = MF.getFunction();
281 printReturnValStr(F: &F, O);
282}
283
284// Return true if MBB is the header of a loop marked with
285// llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
286bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
287 const MachineBasicBlock &MBB) const {
288 MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();
289 // We insert .pragma "nounroll" only to the loop header.
290 if (!LI.isLoopHeader(BB: &MBB))
291 return false;
292
293 // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
294 // we iterate through each back edge of the loop with header MBB, and check
295 // whether its metadata contains llvm.loop.unroll.disable.
296 for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
297 if (LI.getLoopFor(BB: PMBB) != LI.getLoopFor(BB: &MBB)) {
298 // Edges from other loops to MBB are not back edges.
299 continue;
300 }
301 if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
302 if (MDNode *LoopID =
303 PBB->getTerminator()->getMetadata(KindID: LLVMContext::MD_loop)) {
304 if (GetUnrollMetadata(LoopID, Name: "llvm.loop.unroll.disable"))
305 return true;
306 if (MDNode *UnrollCountMD =
307 GetUnrollMetadata(LoopID, Name: "llvm.loop.unroll.count")) {
308 if (mdconst::extract<ConstantInt>(MD: UnrollCountMD->getOperand(I: 1))
309 ->isOne())
310 return true;
311 }
312 }
313 }
314 }
315 return false;
316}
317
318void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
319 AsmPrinter::emitBasicBlockStart(MBB);
320 if (isLoopHeaderOfNoUnroll(MBB))
321 OutStreamer->emitRawText(String: StringRef("\t.pragma \"nounroll\";\n"));
322}
323
324void NVPTXAsmPrinter::emitFunctionEntryLabel() {
325 SmallString<128> Str;
326 raw_svector_ostream O(Str);
327
328 if (!GlobalsEmitted) {
329 emitGlobals(M: *MF->getFunction().getParent());
330 GlobalsEmitted = true;
331 }
332
333 // Set up
334 MRI = &MF->getRegInfo();
335 F = &MF->getFunction();
336 emitLinkageDirective(V: F, O);
337 if (isKernelFunction(F: *F))
338 O << ".entry ";
339 else {
340 O << ".func ";
341 printReturnValStr(MF: *MF, O);
342 }
343
344 CurrentFnSym->print(OS&: O, MAI);
345
346 emitFunctionParamList(F, O);
347 O << "\n";
348
349 if (isKernelFunction(F: *F))
350 emitKernelFunctionDirectives(F: *F, O);
351
352 if (shouldEmitPTXNoReturn(V: F, TM))
353 O << ".noreturn";
354
355 OutStreamer->emitRawText(String: O.str());
356
357 VRegMapping.clear();
358 // Emit open brace for function body.
359 OutStreamer->emitRawText(String: StringRef("{\n"));
360 setAndEmitFunctionVirtualRegisters(*MF);
361 encodeDebugInfoRegisterNumbers(MF: *MF);
362 // Emit initial .loc debug directive for correct relocation symbol data.
363 if (const DISubprogram *SP = MF->getFunction().getSubprogram()) {
364 assert(SP->getUnit());
365 if (!SP->getUnit()->isDebugDirectivesOnly())
366 emitInitialRawDwarfLocDirective(MF: *MF);
367 }
368}
369
370bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
371 bool Result = AsmPrinter::runOnMachineFunction(MF&: F);
372 // Emit closing brace for the body of function F.
373 // The closing brace must be emitted here because we need to emit additional
374 // debug labels/data after the last basic block.
375 // We need to emit the closing brace here because we don't have function that
376 // finished emission of the function body.
377 OutStreamer->emitRawText(String: StringRef("}\n"));
378 return Result;
379}
380
381void NVPTXAsmPrinter::emitFunctionBodyStart() {
382 SmallString<128> Str;
383 raw_svector_ostream O(Str);
384 emitDemotedVars(&MF->getFunction(), O);
385 OutStreamer->emitRawText(String: O.str());
386}
387
388void NVPTXAsmPrinter::emitFunctionBodyEnd() {
389 VRegMapping.clear();
390}
391
392const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
393 SmallString<128> Str;
394 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
395 return OutContext.getOrCreateSymbol(Name: Str);
396}
397
398void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
399 Register RegNo = MI->getOperand(i: 0).getReg();
400 if (RegNo.isVirtual()) {
401 OutStreamer->AddComment(T: Twine("implicit-def: ") +
402 getVirtualRegisterName(RegNo));
403 } else {
404 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
405 OutStreamer->AddComment(T: Twine("implicit-def: ") +
406 STI.getRegisterInfo()->getName(RegNo));
407 }
408 OutStreamer->addBlankLine();
409}
410
411void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
412 raw_ostream &O) const {
413 // If the NVVM IR has some of reqntid* specified, then output
414 // the reqntid directive, and set the unspecified ones to 1.
415 // If none of Reqntid* is specified, don't output reqntid directive.
416 const auto ReqNTID = getReqNTID(F);
417 if (!ReqNTID.empty())
418 O << formatv(Fmt: ".reqntid {0:$[, ]}\n",
419 Vals: make_range(x: ReqNTID.begin(), y: ReqNTID.end()));
420
421 const auto MaxNTID = getMaxNTID(F);
422 if (!MaxNTID.empty())
423 O << formatv(Fmt: ".maxntid {0:$[, ]}\n",
424 Vals: make_range(x: MaxNTID.begin(), y: MaxNTID.end()));
425
426 if (const auto Mincta = getMinCTASm(F))
427 O << ".minnctapersm " << *Mincta << "\n";
428
429 if (const auto Maxnreg = getMaxNReg(F))
430 O << ".maxnreg " << *Maxnreg << "\n";
431
432 // .maxclusterrank directive requires SM_90 or higher, make sure that we
433 // filter it out for lower SM versions, as it causes a hard ptxas crash.
434 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
435 const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
436
437 if (STI->getSmVersion() >= 90) {
438 const auto ClusterDim = getClusterDim(F);
439
440 if (!ClusterDim.empty()) {
441 O << ".explicitcluster\n";
442 if (ClusterDim[0] != 0) {
443 assert(llvm::all_of(ClusterDim, [](unsigned D) { return D != 0; }) &&
444 "cluster_dim_x != 0 implies cluster_dim_y and cluster_dim_z "
445 "should be non-zero as well");
446
447 O << formatv(Fmt: ".reqnctapercluster {0:$[, ]}\n",
448 Vals: make_range(x: ClusterDim.begin(), y: ClusterDim.end()));
449 } else {
450 assert(llvm::all_of(ClusterDim, [](unsigned D) { return D == 0; }) &&
451 "cluster_dim_x == 0 implies cluster_dim_y and cluster_dim_z "
452 "should be 0 as well");
453 }
454 }
455 if (const auto Maxclusterrank = getMaxClusterRank(F))
456 O << ".maxclusterrank " << *Maxclusterrank << "\n";
457 }
458}
459
460std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
461 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
462
463 std::string Name;
464 raw_string_ostream NameStr(Name);
465
466 VRegRCMap::const_iterator I = VRegMapping.find(Val: RC);
467 assert(I != VRegMapping.end() && "Bad register class");
468 const DenseMap<unsigned, unsigned> &RegMap = I->second;
469
470 VRegMap::const_iterator VI = RegMap.find(Val: Reg);
471 assert(VI != RegMap.end() && "Bad virtual register");
472 unsigned MappedVR = VI->second;
473
474 NameStr << getNVPTXRegClassStr(RC) << MappedVR;
475
476 return Name;
477}
478
479void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
480 raw_ostream &O) {
481 O << getVirtualRegisterName(Reg: vr);
482}
483
484void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias *GA,
485 raw_ostream &O) {
486 const Function *F = dyn_cast_or_null<Function>(Val: GA->getAliaseeObject());
487 if (!F || isKernelFunction(F: *F) || F->isDeclaration())
488 report_fatal_error(
489 reason: "NVPTX aliasee must be a non-kernel function definition");
490
491 if (GA->hasLinkOnceLinkage() || GA->hasWeakLinkage() ||
492 GA->hasAvailableExternallyLinkage() || GA->hasCommonLinkage())
493 report_fatal_error(reason: "NVPTX aliasee must not be '.weak'");
494
495 emitDeclarationWithName(F, getSymbol(GV: GA), O);
496}
497
498void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
499 emitDeclarationWithName(F, getSymbol(GV: F), O);
500}
501
502void NVPTXAsmPrinter::emitDeclarationWithName(const Function *F, MCSymbol *S,
503 raw_ostream &O) {
504 emitLinkageDirective(V: F, O);
505 if (isKernelFunction(F: *F))
506 O << ".entry ";
507 else
508 O << ".func ";
509 printReturnValStr(F, O);
510 S->print(OS&: O, MAI);
511 O << "\n";
512 emitFunctionParamList(F, O);
513 O << "\n";
514 if (shouldEmitPTXNoReturn(V: F, TM))
515 O << ".noreturn";
516 O << ";\n";
517}
518
519static bool usedInGlobalVarDef(const Constant *C) {
520 if (!C)
521 return false;
522
523 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: C))
524 return GV->getName() != "llvm.used";
525
526 for (const User *U : C->users())
527 if (const Constant *C = dyn_cast<Constant>(Val: U))
528 if (usedInGlobalVarDef(C))
529 return true;
530
531 return false;
532}
533
534static bool usedInOneFunc(const User *U, Function const *&OneFunc) {
535 if (const GlobalVariable *OtherGV = dyn_cast<GlobalVariable>(Val: U))
536 if (OtherGV->getName() == "llvm.used")
537 return true;
538
539 if (const Instruction *I = dyn_cast<Instruction>(Val: U)) {
540 if (const Function *CurFunc = I->getFunction()) {
541 if (OneFunc && (CurFunc != OneFunc))
542 return false;
543 OneFunc = CurFunc;
544 return true;
545 }
546 return false;
547 }
548
549 for (const User *UU : U->users())
550 if (!usedInOneFunc(U: UU, OneFunc))
551 return false;
552
553 return true;
554}
555
556/* Find out if a global variable can be demoted to local scope.
557 * Currently, this is valid for CUDA shared variables, which have local
558 * scope and global lifetime. So the conditions to check are :
559 * 1. Is the global variable in shared address space?
560 * 2. Does it have local linkage?
561 * 3. Is the global variable referenced only in one function?
562 */
563static bool canDemoteGlobalVar(const GlobalVariable *GV, Function const *&f) {
564 if (!GV->hasLocalLinkage())
565 return false;
566 if (GV->getAddressSpace() != ADDRESS_SPACE_SHARED)
567 return false;
568
569 const Function *oneFunc = nullptr;
570
571 bool flag = usedInOneFunc(U: GV, OneFunc&: oneFunc);
572 if (!flag)
573 return false;
574 if (!oneFunc)
575 return false;
576 f = oneFunc;
577 return true;
578}
579
580static bool useFuncSeen(const Constant *C,
581 const SmallPtrSetImpl<const Function *> &SeenSet) {
582 for (const User *U : C->users()) {
583 if (const Constant *cu = dyn_cast<Constant>(Val: U)) {
584 if (useFuncSeen(C: cu, SeenSet))
585 return true;
586 } else if (const Instruction *I = dyn_cast<Instruction>(Val: U)) {
587 if (const Function *Caller = I->getFunction())
588 if (SeenSet.contains(Ptr: Caller))
589 return true;
590 }
591 }
592 return false;
593}
594
595void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
596 SmallPtrSet<const Function *, 32> SeenSet;
597 for (const Function &F : M) {
598 if (F.getAttributes().hasFnAttr(Kind: "nvptx-libcall-callee")) {
599 emitDeclaration(F: &F, O);
600 continue;
601 }
602
603 if (F.isDeclaration()) {
604 if (F.use_empty())
605 continue;
606 if (F.getIntrinsicID())
607 continue;
608 emitDeclaration(F: &F, O);
609 continue;
610 }
611 for (const User *U : F.users()) {
612 if (const Constant *C = dyn_cast<Constant>(Val: U)) {
613 if (usedInGlobalVarDef(C)) {
614 // The use is in the initialization of a global variable
615 // that is a function pointer, so print a declaration
616 // for the original function
617 emitDeclaration(F: &F, O);
618 break;
619 }
620 // Emit a declaration of this function if the function that
621 // uses this constant expr has already been seen.
622 if (useFuncSeen(C, SeenSet)) {
623 emitDeclaration(F: &F, O);
624 break;
625 }
626 }
627
628 if (!isa<Instruction>(Val: U))
629 continue;
630 const Function *Caller = cast<Instruction>(Val: U)->getFunction();
631 if (!Caller)
632 continue;
633
634 // If a caller has already been seen, then the caller is
635 // appearing in the module before the callee. so print out
636 // a declaration for the callee.
637 if (SeenSet.contains(Ptr: Caller)) {
638 emitDeclaration(F: &F, O);
639 break;
640 }
641 }
642 SeenSet.insert(Ptr: &F);
643 }
644 for (const GlobalAlias &GA : M.aliases())
645 emitAliasDeclaration(GA: &GA, O);
646}
647
648void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
649 // Construct a default subtarget off of the TargetMachine defaults. The
650 // rest of NVPTX isn't friendly to change subtargets per function and
651 // so the default TargetMachine will have all of the options.
652 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
653 const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
654 SmallString<128> Str1;
655 raw_svector_ostream OS1(Str1);
656
657 // Emit header before any dwarf directives are emitted below.
658 emitHeader(M, O&: OS1, STI: *STI);
659 OutStreamer->emitRawText(String: OS1.str());
660}
661
662bool NVPTXAsmPrinter::doInitialization(Module &M) {
663 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
664 const NVPTXSubtarget &STI =
665 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
666 if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
667 report_fatal_error(reason: ".alias requires PTX version >= 6.3 and sm_30");
668
669 // We need to call the parent's one explicitly.
670 bool Result = AsmPrinter::doInitialization(M);
671
672 GlobalsEmitted = false;
673
674 return Result;
675}
676
677void NVPTXAsmPrinter::emitGlobals(const Module &M) {
678 SmallString<128> Str2;
679 raw_svector_ostream OS2(Str2);
680
681 emitDeclarations(M, O&: OS2);
682
683 // As ptxas does not support forward references of globals, we need to first
684 // sort the list of module-level globals in def-use order. We visit each
685 // global variable in order, and ensure that we emit it *after* its dependent
686 // globals. We use a little extra memory maintaining both a set and a list to
687 // have fast searches while maintaining a strict ordering.
688 SmallVector<const GlobalVariable *, 8> Globals;
689 DenseSet<const GlobalVariable *> GVVisited;
690 DenseSet<const GlobalVariable *> GVVisiting;
691
692 // Visit each global variable, in order
693 for (const GlobalVariable &I : M.globals())
694 VisitGlobalVariableForEmission(GV: &I, Order&: Globals, Visited&: GVVisited, Visiting&: GVVisiting);
695
696 assert(GVVisited.size() == M.global_size() && "Missed a global variable");
697 assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
698
699 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
700 const NVPTXSubtarget &STI =
701 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
702
703 // Print out module-level global variables in proper order
704 for (const GlobalVariable *GV : Globals)
705 printModuleLevelGV(GVar: GV, O&: OS2, /*ProcessDemoted=*/processDemoted: false, STI);
706
707 OS2 << '\n';
708
709 OutStreamer->emitRawText(String: OS2.str());
710}
711
712void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
713 SmallString<128> Str;
714 raw_svector_ostream OS(Str);
715
716 MCSymbol *Name = getSymbol(GV: &GA);
717
718 OS << ".alias " << Name->getName() << ", " << GA.getAliaseeObject()->getName()
719 << ";\n";
720
721 OutStreamer->emitRawText(String: OS.str());
722}
723
724void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
725 const NVPTXSubtarget &STI) {
726 const unsigned PTXVersion = STI.getPTXVersion();
727
728 O << "//\n"
729 "// Generated by LLVM NVPTX Back-End\n"
730 "//\n"
731 "\n"
732 << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n"
733 << ".target " << STI.getTargetName();
734
735 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
736 if (NTM.getDrvInterface() == NVPTX::NVCL)
737 O << ", texmode_independent";
738
739 bool HasFullDebugInfo = false;
740 for (DICompileUnit *CU : M.debug_compile_units()) {
741 switch(CU->getEmissionKind()) {
742 case DICompileUnit::NoDebug:
743 case DICompileUnit::DebugDirectivesOnly:
744 break;
745 case DICompileUnit::LineTablesOnly:
746 case DICompileUnit::FullDebug:
747 HasFullDebugInfo = true;
748 break;
749 }
750 if (HasFullDebugInfo)
751 break;
752 }
753 if (HasFullDebugInfo)
754 O << ", debug";
755
756 O << "\n"
757 << ".address_size " << (NTM.is64Bit() ? "64" : "32") << "\n"
758 << "\n";
759}
760
761bool NVPTXAsmPrinter::doFinalization(Module &M) {
762 // If we did not emit any functions, then the global declarations have not
763 // yet been emitted.
764 if (!GlobalsEmitted) {
765 emitGlobals(M);
766 GlobalsEmitted = true;
767 }
768
769 // call doFinalization
770 bool ret = AsmPrinter::doFinalization(M);
771
772 clearAnnotationCache(&M);
773
774 auto *TS =
775 static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
776 // Close the last emitted section
777 if (hasDebugInfo()) {
778 TS->closeLastSection();
779 // Emit empty .debug_macinfo section for better support of the empty files.
780 OutStreamer->emitRawText(String: "\t.section\t.debug_macinfo\t{\t}");
781 }
782
783 // Output last DWARF .file directives, if any.
784 TS->outputDwarfFileDirectives();
785
786 return ret;
787}
788
789// This function emits appropriate linkage directives for
790// functions and global variables.
791//
792// extern function declaration -> .extern
793// extern function definition -> .visible
794// external global variable with init -> .visible
795// external without init -> .extern
796// appending -> not allowed, assert.
797// for any linkage other than
798// internal, private, linker_private,
799// linker_private_weak, linker_private_weak_def_auto,
800// we emit -> .weak.
801
802void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
803 raw_ostream &O) {
804 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
805 if (V->hasExternalLinkage()) {
806 if (const auto *GVar = dyn_cast<GlobalVariable>(Val: V))
807 O << (GVar->hasInitializer() ? ".visible " : ".extern ");
808 else if (V->isDeclaration())
809 O << ".extern ";
810 else
811 O << ".visible ";
812 } else if (V->hasAppendingLinkage()) {
813 report_fatal_error(reason: "Symbol '" + (V->hasName() ? V->getName() : "") +
814 "' has unsupported appending linkage type");
815 } else if (!V->hasInternalLinkage() && !V->hasPrivateLinkage()) {
816 O << ".weak ";
817 }
818 }
819}
820
821void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
822 raw_ostream &O, bool ProcessDemoted,
823 const NVPTXSubtarget &STI) {
824 // Skip meta data
825 if (GVar->hasSection())
826 if (GVar->getSection() == "llvm.metadata")
827 return;
828
829 // Skip LLVM intrinsic global variables
830 if (GVar->getName().starts_with(Prefix: "llvm.") ||
831 GVar->getName().starts_with(Prefix: "nvvm."))
832 return;
833
834 const DataLayout &DL = getDataLayout();
835
836 // GlobalVariables are always constant pointers themselves.
837 Type *ETy = GVar->getValueType();
838
839 if (GVar->hasExternalLinkage()) {
840 if (GVar->hasInitializer())
841 O << ".visible ";
842 else
843 O << ".extern ";
844 } else if (STI.getPTXVersion() >= 50 && GVar->hasCommonLinkage() &&
845 GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) {
846 O << ".common ";
847 } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
848 GVar->hasAvailableExternallyLinkage() ||
849 GVar->hasCommonLinkage()) {
850 O << ".weak ";
851 }
852
853 if (isTexture(*GVar)) {
854 O << ".global .texref " << getTextureName(*GVar) << ";\n";
855 return;
856 }
857
858 if (isSurface(*GVar)) {
859 O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
860 return;
861 }
862
863 if (GVar->isDeclaration()) {
864 // (extern) declarations, no definition or initializer
865 // Currently the only known declaration is for an automatic __local
866 // (.shared) promoted to global.
867 emitPTXGlobalVariable(GVar, O, STI);
868 O << ";\n";
869 return;
870 }
871
872 if (isSampler(*GVar)) {
873 O << ".global .samplerref " << getSamplerName(*GVar);
874
875 const Constant *Initializer = nullptr;
876 if (GVar->hasInitializer())
877 Initializer = GVar->getInitializer();
878 const ConstantInt *CI = nullptr;
879 if (Initializer)
880 CI = dyn_cast<ConstantInt>(Val: Initializer);
881 if (CI) {
882 unsigned sample = CI->getZExtValue();
883
884 O << " = { ";
885
886 for (int i = 0,
887 addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
888 i < 3; i++) {
889 O << "addr_mode_" << i << " = ";
890 switch (addr) {
891 case 0:
892 O << "wrap";
893 break;
894 case 1:
895 O << "clamp_to_border";
896 break;
897 case 2:
898 O << "clamp_to_edge";
899 break;
900 case 3:
901 O << "wrap";
902 break;
903 case 4:
904 O << "mirror";
905 break;
906 }
907 O << ", ";
908 }
909 O << "filter_mode = ";
910 switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
911 case 0:
912 O << "nearest";
913 break;
914 case 1:
915 O << "linear";
916 break;
917 case 2:
918 llvm_unreachable("Anisotropic filtering is not supported");
919 default:
920 O << "nearest";
921 break;
922 }
923 if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
924 O << ", force_unnormalized_coords = 1";
925 }
926 O << " }";
927 }
928
929 O << ";\n";
930 return;
931 }
932
933 if (GVar->hasPrivateLinkage()) {
934 if (GVar->getName().starts_with(Prefix: "unrollpragma"))
935 return;
936
937 // FIXME - need better way (e.g. Metadata) to avoid generating this global
938 if (GVar->getName().starts_with(Prefix: "filename"))
939 return;
940 if (GVar->use_empty())
941 return;
942 }
943
944 const Function *DemotedFunc = nullptr;
945 if (!ProcessDemoted && canDemoteGlobalVar(GV: GVar, f&: DemotedFunc)) {
946 O << "// " << GVar->getName() << " has been demoted\n";
947 localDecls[DemotedFunc].push_back(x: GVar);
948 return;
949 }
950
951 O << ".";
952 emitPTXAddressSpace(AddressSpace: GVar->getAddressSpace(), O);
953
954 if (isManaged(*GVar)) {
955 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
956 report_fatal_error(
957 reason: ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
958 O << " .attribute(.managed)";
959 }
960
961 O << " .align "
962 << GVar->getAlign().value_or(u: DL.getPrefTypeAlign(Ty: ETy)).value();
963
964 if (ETy->isPointerTy() || ((ETy->isIntegerTy() || ETy->isFloatingPointTy()) &&
965 ETy->getScalarSizeInBits() <= 64)) {
966 O << " .";
967 // Special case: ABI requires that we use .u8 for predicates
968 if (ETy->isIntegerTy(Bitwidth: 1))
969 O << "u8";
970 else
971 O << getPTXFundamentalTypeStr(Ty: ETy, false);
972 O << " ";
973 getSymbol(GV: GVar)->print(OS&: O, MAI);
974
975 // Ptx allows variable initilization only for constant and global state
976 // spaces.
977 if (GVar->hasInitializer()) {
978 if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
979 (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) {
980 const Constant *Initializer = GVar->getInitializer();
981 // 'undef' is treated as there is no value specified.
982 if (!Initializer->isNullValue() && !isa<UndefValue>(Val: Initializer)) {
983 O << " = ";
984 printScalarConstant(CPV: Initializer, O);
985 }
986 } else {
987 // The frontend adds zero-initializer to device and constant variables
988 // that don't have an initial value, and UndefValue to shared
989 // variables, so skip warning for this case.
990 if (!GVar->getInitializer()->isNullValue() &&
991 !isa<UndefValue>(Val: GVar->getInitializer())) {
992 report_fatal_error(reason: "initial value of '" + GVar->getName() +
993 "' is not allowed in addrspace(" +
994 Twine(GVar->getAddressSpace()) + ")");
995 }
996 }
997 }
998 } else {
999 // Although PTX has direct support for struct type and array type and
1000 // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1001 // targets that support these high level field accesses. Structs, arrays
1002 // and vectors are lowered into arrays of bytes.
1003 switch (ETy->getTypeID()) {
1004 case Type::IntegerTyID: // Integers larger than 64 bits
1005 case Type::FP128TyID:
1006 case Type::StructTyID:
1007 case Type::ArrayTyID:
1008 case Type::FixedVectorTyID: {
1009 const uint64_t ElementSize = DL.getTypeStoreSize(Ty: ETy);
1010 // Ptx allows variable initilization only for constant and
1011 // global state spaces.
1012 if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1013 (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1014 GVar->hasInitializer()) {
1015 const Constant *Initializer = GVar->getInitializer();
1016 if (!isa<UndefValue>(Val: Initializer) && !Initializer->isNullValue()) {
1017 AggBuffer aggBuffer(ElementSize, *this);
1018 bufferAggregateConstant(CV: Initializer, aggBuffer: &aggBuffer);
1019 if (aggBuffer.numSymbols()) {
1020 const unsigned int ptrSize = MAI->getCodePointerSize();
1021 if (ElementSize % ptrSize ||
1022 !aggBuffer.allSymbolsAligned(ptrSize)) {
1023 // Print in bytes and use the mask() operator for pointers.
1024 if (!STI.hasMaskOperator())
1025 report_fatal_error(
1026 reason: "initialized packed aggregate with pointers '" +
1027 GVar->getName() +
1028 "' requires at least PTX ISA version 7.1");
1029 O << " .u8 ";
1030 getSymbol(GV: GVar)->print(OS&: O, MAI);
1031 O << "[" << ElementSize << "] = {";
1032 aggBuffer.printBytes(os&: O);
1033 O << "}";
1034 } else {
1035 O << " .u" << ptrSize * 8 << " ";
1036 getSymbol(GV: GVar)->print(OS&: O, MAI);
1037 O << "[" << ElementSize / ptrSize << "] = {";
1038 aggBuffer.printWords(os&: O);
1039 O << "}";
1040 }
1041 } else {
1042 O << " .b8 ";
1043 getSymbol(GV: GVar)->print(OS&: O, MAI);
1044 O << "[" << ElementSize << "] = {";
1045 aggBuffer.printBytes(os&: O);
1046 O << "}";
1047 }
1048 } else {
1049 O << " .b8 ";
1050 getSymbol(GV: GVar)->print(OS&: O, MAI);
1051 if (ElementSize)
1052 O << "[" << ElementSize << "]";
1053 }
1054 } else {
1055 O << " .b8 ";
1056 getSymbol(GV: GVar)->print(OS&: O, MAI);
1057 if (ElementSize)
1058 O << "[" << ElementSize << "]";
1059 }
1060 break;
1061 }
1062 default:
1063 llvm_unreachable("type not supported yet");
1064 }
1065 }
1066 O << ";\n";
1067}
1068
1069void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1070 const Value *v = Symbols[nSym];
1071 const Value *v0 = SymbolsBeforeStripping[nSym];
1072 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: v)) {
1073 MCSymbol *Name = AP.getSymbol(GV: GVar);
1074 PointerType *PTy = dyn_cast<PointerType>(Val: v0->getType());
1075 // Is v0 a generic pointer?
1076 bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1077 if (EmitGeneric && isGenericPointer && !isa<Function>(Val: v)) {
1078 os << "generic(";
1079 Name->print(OS&: os, MAI: AP.MAI);
1080 os << ")";
1081 } else {
1082 Name->print(OS&: os, MAI: AP.MAI);
1083 }
1084 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(Val: v0)) {
1085 const MCExpr *Expr = AP.lowerConstantForGV(CV: CExpr, ProcessingGeneric: false);
1086 AP.printMCExpr(Expr: *Expr, OS&: os);
1087 } else
1088 llvm_unreachable("symbol type unknown");
1089}
1090
1091void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
1092 unsigned int ptrSize = AP.MAI->getCodePointerSize();
1093 // Do not emit trailing zero initializers. They will be zero-initialized by
1094 // ptxas. This saves on both space requirements for the generated PTX and on
1095 // memory use by ptxas. (See:
1096 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
1097 unsigned int InitializerCount = size;
1098 // TODO: symbols make this harder, but it would still be good to trim trailing
1099 // 0s for aggs with symbols as well.
1100 if (numSymbols() == 0)
1101 while (InitializerCount >= 1 && !buffer[InitializerCount - 1])
1102 InitializerCount--;
1103
1104 symbolPosInBuffer.push_back(Elt: InitializerCount);
1105 unsigned int nSym = 0;
1106 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1107 for (unsigned int pos = 0; pos < InitializerCount;) {
1108 if (pos)
1109 os << ", ";
1110 if (pos != nextSymbolPos) {
1111 os << (unsigned int)buffer[pos];
1112 ++pos;
1113 continue;
1114 }
1115 // Generate a per-byte mask() operator for the symbol, which looks like:
1116 // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1117 // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1118 std::string symText;
1119 llvm::raw_string_ostream oss(symText);
1120 printSymbol(nSym, os&: oss);
1121 for (unsigned i = 0; i < ptrSize; ++i) {
1122 if (i)
1123 os << ", ";
1124 llvm::write_hex(S&: os, N: 0xFFULL << i * 8, Style: HexPrintStyle::PrefixUpper);
1125 os << "(" << symText << ")";
1126 }
1127 pos += ptrSize;
1128 nextSymbolPos = symbolPosInBuffer[++nSym];
1129 assert(nextSymbolPos >= pos);
1130 }
1131}
1132
1133void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1134 unsigned int ptrSize = AP.MAI->getCodePointerSize();
1135 symbolPosInBuffer.push_back(Elt: size);
1136 unsigned int nSym = 0;
1137 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1138 assert(nextSymbolPos % ptrSize == 0);
1139 for (unsigned int pos = 0; pos < size; pos += ptrSize) {
1140 if (pos)
1141 os << ", ";
1142 if (pos == nextSymbolPos) {
1143 printSymbol(nSym, os);
1144 nextSymbolPos = symbolPosInBuffer[++nSym];
1145 assert(nextSymbolPos % ptrSize == 0);
1146 assert(nextSymbolPos >= pos + ptrSize);
1147 } else if (ptrSize == 4)
1148 os << support::endian::read32le(P: &buffer[pos]);
1149 else
1150 os << support::endian::read64le(P: &buffer[pos]);
1151 }
1152}
1153
1154void NVPTXAsmPrinter::emitDemotedVars(const Function *F, raw_ostream &O) {
1155 auto It = localDecls.find(x: F);
1156 if (It == localDecls.end())
1157 return;
1158
1159 ArrayRef<const GlobalVariable *> GVars = It->second;
1160
1161 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1162 const NVPTXSubtarget &STI =
1163 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1164
1165 for (const GlobalVariable *GV : GVars) {
1166 O << "\t// demoted variable\n\t";
1167 printModuleLevelGV(GVar: GV, O, /*processDemoted=*/ProcessDemoted: true, STI);
1168 }
1169}
1170
1171void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1172 raw_ostream &O) const {
1173 switch (AddressSpace) {
1174 case ADDRESS_SPACE_LOCAL:
1175 O << "local";
1176 break;
1177 case ADDRESS_SPACE_GLOBAL:
1178 O << "global";
1179 break;
1180 case ADDRESS_SPACE_CONST:
1181 O << "const";
1182 break;
1183 case ADDRESS_SPACE_SHARED:
1184 O << "shared";
1185 break;
1186 default:
1187 report_fatal_error(reason: "Bad address space found while emitting PTX: " +
1188 llvm::Twine(AddressSpace));
1189 break;
1190 }
1191}
1192
1193std::string
1194NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1195 switch (Ty->getTypeID()) {
1196 case Type::IntegerTyID: {
1197 unsigned NumBits = cast<IntegerType>(Val: Ty)->getBitWidth();
1198 if (NumBits == 1)
1199 return "pred";
1200 if (NumBits <= 64) {
1201 std::string name = "u";
1202 return name + utostr(X: NumBits);
1203 }
1204 llvm_unreachable("Integer too large");
1205 break;
1206 }
1207 case Type::BFloatTyID:
1208 case Type::HalfTyID:
1209 // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1210 // PTX assembly.
1211 return "b16";
1212 case Type::FloatTyID:
1213 return "f32";
1214 case Type::DoubleTyID:
1215 return "f64";
1216 case Type::PointerTyID: {
1217 unsigned PtrSize = TM.getPointerSizeInBits(AS: Ty->getPointerAddressSpace());
1218 assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1219
1220 if (PtrSize == 64)
1221 if (useB4PTR)
1222 return "b64";
1223 else
1224 return "u64";
1225 else if (useB4PTR)
1226 return "b32";
1227 else
1228 return "u32";
1229 }
1230 default:
1231 break;
1232 }
1233 llvm_unreachable("unexpected type");
1234}
1235
1236void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1237 raw_ostream &O,
1238 const NVPTXSubtarget &STI) {
1239 const DataLayout &DL = getDataLayout();
1240
1241 // GlobalVariables are always constant pointers themselves.
1242 Type *ETy = GVar->getValueType();
1243
1244 O << ".";
1245 emitPTXAddressSpace(AddressSpace: GVar->getType()->getAddressSpace(), O);
1246 if (isManaged(*GVar)) {
1247 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
1248 report_fatal_error(
1249 reason: ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1250
1251 O << " .attribute(.managed)";
1252 }
1253 O << " .align "
1254 << GVar->getAlign().value_or(u: DL.getPrefTypeAlign(Ty: ETy)).value();
1255
1256 // Special case for i128/fp128
1257 if (ETy->getScalarSizeInBits() == 128) {
1258 O << " .b8 ";
1259 getSymbol(GV: GVar)->print(OS&: O, MAI);
1260 O << "[16]";
1261 return;
1262 }
1263
1264 if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1265 O << " ." << getPTXFundamentalTypeStr(Ty: ETy) << " ";
1266 getSymbol(GV: GVar)->print(OS&: O, MAI);
1267 return;
1268 }
1269
1270 int64_t ElementSize = 0;
1271
1272 // Although PTX has direct support for struct type and array type and LLVM IR
1273 // is very similar to PTX, the LLVM CodeGen does not support for targets that
1274 // support these high level field accesses. Structs and arrays are lowered
1275 // into arrays of bytes.
1276 switch (ETy->getTypeID()) {
1277 case Type::StructTyID:
1278 case Type::ArrayTyID:
1279 case Type::FixedVectorTyID:
1280 ElementSize = DL.getTypeStoreSize(Ty: ETy);
1281 O << " .b8 ";
1282 getSymbol(GV: GVar)->print(OS&: O, MAI);
1283 O << "[";
1284 if (ElementSize) {
1285 O << ElementSize;
1286 }
1287 O << "]";
1288 break;
1289 default:
1290 llvm_unreachable("type not supported yet");
1291 }
1292}
1293
/// Print the parenthesized PTX parameter list for function \p F.
///
/// Emits "()" when F has no arguments and is not variadic. Otherwise each
/// argument is printed on its own "\t.param ..." line, separated by ",\n":
///  - kernel sampler/texture/surface arguments: .samplerref/.texref/.surfref,
///    optionally prefixed with ".u64 .ptr " when no image handle symbol is
///    registered for the parameter;
///  - byval arguments and types for which shouldPassAsArray() holds:
///    ".align <a> .b8 <name>[<alloc size>]";
///  - kernel scalars: pointers as ".u<N> .ptr [.<space>] .align <a>",
///    other scalars by their fundamental type (i1 as .u8);
///  - non-kernel scalars: ".b<N>" (integers widened via
///    promoteScalarArgumentSize).
/// A variadic function gets a trailing ".b8 <name>[]" parameter.
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(F: *F);
  const auto *TLI = cast<NVPTXTargetLowering>(Val: STI.getTargetLowering());
  // MF may be null (e.g. when only a declaration is printed); MFI is then null
  // and image parameters always get the ".u64 .ptr" prefix below.
  const NVPTXMachineFunctionInfo *MFI =
      MF ? MF->getInfo<NVPTXMachineFunctionInfo>() : nullptr;

  bool IsFirst = true;
  const bool IsKernelFunc = isKernelFunction(F: *F);

  if (F->arg_empty() && !F->isVarArg()) {
    O << "()";
    return;
  }

  O << "(\n";

  for (const Argument &Arg : F->args()) {
    Type *Ty = Arg.getType();
    const std::string ParamSym = TLI->getParamName(F, Idx: Arg.getArgNo());

    if (!IsFirst)
      O << ",\n";

    IsFirst = false;

    // Handle image/sampler parameters
    if (IsKernelFunc) {
      const bool IsSampler = isSampler(Arg);
      const bool IsTexture = !IsSampler && isImageReadOnly(Arg);
      const bool IsSurface = !IsSampler && !IsTexture &&
                             (isImageReadWrite(Arg) || isImageWriteOnly(Arg));
      if (IsSampler || IsTexture || IsSurface) {
        const bool EmitImgPtr = !MFI || !MFI->checkImageHandleSymbol(Symbol: ParamSym);
        O << "\t.param ";
        if (EmitImgPtr)
          O << ".u64 .ptr ";

        if (IsSampler)
          O << ".samplerref ";
        else if (IsTexture)
          O << ".texref ";
        else // IsSurface
          O << ".surfref ";
        O << ParamSym;
        continue;
      }
    }

    // Alignment used for parameters printed as .b8 arrays. An explicit
    // alignment returned by getAlign() for this argument index wins.
    // Otherwise take the larger of the target-optimized alignment for Ty and
    // (for byval arguments) the declared parameter alignment.
    auto GetOptimalAlignForParam = [TLI, &DL, F, &Arg](Type *Ty) -> Align {
      if (MaybeAlign StackAlign =
              getAlign(F: *F, Index: Arg.getArgNo() + AttributeList::FirstArgIndex))
        return StackAlign.value();

      Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, ArgTy: Ty, DL);
      MaybeAlign ParamAlign =
          Arg.hasByValAttr() ? Arg.getParamAlign() : MaybeAlign();
      return std::max(a: TypeAlign, b: ParamAlign.valueOrOne());
    };

    if (Arg.hasByValAttr()) {
      // param has byVal attribute.
      Type *ETy = Arg.getParamByValType();
      assert(ETy && "Param should have byval type");

      // Print .param .align <a> .b8 .param[size];
      // <a> = optimal alignment for the element type; always multiple of
      //       PAL.getParamAlignment
      // size = typeallocsize of element type
      // Kernels use GetOptimalAlignForParam; other functions ask the target
      // lowering for the byval alignment.
      const Align OptimalAlign =
          IsKernelFunc ? GetOptimalAlignForParam(ETy)
                       : TLI->getFunctionByValParamAlign(
                             F, ArgTy: ETy, InitialAlign: Arg.getParamAlign().valueOrOne(), DL);

      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
        << "[" << DL.getTypeAllocSize(Ty: ETy) << "]";
      continue;
    }

    if (shouldPassAsArray(Ty)) {
      // Just print .param .align <a> .b8 .param[size];
      // <a> = optimal alignment for the element type; always multiple of
      //       PAL.getParamAlignment
      // size = typeallocsize of element type
      Align OptimalAlign = GetOptimalAlignForParam(Ty);

      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
        << "[" << DL.getTypeAllocSize(Ty) << "]";

      continue;
    }
    // Just a scalar
    // For pointers, the width comes from the target's pointer type for the
    // pointer's address space.
    auto *PTy = dyn_cast<PointerType>(Val: Ty);
    unsigned PTySizeInBits = 0;
    if (PTy) {
      PTySizeInBits =
          TLI->getPointerTy(DL, AS: PTy->getAddressSpace()).getSizeInBits();
      assert(PTySizeInBits && "Invalid pointer size");
    }

    if (IsKernelFunc) {
      if (PTy) {
        // Kernel pointer parameter: ".u<N> .ptr", plus the state space for
        // the address spaces listed below (none for other address spaces),
        // plus the declared parameter alignment (1 if absent).
        O << "\t.param .u" << PTySizeInBits << " .ptr";

        switch (PTy->getAddressSpace()) {
        default:
          break;
        case ADDRESS_SPACE_GLOBAL:
          O << " .global";
          break;
        case ADDRESS_SPACE_SHARED:
          O << " .shared";
          break;
        case ADDRESS_SPACE_CONST:
          O << " .const";
          break;
        case ADDRESS_SPACE_LOCAL:
          O << " .local";
          break;
        }

        O << " .align " << Arg.getParamAlign().valueOrOne().value() << " "
          << ParamSym;
        continue;
      }

      // non-pointer scalar to kernel func
      O << "\t.param .";
      // Special case: predicate operands become .u8 types
      if (Ty->isIntegerTy(Bitwidth: 1))
        O << "u8";
      else
        O << getPTXFundamentalTypeStr(Ty);
      O << " " << ParamSym;
      continue;
    }
    // Non-kernel function: print .param .b<size>. Integers are widened via
    // promoteScalarArgumentSize, pointers use the target pointer width, and
    // other scalars use their primitive size in bits.
    unsigned Size;
    if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
      Size = promoteScalarArgumentSize(size: ITy->getBitWidth());
    } else if (PTy) {
      assert(PTySizeInBits && "Invalid pointer size");
      Size = PTySizeInBits;
    } else
      Size = Ty->getPrimitiveSizeInBits();
    O << "\t.param .b" << Size << " " << ParamSym;
  }

  // Variadic functions receive their extra arguments through one trailing
  // unsized byte array aligned to the subtarget's maximum required alignment.
  if (F->isVarArg()) {
    if (!IsFirst)
      O << ",\n";
    O << "\t.param .align " << STI.getMaxRequiredAlignment() << " .b8 "
      << TLI->getParamName(F, /* vararg */ Idx: -1) << "[]";
  }

  O << "\n)";
}
1452
1453void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1454 const MachineFunction &MF) {
1455 SmallString<128> Str;
1456 raw_svector_ostream O(Str);
1457
1458 // Map the global virtual register number to a register class specific
1459 // virtual register number starting from 1 with that class.
1460 const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1461 //unsigned numRegClasses = TRI->getNumRegClasses();
1462
1463 // Emit the Fake Stack Object
1464 const MachineFrameInfo &MFI = MF.getFrameInfo();
1465 int64_t NumBytes = MFI.getStackSize();
1466 if (NumBytes) {
1467 O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1468 << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1469 if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1470 O << "\t.reg .b64 \t%SP;\n"
1471 << "\t.reg .b64 \t%SPL;\n";
1472 } else {
1473 O << "\t.reg .b32 \t%SP;\n"
1474 << "\t.reg .b32 \t%SPL;\n";
1475 }
1476 }
1477
1478 // Go through all virtual registers to establish the mapping between the
1479 // global virtual
1480 // register number and the per class virtual register number.
1481 // We use the per class virtual register number in the ptx output.
1482 unsigned int numVRs = MRI->getNumVirtRegs();
1483 for (unsigned i = 0; i < numVRs; i++) {
1484 Register vr = Register::index2VirtReg(Index: i);
1485 const TargetRegisterClass *RC = MRI->getRegClass(Reg: vr);
1486 DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1487 int n = regmap.size();
1488 regmap.insert(KV: std::make_pair(x&: vr, y: n + 1));
1489 }
1490
1491 // Emit declaration of the virtual registers or 'physical' registers for
1492 // each register class
1493 for (const TargetRegisterClass *RC : TRI->regclasses()) {
1494 const unsigned N = VRegMapping[RC].size();
1495
1496 // Only declare those registers that may be used.
1497 if (N) {
1498 const StringRef RCName = getNVPTXRegClassName(RC);
1499 const StringRef RCStr = getNVPTXRegClassStr(RC);
1500 O << "\t.reg " << RCName << " \t" << RCStr << "<" << (N + 1) << ">;\n";
1501 }
1502 }
1503
1504 OutStreamer->emitRawText(String: O.str());
1505}
1506
1507/// Translate virtual register numbers in DebugInfo locations to their printed
1508/// encodings, as used by CUDA-GDB.
1509void NVPTXAsmPrinter::encodeDebugInfoRegisterNumbers(
1510 const MachineFunction &MF) {
1511 const NVPTXSubtarget &STI = MF.getSubtarget<NVPTXSubtarget>();
1512 const NVPTXRegisterInfo *registerInfo = STI.getRegisterInfo();
1513
1514 // Clear the old mapping, and add the new one. This mapping is used after the
1515 // printing of the current function is complete, but before the next function
1516 // is printed.
1517 registerInfo->clearDebugRegisterMap();
1518
1519 for (auto &classMap : VRegMapping) {
1520 for (auto &registerMapping : classMap.getSecond()) {
1521 auto reg = registerMapping.getFirst();
1522 registerInfo->addToDebugRegisterMap(preEncodedVirtualRegister: reg, RegisterName: getVirtualRegisterName(Reg: reg));
1523 }
1524 }
1525}
1526
1527void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp,
1528 raw_ostream &O) const {
1529 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1530 bool ignored;
1531 unsigned int numHex;
1532 const char *lead;
1533
1534 if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1535 numHex = 8;
1536 lead = "0f";
1537 APF.convert(ToSemantics: APFloat::IEEEsingle(), RM: APFloat::rmNearestTiesToEven, losesInfo: &ignored);
1538 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1539 numHex = 16;
1540 lead = "0d";
1541 APF.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven, losesInfo: &ignored);
1542 } else
1543 llvm_unreachable("unsupported fp type");
1544
1545 APInt API = APF.bitcastToAPInt();
1546 O << lead << format_hex_no_prefix(N: API.getZExtValue(), Width: numHex, /*Upper=*/true);
1547}
1548
1549void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1550 if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CPV)) {
1551 O << CI->getValue();
1552 return;
1553 }
1554 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(Val: CPV)) {
1555 printFPConstant(Fp: CFP, O);
1556 return;
1557 }
1558 if (isa<ConstantPointerNull>(Val: CPV)) {
1559 O << "0";
1560 return;
1561 }
1562 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: CPV)) {
1563 const bool IsNonGenericPointer = GVar->getAddressSpace() != 0;
1564 if (EmitGeneric && !isa<Function>(Val: CPV) && !IsNonGenericPointer) {
1565 O << "generic(";
1566 getSymbol(GV: GVar)->print(OS&: O, MAI);
1567 O << ")";
1568 } else {
1569 getSymbol(GV: GVar)->print(OS&: O, MAI);
1570 }
1571 return;
1572 }
1573 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
1574 const MCExpr *E = lowerConstantForGV(CV: cast<Constant>(Val: Cexpr), ProcessingGeneric: false);
1575 printMCExpr(Expr: *E, OS&: O);
1576 return;
1577 }
1578 llvm_unreachable("Not scalar type found in printScalarConstant()");
1579}
1580
1581void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1582 AggBuffer *AggBuffer) {
1583 const DataLayout &DL = getDataLayout();
1584 int AllocSize = DL.getTypeAllocSize(Ty: CPV->getType());
1585 if (isa<UndefValue>(Val: CPV) || CPV->isNullValue()) {
1586 // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1587 // only the space allocated by CPV.
1588 AggBuffer->addZeros(Num: Bytes ? Bytes : AllocSize);
1589 return;
1590 }
1591
1592 // Helper for filling AggBuffer with APInts.
1593 auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1594 size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1595 SmallVector<unsigned char, 16> Buf(NumBytes);
1596 // `extractBitsAsZExtValue` does not allow the extraction of bits beyond the
1597 // input's bit width, and i1 arrays may not have a length that is a multuple
1598 // of 8. We handle the last byte separately, so we never request out of
1599 // bounds bits.
1600 for (unsigned I = 0; I < NumBytes - 1; ++I) {
1601 Buf[I] = Val.extractBitsAsZExtValue(numBits: 8, bitPosition: I * 8);
1602 }
1603 size_t LastBytePosition = (NumBytes - 1) * 8;
1604 size_t LastByteBits = Val.getBitWidth() - LastBytePosition;
1605 Buf[NumBytes - 1] =
1606 Val.extractBitsAsZExtValue(numBits: LastByteBits, bitPosition: LastBytePosition);
1607 AggBuffer->addBytes(Ptr: Buf.data(), Num: NumBytes, Bytes);
1608 };
1609
1610 switch (CPV->getType()->getTypeID()) {
1611 case Type::IntegerTyID:
1612 if (const auto *CI = dyn_cast<ConstantInt>(Val: CPV)) {
1613 AddIntToBuffer(CI->getValue());
1614 break;
1615 }
1616 if (const auto *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
1617 if (const auto *CI =
1618 dyn_cast<ConstantInt>(Val: ConstantFoldConstant(C: Cexpr, DL))) {
1619 AddIntToBuffer(CI->getValue());
1620 break;
1621 }
1622 if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1623 Value *V = Cexpr->getOperand(i_nocapture: 0)->stripPointerCasts();
1624 AggBuffer->addSymbol(GVar: V, GVarBeforeStripping: Cexpr->getOperand(i_nocapture: 0));
1625 AggBuffer->addZeros(Num: AllocSize);
1626 break;
1627 }
1628 }
1629 llvm_unreachable("unsupported integer const type");
1630 break;
1631
1632 case Type::HalfTyID:
1633 case Type::BFloatTyID:
1634 case Type::FloatTyID:
1635 case Type::DoubleTyID:
1636 AddIntToBuffer(cast<ConstantFP>(Val: CPV)->getValueAPF().bitcastToAPInt());
1637 break;
1638
1639 case Type::PointerTyID: {
1640 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: CPV)) {
1641 AggBuffer->addSymbol(GVar, GVarBeforeStripping: GVar);
1642 } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
1643 const Value *v = Cexpr->stripPointerCasts();
1644 AggBuffer->addSymbol(GVar: v, GVarBeforeStripping: Cexpr);
1645 }
1646 AggBuffer->addZeros(Num: AllocSize);
1647 break;
1648 }
1649
1650 case Type::ArrayTyID:
1651 case Type::FixedVectorTyID:
1652 case Type::StructTyID: {
1653 if (isa<ConstantAggregate>(Val: CPV) || isa<ConstantDataSequential>(Val: CPV)) {
1654 bufferAggregateConstant(CV: CPV, aggBuffer: AggBuffer);
1655 if (Bytes > AllocSize)
1656 AggBuffer->addZeros(Num: Bytes - AllocSize);
1657 } else if (isa<ConstantAggregateZero>(Val: CPV))
1658 AggBuffer->addZeros(Num: Bytes);
1659 else
1660 llvm_unreachable("Unexpected Constant type");
1661 break;
1662 }
1663
1664 default:
1665 llvm_unreachable("unsupported type");
1666 }
1667}
1668
1669void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1670 AggBuffer *aggBuffer) {
1671 const DataLayout &DL = getDataLayout();
1672
1673 auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
1674 for (unsigned I : llvm::seq(Size: Val.getBitWidth() / 8))
1675 Buffer->addByte(Byte: Val.extractBitsAsZExtValue(numBits: 8, bitPosition: I * 8));
1676 };
1677
1678 // Integers of arbitrary width
1679 if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CPV)) {
1680 ExtendBuffer(CI->getValue(), aggBuffer);
1681 return;
1682 }
1683
1684 // f128
1685 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(Val: CPV)) {
1686 if (CFP->getType()->isFP128Ty()) {
1687 ExtendBuffer(CFP->getValueAPF().bitcastToAPInt(), aggBuffer);
1688 return;
1689 }
1690 }
1691
1692 // Old constants
1693 if (isa<ConstantArray>(Val: CPV) || isa<ConstantVector>(Val: CPV)) {
1694 for (const auto &Op : CPV->operands())
1695 bufferLEByte(CPV: cast<Constant>(Val: Op), Bytes: 0, AggBuffer: aggBuffer);
1696 return;
1697 }
1698
1699 if (const auto *CDS = dyn_cast<ConstantDataSequential>(Val: CPV)) {
1700 for (unsigned I : llvm::seq(Size: CDS->getNumElements()))
1701 bufferLEByte(CPV: cast<Constant>(Val: CDS->getElementAsConstant(i: I)), Bytes: 0, AggBuffer: aggBuffer);
1702 return;
1703 }
1704
1705 if (isa<ConstantStruct>(Val: CPV)) {
1706 if (CPV->getNumOperands()) {
1707 StructType *ST = cast<StructType>(Val: CPV->getType());
1708 for (unsigned I : llvm::seq(Size: CPV->getNumOperands())) {
1709 int EndOffset = (I + 1 == CPV->getNumOperands())
1710 ? DL.getStructLayout(Ty: ST)->getElementOffset(Idx: 0) +
1711 DL.getTypeAllocSize(Ty: ST)
1712 : DL.getStructLayout(Ty: ST)->getElementOffset(Idx: I + 1);
1713 int Bytes = EndOffset - DL.getStructLayout(Ty: ST)->getElementOffset(Idx: I);
1714 bufferLEByte(CPV: cast<Constant>(Val: CPV->getOperand(i: I)), Bytes, AggBuffer: aggBuffer);
1715 }
1716 }
1717 return;
1718 }
1719 llvm_unreachable("unsupported constant type in printAggregateConstant()");
1720}
1721
/// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
/// a copy from AsmPrinter::lowerConstant, except customized to only handle
/// expressions that are representable in PTX and create
/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
///
/// \p ProcessingGeneric is set once an addrspacecast to address space 0 has
/// been stripped. Global symbols reached under it are wrapped in an
/// NVPTXGenericMCSymbolRefExpr. Unsupported expressions that cannot be folded
/// away via ConstantFoldConstant raise a fatal error.
const MCExpr *
NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV,
                                    bool ProcessingGeneric) const {
  MCContext &Ctx = OutContext;

  // Null and undef both lower to the literal 0.
  if (CV->isNullValue() || isa<UndefValue>(Val: CV))
    return MCConstantExpr::create(Value: 0, Ctx);

  if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CV))
    return MCConstantExpr::create(Value: CI->getZExtValue(), Ctx);

  if (const GlobalValue *GV = dyn_cast<GlobalValue>(Val: CV)) {
    const MCSymbolRefExpr *Expr = MCSymbolRefExpr::create(Symbol: getSymbol(GV), Ctx);
    if (ProcessingGeneric)
      return NVPTXGenericMCSymbolRefExpr::create(SymExpr: Expr, Ctx);
    return Expr;
  }

  const ConstantExpr *CE = dyn_cast<ConstantExpr>(Val: CV);
  if (!CE) {
    llvm_unreachable("Unknown constant value to lower!");
  }

  // Opcodes that "break" out of this switch fall through to the
  // constant-folding fallback / fatal error after it.
  switch (CE->getOpcode()) {
  default:
    break; // Error

  case Instruction::AddrSpaceCast: {
    // Strip the addrspacecast and pass along the operand
    // Only casts to the generic address space (0) are supported; they switch
    // lowering of the operand into "generic" mode.
    PointerType *DstTy = cast<PointerType>(Val: CE->getType());
    if (DstTy->getAddressSpace() == 0)
      return lowerConstantForGV(CV: cast<const Constant>(Val: CE->getOperand(i_nocapture: 0)), ProcessingGeneric: true);

    break; // Error
  }

  case Instruction::GetElementPtr: {
    const DataLayout &DL = getDataLayout();

    // Generate a symbolic expression for the byte address
    // i.e. Base + constant byte offset, or just Base when the offset is 0.
    APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
    cast<GEPOperator>(Val: CE)->accumulateConstantOffset(DL, Offset&: OffsetAI);

    const MCExpr *Base = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0),
                                            ProcessingGeneric);
    if (!OffsetAI)
      return Base;

    int64_t Offset = OffsetAI.getSExtValue();
    return MCBinaryExpr::createAdd(LHS: Base, RHS: MCConstantExpr::create(Value: Offset, Ctx),
                                   Ctx);
  }

  case Instruction::Trunc:
    // We emit the value and depend on the assembler to truncate the generated
    // expression properly.  This is important for differences between
    // blockaddress labels.  Since the two labels are in the same function, it
    // is reasonable to treat their delta as a 32-bit value.
    [[fallthrough]];
  case Instruction::BitCast:
    return lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0), ProcessingGeneric);

  case Instruction::IntToPtr: {
    const DataLayout &DL = getDataLayout();

    // Handle casts to pointers by changing them into casts to the appropriate
    // integer type.  This promotes constant folding and simplifies this code.
    Constant *Op = CE->getOperand(i_nocapture: 0);
    Op = ConstantFoldIntegerCast(C: Op, DestTy: DL.getIntPtrType(CV->getType()),
                                 /*IsSigned*/ false, DL);
    if (Op)
      return lowerConstantForGV(CV: Op, ProcessingGeneric);

    break; // Error
  }

  case Instruction::PtrToInt: {
    const DataLayout &DL = getDataLayout();

    // Support only foldable casts to/from pointers that can be eliminated by
    // changing the pointer to the appropriately sized integer type.
    Constant *Op = CE->getOperand(i_nocapture: 0);
    Type *Ty = CE->getType();

    const MCExpr *OpExpr = lowerConstantForGV(CV: Op, ProcessingGeneric);

    // We can emit the pointer value into this slot if the slot is an
    // integer slot equal to the size of the pointer.
    if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Ty: Op->getType()))
      return OpExpr;

    // Otherwise the pointer is smaller than the resultant integer, mask off
    // the high bits so we are sure to get a proper truncation if the input is
    // a constant expr.
    // NOTE(review): if the integer is instead *narrower* than the pointer,
    // this mask covers the full pointer width and does not truncate to the
    // integer width; confirm whether such casts can reach here.
    unsigned InBits = DL.getTypeAllocSizeInBits(Ty: Op->getType());
    const MCExpr *MaskExpr = MCConstantExpr::create(Value: ~0ULL >> (64-InBits), Ctx);
    return MCBinaryExpr::createAnd(LHS: OpExpr, RHS: MaskExpr, Ctx);
  }

  // The MC library also has a right-shift operator, but it isn't consistently
  // signed or unsigned between different targets.
  // (Comment inherited from AsmPrinter::lowerConstant; only Add is handled
  // here.)
  case Instruction::Add: {
    const MCExpr *LHS = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0), ProcessingGeneric);
    const MCExpr *RHS = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 1), ProcessingGeneric);
    switch (CE->getOpcode()) {
    default: llvm_unreachable("Unknown binary operator constant cast expr");
    case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
    }
  }
  }

  // If the code isn't optimized, there may be outstanding folding
  // opportunities. Attempt to fold the expression using DataLayout as a
  // last resort before giving up.
  Constant *C = ConstantFoldConstant(C: CE, DL: getDataLayout());
  if (C != CE)
    return lowerConstantForGV(CV: C, ProcessingGeneric);

  // Otherwise report the problem to the user.
  std::string S;
  raw_string_ostream OS(S);
  OS << "Unsupported expression in static initializer: ";
  CE->printAsOperand(O&: OS, /*PrintType=*/false,
                     M: !MF ? nullptr : MF->getFunction().getParent());
  report_fatal_error(reason: Twine(OS.str()));
}
1852
1853void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) const {
1854 OutContext.getAsmInfo()->printExpr(OS, Expr);
1855}
1856
1857/// PrintAsmOperand - Print out an operand for an inline asm expression.
1858///
1859bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
1860 const char *ExtraCode, raw_ostream &O) {
1861 if (ExtraCode && ExtraCode[0]) {
1862 if (ExtraCode[1] != 0)
1863 return true; // Unknown modifier.
1864
1865 switch (ExtraCode[0]) {
1866 default:
1867 // See if this is a generic print operand
1868 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, OS&: O);
1869 case 'r':
1870 break;
1871 }
1872 }
1873
1874 printOperand(MI, OpNum: OpNo, O);
1875
1876 return false;
1877}
1878
1879bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
1880 unsigned OpNo,
1881 const char *ExtraCode,
1882 raw_ostream &O) {
1883 if (ExtraCode && ExtraCode[0])
1884 return true; // Unknown modifier
1885
1886 O << '[';
1887 printMemOperand(MI, OpNum: OpNo, O);
1888 O << ']';
1889
1890 return false;
1891}
1892
1893void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,
1894 raw_ostream &O) {
1895 const MachineOperand &MO = MI->getOperand(i: OpNum);
1896 switch (MO.getType()) {
1897 case MachineOperand::MO_Register:
1898 if (MO.getReg().isPhysical()) {
1899 if (MO.getReg() == NVPTX::VRDepot)
1900 O << DEPOTNAME << getFunctionNumber();
1901 else
1902 O << NVPTXInstPrinter::getRegisterName(Reg: MO.getReg());
1903 } else {
1904 emitVirtualRegister(vr: MO.getReg(), O);
1905 }
1906 break;
1907
1908 case MachineOperand::MO_Immediate:
1909 O << MO.getImm();
1910 break;
1911
1912 case MachineOperand::MO_FPImmediate:
1913 printFPConstant(Fp: MO.getFPImm(), O);
1914 break;
1915
1916 case MachineOperand::MO_GlobalAddress:
1917 PrintSymbolOperand(MO, OS&: O);
1918 break;
1919
1920 case MachineOperand::MO_MachineBasicBlock:
1921 MO.getMBB()->getSymbol()->print(OS&: O, MAI);
1922 break;
1923
1924 default:
1925 llvm_unreachable("Operand type not supported.");
1926 }
1927}
1928
1929void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,
1930 raw_ostream &O, const char *Modifier) {
1931 printOperand(MI, OpNum, O);
1932
1933 if (Modifier && strcmp(s1: Modifier, s2: "add") == 0) {
1934 O << ", ";
1935 printOperand(MI, OpNum: OpNum + 1, O);
1936 } else {
1937 if (MI->getOperand(i: OpNum + 1).isImm() &&
1938 MI->getOperand(i: OpNum + 1).getImm() == 0)
1939 return; // don't print ',0' or '+0'
1940 O << "+";
1941 printOperand(MI, OpNum: OpNum + 1, O);
1942 }
1943}
1944
// Pass identifier for NVPTXAsmPrinter. The legacy pass infrastructure keys
// passes on the address of this static member; the stored value is unused.
char NVPTXAsmPrinter::ID = 0;

// Register the pass under the command-line name "nvptx-asm-printer" with a
// human-readable description. The two trailing `false` arguments mark it as
// neither CFG-only nor an analysis pass.
INITIALIZE_PASS(NVPTXAsmPrinter, "nvptx-asm-printer", "NVPTX Assembly Printer",
                false, false)
1949
1950// Force static initialization.
1951extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void
1952LLVMInitializeNVPTXAsmPrinter() {
1953 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
1954 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
1955}
1956