1//===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a printer that converts from our internal representation
10// of machine-dependent LLVM code to NVPTX assembly language.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXAsmPrinter.h"
15#include "MCTargetDesc/NVPTXBaseInfo.h"
16#include "MCTargetDesc/NVPTXInstPrinter.h"
17#include "MCTargetDesc/NVPTXMCAsmInfo.h"
18#include "MCTargetDesc/NVPTXTargetStreamer.h"
19#include "NVPTX.h"
20#include "NVPTXDwarfDebug.h"
21#include "NVPTXMCExpr.h"
22#include "NVPTXMachineFunctionInfo.h"
23#include "NVPTXRegisterInfo.h"
24#include "NVPTXSubtarget.h"
25#include "NVPTXTargetMachine.h"
26#include "NVPTXUtilities.h"
27#include "TargetInfo/NVPTXTargetInfo.h"
28#include "cl_common_defines.h"
29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/APInt.h"
31#include "llvm/ADT/ArrayRef.h"
32#include "llvm/ADT/DenseMap.h"
33#include "llvm/ADT/DenseSet.h"
34#include "llvm/ADT/SmallString.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/StringExtras.h"
37#include "llvm/ADT/StringRef.h"
38#include "llvm/ADT/Twine.h"
39#include "llvm/ADT/iterator_range.h"
40#include "llvm/Analysis/ConstantFolding.h"
41#include "llvm/CodeGen/Analysis.h"
42#include "llvm/CodeGen/MachineBasicBlock.h"
43#include "llvm/CodeGen/MachineFrameInfo.h"
44#include "llvm/CodeGen/MachineFunction.h"
45#include "llvm/CodeGen/MachineInstr.h"
46#include "llvm/CodeGen/MachineLoopInfo.h"
47#include "llvm/CodeGen/MachineModuleInfo.h"
48#include "llvm/CodeGen/MachineOperand.h"
49#include "llvm/CodeGen/MachineRegisterInfo.h"
50#include "llvm/CodeGen/TargetRegisterInfo.h"
51#include "llvm/CodeGen/ValueTypes.h"
52#include "llvm/CodeGenTypes/MachineValueType.h"
53#include "llvm/IR/Argument.h"
54#include "llvm/IR/Attributes.h"
55#include "llvm/IR/BasicBlock.h"
56#include "llvm/IR/Constant.h"
57#include "llvm/IR/Constants.h"
58#include "llvm/IR/DataLayout.h"
59#include "llvm/IR/DebugInfo.h"
60#include "llvm/IR/DebugInfoMetadata.h"
61#include "llvm/IR/DebugLoc.h"
62#include "llvm/IR/DerivedTypes.h"
63#include "llvm/IR/Function.h"
64#include "llvm/IR/GlobalAlias.h"
65#include "llvm/IR/GlobalValue.h"
66#include "llvm/IR/GlobalVariable.h"
67#include "llvm/IR/Instruction.h"
68#include "llvm/IR/LLVMContext.h"
69#include "llvm/IR/Module.h"
70#include "llvm/IR/Operator.h"
71#include "llvm/IR/Type.h"
72#include "llvm/IR/User.h"
73#include "llvm/MC/MCExpr.h"
74#include "llvm/MC/MCInst.h"
75#include "llvm/MC/MCInstrDesc.h"
76#include "llvm/MC/MCStreamer.h"
77#include "llvm/MC/MCSymbol.h"
78#include "llvm/MC/TargetRegistry.h"
79#include "llvm/Support/Alignment.h"
80#include "llvm/Support/Casting.h"
81#include "llvm/Support/Compiler.h"
82#include "llvm/Support/Endian.h"
83#include "llvm/Support/ErrorHandling.h"
84#include "llvm/Support/NativeFormatting.h"
85#include "llvm/Support/raw_ostream.h"
86#include "llvm/Target/TargetLoweringObjectFile.h"
87#include "llvm/Target/TargetMachine.h"
88#include "llvm/Transforms/Utils/UnrollLoop.h"
89#include <cassert>
90#include <cstdint>
91#include <cstring>
92#include <string>
93
94using namespace llvm;
95
96#define DEPOTNAME "__local_depot"
97
98/// discoverDependentGlobals - Return a set of GlobalVariables on which \p V
99/// depends.
100static void
101discoverDependentGlobals(const Value *V,
102 DenseSet<const GlobalVariable *> &Globals) {
103 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: V)) {
104 Globals.insert(V: GV);
105 return;
106 }
107
108 if (const User *U = dyn_cast<User>(Val: V))
109 for (const auto &O : U->operands())
110 discoverDependentGlobals(V: O, Globals);
111}
112
/// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
/// instances to be emitted, but only after any dependents have been added
/// first.
116static void
117VisitGlobalVariableForEmission(const GlobalVariable *GV,
118 SmallVectorImpl<const GlobalVariable *> &Order,
119 DenseSet<const GlobalVariable *> &Visited,
120 DenseSet<const GlobalVariable *> &Visiting) {
121 // Have we already visited this one?
122 if (Visited.count(V: GV))
123 return;
124
125 // Do we have a circular dependency?
126 if (!Visiting.insert(V: GV).second)
127 report_fatal_error(reason: "Circular dependency found in global variable set");
128
129 // Make sure we visit all dependents first
130 DenseSet<const GlobalVariable *> Others;
131 for (const auto &O : GV->operands())
132 discoverDependentGlobals(V: O, Globals&: Others);
133
134 for (const GlobalVariable *GV : Others)
135 VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
136
137 // Now we can visit ourself
138 Order.push_back(Elt: GV);
139 Visited.insert(V: GV);
140 Visiting.erase(V: GV);
141}
142
143void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
144 NVPTX_MC::verifyInstructionPredicates(Opcode: MI->getOpcode(),
145 Features: getSubtargetInfo().getFeatureBits());
146
147 MCInst Inst;
148 lowerToMCInst(MI, OutMI&: Inst);
149 EmitToStreamer(S&: *OutStreamer, Inst);
150}
151
152void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
153 OutMI.setOpcode(MI->getOpcode());
154 // Special: Do not mangle symbol operand of CALL_PROTOTYPE
155 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
156 const MachineOperand &MO = MI->getOperand(i: 0);
157 OutMI.addOperand(Op: GetSymbolRef(
158 Symbol: OutContext.getOrCreateSymbol(Name: Twine(MO.getSymbolName()))));
159 return;
160 }
161
162 for (const auto MO : MI->operands())
163 OutMI.addOperand(Op: lowerOperand(MO));
164}
165
166MCOperand NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO) {
167 switch (MO.getType()) {
168 default:
169 llvm_unreachable("unknown operand type");
170 case MachineOperand::MO_Register:
171 return MCOperand::createReg(Reg: encodeVirtualRegister(Reg: MO.getReg()));
172 case MachineOperand::MO_Immediate:
173 return MCOperand::createImm(Val: MO.getImm());
174 case MachineOperand::MO_MachineBasicBlock:
175 return MCOperand::createExpr(
176 Val: MCSymbolRefExpr::create(Symbol: MO.getMBB()->getSymbol(), Ctx&: OutContext));
177 case MachineOperand::MO_ExternalSymbol:
178 return GetSymbolRef(Symbol: GetExternalSymbolSymbol(Sym: MO.getSymbolName()));
179 case MachineOperand::MO_GlobalAddress:
180 return GetSymbolRef(Symbol: getSymbol(GV: MO.getGlobal()));
181 case MachineOperand::MO_FPImmediate: {
182 const ConstantFP *Cnt = MO.getFPImm();
183 const APFloat &Val = Cnt->getValueAPF();
184
185 switch (Cnt->getType()->getTypeID()) {
186 default:
187 report_fatal_error(reason: "Unsupported FP type");
188 break;
189 case Type::HalfTyID:
190 return MCOperand::createExpr(
191 Val: NVPTXFloatMCExpr::createConstantFPHalf(Flt: Val, Ctx&: OutContext));
192 case Type::BFloatTyID:
193 return MCOperand::createExpr(
194 Val: NVPTXFloatMCExpr::createConstantBFPHalf(Flt: Val, Ctx&: OutContext));
195 case Type::FloatTyID:
196 return MCOperand::createExpr(
197 Val: NVPTXFloatMCExpr::createConstantFPSingle(Flt: Val, Ctx&: OutContext));
198 case Type::DoubleTyID:
199 return MCOperand::createExpr(
200 Val: NVPTXFloatMCExpr::createConstantFPDouble(Flt: Val, Ctx&: OutContext));
201 }
202 break;
203 }
204 }
205}
206
207unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
208 if (Register::isVirtualRegister(Reg)) {
209 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
210
211 DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
212 unsigned RegNum = RegMap[Reg];
213
214 // Encode the register class in the upper 4 bits
215 // Must be kept in sync with NVPTXInstPrinter::printRegName
216 unsigned Ret = 0;
217 if (RC == &NVPTX::B1RegClass) {
218 Ret = (1 << 28);
219 } else if (RC == &NVPTX::B16RegClass) {
220 Ret = (2 << 28);
221 } else if (RC == &NVPTX::B32RegClass) {
222 Ret = (3 << 28);
223 } else if (RC == &NVPTX::B64RegClass) {
224 Ret = (4 << 28);
225 } else if (RC == &NVPTX::B128RegClass) {
226 Ret = (7 << 28);
227 } else {
228 report_fatal_error(reason: "Bad register class");
229 }
230
231 // Insert the vreg number
232 Ret |= (RegNum & 0x0FFFFFFF);
233 return Ret;
234 } else {
235 // Some special-use registers are actually physical registers.
236 // Encode this as the register class ID of 0 and the real register ID.
237 return Reg & 0x0FFFFFFF;
238 }
239}
240
241MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
242 const MCExpr *Expr;
243 Expr = MCSymbolRefExpr::create(Symbol, Ctx&: OutContext);
244 return MCOperand::createExpr(Val: Expr);
245}
246
247void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
248 const DataLayout &DL = getDataLayout();
249 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(F: *F);
250 const auto *TLI = cast<NVPTXTargetLowering>(Val: STI.getTargetLowering());
251
252 Type *Ty = F->getReturnType();
253 if (Ty->getTypeID() == Type::VoidTyID)
254 return;
255 O << " (";
256
257 auto PrintScalarRetVal = [&](unsigned Size) {
258 O << ".param .b" << promoteScalarArgumentSize(size: Size) << " func_retval0";
259 };
260 if (shouldPassAsArray(Ty)) {
261 const unsigned TotalSize = DL.getTypeAllocSize(Ty);
262 const Align RetAlignment = TLI->getFunctionArgumentAlignment(
263 F, Ty, Idx: AttributeList::ReturnIndex, DL);
264 O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
265 << TotalSize << "]";
266 } else if (Ty->isFloatingPointTy()) {
267 PrintScalarRetVal(Ty->getPrimitiveSizeInBits());
268 } else if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
269 PrintScalarRetVal(ITy->getBitWidth());
270 } else if (isa<PointerType>(Val: Ty)) {
271 PrintScalarRetVal(TLI->getPointerTy(DL).getSizeInBits());
272 } else
273 llvm_unreachable("Unknown return type");
274 O << ") ";
275}
276
277void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
278 raw_ostream &O) {
279 const Function &F = MF.getFunction();
280 printReturnValStr(F: &F, O);
281}
282
// Return true if MBB is the header of a loop marked with
// llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
    const MachineBasicBlock &MBB) const {
  MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();
  // We insert .pragma "nounroll" only to the loop header.
  if (!LI.isLoopHeader(BB: &MBB))
    return false;

  // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
  // we iterate through each back edge of the loop with header MBB, and check
  // whether its metadata contains llvm.loop.unroll.disable.
  for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
    if (LI.getLoopFor(BB: PMBB) != LI.getLoopFor(BB: &MBB)) {
      // Edges from other loops to MBB are not back edges.
      continue;
    }
    // Only predecessors that still track their originating IR block can
    // carry loop metadata.
    if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
      if (MDNode *LoopID =
              PBB->getTerminator()->getMetadata(KindID: LLVMContext::MD_loop)) {
        // An explicit disable request always counts.
        if (GetUnrollMetadata(LoopID, Name: "llvm.loop.unroll.disable"))
          return true;
        // An unroll count of exactly 1 is treated the same as disabling.
        if (MDNode *UnrollCountMD =
                GetUnrollMetadata(LoopID, Name: "llvm.loop.unroll.count")) {
          if (mdconst::extract<ConstantInt>(MD: UnrollCountMD->getOperand(I: 1))
                  ->isOne())
            return true;
        }
      }
    }
  }
  return false;
}
316
317void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
318 AsmPrinter::emitBasicBlockStart(MBB);
319 if (isLoopHeaderOfNoUnroll(MBB))
320 OutStreamer->emitRawText(String: StringRef("\t.pragma \"nounroll\";\n"));
321}
322
// Emit the PTX prototype line (".entry"/".func"), kernel directives, the
// opening brace of the body, and the vreg declarations for the current
// function. The ordering of the raw-text emissions below matters.
void NVPTXAsmPrinter::emitFunctionEntryLabel() {
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  // Module-level globals must precede the first function body; emit them
  // lazily here so declarations for every referenced symbol already exist.
  if (!GlobalsEmitted) {
    emitGlobals(M: *MF->getFunction().getParent());
    GlobalsEmitted = true;
  }

  // Set up
  MRI = &MF->getRegInfo();
  F = &MF->getFunction();
  emitLinkageDirective(V: F, O);
  // Kernels use ".entry"; device functions use ".func" plus a return clause.
  if (isKernelFunction(F: *F))
    O << ".entry ";
  else {
    O << ".func ";
    printReturnValStr(MF: *MF, O);
  }

  CurrentFnSym->print(OS&: O, MAI);

  emitFunctionParamList(F, O);
  O << "\n";

  // Launch-bound style directives are only valid on kernels.
  if (isKernelFunction(F: *F))
    emitKernelFunctionDirectives(F: *F, O);

  if (shouldEmitPTXNoReturn(V: F, TM))
    O << ".noreturn";

  OutStreamer->emitRawText(String: O.str());

  // Fresh per-function vreg numbering.
  VRegMapping.clear();
  // Emit open brace for function body.
  OutStreamer->emitRawText(String: StringRef("{\n"));
  setAndEmitFunctionVirtualRegisters(*MF);
  encodeDebugInfoRegisterNumbers(MF: *MF);
  // Emit initial .loc debug directive for correct relocation symbol data.
  if (const DISubprogram *SP = MF->getFunction().getSubprogram()) {
    assert(SP->getUnit());
    if (!SP->getUnit()->isDebugDirectivesOnly())
      emitInitialRawDwarfLocDirective(MF: *MF);
  }
}
368
369bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
370 bool Result = AsmPrinter::runOnMachineFunction(MF&: F);
371 // Emit closing brace for the body of function F.
372 // The closing brace must be emitted here because we need to emit additional
373 // debug labels/data after the last basic block.
374 // We need to emit the closing brace here because we don't have function that
375 // finished emission of the function body.
376 OutStreamer->emitRawText(String: StringRef("}\n"));
377 return Result;
378}
379
380void NVPTXAsmPrinter::emitFunctionBodyStart() {
381 SmallString<128> Str;
382 raw_svector_ostream O(Str);
383 emitDemotedVars(&MF->getFunction(), O);
384 OutStreamer->emitRawText(String: O.str());
385}
386
387void NVPTXAsmPrinter::emitFunctionBodyEnd() {
388 VRegMapping.clear();
389}
390
391const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
392 SmallString<128> Str;
393 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
394 return OutContext.getOrCreateSymbol(Name: Str);
395}
396
397void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
398 Register RegNo = MI->getOperand(i: 0).getReg();
399 if (RegNo.isVirtual()) {
400 OutStreamer->AddComment(T: Twine("implicit-def: ") +
401 getVirtualRegisterName(RegNo));
402 } else {
403 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
404 OutStreamer->AddComment(T: Twine("implicit-def: ") +
405 STI.getRegisterInfo()->getName(RegNo));
406 }
407 OutStreamer->addBlankLine();
408}
409
// Emit kernel launch-configuration directives (.reqntid, .maxntid,
// .minnctapersm, .maxnreg, and the SM_90+ cluster directives) derived from
// the function's NVVM annotations.
void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
                                                   raw_ostream &O) const {
  // If the NVVM IR has some of reqntid* specified, then output
  // the reqntid directive, and set the unspecified ones to 1.
  // If none of Reqntid* is specified, don't output reqntid directive.
  const auto ReqNTID = getReqNTID(F);
  if (!ReqNTID.empty())
    O << formatv(Fmt: ".reqntid {0:$[, ]}\n",
                 Vals: make_range(x: ReqNTID.begin(), y: ReqNTID.end()));

  const auto MaxNTID = getMaxNTID(F);
  if (!MaxNTID.empty())
    O << formatv(Fmt: ".maxntid {0:$[, ]}\n",
                 Vals: make_range(x: MaxNTID.begin(), y: MaxNTID.end()));

  // Optional occupancy hints.
  if (const auto Mincta = getMinCTASm(F))
    O << ".minnctapersm " << *Mincta << "\n";

  if (const auto Maxnreg = getMaxNReg(F))
    O << ".maxnreg " << *Maxnreg << "\n";

  // .maxclusterrank directive requires SM_90 or higher, make sure that we
  // filter it out for lower SM versions, as it causes a hard ptxas crash.
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget *STI = &NTM.getSubtarget<NVPTXSubtarget>(F);

  if (STI->getSmVersion() >= 90) {
    const auto ClusterDim = getClusterDim(F);
    const bool BlocksAreClusters = hasBlocksAreClusters(F);

    if (!ClusterDim.empty()) {

      // With blocksareclusters the cluster shape is implied, so
      // .explicitcluster is only emitted without it.
      if (!BlocksAreClusters)
        O << ".explicitcluster\n";

      // A nonzero first dimension means a concrete cluster shape was
      // requested; all-zero means "clustered launch, shape unspecified".
      if (ClusterDim[0] != 0) {
        assert(llvm::all_of(ClusterDim, not_equal_to(0)) &&
               "cluster_dim_x != 0 implies cluster_dim_y and cluster_dim_z "
               "should be non-zero as well");

        O << formatv(Fmt: ".reqnctapercluster {0:$[, ]}\n",
                     Vals: make_range(x: ClusterDim.begin(), y: ClusterDim.end()));
      } else {
        assert(llvm::all_of(ClusterDim, equal_to(0)) &&
               "cluster_dim_x == 0 implies cluster_dim_y and cluster_dim_z "
               "should be 0 as well");
      }
    }

    if (BlocksAreClusters) {
      // blocksareclusters requires both launch bounds and a cluster shape,
      // plus PTX >= 9.0; diagnose instead of emitting invalid PTX.
      LLVMContext &Ctx = F.getContext();
      if (ReqNTID.empty() || ClusterDim.empty())
        Ctx.diagnose(DI: DiagnosticInfoUnsupported(
            F, "blocksareclusters requires reqntid and cluster_dim attributes",
            F.getSubprogram()));
      else if (STI->getPTXVersion() < 90)
        Ctx.diagnose(DI: DiagnosticInfoUnsupported(
            F, "blocksareclusters requires PTX version >= 9.0",
            F.getSubprogram()));
      else
        O << ".blocksareclusters\n";
    }

    if (const auto Maxclusterrank = getMaxClusterRank(F))
      O << ".maxclusterrank " << *Maxclusterrank << "\n";
  }
}
477
478std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
479 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
480
481 std::string Name;
482 raw_string_ostream NameStr(Name);
483
484 VRegRCMap::const_iterator I = VRegMapping.find(Val: RC);
485 assert(I != VRegMapping.end() && "Bad register class");
486 const DenseMap<unsigned, unsigned> &RegMap = I->second;
487
488 VRegMap::const_iterator VI = RegMap.find(Val: Reg);
489 assert(VI != RegMap.end() && "Bad virtual register");
490 unsigned MappedVR = VI->second;
491
492 NameStr << getNVPTXRegClassStr(RC) << MappedVR;
493
494 return Name;
495}
496
497void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
498 raw_ostream &O) {
499 O << getVirtualRegisterName(Reg: vr);
500}
501
502void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias *GA,
503 raw_ostream &O) {
504 const Function *F = dyn_cast_or_null<Function>(Val: GA->getAliaseeObject());
505 if (!F || isKernelFunction(F: *F) || F->isDeclaration())
506 report_fatal_error(
507 reason: "NVPTX aliasee must be a non-kernel function definition");
508
509 if (GA->hasLinkOnceLinkage() || GA->hasWeakLinkage() ||
510 GA->hasAvailableExternallyLinkage() || GA->hasCommonLinkage())
511 report_fatal_error(reason: "NVPTX aliasee must not be '.weak'");
512
513 emitDeclarationWithName(F, getSymbol(GV: GA), O);
514}
515
516void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
517 emitDeclarationWithName(F, getSymbol(GV: F), O);
518}
519
520void NVPTXAsmPrinter::emitDeclarationWithName(const Function *F, MCSymbol *S,
521 raw_ostream &O) {
522 emitLinkageDirective(V: F, O);
523 if (isKernelFunction(F: *F))
524 O << ".entry ";
525 else
526 O << ".func ";
527 printReturnValStr(F, O);
528 S->print(OS&: O, MAI);
529 O << "\n";
530 emitFunctionParamList(F, O);
531 O << "\n";
532 if (shouldEmitPTXNoReturn(V: F, TM))
533 O << ".noreturn";
534 O << ";\n";
535}
536
537static bool usedInGlobalVarDef(const Constant *C) {
538 if (!C)
539 return false;
540
541 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: C))
542 return GV->getName() != "llvm.used";
543
544 for (const User *U : C->users())
545 if (const Constant *C = dyn_cast<Constant>(Val: U))
546 if (usedInGlobalVarDef(C))
547 return true;
548
549 return false;
550}
551
552static bool usedInOneFunc(const User *U, Function const *&OneFunc) {
553 if (const GlobalVariable *OtherGV = dyn_cast<GlobalVariable>(Val: U))
554 if (OtherGV->getName() == "llvm.used")
555 return true;
556
557 if (const Instruction *I = dyn_cast<Instruction>(Val: U)) {
558 if (const Function *CurFunc = I->getFunction()) {
559 if (OneFunc && (CurFunc != OneFunc))
560 return false;
561 OneFunc = CurFunc;
562 return true;
563 }
564 return false;
565 }
566
567 for (const User *UU : U->users())
568 if (!usedInOneFunc(U: UU, OneFunc))
569 return false;
570
571 return true;
572}
573
574/* Find out if a global variable can be demoted to local scope.
575 * Currently, this is valid for CUDA shared variables, which have local
576 * scope and global lifetime. So the conditions to check are :
577 * 1. Is the global variable in shared address space?
578 * 2. Does it have local linkage?
579 * 3. Is the global variable referenced only in one function?
580 */
581static bool canDemoteGlobalVar(const GlobalVariable *GV, Function const *&f) {
582 if (!GV->hasLocalLinkage())
583 return false;
584 if (GV->getAddressSpace() != ADDRESS_SPACE_SHARED)
585 return false;
586
587 const Function *oneFunc = nullptr;
588
589 bool flag = usedInOneFunc(U: GV, OneFunc&: oneFunc);
590 if (!flag)
591 return false;
592 if (!oneFunc)
593 return false;
594 f = oneFunc;
595 return true;
596}
597
598static bool useFuncSeen(const Constant *C,
599 const SmallPtrSetImpl<const Function *> &SeenSet) {
600 for (const User *U : C->users()) {
601 if (const Constant *cu = dyn_cast<Constant>(Val: U)) {
602 if (useFuncSeen(C: cu, SeenSet))
603 return true;
604 } else if (const Instruction *I = dyn_cast<Instruction>(Val: U)) {
605 if (const Function *Caller = I->getFunction())
606 if (SeenSet.contains(Ptr: Caller))
607 return true;
608 }
609 }
610 return false;
611}
612
// Emit forward declarations for every function that is referenced before its
// definition in the module, since ptxas cannot resolve forward references.
void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
  // Functions whose definitions have already been passed in module order.
  SmallPtrSet<const Function *, 32> SeenSet;
  for (const Function &F : M) {
    // Libcall callees are always declared up front.
    if (F.getAttributes().hasFnAttr(Kind: "nvptx-libcall-callee")) {
      emitDeclaration(F: &F, O);
      continue;
    }

    if (F.isDeclaration()) {
      // External functions need a declaration only when actually used, and
      // intrinsics never do (they are lowered away).
      if (F.use_empty())
        continue;
      if (F.getIntrinsicID())
        continue;
      emitDeclaration(F: &F, O);
      continue;
    }
    for (const User *U : F.users()) {
      if (const Constant *C = dyn_cast<Constant>(Val: U)) {
        if (usedInGlobalVarDef(C)) {
          // The use is in the initialization of a global variable
          // that is a function pointer, so print a declaration
          // for the original function
          emitDeclaration(F: &F, O);
          break;
        }
        // Emit a declaration of this function if the function that
        // uses this constant expr has already been seen.
        if (useFuncSeen(C, SeenSet)) {
          emitDeclaration(F: &F, O);
          break;
        }
      }

      if (!isa<Instruction>(Val: U))
        continue;
      const Function *Caller = cast<Instruction>(Val: U)->getFunction();
      if (!Caller)
        continue;

      // If a caller has already been seen, then the caller is
      // appearing in the module before the callee. so print out
      // a declaration for the callee.
      if (SeenSet.contains(Ptr: Caller)) {
        emitDeclaration(F: &F, O);
        break;
      }
    }
    SeenSet.insert(Ptr: &F);
  }
  // Aliases always get a declaration so they can be referenced by name.
  for (const GlobalAlias &GA : M.aliases())
    emitAliasDeclaration(GA: &GA, O);
}
665
666void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
667 // Construct a default subtarget off of the TargetMachine defaults. The
668 // rest of NVPTX isn't friendly to change subtargets per function and
669 // so the default TargetMachine will have all of the options.
670 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
671 const NVPTXSubtarget *STI = NTM.getSubtargetImpl();
672 SmallString<128> Str1;
673 raw_svector_ostream OS1(Str1);
674
675 // Emit header before any dwarf directives are emitted below.
676 emitHeader(M, O&: OS1, STI: *STI);
677 OutStreamer->emitRawText(String: OS1.str());
678}
679
/// Create NVPTX-specific DwarfDebug handler.
DwarfDebug *NVPTXAsmPrinter::createDwarfDebug() {
  // Overrides the generic handler with the NVPTX subclass.
  return new NVPTXDwarfDebug(this);
}
684
685bool NVPTXAsmPrinter::doInitialization(Module &M) {
686 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
687 const NVPTXSubtarget &STI = *NTM.getSubtargetImpl();
688 if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
689 report_fatal_error(reason: ".alias requires PTX version >= 6.3 and sm_30");
690
691 // We need to call the parent's one explicitly.
692 bool Result = AsmPrinter::doInitialization(M);
693
694 GlobalsEmitted = false;
695
696 return Result;
697}
698
// Emit function forward declarations followed by all module-level global
// variables, ordered so that every global appears after its dependencies.
void NVPTXAsmPrinter::emitGlobals(const Module &M) {
  SmallString<128> Str2;
  raw_svector_ostream OS2(Str2);

  emitDeclarations(M, O&: OS2);

  // As ptxas does not support forward references of globals, we need to first
  // sort the list of module-level globals in def-use order. We visit each
  // global variable in order, and ensure that we emit it *after* its dependent
  // globals. We use a little extra memory maintaining both a set and a list to
  // have fast searches while maintaining a strict ordering.
  SmallVector<const GlobalVariable *, 8> Globals;
  DenseSet<const GlobalVariable *> GVVisited;
  DenseSet<const GlobalVariable *> GVVisiting;

  // Visit each global variable, in order
  for (const GlobalVariable &I : M.globals())
    VisitGlobalVariableForEmission(GV: &I, Order&: Globals, Visited&: GVVisited, Visiting&: GVVisiting);

  // Every global must have been scheduled exactly once, with an empty
  // in-progress stack (i.e. no cycles and nothing skipped).
  assert(GVVisited.size() == M.global_size() && "Missed a global variable");
  assert(GVVisiting.size() == 0 && "Did not fully process a global variable");

  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI = *NTM.getSubtargetImpl();

  // Print out module-level global variables in proper order
  for (const GlobalVariable *GV : Globals)
    printModuleLevelGV(GVar: GV, O&: OS2, /*ProcessDemoted=*/processDemoted: false, STI);

  OS2 << '\n';

  OutStreamer->emitRawText(String: OS2.str());
}
732
733void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
734 SmallString<128> Str;
735 raw_svector_ostream OS(Str);
736
737 MCSymbol *Name = getSymbol(GV: &GA);
738
739 OS << ".alias " << Name->getName() << ", " << GA.getAliaseeObject()->getName()
740 << ";\n";
741
742 OutStreamer->emitRawText(String: OS.str());
743}
744
745void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
746 const NVPTXSubtarget &STI) {
747 const unsigned PTXVersion = STI.getPTXVersion();
748
749 O << "//\n"
750 "// Generated by LLVM NVPTX Back-End\n"
751 "//\n"
752 "\n"
753 << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n"
754 << ".target " << STI.getTargetName();
755
756 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
757 if (NTM.getDrvInterface() == NVPTX::NVCL)
758 O << ", texmode_independent";
759
760 bool HasFullDebugInfo = false;
761 for (DICompileUnit *CU : M.debug_compile_units()) {
762 switch(CU->getEmissionKind()) {
763 case DICompileUnit::NoDebug:
764 case DICompileUnit::DebugDirectivesOnly:
765 break;
766 case DICompileUnit::LineTablesOnly:
767 case DICompileUnit::FullDebug:
768 HasFullDebugInfo = true;
769 break;
770 }
771 if (HasFullDebugInfo)
772 break;
773 }
774 if (HasFullDebugInfo)
775 O << ", debug";
776
777 O << "\n"
778 << ".address_size " << (NTM.is64Bit() ? "64" : "32") << "\n"
779 << "\n";
780}
781
bool NVPTXAsmPrinter::doFinalization(Module &M) {
  // If we did not emit any functions, then the global declarations have not
  // yet been emitted.
  if (!GlobalsEmitted) {
    emitGlobals(M);
    GlobalsEmitted = true;
  }

  // call doFinalization
  bool ret = AsmPrinter::doFinalization(M);

  // The per-module annotation cache is stale once the module is finished.
  clearAnnotationCache(&M);

  auto *TS =
      static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
  // Close the last emitted section
  if (hasDebugInfo()) {
    TS->closeLastSection();
    // Emit empty .debug_macinfo section for better support of the empty files.
    OutStreamer->emitRawText(String: "\t.section\t.debug_macinfo\t{\t}");
  }

  // Output last DWARF .file directives, if any.
  TS->outputDwarfFileDirectives();

  return ret;
}
809
810// This function emits appropriate linkage directives for
811// functions and global variables.
812//
813// extern function declaration -> .extern
814// extern function definition -> .visible
815// external global variable with init -> .visible
816// external without init -> .extern
817// appending -> not allowed, assert.
818// for any linkage other than
819// internal, private, linker_private,
820// linker_private_weak, linker_private_weak_def_auto,
821// we emit -> .weak.
822
823void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
824 raw_ostream &O) {
825 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
826 if (V->hasExternalLinkage()) {
827 if (const auto *GVar = dyn_cast<GlobalVariable>(Val: V))
828 O << (GVar->hasInitializer() ? ".visible " : ".extern ");
829 else if (V->isDeclaration())
830 O << ".extern ";
831 else
832 O << ".visible ";
833 } else if (V->hasAppendingLinkage()) {
834 report_fatal_error(reason: "Symbol '" + (V->hasName() ? V->getName() : "") +
835 "' has unsupported appending linkage type");
836 } else if (!V->hasInternalLinkage() && !V->hasPrivateLinkage()) {
837 O << ".weak ";
838 }
839 }
840}
841
842void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
843 raw_ostream &O, bool ProcessDemoted,
844 const NVPTXSubtarget &STI) {
845 // Skip meta data
846 if (GVar->hasSection())
847 if (GVar->getSection() == "llvm.metadata")
848 return;
849
850 // Skip LLVM intrinsic global variables
851 if (GVar->getName().starts_with(Prefix: "llvm.") ||
852 GVar->getName().starts_with(Prefix: "nvvm."))
853 return;
854
855 const DataLayout &DL = getDataLayout();
856
857 // GlobalVariables are always constant pointers themselves.
858 Type *ETy = GVar->getValueType();
859
860 if (GVar->hasExternalLinkage()) {
861 if (GVar->hasInitializer())
862 O << ".visible ";
863 else
864 O << ".extern ";
865 } else if (STI.getPTXVersion() >= 50 && GVar->hasCommonLinkage() &&
866 GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) {
867 O << ".common ";
868 } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
869 GVar->hasAvailableExternallyLinkage() ||
870 GVar->hasCommonLinkage()) {
871 O << ".weak ";
872 }
873
874 if (isTexture(*GVar)) {
875 O << ".global .texref " << getTextureName(*GVar) << ";\n";
876 return;
877 }
878
879 if (isSurface(*GVar)) {
880 O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
881 return;
882 }
883
884 if (GVar->isDeclaration()) {
885 // (extern) declarations, no definition or initializer
886 // Currently the only known declaration is for an automatic __local
887 // (.shared) promoted to global.
888 emitPTXGlobalVariable(GVar, O, STI);
889 O << ";\n";
890 return;
891 }
892
893 if (isSampler(*GVar)) {
894 O << ".global .samplerref " << getSamplerName(*GVar);
895
896 const Constant *Initializer = nullptr;
897 if (GVar->hasInitializer())
898 Initializer = GVar->getInitializer();
899 const ConstantInt *CI = nullptr;
900 if (Initializer)
901 CI = dyn_cast<ConstantInt>(Val: Initializer);
902 if (CI) {
903 unsigned sample = CI->getZExtValue();
904
905 O << " = { ";
906
907 for (int i = 0,
908 addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
909 i < 3; i++) {
910 O << "addr_mode_" << i << " = ";
911 switch (addr) {
912 case 0:
913 O << "wrap";
914 break;
915 case 1:
916 O << "clamp_to_border";
917 break;
918 case 2:
919 O << "clamp_to_edge";
920 break;
921 case 3:
922 O << "wrap";
923 break;
924 case 4:
925 O << "mirror";
926 break;
927 }
928 O << ", ";
929 }
930 O << "filter_mode = ";
931 switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
932 case 0:
933 O << "nearest";
934 break;
935 case 1:
936 O << "linear";
937 break;
938 case 2:
939 llvm_unreachable("Anisotropic filtering is not supported");
940 default:
941 O << "nearest";
942 break;
943 }
944 if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
945 O << ", force_unnormalized_coords = 1";
946 }
947 O << " }";
948 }
949
950 O << ";\n";
951 return;
952 }
953
954 if (GVar->hasPrivateLinkage()) {
955 if (GVar->getName().starts_with(Prefix: "unrollpragma"))
956 return;
957
958 // FIXME - need better way (e.g. Metadata) to avoid generating this global
959 if (GVar->getName().starts_with(Prefix: "filename"))
960 return;
961 if (GVar->use_empty())
962 return;
963 }
964
965 const Function *DemotedFunc = nullptr;
966 if (!ProcessDemoted && canDemoteGlobalVar(GV: GVar, f&: DemotedFunc)) {
967 O << "// " << GVar->getName() << " has been demoted\n";
968 localDecls[DemotedFunc].push_back(x: GVar);
969 return;
970 }
971
972 O << ".";
973 emitPTXAddressSpace(AddressSpace: GVar->getAddressSpace(), O);
974
975 if (isManaged(*GVar)) {
976 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
977 report_fatal_error(
978 reason: ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
979 O << " .attribute(.managed)";
980 }
981
982 O << " .align "
983 << GVar->getAlign().value_or(u: DL.getPrefTypeAlign(Ty: ETy)).value();
984
985 if (ETy->isPointerTy() || ((ETy->isIntegerTy() || ETy->isFloatingPointTy()) &&
986 ETy->getScalarSizeInBits() <= 64)) {
987 O << " .";
988 // Special case: ABI requires that we use .u8 for predicates
989 if (ETy->isIntegerTy(Bitwidth: 1))
990 O << "u8";
991 else
992 O << getPTXFundamentalTypeStr(Ty: ETy, false);
993 O << " ";
994 getSymbol(GV: GVar)->print(OS&: O, MAI);
995
996 // Ptx allows variable initilization only for constant and global state
997 // spaces.
998 if (GVar->hasInitializer()) {
999 if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1000 (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1001 const Constant *Initializer = GVar->getInitializer();
1002 // 'undef' is treated as there is no value specified.
1003 if (!Initializer->isNullValue() && !isa<UndefValue>(Val: Initializer)) {
1004 O << " = ";
1005 printScalarConstant(CPV: Initializer, O);
1006 }
1007 } else {
1008 // The frontend adds zero-initializer to device and constant variables
1009 // that don't have an initial value, and UndefValue to shared
1010 // variables, so skip warning for this case.
1011 if (!GVar->getInitializer()->isNullValue() &&
1012 !isa<UndefValue>(Val: GVar->getInitializer())) {
1013 report_fatal_error(reason: "initial value of '" + GVar->getName() +
1014 "' is not allowed in addrspace(" +
1015 Twine(GVar->getAddressSpace()) + ")");
1016 }
1017 }
1018 }
1019 } else {
1020 // Although PTX has direct support for struct type and array type and
1021 // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1022 // targets that support these high level field accesses. Structs, arrays
1023 // and vectors are lowered into arrays of bytes.
1024 switch (ETy->getTypeID()) {
1025 case Type::IntegerTyID: // Integers larger than 64 bits
1026 case Type::FP128TyID:
1027 case Type::StructTyID:
1028 case Type::ArrayTyID:
1029 case Type::FixedVectorTyID: {
1030 const uint64_t ElementSize = DL.getTypeStoreSize(Ty: ETy);
1031 // Ptx allows variable initilization only for constant and
1032 // global state spaces.
1033 if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1034 (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1035 GVar->hasInitializer()) {
1036 const Constant *Initializer = GVar->getInitializer();
1037 if (!isa<UndefValue>(Val: Initializer) && !Initializer->isNullValue()) {
1038 AggBuffer aggBuffer(ElementSize, *this);
1039 bufferAggregateConstant(CV: Initializer, aggBuffer: &aggBuffer);
1040 if (aggBuffer.numSymbols()) {
1041 const unsigned int ptrSize = MAI->getCodePointerSize();
1042 if (ElementSize % ptrSize ||
1043 !aggBuffer.allSymbolsAligned(ptrSize)) {
1044 // Print in bytes and use the mask() operator for pointers.
1045 if (!STI.hasMaskOperator())
1046 report_fatal_error(
1047 reason: "initialized packed aggregate with pointers '" +
1048 GVar->getName() +
1049 "' requires at least PTX ISA version 7.1");
1050 O << " .u8 ";
1051 getSymbol(GV: GVar)->print(OS&: O, MAI);
1052 O << "[" << ElementSize << "] = {";
1053 aggBuffer.printBytes(os&: O);
1054 O << "}";
1055 } else {
1056 O << " .u" << ptrSize * 8 << " ";
1057 getSymbol(GV: GVar)->print(OS&: O, MAI);
1058 O << "[" << ElementSize / ptrSize << "] = {";
1059 aggBuffer.printWords(os&: O);
1060 O << "}";
1061 }
1062 } else {
1063 O << " .b8 ";
1064 getSymbol(GV: GVar)->print(OS&: O, MAI);
1065 O << "[" << ElementSize << "] = {";
1066 aggBuffer.printBytes(os&: O);
1067 O << "}";
1068 }
1069 } else {
1070 O << " .b8 ";
1071 getSymbol(GV: GVar)->print(OS&: O, MAI);
1072 if (ElementSize)
1073 O << "[" << ElementSize << "]";
1074 }
1075 } else {
1076 O << " .b8 ";
1077 getSymbol(GV: GVar)->print(OS&: O, MAI);
1078 if (ElementSize)
1079 O << "[" << ElementSize << "]";
1080 }
1081 break;
1082 }
1083 default:
1084 llvm_unreachable("type not supported yet");
1085 }
1086 }
1087 O << ";\n";
1088}
1089
1090void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1091 const Value *v = Symbols[nSym];
1092 const Value *v0 = SymbolsBeforeStripping[nSym];
1093 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: v)) {
1094 MCSymbol *Name = AP.getSymbol(GV: GVar);
1095 PointerType *PTy = dyn_cast<PointerType>(Val: v0->getType());
1096 // Is v0 a generic pointer?
1097 bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1098 if (EmitGeneric && isGenericPointer && !isa<Function>(Val: v)) {
1099 os << "generic(";
1100 Name->print(OS&: os, MAI: AP.MAI);
1101 os << ")";
1102 } else {
1103 Name->print(OS&: os, MAI: AP.MAI);
1104 }
1105 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(Val: v0)) {
1106 const MCExpr *Expr = AP.lowerConstantForGV(CV: CExpr, ProcessingGeneric: false);
1107 AP.printMCExpr(Expr: *Expr, OS&: os);
1108 } else
1109 llvm_unreachable("symbol type unknown");
1110}
1111
// Print the aggregate buffer as a comma-separated list of bytes, emitting a
// per-byte mask() expression for every position that holds a symbol.
void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
  unsigned int ptrSize = AP.MAI->getCodePointerSize();
  // Do not emit trailing zero initializers. They will be zero-initialized by
  // ptxas. This saves on both space requirements for the generated PTX and on
  // memory use by ptxas. (See:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
  unsigned int InitializerCount = size;
  // TODO: symbols make this harder, but it would still be good to trim trailing
  // 0s for aggs with symbols as well.
  if (numSymbols() == 0)
    while (InitializerCount >= 1 && !buffer[InitializerCount - 1])
      InitializerCount--;

  // Sentinel one past the emitted range so the symbol cursor below is always
  // valid, even after the last real symbol is consumed.
  symbolPosInBuffer.push_back(Elt: InitializerCount);
  unsigned int nSym = 0;
  unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
  for (unsigned int pos = 0; pos < InitializerCount;) {
    if (pos)
      os << ", ";
    // Plain data byte: print it and advance by one.
    if (pos != nextSymbolPos) {
      os << (unsigned int)buffer[pos];
      ++pos;
      continue;
    }
    // Generate a per-byte mask() operator for the symbol, which looks like:
    // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
    // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
    std::string symText;
    llvm::raw_string_ostream oss(symText);
    printSymbol(nSym, os&: oss);
    for (unsigned i = 0; i < ptrSize; ++i) {
      if (i)
        os << ", ";
      llvm::write_hex(S&: os, N: 0xFFULL << i * 8, Style: HexPrintStyle::PrefixUpper);
      os << "(" << symText << ")";
    }
    // A symbol occupies a full pointer's worth of bytes in the buffer.
    pos += ptrSize;
    nextSymbolPos = symbolPosInBuffer[++nSym];
    assert(nextSymbolPos >= pos);
  }
}
1153
1154void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1155 unsigned int ptrSize = AP.MAI->getCodePointerSize();
1156 symbolPosInBuffer.push_back(Elt: size);
1157 unsigned int nSym = 0;
1158 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1159 assert(nextSymbolPos % ptrSize == 0);
1160 for (unsigned int pos = 0; pos < size; pos += ptrSize) {
1161 if (pos)
1162 os << ", ";
1163 if (pos == nextSymbolPos) {
1164 printSymbol(nSym, os);
1165 nextSymbolPos = symbolPosInBuffer[++nSym];
1166 assert(nextSymbolPos % ptrSize == 0);
1167 assert(nextSymbolPos >= pos + ptrSize);
1168 } else if (ptrSize == 4)
1169 os << support::endian::read32le(P: &buffer[pos]);
1170 else
1171 os << support::endian::read64le(P: &buffer[pos]);
1172 }
1173}
1174
1175void NVPTXAsmPrinter::emitDemotedVars(const Function *F, raw_ostream &O) {
1176 auto It = localDecls.find(x: F);
1177 if (It == localDecls.end())
1178 return;
1179
1180 ArrayRef<const GlobalVariable *> GVars = It->second;
1181
1182 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1183 const NVPTXSubtarget &STI = *NTM.getSubtargetImpl();
1184
1185 for (const GlobalVariable *GV : GVars) {
1186 O << "\t// demoted variable\n\t";
1187 printModuleLevelGV(GVar: GV, O, /*processDemoted=*/ProcessDemoted: true, STI);
1188 }
1189}
1190
1191void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1192 raw_ostream &O) const {
1193 switch (AddressSpace) {
1194 case ADDRESS_SPACE_LOCAL:
1195 O << "local";
1196 break;
1197 case ADDRESS_SPACE_GLOBAL:
1198 O << "global";
1199 break;
1200 case ADDRESS_SPACE_CONST:
1201 O << "const";
1202 break;
1203 case ADDRESS_SPACE_SHARED:
1204 O << "shared";
1205 break;
1206 default:
1207 report_fatal_error(reason: "Bad address space found while emitting PTX: " +
1208 llvm::Twine(AddressSpace));
1209 break;
1210 }
1211}
1212
1213std::string
1214NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1215 switch (Ty->getTypeID()) {
1216 case Type::IntegerTyID: {
1217 unsigned NumBits = cast<IntegerType>(Val: Ty)->getBitWidth();
1218 if (NumBits == 1)
1219 return "pred";
1220 if (NumBits <= 64) {
1221 std::string name = "u";
1222 return name + utostr(X: NumBits);
1223 }
1224 llvm_unreachable("Integer too large");
1225 break;
1226 }
1227 case Type::BFloatTyID:
1228 case Type::HalfTyID:
1229 // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1230 // PTX assembly.
1231 return "b16";
1232 case Type::FloatTyID:
1233 return "f32";
1234 case Type::DoubleTyID:
1235 return "f64";
1236 case Type::PointerTyID: {
1237 unsigned PtrSize = TM.getPointerSizeInBits(AS: Ty->getPointerAddressSpace());
1238 assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1239
1240 if (PtrSize == 64)
1241 if (useB4PTR)
1242 return "b64";
1243 else
1244 return "u64";
1245 else if (useB4PTR)
1246 return "b32";
1247 else
1248 return "u32";
1249 }
1250 default:
1251 break;
1252 }
1253 llvm_unreachable("unexpected type");
1254}
1255
// Emit a PTX declaration (no initializer) for GVar: state space, optional
// .attribute(.managed), alignment, and a typed symbol or byte-array symbol.
void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
                                            raw_ostream &O,
                                            const NVPTXSubtarget &STI) {
  const DataLayout &DL = getDataLayout();

  // GlobalVariables are always constant pointers themselves.
  Type *ETy = GVar->getValueType();

  O << ".";
  emitPTXAddressSpace(AddressSpace: GVar->getType()->getAddressSpace(), O);
  if (isManaged(*GVar)) {
    // .managed is only available from PTX 4.0 / sm_30 onwards.
    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
      report_fatal_error(
          reason: ".attribute(.managed) requires PTX version >= 4.0 and sm_30");

    O << " .attribute(.managed)";
  }
  // Explicit alignment wins; otherwise fall back to the preferred alignment
  // for the value type.
  O << " .align "
    << GVar->getAlign().value_or(u: DL.getPrefTypeAlign(Ty: ETy)).value();

  // Special case for i128/fp128
  if (ETy->getScalarSizeInBits() == 128) {
    O << " .b8 ";
    getSymbol(GV: GVar)->print(OS&: O, MAI);
    O << "[16]";
    return;
  }

  // Scalars small enough for a fundamental PTX type.
  if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
    O << " ." << getPTXFundamentalTypeStr(Ty: ETy) << " ";
    getSymbol(GV: GVar)->print(OS&: O, MAI);
    return;
  }

  int64_t ElementSize = 0;

  // Although PTX has direct support for struct type and array type and LLVM IR
  // is very similar to PTX, the LLVM CodeGen does not support for targets that
  // support these high level field accesses. Structs and arrays are lowered
  // into arrays of bytes.
  switch (ETy->getTypeID()) {
  case Type::StructTyID:
  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
    ElementSize = DL.getTypeStoreSize(Ty: ETy);
    O << " .b8 ";
    getSymbol(GV: GVar)->print(OS&: O, MAI);
    O << "[";
    // A zero-sized aggregate is declared as an unsized array "[]".
    if (ElementSize) {
      O << ElementSize;
    }
    O << "]";
    break;
  default:
    llvm_unreachable("type not supported yet");
  }
}
1313
// Print the PTX parameter list "( ... )" for function F, covering kernel
// image/sampler params, byval aggregates, array-passed types, pointers,
// scalars, and a trailing vararg array param.
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
  const DataLayout &DL = getDataLayout();
  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(F: *F);
  const auto *TLI = cast<NVPTXTargetLowering>(Val: STI.getTargetLowering());
  // MFI may be null, e.g. when emitting a declaration without a
  // MachineFunction.
  const NVPTXMachineFunctionInfo *MFI =
      MF ? MF->getInfo<NVPTXMachineFunctionInfo>() : nullptr;

  bool IsFirst = true;
  const bool IsKernelFunc = isKernelFunction(F: *F);

  if (F->arg_empty() && !F->isVarArg()) {
    O << "()";
    return;
  }

  O << "(\n";

  for (const Argument &Arg : F->args()) {
    Type *Ty = Arg.getType();
    const std::string ParamSym = TLI->getParamName(F, Idx: Arg.getArgNo());

    if (!IsFirst)
      O << ",\n";

    IsFirst = false;

    // Handle image/sampler parameters
    if (IsKernelFunc) {
      const bool IsSampler = isSampler(Arg);
      const bool IsTexture = !IsSampler && isImageReadOnly(Arg);
      const bool IsSurface = !IsSampler && !IsTexture &&
                             (isImageReadWrite(Arg) || isImageWriteOnly(Arg));
      if (IsSampler || IsTexture || IsSurface) {
        // Emit a plain .u64 .ptr param unless the image handle was resolved
        // to a symbol (tracked in the machine function info).
        const bool EmitImgPtr = !MFI || !MFI->checkImageHandleSymbol(Symbol: ParamSym);
        O << "\t.param ";
        if (EmitImgPtr)
          O << ".u64 .ptr ";

        if (IsSampler)
          O << ".samplerref ";
        else if (IsTexture)
          O << ".texref ";
        else // IsSurface
          O << ".surfref ";
        O << ParamSym;
        continue;
      }
    }

    // Alignment for a param: an explicit stackalign attribute wins;
    // otherwise take the max of the optimized type alignment and any
    // byval param alignment.
    auto GetOptimalAlignForParam = [TLI, &DL, F, &Arg](Type *Ty) -> Align {
      if (MaybeAlign StackAlign =
              getAlign(F: *F, Index: Arg.getArgNo() + AttributeList::FirstArgIndex))
        return StackAlign.value();

      Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, ArgTy: Ty, DL);
      MaybeAlign ParamAlign =
          Arg.hasByValAttr() ? Arg.getParamAlign() : MaybeAlign();
      return std::max(a: TypeAlign, b: ParamAlign.valueOrOne());
    };

    if (Arg.hasByValAttr()) {
      // param has byVal attribute.
      Type *ETy = Arg.getParamByValType();
      assert(ETy && "Param should have byval type");

      // Print .param .align <a> .b8 .param[size];
      // <a> = optimal alignment for the element type; always multiple of
      // PAL.getParamAlignment
      // size = typeallocsize of element type
      const Align OptimalAlign =
          IsKernelFunc ? GetOptimalAlignForParam(ETy)
                       : TLI->getFunctionByValParamAlign(
                             F, ArgTy: ETy, InitialAlign: Arg.getParamAlign().valueOrOne(), DL);

      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
        << "[" << DL.getTypeAllocSize(Ty: ETy) << "]";
      continue;
    }

    if (shouldPassAsArray(Ty)) {
      // Just print .param .align <a> .b8 .param[size];
      // <a> = optimal alignment for the element type; always multiple of
      // PAL.getParamAlignment
      // size = typeallocsize of element type
      Align OptimalAlign = GetOptimalAlignForParam(Ty);

      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
        << "[" << DL.getTypeAllocSize(Ty) << "]";

      continue;
    }
    // Just a scalar
    auto *PTy = dyn_cast<PointerType>(Val: Ty);
    unsigned PTySizeInBits = 0;
    if (PTy) {
      PTySizeInBits =
          TLI->getPointerTy(DL, AS: PTy->getAddressSpace()).getSizeInBits();
      assert(PTySizeInBits && "Invalid pointer size");
    }

    if (IsKernelFunc) {
      if (PTy) {
        // Kernel pointer params carry .ptr plus a state-space qualifier
        // when the address space is known.
        O << "\t.param .u" << PTySizeInBits << " .ptr";

        switch (PTy->getAddressSpace()) {
        default:
          break;
        case ADDRESS_SPACE_GLOBAL:
          O << " .global";
          break;
        case ADDRESS_SPACE_SHARED:
          O << " .shared";
          break;
        case ADDRESS_SPACE_CONST:
          O << " .const";
          break;
        case ADDRESS_SPACE_LOCAL:
          O << " .local";
          break;
        }

        O << " .align " << Arg.getParamAlign().valueOrOne().value() << " "
          << ParamSym;
        continue;
      }

      // non-pointer scalar to kernel func
      O << "\t.param .";
      // Special case: predicate operands become .u8 types
      if (Ty->isIntegerTy(Bitwidth: 1))
        O << "u8";
      else
        O << getPTXFundamentalTypeStr(Ty);
      O << " " << ParamSym;
      continue;
    }
    // Non-kernel function, just print .param .b<size> for ABI
    // and .reg .b<size> for non-ABI
    unsigned Size;
    if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
      // Small integers are promoted to the ABI's minimum scalar size.
      Size = promoteScalarArgumentSize(size: ITy->getBitWidth());
    } else if (PTy) {
      assert(PTySizeInBits && "Invalid pointer size");
      Size = PTySizeInBits;
    } else
      Size = Ty->getPrimitiveSizeInBits();
    O << "\t.param .b" << Size << " " << ParamSym;
  }

  if (F->isVarArg()) {
    if (!IsFirst)
      O << ",\n";
    // Varargs are lowered to a single maximally-aligned byte array param.
    O << "\t.param .align " << STI.getMaxRequiredAlignment() << " .b8 "
      << TLI->getParamName(F, /* vararg */ Idx: -1) << "[]";
  }

  O << "\n)";
}
1472
// Emit the per-function register preamble: the fake stack depot (when the
// function uses stack) plus one .reg declaration per used register class,
// and build the global-vreg -> per-class-vreg numbering used for printing.
void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
    const MachineFunction &MF) {
  SmallString<128> Str;
  raw_svector_ostream O(Str);

  // Map the global virtual register number to a register class specific
  // virtual register number starting from 1 with that class.
  const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();

  // Emit the Fake Stack Object
  const MachineFrameInfo &MFI = MF.getFrameInfo();
  int64_t NumBytes = MFI.getStackSize();
  if (NumBytes) {
    O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
      << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
    // %SP/%SPL are sized to match the target's pointer width.
    if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
      O << "\t.reg .b64 \t%SP;\n"
        << "\t.reg .b64 \t%SPL;\n";
    } else {
      O << "\t.reg .b32 \t%SP;\n"
        << "\t.reg .b32 \t%SPL;\n";
    }
  }

  // Go through all virtual registers to establish the mapping between the
  // global virtual
  // register number and the per class virtual register number.
  // We use the per class virtual register number in the ptx output.
  for (unsigned I : llvm::seq(Size: MRI->getNumVirtRegs())) {
    Register VR = Register::index2VirtReg(Index: I);
    // Skip vregs with neither uses nor defs; they need no PTX name.
    if (MRI->use_empty(RegNo: VR) && MRI->def_empty(RegNo: VR))
      continue;
    auto &RCRegMap = VRegMapping[MRI->getRegClass(Reg: VR)];
    RCRegMap[VR] = RCRegMap.size() + 1;
  }

  // Emit declaration of the virtual registers or 'physical' registers for
  // each register class
  for (const TargetRegisterClass *RC : TRI->regclasses()) {
    const unsigned N = VRegMapping[RC].size();

    // Only declare those registers that may be used.
    if (N) {
      const StringRef RCName = getNVPTXRegClassName(RC);
      const StringRef RCStr = getNVPTXRegClassStr(RC);
      O << "\t.reg " << RCName << " \t" << RCStr << "<" << (N + 1) << ">;\n";
    }
  }

  OutStreamer->emitRawText(String: O.str());
}
1524
1525/// Translate virtual register numbers in DebugInfo locations to their printed
1526/// encodings, as used by CUDA-GDB.
1527void NVPTXAsmPrinter::encodeDebugInfoRegisterNumbers(
1528 const MachineFunction &MF) {
1529 const NVPTXSubtarget &STI = MF.getSubtarget<NVPTXSubtarget>();
1530 const NVPTXRegisterInfo *registerInfo = STI.getRegisterInfo();
1531
1532 // Clear the old mapping, and add the new one. This mapping is used after the
1533 // printing of the current function is complete, but before the next function
1534 // is printed.
1535 registerInfo->clearDebugRegisterMap();
1536
1537 for (auto &classMap : VRegMapping) {
1538 for (auto &registerMapping : classMap.getSecond()) {
1539 auto reg = registerMapping.getFirst();
1540 registerInfo->addToDebugRegisterMap(preEncodedVirtualRegister: reg, RegisterName: getVirtualRegisterName(Reg: reg));
1541 }
1542 }
1543}
1544
1545void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp,
1546 raw_ostream &O) const {
1547 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1548 bool ignored;
1549 unsigned int numHex;
1550 const char *lead;
1551
1552 if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1553 numHex = 8;
1554 lead = "0f";
1555 APF.convert(ToSemantics: APFloat::IEEEsingle(), RM: APFloat::rmNearestTiesToEven, losesInfo: &ignored);
1556 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1557 numHex = 16;
1558 lead = "0d";
1559 APF.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven, losesInfo: &ignored);
1560 } else
1561 llvm_unreachable("unsupported fp type");
1562
1563 APInt API = APF.bitcastToAPInt();
1564 O << lead << format_hex_no_prefix(N: API.getZExtValue(), Width: numHex, /*Upper=*/true);
1565}
1566
1567void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1568 if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CPV)) {
1569 O << CI->getValue();
1570 return;
1571 }
1572 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(Val: CPV)) {
1573 printFPConstant(Fp: CFP, O);
1574 return;
1575 }
1576 if (isa<ConstantPointerNull>(Val: CPV)) {
1577 O << "0";
1578 return;
1579 }
1580 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: CPV)) {
1581 const bool IsNonGenericPointer = GVar->getAddressSpace() != 0;
1582 if (EmitGeneric && !isa<Function>(Val: CPV) && !IsNonGenericPointer) {
1583 O << "generic(";
1584 getSymbol(GV: GVar)->print(OS&: O, MAI);
1585 O << ")";
1586 } else {
1587 getSymbol(GV: GVar)->print(OS&: O, MAI);
1588 }
1589 return;
1590 }
1591 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
1592 const MCExpr *E = lowerConstantForGV(CV: cast<Constant>(Val: Cexpr), ProcessingGeneric: false);
1593 printMCExpr(Expr: *E, OS&: O);
1594 return;
1595 }
1596 llvm_unreachable("Not scalar type found in printScalarConstant()");
1597}
1598
// Serialize the constant CPV into AggBuffer in little-endian byte order.
// A non-zero Bytes requests zero-padding up to that many bytes.
void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
                                   AggBuffer *AggBuffer) {
  const DataLayout &DL = getDataLayout();
  int AllocSize = DL.getTypeAllocSize(Ty: CPV->getType());
  if (isa<UndefValue>(Val: CPV) || CPV->isNullValue()) {
    // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
    // only the space allocated by CPV.
    AggBuffer->addZeros(Num: Bytes ? Bytes : AllocSize);
    return;
  }

  // Helper for filling AggBuffer with APInts.
  auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
    size_t NumBytes = (Val.getBitWidth() + 7) / 8;
    SmallVector<unsigned char, 16> Buf(NumBytes);
    // `extractBitsAsZExtValue` does not allow the extraction of bits beyond the
    // input's bit width, and i1 arrays may not have a length that is a multuple
    // of 8. We handle the last byte separately, so we never request out of
    // bounds bits.
    for (unsigned I = 0; I < NumBytes - 1; ++I) {
      Buf[I] = Val.extractBitsAsZExtValue(numBits: 8, bitPosition: I * 8);
    }
    size_t LastBytePosition = (NumBytes - 1) * 8;
    size_t LastByteBits = Val.getBitWidth() - LastBytePosition;
    Buf[NumBytes - 1] =
        Val.extractBitsAsZExtValue(numBits: LastByteBits, bitPosition: LastBytePosition);
    AggBuffer->addBytes(Ptr: Buf.data(), Num: NumBytes, Bytes);
  };

  switch (CPV->getType()->getTypeID()) {
  case Type::IntegerTyID:
    if (const auto *CI = dyn_cast<ConstantInt>(Val: CPV)) {
      AddIntToBuffer(CI->getValue());
      break;
    }
    if (const auto *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
      // Try to fold the expression down to a plain integer first.
      if (const auto *CI =
              dyn_cast<ConstantInt>(Val: ConstantFoldConstant(C: Cexpr, DL))) {
        AddIntToBuffer(CI->getValue());
        break;
      }
      // ptrtoint: record the underlying symbol and reserve its slot with
      // zeros; the symbol is patched in when the buffer is printed.
      if (Cexpr->getOpcode() == Instruction::PtrToInt) {
        Value *V = Cexpr->getOperand(i_nocapture: 0)->stripPointerCasts();
        AggBuffer->addSymbol(GVar: V, GVarBeforeStripping: Cexpr->getOperand(i_nocapture: 0));
        AggBuffer->addZeros(Num: AllocSize);
        break;
      }
    }
    llvm_unreachable("unsupported integer const type");
    break;

  case Type::HalfTyID:
  case Type::BFloatTyID:
  case Type::FloatTyID:
  case Type::DoubleTyID:
    // Floats are buffered via their raw bit pattern.
    AddIntToBuffer(cast<ConstantFP>(Val: CPV)->getValueAPF().bitcastToAPInt());
    break;

  case Type::PointerTyID: {
    if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: CPV)) {
      AggBuffer->addSymbol(GVar, GVarBeforeStripping: GVar);
    } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
      const Value *v = Cexpr->stripPointerCasts();
      AggBuffer->addSymbol(GVar: v, GVarBeforeStripping: Cexpr);
    }
    // The zeros are a placeholder for the pointer-sized slot.
    AggBuffer->addZeros(Num: AllocSize);
    break;
  }

  case Type::ArrayTyID:
  case Type::FixedVectorTyID:
  case Type::StructTyID: {
    if (isa<ConstantAggregate>(Val: CPV) || isa<ConstantDataSequential>(Val: CPV)) {
      bufferAggregateConstant(CV: CPV, aggBuffer: AggBuffer);
      // Zero-pad when the caller asked for more bytes than the aggregate.
      if (Bytes > AllocSize)
        AggBuffer->addZeros(Num: Bytes - AllocSize);
    } else if (isa<ConstantAggregateZero>(Val: CPV))
      AggBuffer->addZeros(Num: Bytes);
    else
      llvm_unreachable("Unexpected Constant type");
    break;
  }

  default:
    llvm_unreachable("unsupported type");
  }
}
1686
// Serialize an aggregate (or wide-integer / f128) constant into aggBuffer by
// delegating each element to bufferLEByte.
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
                                              AggBuffer *aggBuffer) {
  const DataLayout &DL = getDataLayout();

  // Append the raw little-endian bytes of Val to the buffer.
  auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
    for (unsigned I : llvm::seq(Size: Val.getBitWidth() / 8))
      Buffer->addByte(Byte: Val.extractBitsAsZExtValue(numBits: 8, bitPosition: I * 8));
  };

  // Integers of arbitrary width
  if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CPV)) {
    ExtendBuffer(CI->getValue(), aggBuffer);
    return;
  }

  // f128
  if (const ConstantFP *CFP = dyn_cast<ConstantFP>(Val: CPV)) {
    if (CFP->getType()->isFP128Ty()) {
      ExtendBuffer(CFP->getValueAPF().bitcastToAPInt(), aggBuffer);
      return;
    }
  }

  // Old constants
  if (isa<ConstantArray>(Val: CPV) || isa<ConstantVector>(Val: CPV)) {
    for (const auto &Op : CPV->operands())
      bufferLEByte(CPV: cast<Constant>(Val: Op), Bytes: 0, AggBuffer: aggBuffer);
    return;
  }

  if (const auto *CDS = dyn_cast<ConstantDataSequential>(Val: CPV)) {
    for (unsigned I : llvm::seq(Size: CDS->getNumElements()))
      bufferLEByte(CPV: cast<Constant>(Val: CDS->getElementAsConstant(i: I)), Bytes: 0, AggBuffer: aggBuffer);
    return;
  }

  if (isa<ConstantStruct>(Val: CPV)) {
    if (CPV->getNumOperands()) {
      StructType *ST = cast<StructType>(Val: CPV->getType());
      for (unsigned I : llvm::seq(Size: CPV->getNumOperands())) {
        // Each field is buffered with enough padding to reach the start of
        // the next field (or, for the last field, the struct's alloc size),
        // so inter-field and tail padding come out zero-filled.
        int EndOffset = (I + 1 == CPV->getNumOperands())
                            ? DL.getStructLayout(Ty: ST)->getElementOffset(Idx: 0) +
                                  DL.getTypeAllocSize(Ty: ST)
                            : DL.getStructLayout(Ty: ST)->getElementOffset(Idx: I + 1);
        int Bytes = EndOffset - DL.getStructLayout(Ty: ST)->getElementOffset(Idx: I);
        bufferLEByte(CPV: cast<Constant>(Val: CPV->getOperand(i: I)), Bytes, AggBuffer: aggBuffer);
      }
    }
    return;
  }
  llvm_unreachable("unsupported constant type in printAggregateConstant()");
}
1739
/// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
/// a copy from AsmPrinter::lowerConstant, except customized to only handle
/// expressions that are representable in PTX and create
/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
const MCExpr *
NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV,
                                    bool ProcessingGeneric) const {
  MCContext &Ctx = OutContext;

  // Null and undef both lower to the constant 0.
  if (CV->isNullValue() || isa<UndefValue>(Val: CV))
    return MCConstantExpr::create(Value: 0, Ctx);

  if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CV))
    return MCConstantExpr::create(Value: CI->getZExtValue(), Ctx);

  if (const GlobalValue *GV = dyn_cast<GlobalValue>(Val: CV)) {
    const MCSymbolRefExpr *Expr = MCSymbolRefExpr::create(Symbol: getSymbol(GV), Ctx);
    // Inside an addrspacecast-to-generic, wrap the symbol so the printer can
    // emit generic() around it.
    if (ProcessingGeneric)
      return NVPTXGenericMCSymbolRefExpr::create(SymExpr: Expr, Ctx);
    return Expr;
  }

  const ConstantExpr *CE = dyn_cast<ConstantExpr>(Val: CV);
  if (!CE) {
    llvm_unreachable("Unknown constant value to lower!");
  }

  switch (CE->getOpcode()) {
  default:
    break; // Error

  case Instruction::AddrSpaceCast: {
    // Strip the addrspacecast and pass along the operand
    PointerType *DstTy = cast<PointerType>(Val: CE->getType());
    // Only casts into the generic (0) address space are representable.
    if (DstTy->getAddressSpace() == 0)
      return lowerConstantForGV(CV: cast<const Constant>(Val: CE->getOperand(i_nocapture: 0)), ProcessingGeneric: true);

    break; // Error
  }

  case Instruction::GetElementPtr: {
    const DataLayout &DL = getDataLayout();

    // Generate a symbolic expression for the byte address
    APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
    cast<GEPOperator>(Val: CE)->accumulateConstantOffset(DL, Offset&: OffsetAI);

    const MCExpr *Base = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0),
                                            ProcessingGeneric);
    // A zero offset degenerates to the base expression.
    if (!OffsetAI)
      return Base;

    int64_t Offset = OffsetAI.getSExtValue();
    return MCBinaryExpr::createAdd(LHS: Base, RHS: MCConstantExpr::create(Value: Offset, Ctx),
                                   Ctx);
  }

  case Instruction::Trunc:
    // We emit the value and depend on the assembler to truncate the generated
    // expression properly. This is important for differences between
    // blockaddress labels. Since the two labels are in the same function, it
    // is reasonable to treat their delta as a 32-bit value.
    [[fallthrough]];
  case Instruction::BitCast:
    return lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0), ProcessingGeneric);

  case Instruction::IntToPtr: {
    const DataLayout &DL = getDataLayout();

    // Handle casts to pointers by changing them into casts to the appropriate
    // integer type. This promotes constant folding and simplifies this code.
    Constant *Op = CE->getOperand(i_nocapture: 0);
    Op = ConstantFoldIntegerCast(C: Op, DestTy: DL.getIntPtrType(CV->getType()),
                                 /*IsSigned*/ false, DL);
    if (Op)
      return lowerConstantForGV(CV: Op, ProcessingGeneric);

    break; // Error
  }

  case Instruction::PtrToInt: {
    const DataLayout &DL = getDataLayout();

    // Support only foldable casts to/from pointers that can be eliminated by
    // changing the pointer to the appropriately sized integer type.
    Constant *Op = CE->getOperand(i_nocapture: 0);
    Type *Ty = CE->getType();

    const MCExpr *OpExpr = lowerConstantForGV(CV: Op, ProcessingGeneric);

    // We can emit the pointer value into this slot if the slot is an
    // integer slot equal to the size of the pointer.
    if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Ty: Op->getType()))
      return OpExpr;

    // Otherwise the pointer is smaller than the resultant integer, mask off
    // the high bits so we are sure to get a proper truncation if the input is
    // a constant expr.
    unsigned InBits = DL.getTypeAllocSizeInBits(Ty: Op->getType());
    const MCExpr *MaskExpr = MCConstantExpr::create(Value: ~0ULL >> (64-InBits), Ctx);
    return MCBinaryExpr::createAnd(LHS: OpExpr, RHS: MaskExpr, Ctx);
  }

  // The MC library also has a right-shift operator, but it isn't consistently
  // signed or unsigned between different targets.
  case Instruction::Add: {
    const MCExpr *LHS = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0), ProcessingGeneric);
    const MCExpr *RHS = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 1), ProcessingGeneric);
    switch (CE->getOpcode()) {
    default: llvm_unreachable("Unknown binary operator constant cast expr");
    case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
    }
  }
  }

  // If the code isn't optimized, there may be outstanding folding
  // opportunities. Attempt to fold the expression using DataLayout as a
  // last resort before giving up.
  Constant *C = ConstantFoldConstant(C: CE, DL: getDataLayout());
  if (C != CE)
    return lowerConstantForGV(CV: C, ProcessingGeneric);

  // Otherwise report the problem to the user.
  std::string S;
  raw_string_ostream OS(S);
  OS << "Unsupported expression in static initializer: ";
  CE->printAsOperand(O&: OS, /*PrintType=*/false,
                     M: !MF ? nullptr : MF->getFunction().getParent());
  report_fatal_error(reason: Twine(OS.str()));
}
1870
1871void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) const {
1872 OutContext.getAsmInfo()->printExpr(OS, Expr);
1873}
1874
1875/// PrintAsmOperand - Print out an operand for an inline asm expression.
1876///
1877bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
1878 const char *ExtraCode, raw_ostream &O) {
1879 if (ExtraCode && ExtraCode[0]) {
1880 if (ExtraCode[1] != 0)
1881 return true; // Unknown modifier.
1882
1883 switch (ExtraCode[0]) {
1884 default:
1885 // See if this is a generic print operand
1886 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, OS&: O);
1887 case 'r':
1888 break;
1889 }
1890 }
1891
1892 printOperand(MI, OpNum: OpNo, O);
1893
1894 return false;
1895}
1896
1897bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
1898 unsigned OpNo,
1899 const char *ExtraCode,
1900 raw_ostream &O) {
1901 if (ExtraCode && ExtraCode[0])
1902 return true; // Unknown modifier
1903
1904 O << '[';
1905 printMemOperand(MI, OpNum: OpNo, O);
1906 O << ']';
1907
1908 return false;
1909}
1910
1911void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,
1912 raw_ostream &O) {
1913 const MachineOperand &MO = MI->getOperand(i: OpNum);
1914 switch (MO.getType()) {
1915 case MachineOperand::MO_Register:
1916 if (MO.getReg().isPhysical()) {
1917 if (MO.getReg() == NVPTX::VRDepot)
1918 O << DEPOTNAME << getFunctionNumber();
1919 else
1920 O << NVPTXInstPrinter::getRegisterName(Reg: MO.getReg());
1921 } else {
1922 emitVirtualRegister(vr: MO.getReg(), O);
1923 }
1924 break;
1925
1926 case MachineOperand::MO_Immediate:
1927 O << MO.getImm();
1928 break;
1929
1930 case MachineOperand::MO_FPImmediate:
1931 printFPConstant(Fp: MO.getFPImm(), O);
1932 break;
1933
1934 case MachineOperand::MO_GlobalAddress:
1935 PrintSymbolOperand(MO, OS&: O);
1936 break;
1937
1938 case MachineOperand::MO_MachineBasicBlock:
1939 MO.getMBB()->getSymbol()->print(OS&: O, MAI);
1940 break;
1941
1942 default:
1943 llvm_unreachable("Operand type not supported.");
1944 }
1945}
1946
1947void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,
1948 raw_ostream &O, const char *Modifier) {
1949 printOperand(MI, OpNum, O);
1950
1951 if (Modifier && strcmp(s1: Modifier, s2: "add") == 0) {
1952 O << ", ";
1953 printOperand(MI, OpNum: OpNum + 1, O);
1954 } else {
1955 if (MI->getOperand(i: OpNum + 1).isImm() &&
1956 MI->getOperand(i: OpNum + 1).getImm() == 0)
1957 return; // don't print ',0' or '+0'
1958 O << "+";
1959 printOperand(MI, OpNum: OpNum + 1, O);
1960 }
1961}
1962
// Pass identification: the address of ID uniquely identifies this pass.
char NVPTXAsmPrinter::ID = 0;

// Register the pass with the LLVM pass registry (no CFG analysis,
// not an analysis pass).
INITIALIZE_PASS(NVPTXAsmPrinter, "nvptx-asm-printer", "NVPTX Assembly Printer",
                false, false)
1967
1968// Force static initialization.
1969extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void
1970LLVMInitializeNVPTXAsmPrinter() {
1971 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
1972 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
1973}
1974