1//===-- NVPTXInstPrinter.cpp - PTX assembly instruction printing ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Print MCInst instructions to .ptx format.
10//
11//===----------------------------------------------------------------------===//
12
13#include "MCTargetDesc/NVPTXInstPrinter.h"
14#include "NVPTX.h"
15#include "NVPTXUtilities.h"
16#include "llvm/ADT/StringRef.h"
17#include "llvm/IR/NVVMIntrinsicUtils.h"
18#include "llvm/MC/MCAsmInfo.h"
19#include "llvm/MC/MCExpr.h"
20#include "llvm/MC/MCInst.h"
21#include "llvm/MC/MCInstrInfo.h"
22#include "llvm/MC/MCSubtargetInfo.h"
23#include "llvm/MC/MCSymbol.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include <cctype>
27using namespace llvm;
28
29#define DEBUG_TYPE "asm-printer"
30
31#define GET_SUBTARGETINFO_ENUM
32#include "NVPTXGenSubtargetInfo.inc"
33
34#include "NVPTXGenAsmWriter.inc"
35
36static bool hasParamSubqualifiers(const MCSubtargetInfo &STI) {
37 return STI.hasFeature(Feature: NVPTX::PTX83);
38}
39
40NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII,
41 const MCRegisterInfo &MRI)
42 : MCInstPrinter(MAI, MII, MRI) {}
43
44void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) {
45 // Decode the virtual register
46 // Must be kept in sync with NVPTXAsmPrinter::encodeVirtualRegister
47 unsigned RCId = (Reg.id() >> 28);
48 switch (RCId) {
49 default: report_fatal_error(reason: "Bad virtual register encoding");
50 case 0:
51 // This is actually a physical register, so defer to the autogenerated
52 // register printer
53 OS << getRegisterName(Reg);
54 return;
55 case 1:
56 OS << "%p";
57 break;
58 case 2:
59 OS << "%rs";
60 break;
61 case 3:
62 OS << "%r";
63 break;
64 case 4:
65 OS << "%rd";
66 break;
67 case 5:
68 OS << "%f";
69 break;
70 case 6:
71 OS << "%fd";
72 break;
73 case 7:
74 OS << "%rq";
75 break;
76 }
77
78 unsigned VReg = Reg.id() & 0x0FFFFFFF;
79 OS << VReg;
80}
81
82void NVPTXInstPrinter::printInst(const MCInst *MI, uint64_t Address,
83 StringRef Annot, const MCSubtargetInfo &STI,
84 raw_ostream &OS) {
85 printInstruction(MI, Address, STI, O&: OS);
86
87 // Next always print the annotation.
88 printAnnotation(OS, Annot);
89}
90
91void NVPTXInstPrinter::printOperand(const MCInst *MI, unsigned OpNo,
92 const MCSubtargetInfo &, raw_ostream &O) {
93 const MCOperand &Op = MI->getOperand(i: OpNo);
94 if (Op.isReg()) {
95 MCRegister Reg = Op.getReg();
96 printRegName(OS&: O, Reg);
97 } else if (Op.isImm()) {
98 markup(OS&: O, M: Markup::Immediate) << formatImm(Value: Op.getImm());
99 } else {
100 assert(Op.isExpr() && "Unknown operand kind in printOperand");
101 MAI.printExpr(O, *Op.getExpr());
102 }
103}
104
105void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum,
106 const MCSubtargetInfo &, raw_ostream &O,
107 StringRef Modifier) {
108 const MCOperand &MO = MI->getOperand(i: OpNum);
109 int64_t Imm = MO.getImm();
110
111 if (Modifier == "ftz") {
112 // FTZ flag
113 if (Imm & NVPTX::PTXCvtMode::FTZ_FLAG)
114 O << ".ftz";
115 return;
116 } else if (Modifier == "sat") {
117 // SAT flag
118 if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
119 O << ".sat";
120 return;
121 } else if (Modifier == "satfinite") {
122 // SATFINITE flag
123 if (Imm & NVPTX::PTXCvtMode::SATFINITE_FLAG)
124 O << ".satfinite";
125 return;
126 } else if (Modifier == "relu") {
127 // RELU flag
128 if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
129 O << ".relu";
130 return;
131 } else if (Modifier == "base") {
132 // Default operand
133 switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) {
134 default:
135 return;
136 case NVPTX::PTXCvtMode::NONE:
137 return;
138 case NVPTX::PTXCvtMode::RNI:
139 O << ".rni";
140 return;
141 case NVPTX::PTXCvtMode::RZI:
142 O << ".rzi";
143 return;
144 case NVPTX::PTXCvtMode::RMI:
145 O << ".rmi";
146 return;
147 case NVPTX::PTXCvtMode::RPI:
148 O << ".rpi";
149 return;
150 case NVPTX::PTXCvtMode::RN:
151 O << ".rn";
152 return;
153 case NVPTX::PTXCvtMode::RZ:
154 O << ".rz";
155 return;
156 case NVPTX::PTXCvtMode::RM:
157 O << ".rm";
158 return;
159 case NVPTX::PTXCvtMode::RP:
160 O << ".rp";
161 return;
162 case NVPTX::PTXCvtMode::RNA:
163 O << ".rna";
164 return;
165 case NVPTX::PTXCvtMode::RS:
166 O << ".rs";
167 return;
168 }
169 }
170 llvm_unreachable("Invalid conversion modifier");
171}
172
173void NVPTXInstPrinter::printFTZFlag(const MCInst *MI, int OpNum,
174 const MCSubtargetInfo &, raw_ostream &O) {
175 const MCOperand &MO = MI->getOperand(i: OpNum);
176 const int Imm = MO.getImm();
177 if (Imm)
178 O << ".ftz";
179}
180
181void NVPTXInstPrinter::printNegatedPredicate(const MCInst *MI, int OpNum,
182 const MCSubtargetInfo &,
183 raw_ostream &O) {
184 if (MI->getOperand(i: OpNum).getImm())
185 O << "!";
186}
187
188void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum,
189 const MCSubtargetInfo &, raw_ostream &O,
190 StringRef Modifier) {
191 const MCOperand &MO = MI->getOperand(i: OpNum);
192 int64_t Imm = MO.getImm();
193
194 if (Modifier == "FCmp") {
195 switch (Imm) {
196 default:
197 return;
198 case NVPTX::PTXCmpMode::EQ:
199 O << "eq";
200 return;
201 case NVPTX::PTXCmpMode::NE:
202 O << "ne";
203 return;
204 case NVPTX::PTXCmpMode::LT:
205 O << "lt";
206 return;
207 case NVPTX::PTXCmpMode::LE:
208 O << "le";
209 return;
210 case NVPTX::PTXCmpMode::GT:
211 O << "gt";
212 return;
213 case NVPTX::PTXCmpMode::GE:
214 O << "ge";
215 return;
216 case NVPTX::PTXCmpMode::EQU:
217 O << "equ";
218 return;
219 case NVPTX::PTXCmpMode::NEU:
220 O << "neu";
221 return;
222 case NVPTX::PTXCmpMode::LTU:
223 O << "ltu";
224 return;
225 case NVPTX::PTXCmpMode::LEU:
226 O << "leu";
227 return;
228 case NVPTX::PTXCmpMode::GTU:
229 O << "gtu";
230 return;
231 case NVPTX::PTXCmpMode::GEU:
232 O << "geu";
233 return;
234 case NVPTX::PTXCmpMode::NUM:
235 O << "num";
236 return;
237 case NVPTX::PTXCmpMode::NotANumber:
238 O << "nan";
239 return;
240 }
241 }
242 if (Modifier == "ICmp") {
243 switch (Imm) {
244 default:
245 llvm_unreachable("Invalid ICmp mode");
246 case NVPTX::PTXCmpMode::EQ:
247 O << "eq";
248 return;
249 case NVPTX::PTXCmpMode::NE:
250 O << "ne";
251 return;
252 case NVPTX::PTXCmpMode::LT:
253 case NVPTX::PTXCmpMode::LTU:
254 O << "lt";
255 return;
256 case NVPTX::PTXCmpMode::LE:
257 case NVPTX::PTXCmpMode::LEU:
258 O << "le";
259 return;
260 case NVPTX::PTXCmpMode::GT:
261 case NVPTX::PTXCmpMode::GTU:
262 O << "gt";
263 return;
264 case NVPTX::PTXCmpMode::GE:
265 case NVPTX::PTXCmpMode::GEU:
266 O << "ge";
267 return;
268 }
269 }
270 if (Modifier == "IType") {
271 switch (Imm) {
272 default:
273 llvm_unreachable("Invalid IType");
274 case NVPTX::PTXCmpMode::EQ:
275 case NVPTX::PTXCmpMode::NE:
276 O << "b";
277 return;
278 case NVPTX::PTXCmpMode::LT:
279 case NVPTX::PTXCmpMode::LE:
280 case NVPTX::PTXCmpMode::GT:
281 case NVPTX::PTXCmpMode::GE:
282 O << "s";
283 return;
284 case NVPTX::PTXCmpMode::LTU:
285 case NVPTX::PTXCmpMode::LEU:
286 case NVPTX::PTXCmpMode::GTU:
287 case NVPTX::PTXCmpMode::GEU:
288 O << "u";
289 return;
290 }
291 }
292 llvm_unreachable("Empty Modifier");
293}
294
295void NVPTXInstPrinter::printAtomicCode(const MCInst *MI, int OpNum,
296 const MCSubtargetInfo &STI,
297 raw_ostream &O, StringRef Modifier) {
298 const MCOperand &MO = MI->getOperand(i: OpNum);
299 int Imm = (int)MO.getImm();
300 if (Modifier == "sem") {
301 auto Ordering = NVPTX::Ordering(Imm);
302 switch (Ordering) {
303 case NVPTX::Ordering::NotAtomic:
304 return;
305 case NVPTX::Ordering::Relaxed:
306 O << ".relaxed";
307 return;
308 case NVPTX::Ordering::Acquire:
309 O << ".acquire";
310 return;
311 case NVPTX::Ordering::Release:
312 O << ".release";
313 return;
314 case NVPTX::Ordering::AcquireRelease:
315 O << ".acq_rel";
316 return;
317 case NVPTX::Ordering::SequentiallyConsistent:
318 report_fatal_error(
319 reason: "NVPTX AtomicCode Printer does not support \"seq_cst\" ordering.");
320 return;
321 case NVPTX::Ordering::Volatile:
322 O << ".volatile";
323 return;
324 case NVPTX::Ordering::RelaxedMMIO:
325 O << ".mmio.relaxed";
326 return;
327 }
328 } else if (Modifier == "scope") {
329 auto S = NVPTX::Scope(Imm);
330 switch (S) {
331 case NVPTX::Scope::Thread:
332 case NVPTX::Scope::DefaultDevice:
333 return;
334 case NVPTX::Scope::System:
335 O << ".sys";
336 return;
337 case NVPTX::Scope::Block:
338 O << ".cta";
339 return;
340 case NVPTX::Scope::Cluster:
341 O << ".cluster";
342 return;
343 case NVPTX::Scope::Device:
344 O << ".gpu";
345 return;
346 }
347 report_fatal_error(reason: formatv(
348 Fmt: "NVPTX AtomicCode Printer does not support \"{}\" scope modifier.",
349 Vals: ScopeToString(S)));
350 } else if (Modifier == "addsp") {
351 auto A = NVPTX::AddressSpace(Imm);
352 switch (A) {
353 case NVPTX::AddressSpace::Generic:
354 return;
355 case NVPTX::AddressSpace::Global:
356 case NVPTX::AddressSpace::Const:
357 case NVPTX::AddressSpace::Shared:
358 case NVPTX::AddressSpace::SharedCluster:
359 case NVPTX::AddressSpace::EntryParam:
360 case NVPTX::AddressSpace::DeviceParam:
361 case NVPTX::AddressSpace::Local:
362 O << "." << addressSpaceToString(A, UseParamSubqualifiers: hasParamSubqualifiers(STI));
363 return;
364 }
365 report_fatal_error(reason: formatv(
366 Fmt: "NVPTX AtomicCode Printer does not support \"{}\" addsp modifier.",
367 Vals: addressSpaceToString(A)));
368 } else if (Modifier == "sign") {
369 switch (Imm) {
370 case NVPTX::PTXLdStInstCode::Signed:
371 O << "s";
372 return;
373 case NVPTX::PTXLdStInstCode::Unsigned:
374 O << "u";
375 return;
376 case NVPTX::PTXLdStInstCode::Untyped:
377 O << "b";
378 return;
379 case NVPTX::PTXLdStInstCode::Float:
380 O << "f";
381 return;
382 default:
383 llvm_unreachable("Unknown register type");
384 }
385 }
386 llvm_unreachable(formatv("Unknown Modifier: {}", Modifier).str().c_str());
387}
388
389void NVPTXInstPrinter::printMmaCode(const MCInst *MI, int OpNum,
390 const MCSubtargetInfo &, raw_ostream &O,
391 StringRef Modifier) {
392 const MCOperand &MO = MI->getOperand(i: OpNum);
393 int Imm = (int)MO.getImm();
394 if (Modifier.empty() || Modifier == "version") {
395 O << Imm; // Just print out PTX version
396 return;
397 } else if (Modifier == "aligned") {
398 // PTX63 requires '.aligned' in the name of the instruction.
399 if (Imm >= 63)
400 O << ".aligned";
401 return;
402 }
403 llvm_unreachable("Unknown Modifier");
404}
405
406void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
407 const MCSubtargetInfo &STI,
408 raw_ostream &O, StringRef Modifier) {
409 printOperand(MI, OpNo: OpNum, STI, O);
410
411 if (Modifier == "add") {
412 O << ", ";
413 printOperand(MI, OpNo: OpNum + 1, STI, O);
414 } else {
415 if (MI->getOperand(i: OpNum + 1).isImm() &&
416 MI->getOperand(i: OpNum + 1).getImm() == 0)
417 return; // don't print ',0' or '+0'
418 O << "+";
419 printOperand(MI, OpNo: OpNum + 1, STI, O);
420 }
421}
422
423void NVPTXInstPrinter::printUsedBytesMaskPragma(const MCInst *MI, int OpNum,
424 const MCSubtargetInfo &,
425 raw_ostream &O) {
426 auto &Op = MI->getOperand(i: OpNum);
427 assert(Op.isImm() && "Invalid operand");
428 uint32_t Imm = (uint32_t)Op.getImm();
429 if (Imm != UINT32_MAX) {
430 O << ".pragma \"used_bytes_mask " << format_hex(N: Imm, Width: 1) << "\";\n\t";
431 }
432}
433
434void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
435 const MCSubtargetInfo &STI,
436 raw_ostream &O) {
437 const MCOperand &Op = MI->getOperand(i: OpNum);
438 if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
439 O << "_";
440 else
441 printOperand(MI, OpNo: OpNum, STI, O);
442}
443
444void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
445 const MCSubtargetInfo &, raw_ostream &O) {
446 int64_t Imm = MI->getOperand(i: OpNum).getImm();
447 O << formatHex(Value: Imm) << "U";
448}
449
450void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
451 const MCSubtargetInfo &,
452 raw_ostream &O) {
453 const MCOperand &Op = MI->getOperand(i: OpNum);
454 assert(Op.isExpr() && "Call prototype is not an MCExpr?");
455 const MCExpr *Expr = Op.getExpr();
456 const MCSymbol &Sym = cast<MCSymbolRefExpr>(Val: Expr)->getSymbol();
457 O << Sym.getName();
458}
459
460void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
461 const MCSubtargetInfo &, raw_ostream &O) {
462 const MCOperand &MO = MI->getOperand(i: OpNum);
463 int64_t Imm = MO.getImm();
464
465 switch (Imm) {
466 default:
467 return;
468 case NVPTX::PTXPrmtMode::NONE:
469 return;
470 case NVPTX::PTXPrmtMode::F4E:
471 O << ".f4e";
472 return;
473 case NVPTX::PTXPrmtMode::B4E:
474 O << ".b4e";
475 return;
476 case NVPTX::PTXPrmtMode::RC8:
477 O << ".rc8";
478 return;
479 case NVPTX::PTXPrmtMode::ECL:
480 O << ".ecl";
481 return;
482 case NVPTX::PTXPrmtMode::ECR:
483 O << ".ecr";
484 return;
485 case NVPTX::PTXPrmtMode::RC16:
486 O << ".rc16";
487 return;
488 }
489}
490
491void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
492 const MCSubtargetInfo &,
493 raw_ostream &O) {
494 const MCOperand &MO = MI->getOperand(i: OpNum);
495 using RedTy = nvvm::TMAReductionOp;
496
497 switch (static_cast<RedTy>(MO.getImm())) {
498 case RedTy::ADD:
499 O << ".add";
500 return;
501 case RedTy::MIN:
502 O << ".min";
503 return;
504 case RedTy::MAX:
505 O << ".max";
506 return;
507 case RedTy::INC:
508 O << ".inc";
509 return;
510 case RedTy::DEC:
511 O << ".dec";
512 return;
513 case RedTy::AND:
514 O << ".and";
515 return;
516 case RedTy::OR:
517 O << ".or";
518 return;
519 case RedTy::XOR:
520 O << ".xor";
521 return;
522 }
523 llvm_unreachable(
524 "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
525}
526
527void NVPTXInstPrinter::printCTAGroup(const MCInst *MI, int OpNum,
528 const MCSubtargetInfo &, raw_ostream &O) {
529 const MCOperand &MO = MI->getOperand(i: OpNum);
530 using CGTy = nvvm::CTAGroupKind;
531
532 switch (static_cast<CGTy>(MO.getImm())) {
533 case CGTy::CG_NONE:
534 O << "";
535 return;
536 case CGTy::CG_1:
537 O << ".cta_group::1";
538 return;
539 case CGTy::CG_2:
540 O << ".cta_group::2";
541 return;
542 }
543 llvm_unreachable("Invalid cta_group in printCTAGroup");
544}
545
546void NVPTXInstPrinter::printCallOperand(const MCInst *MI, int OpNum,
547 const MCSubtargetInfo &, raw_ostream &O,
548 StringRef Modifier) {
549 const MCOperand &MO = MI->getOperand(i: OpNum);
550 assert(MO.isImm() && "Invalid operand");
551 const auto Imm = MO.getImm();
552
553 if (Modifier == "RetList") {
554 assert((Imm == 1 || Imm == 0) && "Invalid return list");
555 if (Imm)
556 O << " (retval0),";
557 return;
558 }
559
560 if (Modifier == "ParamList") {
561 assert(Imm >= 0 && "Invalid parameter list");
562 interleaveComma(c: llvm::seq(Size: Imm), os&: O,
563 each_fn: [&](const auto &I) { O << "param" << I; });
564 return;
565 }
566 llvm_unreachable("Invalid modifier");
567}
568
569template <unsigned Bits>
570void NVPTXInstPrinter::printHexUImm(const MCInst *MI, int OpNum,
571 const MCSubtargetInfo &, raw_ostream &O) {
572 const MCOperand &MO = MI->getOperand(i: OpNum);
573 assert(MO.isImm() && "Expected immediate operand");
574 assert(isInt<Bits>(MO.getImm()) &&
575 "Immediate value does not fit in specified bits");
576 uint64_t Imm = MO.getImm();
577 Imm &= maskTrailingOnes<uint64_t>(N: Bits);
578 O << formatHex(Value: Imm) << "U";
579}
580