1//===-- NVPTXInstPrinter.cpp - PTX assembly instruction printing ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Print MCInst instructions to .ptx format.
10//
11//===----------------------------------------------------------------------===//
12
13#include "MCTargetDesc/NVPTXInstPrinter.h"
14#include "NVPTX.h"
15#include "NVPTXUtilities.h"
16#include "llvm/ADT/StringRef.h"
17#include "llvm/IR/NVVMIntrinsicUtils.h"
18#include "llvm/MC/MCAsmInfo.h"
19#include "llvm/MC/MCExpr.h"
20#include "llvm/MC/MCInst.h"
21#include "llvm/MC/MCInstrInfo.h"
22#include "llvm/MC/MCSubtargetInfo.h"
23#include "llvm/MC/MCSymbol.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include <cctype>
27using namespace llvm;
28
29#define DEBUG_TYPE "asm-printer"
30
31#include "NVPTXGenAsmWriter.inc"
32
33NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII,
34 const MCRegisterInfo &MRI)
35 : MCInstPrinter(MAI, MII, MRI) {}
36
37void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) {
38 // Decode the virtual register
39 // Must be kept in sync with NVPTXAsmPrinter::encodeVirtualRegister
40 unsigned RCId = (Reg.id() >> 28);
41 switch (RCId) {
42 default: report_fatal_error(reason: "Bad virtual register encoding");
43 case 0:
44 // This is actually a physical register, so defer to the autogenerated
45 // register printer
46 OS << getRegisterName(Reg);
47 return;
48 case 1:
49 OS << "%p";
50 break;
51 case 2:
52 OS << "%rs";
53 break;
54 case 3:
55 OS << "%r";
56 break;
57 case 4:
58 OS << "%rd";
59 break;
60 case 5:
61 OS << "%f";
62 break;
63 case 6:
64 OS << "%fd";
65 break;
66 case 7:
67 OS << "%rq";
68 break;
69 }
70
71 unsigned VReg = Reg.id() & 0x0FFFFFFF;
72 OS << VReg;
73}
74
75void NVPTXInstPrinter::printInst(const MCInst *MI, uint64_t Address,
76 StringRef Annot, const MCSubtargetInfo &STI,
77 raw_ostream &OS) {
78 printInstruction(MI, Address, O&: OS);
79
80 // Next always print the annotation.
81 printAnnotation(OS, Annot);
82}
83
84void NVPTXInstPrinter::printOperand(const MCInst *MI, unsigned OpNo,
85 raw_ostream &O) {
86 const MCOperand &Op = MI->getOperand(i: OpNo);
87 if (Op.isReg()) {
88 MCRegister Reg = Op.getReg();
89 printRegName(OS&: O, Reg);
90 } else if (Op.isImm()) {
91 markup(OS&: O, M: Markup::Immediate) << formatImm(Value: Op.getImm());
92 } else {
93 assert(Op.isExpr() && "Unknown operand kind in printOperand");
94 MAI.printExpr(O, *Op.getExpr());
95 }
96}
97
98void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
99 StringRef Modifier) {
100 const MCOperand &MO = MI->getOperand(i: OpNum);
101 int64_t Imm = MO.getImm();
102
103 if (Modifier == "ftz") {
104 // FTZ flag
105 if (Imm & NVPTX::PTXCvtMode::FTZ_FLAG)
106 O << ".ftz";
107 return;
108 } else if (Modifier == "sat") {
109 // SAT flag
110 if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
111 O << ".sat";
112 return;
113 } else if (Modifier == "relu") {
114 // RELU flag
115 if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
116 O << ".relu";
117 return;
118 } else if (Modifier == "base") {
119 // Default operand
120 switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) {
121 default:
122 return;
123 case NVPTX::PTXCvtMode::NONE:
124 return;
125 case NVPTX::PTXCvtMode::RNI:
126 O << ".rni";
127 return;
128 case NVPTX::PTXCvtMode::RZI:
129 O << ".rzi";
130 return;
131 case NVPTX::PTXCvtMode::RMI:
132 O << ".rmi";
133 return;
134 case NVPTX::PTXCvtMode::RPI:
135 O << ".rpi";
136 return;
137 case NVPTX::PTXCvtMode::RN:
138 O << ".rn";
139 return;
140 case NVPTX::PTXCvtMode::RZ:
141 O << ".rz";
142 return;
143 case NVPTX::PTXCvtMode::RM:
144 O << ".rm";
145 return;
146 case NVPTX::PTXCvtMode::RP:
147 O << ".rp";
148 return;
149 case NVPTX::PTXCvtMode::RNA:
150 O << ".rna";
151 return;
152 case NVPTX::PTXCvtMode::RS:
153 O << ".rs";
154 return;
155 }
156 }
157 llvm_unreachable("Invalid conversion modifier");
158}
159
160void NVPTXInstPrinter::printFTZFlag(const MCInst *MI, int OpNum,
161 raw_ostream &O) {
162 const MCOperand &MO = MI->getOperand(i: OpNum);
163 const int Imm = MO.getImm();
164 if (Imm)
165 O << ".ftz";
166}
167
168void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O,
169 StringRef Modifier) {
170 const MCOperand &MO = MI->getOperand(i: OpNum);
171 int64_t Imm = MO.getImm();
172
173 if (Modifier == "FCmp") {
174 switch (Imm) {
175 default:
176 return;
177 case NVPTX::PTXCmpMode::EQ:
178 O << "eq";
179 return;
180 case NVPTX::PTXCmpMode::NE:
181 O << "ne";
182 return;
183 case NVPTX::PTXCmpMode::LT:
184 O << "lt";
185 return;
186 case NVPTX::PTXCmpMode::LE:
187 O << "le";
188 return;
189 case NVPTX::PTXCmpMode::GT:
190 O << "gt";
191 return;
192 case NVPTX::PTXCmpMode::GE:
193 O << "ge";
194 return;
195 case NVPTX::PTXCmpMode::EQU:
196 O << "equ";
197 return;
198 case NVPTX::PTXCmpMode::NEU:
199 O << "neu";
200 return;
201 case NVPTX::PTXCmpMode::LTU:
202 O << "ltu";
203 return;
204 case NVPTX::PTXCmpMode::LEU:
205 O << "leu";
206 return;
207 case NVPTX::PTXCmpMode::GTU:
208 O << "gtu";
209 return;
210 case NVPTX::PTXCmpMode::GEU:
211 O << "geu";
212 return;
213 case NVPTX::PTXCmpMode::NUM:
214 O << "num";
215 return;
216 case NVPTX::PTXCmpMode::NotANumber:
217 O << "nan";
218 return;
219 }
220 }
221 if (Modifier == "ICmp") {
222 switch (Imm) {
223 default:
224 llvm_unreachable("Invalid ICmp mode");
225 case NVPTX::PTXCmpMode::EQ:
226 O << "eq";
227 return;
228 case NVPTX::PTXCmpMode::NE:
229 O << "ne";
230 return;
231 case NVPTX::PTXCmpMode::LT:
232 case NVPTX::PTXCmpMode::LTU:
233 O << "lt";
234 return;
235 case NVPTX::PTXCmpMode::LE:
236 case NVPTX::PTXCmpMode::LEU:
237 O << "le";
238 return;
239 case NVPTX::PTXCmpMode::GT:
240 case NVPTX::PTXCmpMode::GTU:
241 O << "gt";
242 return;
243 case NVPTX::PTXCmpMode::GE:
244 case NVPTX::PTXCmpMode::GEU:
245 O << "ge";
246 return;
247 }
248 }
249 if (Modifier == "IType") {
250 switch (Imm) {
251 default:
252 llvm_unreachable("Invalid IType");
253 case NVPTX::PTXCmpMode::EQ:
254 case NVPTX::PTXCmpMode::NE:
255 O << "b";
256 return;
257 case NVPTX::PTXCmpMode::LT:
258 case NVPTX::PTXCmpMode::LE:
259 case NVPTX::PTXCmpMode::GT:
260 case NVPTX::PTXCmpMode::GE:
261 O << "s";
262 return;
263 case NVPTX::PTXCmpMode::LTU:
264 case NVPTX::PTXCmpMode::LEU:
265 case NVPTX::PTXCmpMode::GTU:
266 case NVPTX::PTXCmpMode::GEU:
267 O << "u";
268 return;
269 }
270 }
271 llvm_unreachable("Empty Modifier");
272}
273
274void NVPTXInstPrinter::printAtomicCode(const MCInst *MI, int OpNum,
275 raw_ostream &O, StringRef Modifier) {
276 const MCOperand &MO = MI->getOperand(i: OpNum);
277 int Imm = (int)MO.getImm();
278 if (Modifier == "sem") {
279 auto Ordering = NVPTX::Ordering(Imm);
280 switch (Ordering) {
281 case NVPTX::Ordering::NotAtomic:
282 return;
283 case NVPTX::Ordering::Relaxed:
284 O << ".relaxed";
285 return;
286 case NVPTX::Ordering::Acquire:
287 O << ".acquire";
288 return;
289 case NVPTX::Ordering::Release:
290 O << ".release";
291 return;
292 case NVPTX::Ordering::AcquireRelease:
293 O << ".acq_rel";
294 return;
295 case NVPTX::Ordering::SequentiallyConsistent:
296 report_fatal_error(
297 reason: "NVPTX AtomicCode Printer does not support \"seq_cst\" ordering.");
298 return;
299 case NVPTX::Ordering::Volatile:
300 O << ".volatile";
301 return;
302 case NVPTX::Ordering::RelaxedMMIO:
303 O << ".mmio.relaxed";
304 return;
305 }
306 } else if (Modifier == "scope") {
307 auto S = NVPTX::Scope(Imm);
308 switch (S) {
309 case NVPTX::Scope::Thread:
310 case NVPTX::Scope::DefaultDevice:
311 return;
312 case NVPTX::Scope::System:
313 O << ".sys";
314 return;
315 case NVPTX::Scope::Block:
316 O << ".cta";
317 return;
318 case NVPTX::Scope::Cluster:
319 O << ".cluster";
320 return;
321 case NVPTX::Scope::Device:
322 O << ".gpu";
323 return;
324 }
325 report_fatal_error(reason: formatv(
326 Fmt: "NVPTX AtomicCode Printer does not support \"{}\" scope modifier.",
327 Vals: ScopeToString(S)));
328 } else if (Modifier == "addsp") {
329 auto A = NVPTX::AddressSpace(Imm);
330 switch (A) {
331 case NVPTX::AddressSpace::Generic:
332 return;
333 case NVPTX::AddressSpace::Global:
334 case NVPTX::AddressSpace::Const:
335 case NVPTX::AddressSpace::Shared:
336 case NVPTX::AddressSpace::SharedCluster:
337 case NVPTX::AddressSpace::Param:
338 case NVPTX::AddressSpace::Local:
339 O << "." << A;
340 return;
341 }
342 report_fatal_error(reason: formatv(
343 Fmt: "NVPTX AtomicCode Printer does not support \"{}\" addsp modifier.",
344 Vals: AddressSpaceToString(A)));
345 } else if (Modifier == "sign") {
346 switch (Imm) {
347 case NVPTX::PTXLdStInstCode::Signed:
348 O << "s";
349 return;
350 case NVPTX::PTXLdStInstCode::Unsigned:
351 O << "u";
352 return;
353 case NVPTX::PTXLdStInstCode::Untyped:
354 O << "b";
355 return;
356 case NVPTX::PTXLdStInstCode::Float:
357 O << "f";
358 return;
359 default:
360 llvm_unreachable("Unknown register type");
361 }
362 }
363 llvm_unreachable(formatv("Unknown Modifier: {}", Modifier).str().c_str());
364}
365
366void NVPTXInstPrinter::printMmaCode(const MCInst *MI, int OpNum, raw_ostream &O,
367 StringRef Modifier) {
368 const MCOperand &MO = MI->getOperand(i: OpNum);
369 int Imm = (int)MO.getImm();
370 if (Modifier.empty() || Modifier == "version") {
371 O << Imm; // Just print out PTX version
372 return;
373 } else if (Modifier == "aligned") {
374 // PTX63 requires '.aligned' in the name of the instruction.
375 if (Imm >= 63)
376 O << ".aligned";
377 return;
378 }
379 llvm_unreachable("Unknown Modifier");
380}
381
382void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
383 raw_ostream &O, StringRef Modifier) {
384 printOperand(MI, OpNo: OpNum, O);
385
386 if (Modifier == "add") {
387 O << ", ";
388 printOperand(MI, OpNo: OpNum + 1, O);
389 } else {
390 if (MI->getOperand(i: OpNum + 1).isImm() &&
391 MI->getOperand(i: OpNum + 1).getImm() == 0)
392 return; // don't print ',0' or '+0'
393 O << "+";
394 printOperand(MI, OpNo: OpNum + 1, O);
395 }
396}
397
398void NVPTXInstPrinter::printUsedBytesMaskPragma(const MCInst *MI, int OpNum,
399 raw_ostream &O) {
400 auto &Op = MI->getOperand(i: OpNum);
401 assert(Op.isImm() && "Invalid operand");
402 uint32_t Imm = (uint32_t)Op.getImm();
403 if (Imm != UINT32_MAX) {
404 O << ".pragma \"used_bytes_mask " << format_hex(N: Imm, Width: 1) << "\";\n\t";
405 }
406}
407
408void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
409 raw_ostream &O) {
410 const MCOperand &Op = MI->getOperand(i: OpNum);
411 if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
412 O << "_";
413 else
414 printOperand(MI, OpNo: OpNum, O);
415}
416
417void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
418 raw_ostream &O) {
419 int64_t Imm = MI->getOperand(i: OpNum).getImm();
420 O << formatHex(Value: Imm) << "U";
421}
422
423void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
424 raw_ostream &O) {
425 const MCOperand &Op = MI->getOperand(i: OpNum);
426 assert(Op.isExpr() && "Call prototype is not an MCExpr?");
427 const MCExpr *Expr = Op.getExpr();
428 const MCSymbol &Sym = cast<MCSymbolRefExpr>(Val: Expr)->getSymbol();
429 O << Sym.getName();
430}
431
432void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
433 raw_ostream &O) {
434 const MCOperand &MO = MI->getOperand(i: OpNum);
435 int64_t Imm = MO.getImm();
436
437 switch (Imm) {
438 default:
439 return;
440 case NVPTX::PTXPrmtMode::NONE:
441 return;
442 case NVPTX::PTXPrmtMode::F4E:
443 O << ".f4e";
444 return;
445 case NVPTX::PTXPrmtMode::B4E:
446 O << ".b4e";
447 return;
448 case NVPTX::PTXPrmtMode::RC8:
449 O << ".rc8";
450 return;
451 case NVPTX::PTXPrmtMode::ECL:
452 O << ".ecl";
453 return;
454 case NVPTX::PTXPrmtMode::ECR:
455 O << ".ecr";
456 return;
457 case NVPTX::PTXPrmtMode::RC16:
458 O << ".rc16";
459 return;
460 }
461}
462
463void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
464 raw_ostream &O) {
465 const MCOperand &MO = MI->getOperand(i: OpNum);
466 using RedTy = nvvm::TMAReductionOp;
467
468 switch (static_cast<RedTy>(MO.getImm())) {
469 case RedTy::ADD:
470 O << ".add";
471 return;
472 case RedTy::MIN:
473 O << ".min";
474 return;
475 case RedTy::MAX:
476 O << ".max";
477 return;
478 case RedTy::INC:
479 O << ".inc";
480 return;
481 case RedTy::DEC:
482 O << ".dec";
483 return;
484 case RedTy::AND:
485 O << ".and";
486 return;
487 case RedTy::OR:
488 O << ".or";
489 return;
490 case RedTy::XOR:
491 O << ".xor";
492 return;
493 }
494 llvm_unreachable(
495 "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
496}
497
498void NVPTXInstPrinter::printCTAGroup(const MCInst *MI, int OpNum,
499 raw_ostream &O) {
500 const MCOperand &MO = MI->getOperand(i: OpNum);
501 using CGTy = nvvm::CTAGroupKind;
502
503 switch (static_cast<CGTy>(MO.getImm())) {
504 case CGTy::CG_NONE:
505 O << "";
506 return;
507 case CGTy::CG_1:
508 O << ".cta_group::1";
509 return;
510 case CGTy::CG_2:
511 O << ".cta_group::2";
512 return;
513 }
514 llvm_unreachable("Invalid cta_group in printCTAGroup");
515}
516
517void NVPTXInstPrinter::printCallOperand(const MCInst *MI, int OpNum,
518 raw_ostream &O, StringRef Modifier) {
519 const MCOperand &MO = MI->getOperand(i: OpNum);
520 assert(MO.isImm() && "Invalid operand");
521 const auto Imm = MO.getImm();
522
523 if (Modifier == "RetList") {
524 assert((Imm == 1 || Imm == 0) && "Invalid return list");
525 if (Imm)
526 O << " (retval0),";
527 return;
528 }
529
530 if (Modifier == "ParamList") {
531 assert(Imm >= 0 && "Invalid parameter list");
532 interleaveComma(c: llvm::seq(Size: Imm), os&: O,
533 each_fn: [&](const auto &I) { O << "param" << I; });
534 return;
535 }
536 llvm_unreachable("Invalid modifier");
537}
538
539template <unsigned Bits>
540void NVPTXInstPrinter::printHexUImm(const MCInst *MI, int OpNum,
541 raw_ostream &O) {
542 const MCOperand &MO = MI->getOperand(i: OpNum);
543 assert(MO.isImm() && "Expected immediate operand");
544 assert(isInt<Bits>(MO.getImm()) &&
545 "Immediate value does not fit in specified bits");
546 uint64_t Imm = MO.getImm();
547 Imm &= maskTrailingOnes<uint64_t>(N: Bits);
548 O << formatHex(Value: Imm) << "U";
549}
550