1//===-- NVPTXInstPrinter.cpp - PTX assembly instruction printing ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Print MCInst instructions to .ptx format.
10//
11//===----------------------------------------------------------------------===//
12
13#include "MCTargetDesc/NVPTXInstPrinter.h"
14#include "NVPTX.h"
15#include "NVPTXUtilities.h"
16#include "llvm/ADT/StringRef.h"
17#include "llvm/IR/NVVMIntrinsicUtils.h"
18#include "llvm/MC/MCAsmInfo.h"
19#include "llvm/MC/MCExpr.h"
20#include "llvm/MC/MCInst.h"
21#include "llvm/MC/MCInstrInfo.h"
22#include "llvm/MC/MCSubtargetInfo.h"
23#include "llvm/MC/MCSymbol.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include <cctype>
27using namespace llvm;
28
29#define DEBUG_TYPE "asm-printer"
30
31#define GET_SUBTARGETINFO_ENUM
32#include "NVPTXGenSubtargetInfo.inc"
33
34#include "NVPTXGenAsmWriter.inc"
35
36static bool hasParamSubqualifiers(const MCSubtargetInfo &STI) {
37 return STI.hasFeature(Feature: NVPTX::PTX83);
38}
39
40NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII,
41 const MCRegisterInfo &MRI)
42 : MCInstPrinter(MAI, MII, MRI) {}
43
44void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) {
45 // Decode the virtual register
46 // Must be kept in sync with NVPTXAsmPrinter::encodeVirtualRegister
47 unsigned RCId = (Reg.id() >> 28);
48 switch (RCId) {
49 default: report_fatal_error(reason: "Bad virtual register encoding");
50 case 0:
51 // This is actually a physical register, so defer to the autogenerated
52 // register printer
53 OS << getRegisterName(Reg);
54 return;
55 case 1:
56 OS << "%p";
57 break;
58 case 2:
59 OS << "%rs";
60 break;
61 case 3:
62 OS << "%r";
63 break;
64 case 4:
65 OS << "%rd";
66 break;
67 case 5:
68 OS << "%f";
69 break;
70 case 6:
71 OS << "%fd";
72 break;
73 case 7:
74 OS << "%rq";
75 break;
76 }
77
78 unsigned VReg = Reg.id() & 0x0FFFFFFF;
79 OS << VReg;
80}
81
82void NVPTXInstPrinter::printInst(const MCInst *MI, uint64_t Address,
83 StringRef Annot, const MCSubtargetInfo &STI,
84 raw_ostream &OS) {
85 printInstruction(MI, Address, STI, O&: OS);
86
87 // Next always print the annotation.
88 printAnnotation(OS, Annot);
89}
90
91void NVPTXInstPrinter::printOperand(const MCInst *MI, unsigned OpNo,
92 const MCSubtargetInfo &, raw_ostream &O) {
93 const MCOperand &Op = MI->getOperand(i: OpNo);
94 if (Op.isReg()) {
95 MCRegister Reg = Op.getReg();
96 printRegName(OS&: O, Reg);
97 } else if (Op.isImm()) {
98 markup(OS&: O, M: Markup::Immediate) << formatImm(Value: Op.getImm());
99 } else {
100 assert(Op.isExpr() && "Unknown operand kind in printOperand");
101 MAI.printExpr(O, *Op.getExpr());
102 }
103}
104
105void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum,
106 const MCSubtargetInfo &, raw_ostream &O,
107 StringRef Modifier) {
108 const MCOperand &MO = MI->getOperand(i: OpNum);
109 int64_t Imm = MO.getImm();
110
111 if (Modifier == "ftz") {
112 // FTZ flag
113 if (Imm & NVPTX::PTXCvtMode::FTZ_FLAG)
114 O << ".ftz";
115 return;
116 } else if (Modifier == "sat") {
117 // SAT flag
118 if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
119 O << ".sat";
120 return;
121 } else if (Modifier == "relu") {
122 // RELU flag
123 if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
124 O << ".relu";
125 return;
126 } else if (Modifier == "base") {
127 // Default operand
128 switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) {
129 default:
130 return;
131 case NVPTX::PTXCvtMode::NONE:
132 return;
133 case NVPTX::PTXCvtMode::RNI:
134 O << ".rni";
135 return;
136 case NVPTX::PTXCvtMode::RZI:
137 O << ".rzi";
138 return;
139 case NVPTX::PTXCvtMode::RMI:
140 O << ".rmi";
141 return;
142 case NVPTX::PTXCvtMode::RPI:
143 O << ".rpi";
144 return;
145 case NVPTX::PTXCvtMode::RN:
146 O << ".rn";
147 return;
148 case NVPTX::PTXCvtMode::RZ:
149 O << ".rz";
150 return;
151 case NVPTX::PTXCvtMode::RM:
152 O << ".rm";
153 return;
154 case NVPTX::PTXCvtMode::RP:
155 O << ".rp";
156 return;
157 case NVPTX::PTXCvtMode::RNA:
158 O << ".rna";
159 return;
160 case NVPTX::PTXCvtMode::RS:
161 O << ".rs";
162 return;
163 }
164 }
165 llvm_unreachable("Invalid conversion modifier");
166}
167
168void NVPTXInstPrinter::printFTZFlag(const MCInst *MI, int OpNum,
169 const MCSubtargetInfo &, raw_ostream &O) {
170 const MCOperand &MO = MI->getOperand(i: OpNum);
171 const int Imm = MO.getImm();
172 if (Imm)
173 O << ".ftz";
174}
175
176void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum,
177 const MCSubtargetInfo &, raw_ostream &O,
178 StringRef Modifier) {
179 const MCOperand &MO = MI->getOperand(i: OpNum);
180 int64_t Imm = MO.getImm();
181
182 if (Modifier == "FCmp") {
183 switch (Imm) {
184 default:
185 return;
186 case NVPTX::PTXCmpMode::EQ:
187 O << "eq";
188 return;
189 case NVPTX::PTXCmpMode::NE:
190 O << "ne";
191 return;
192 case NVPTX::PTXCmpMode::LT:
193 O << "lt";
194 return;
195 case NVPTX::PTXCmpMode::LE:
196 O << "le";
197 return;
198 case NVPTX::PTXCmpMode::GT:
199 O << "gt";
200 return;
201 case NVPTX::PTXCmpMode::GE:
202 O << "ge";
203 return;
204 case NVPTX::PTXCmpMode::EQU:
205 O << "equ";
206 return;
207 case NVPTX::PTXCmpMode::NEU:
208 O << "neu";
209 return;
210 case NVPTX::PTXCmpMode::LTU:
211 O << "ltu";
212 return;
213 case NVPTX::PTXCmpMode::LEU:
214 O << "leu";
215 return;
216 case NVPTX::PTXCmpMode::GTU:
217 O << "gtu";
218 return;
219 case NVPTX::PTXCmpMode::GEU:
220 O << "geu";
221 return;
222 case NVPTX::PTXCmpMode::NUM:
223 O << "num";
224 return;
225 case NVPTX::PTXCmpMode::NotANumber:
226 O << "nan";
227 return;
228 }
229 }
230 if (Modifier == "ICmp") {
231 switch (Imm) {
232 default:
233 llvm_unreachable("Invalid ICmp mode");
234 case NVPTX::PTXCmpMode::EQ:
235 O << "eq";
236 return;
237 case NVPTX::PTXCmpMode::NE:
238 O << "ne";
239 return;
240 case NVPTX::PTXCmpMode::LT:
241 case NVPTX::PTXCmpMode::LTU:
242 O << "lt";
243 return;
244 case NVPTX::PTXCmpMode::LE:
245 case NVPTX::PTXCmpMode::LEU:
246 O << "le";
247 return;
248 case NVPTX::PTXCmpMode::GT:
249 case NVPTX::PTXCmpMode::GTU:
250 O << "gt";
251 return;
252 case NVPTX::PTXCmpMode::GE:
253 case NVPTX::PTXCmpMode::GEU:
254 O << "ge";
255 return;
256 }
257 }
258 if (Modifier == "IType") {
259 switch (Imm) {
260 default:
261 llvm_unreachable("Invalid IType");
262 case NVPTX::PTXCmpMode::EQ:
263 case NVPTX::PTXCmpMode::NE:
264 O << "b";
265 return;
266 case NVPTX::PTXCmpMode::LT:
267 case NVPTX::PTXCmpMode::LE:
268 case NVPTX::PTXCmpMode::GT:
269 case NVPTX::PTXCmpMode::GE:
270 O << "s";
271 return;
272 case NVPTX::PTXCmpMode::LTU:
273 case NVPTX::PTXCmpMode::LEU:
274 case NVPTX::PTXCmpMode::GTU:
275 case NVPTX::PTXCmpMode::GEU:
276 O << "u";
277 return;
278 }
279 }
280 llvm_unreachable("Empty Modifier");
281}
282
283void NVPTXInstPrinter::printAtomicCode(const MCInst *MI, int OpNum,
284 const MCSubtargetInfo &STI,
285 raw_ostream &O, StringRef Modifier) {
286 const MCOperand &MO = MI->getOperand(i: OpNum);
287 int Imm = (int)MO.getImm();
288 if (Modifier == "sem") {
289 auto Ordering = NVPTX::Ordering(Imm);
290 switch (Ordering) {
291 case NVPTX::Ordering::NotAtomic:
292 return;
293 case NVPTX::Ordering::Relaxed:
294 O << ".relaxed";
295 return;
296 case NVPTX::Ordering::Acquire:
297 O << ".acquire";
298 return;
299 case NVPTX::Ordering::Release:
300 O << ".release";
301 return;
302 case NVPTX::Ordering::AcquireRelease:
303 O << ".acq_rel";
304 return;
305 case NVPTX::Ordering::SequentiallyConsistent:
306 report_fatal_error(
307 reason: "NVPTX AtomicCode Printer does not support \"seq_cst\" ordering.");
308 return;
309 case NVPTX::Ordering::Volatile:
310 O << ".volatile";
311 return;
312 case NVPTX::Ordering::RelaxedMMIO:
313 O << ".mmio.relaxed";
314 return;
315 }
316 } else if (Modifier == "scope") {
317 auto S = NVPTX::Scope(Imm);
318 switch (S) {
319 case NVPTX::Scope::Thread:
320 case NVPTX::Scope::DefaultDevice:
321 return;
322 case NVPTX::Scope::System:
323 O << ".sys";
324 return;
325 case NVPTX::Scope::Block:
326 O << ".cta";
327 return;
328 case NVPTX::Scope::Cluster:
329 O << ".cluster";
330 return;
331 case NVPTX::Scope::Device:
332 O << ".gpu";
333 return;
334 }
335 report_fatal_error(reason: formatv(
336 Fmt: "NVPTX AtomicCode Printer does not support \"{}\" scope modifier.",
337 Vals: ScopeToString(S)));
338 } else if (Modifier == "addsp") {
339 auto A = NVPTX::AddressSpace(Imm);
340 switch (A) {
341 case NVPTX::AddressSpace::Generic:
342 return;
343 case NVPTX::AddressSpace::Global:
344 case NVPTX::AddressSpace::Const:
345 case NVPTX::AddressSpace::Shared:
346 case NVPTX::AddressSpace::SharedCluster:
347 case NVPTX::AddressSpace::EntryParam:
348 case NVPTX::AddressSpace::DeviceParam:
349 case NVPTX::AddressSpace::Local:
350 O << "." << addressSpaceToString(A, UseParamSubqualifiers: hasParamSubqualifiers(STI));
351 return;
352 }
353 report_fatal_error(reason: formatv(
354 Fmt: "NVPTX AtomicCode Printer does not support \"{}\" addsp modifier.",
355 Vals: addressSpaceToString(A)));
356 } else if (Modifier == "sign") {
357 switch (Imm) {
358 case NVPTX::PTXLdStInstCode::Signed:
359 O << "s";
360 return;
361 case NVPTX::PTXLdStInstCode::Unsigned:
362 O << "u";
363 return;
364 case NVPTX::PTXLdStInstCode::Untyped:
365 O << "b";
366 return;
367 case NVPTX::PTXLdStInstCode::Float:
368 O << "f";
369 return;
370 default:
371 llvm_unreachable("Unknown register type");
372 }
373 }
374 llvm_unreachable(formatv("Unknown Modifier: {}", Modifier).str().c_str());
375}
376
377void NVPTXInstPrinter::printMmaCode(const MCInst *MI, int OpNum,
378 const MCSubtargetInfo &, raw_ostream &O,
379 StringRef Modifier) {
380 const MCOperand &MO = MI->getOperand(i: OpNum);
381 int Imm = (int)MO.getImm();
382 if (Modifier.empty() || Modifier == "version") {
383 O << Imm; // Just print out PTX version
384 return;
385 } else if (Modifier == "aligned") {
386 // PTX63 requires '.aligned' in the name of the instruction.
387 if (Imm >= 63)
388 O << ".aligned";
389 return;
390 }
391 llvm_unreachable("Unknown Modifier");
392}
393
394void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
395 const MCSubtargetInfo &STI,
396 raw_ostream &O, StringRef Modifier) {
397 printOperand(MI, OpNo: OpNum, STI, O);
398
399 if (Modifier == "add") {
400 O << ", ";
401 printOperand(MI, OpNo: OpNum + 1, STI, O);
402 } else {
403 if (MI->getOperand(i: OpNum + 1).isImm() &&
404 MI->getOperand(i: OpNum + 1).getImm() == 0)
405 return; // don't print ',0' or '+0'
406 O << "+";
407 printOperand(MI, OpNo: OpNum + 1, STI, O);
408 }
409}
410
411void NVPTXInstPrinter::printUsedBytesMaskPragma(const MCInst *MI, int OpNum,
412 const MCSubtargetInfo &,
413 raw_ostream &O) {
414 auto &Op = MI->getOperand(i: OpNum);
415 assert(Op.isImm() && "Invalid operand");
416 uint32_t Imm = (uint32_t)Op.getImm();
417 if (Imm != UINT32_MAX) {
418 O << ".pragma \"used_bytes_mask " << format_hex(N: Imm, Width: 1) << "\";\n\t";
419 }
420}
421
422void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
423 const MCSubtargetInfo &STI,
424 raw_ostream &O) {
425 const MCOperand &Op = MI->getOperand(i: OpNum);
426 if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
427 O << "_";
428 else
429 printOperand(MI, OpNo: OpNum, STI, O);
430}
431
432void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
433 const MCSubtargetInfo &, raw_ostream &O) {
434 int64_t Imm = MI->getOperand(i: OpNum).getImm();
435 O << formatHex(Value: Imm) << "U";
436}
437
438void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
439 const MCSubtargetInfo &,
440 raw_ostream &O) {
441 const MCOperand &Op = MI->getOperand(i: OpNum);
442 assert(Op.isExpr() && "Call prototype is not an MCExpr?");
443 const MCExpr *Expr = Op.getExpr();
444 const MCSymbol &Sym = cast<MCSymbolRefExpr>(Val: Expr)->getSymbol();
445 O << Sym.getName();
446}
447
448void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
449 const MCSubtargetInfo &, raw_ostream &O) {
450 const MCOperand &MO = MI->getOperand(i: OpNum);
451 int64_t Imm = MO.getImm();
452
453 switch (Imm) {
454 default:
455 return;
456 case NVPTX::PTXPrmtMode::NONE:
457 return;
458 case NVPTX::PTXPrmtMode::F4E:
459 O << ".f4e";
460 return;
461 case NVPTX::PTXPrmtMode::B4E:
462 O << ".b4e";
463 return;
464 case NVPTX::PTXPrmtMode::RC8:
465 O << ".rc8";
466 return;
467 case NVPTX::PTXPrmtMode::ECL:
468 O << ".ecl";
469 return;
470 case NVPTX::PTXPrmtMode::ECR:
471 O << ".ecr";
472 return;
473 case NVPTX::PTXPrmtMode::RC16:
474 O << ".rc16";
475 return;
476 }
477}
478
479void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
480 const MCSubtargetInfo &,
481 raw_ostream &O) {
482 const MCOperand &MO = MI->getOperand(i: OpNum);
483 using RedTy = nvvm::TMAReductionOp;
484
485 switch (static_cast<RedTy>(MO.getImm())) {
486 case RedTy::ADD:
487 O << ".add";
488 return;
489 case RedTy::MIN:
490 O << ".min";
491 return;
492 case RedTy::MAX:
493 O << ".max";
494 return;
495 case RedTy::INC:
496 O << ".inc";
497 return;
498 case RedTy::DEC:
499 O << ".dec";
500 return;
501 case RedTy::AND:
502 O << ".and";
503 return;
504 case RedTy::OR:
505 O << ".or";
506 return;
507 case RedTy::XOR:
508 O << ".xor";
509 return;
510 }
511 llvm_unreachable(
512 "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
513}
514
515void NVPTXInstPrinter::printCTAGroup(const MCInst *MI, int OpNum,
516 const MCSubtargetInfo &, raw_ostream &O) {
517 const MCOperand &MO = MI->getOperand(i: OpNum);
518 using CGTy = nvvm::CTAGroupKind;
519
520 switch (static_cast<CGTy>(MO.getImm())) {
521 case CGTy::CG_NONE:
522 O << "";
523 return;
524 case CGTy::CG_1:
525 O << ".cta_group::1";
526 return;
527 case CGTy::CG_2:
528 O << ".cta_group::2";
529 return;
530 }
531 llvm_unreachable("Invalid cta_group in printCTAGroup");
532}
533
534void NVPTXInstPrinter::printCallOperand(const MCInst *MI, int OpNum,
535 const MCSubtargetInfo &, raw_ostream &O,
536 StringRef Modifier) {
537 const MCOperand &MO = MI->getOperand(i: OpNum);
538 assert(MO.isImm() && "Invalid operand");
539 const auto Imm = MO.getImm();
540
541 if (Modifier == "RetList") {
542 assert((Imm == 1 || Imm == 0) && "Invalid return list");
543 if (Imm)
544 O << " (retval0),";
545 return;
546 }
547
548 if (Modifier == "ParamList") {
549 assert(Imm >= 0 && "Invalid parameter list");
550 interleaveComma(c: llvm::seq(Size: Imm), os&: O,
551 each_fn: [&](const auto &I) { O << "param" << I; });
552 return;
553 }
554 llvm_unreachable("Invalid modifier");
555}
556
557template <unsigned Bits>
558void NVPTXInstPrinter::printHexUImm(const MCInst *MI, int OpNum,
559 const MCSubtargetInfo &, raw_ostream &O) {
560 const MCOperand &MO = MI->getOperand(i: OpNum);
561 assert(MO.isImm() && "Expected immediate operand");
562 assert(isInt<Bits>(MO.getImm()) &&
563 "Immediate value does not fit in specified bits");
564 uint64_t Imm = MO.getImm();
565 Imm &= maskTrailingOnes<uint64_t>(N: Bits);
566 O << formatHex(Value: Imm) << "U";
567}
568