1//===-- NVPTXInstPrinter.cpp - PTX assembly instruction printing ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Print MCInst instructions to .ptx format.
10//
11//===----------------------------------------------------------------------===//
12
13#include "MCTargetDesc/NVPTXInstPrinter.h"
14#include "NVPTX.h"
15#include "NVPTXUtilities.h"
16#include "llvm/ADT/StringRef.h"
17#include "llvm/IR/NVVMIntrinsicUtils.h"
18#include "llvm/MC/MCAsmInfo.h"
19#include "llvm/MC/MCExpr.h"
20#include "llvm/MC/MCInst.h"
21#include "llvm/MC/MCInstrInfo.h"
22#include "llvm/MC/MCSubtargetInfo.h"
23#include "llvm/MC/MCSymbol.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include <cctype>
27using namespace llvm;
28
29#define DEBUG_TYPE "asm-printer"
30
31#include "NVPTXGenAsmWriter.inc"
32
33NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII,
34 const MCRegisterInfo &MRI)
35 : MCInstPrinter(MAI, MII, MRI) {}
36
37void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) {
38 // Decode the virtual register
39 // Must be kept in sync with NVPTXAsmPrinter::encodeVirtualRegister
40 unsigned RCId = (Reg.id() >> 28);
41 switch (RCId) {
42 default: report_fatal_error(reason: "Bad virtual register encoding");
43 case 0:
44 // This is actually a physical register, so defer to the autogenerated
45 // register printer
46 OS << getRegisterName(Reg);
47 return;
48 case 1:
49 OS << "%p";
50 break;
51 case 2:
52 OS << "%rs";
53 break;
54 case 3:
55 OS << "%r";
56 break;
57 case 4:
58 OS << "%rd";
59 break;
60 case 5:
61 OS << "%f";
62 break;
63 case 6:
64 OS << "%fd";
65 break;
66 case 7:
67 OS << "%rq";
68 break;
69 }
70
71 unsigned VReg = Reg.id() & 0x0FFFFFFF;
72 OS << VReg;
73}
74
75void NVPTXInstPrinter::printInst(const MCInst *MI, uint64_t Address,
76 StringRef Annot, const MCSubtargetInfo &STI,
77 raw_ostream &OS) {
78 printInstruction(MI, Address, O&: OS);
79
80 // Next always print the annotation.
81 printAnnotation(OS, Annot);
82}
83
84void NVPTXInstPrinter::printOperand(const MCInst *MI, unsigned OpNo,
85 raw_ostream &O) {
86 const MCOperand &Op = MI->getOperand(i: OpNo);
87 if (Op.isReg()) {
88 MCRegister Reg = Op.getReg();
89 printRegName(OS&: O, Reg);
90 } else if (Op.isImm()) {
91 markup(OS&: O, M: Markup::Immediate) << formatImm(Value: Op.getImm());
92 } else {
93 assert(Op.isExpr() && "Unknown operand kind in printOperand");
94 MAI.printExpr(O, *Op.getExpr());
95 }
96}
97
98void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
99 StringRef Modifier) {
100 const MCOperand &MO = MI->getOperand(i: OpNum);
101 int64_t Imm = MO.getImm();
102
103 if (Modifier == "ftz") {
104 // FTZ flag
105 if (Imm & NVPTX::PTXCvtMode::FTZ_FLAG)
106 O << ".ftz";
107 return;
108 } else if (Modifier == "sat") {
109 // SAT flag
110 if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
111 O << ".sat";
112 return;
113 } else if (Modifier == "relu") {
114 // RELU flag
115 if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
116 O << ".relu";
117 return;
118 } else if (Modifier == "base") {
119 // Default operand
120 switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) {
121 default:
122 return;
123 case NVPTX::PTXCvtMode::NONE:
124 return;
125 case NVPTX::PTXCvtMode::RNI:
126 O << ".rni";
127 return;
128 case NVPTX::PTXCvtMode::RZI:
129 O << ".rzi";
130 return;
131 case NVPTX::PTXCvtMode::RMI:
132 O << ".rmi";
133 return;
134 case NVPTX::PTXCvtMode::RPI:
135 O << ".rpi";
136 return;
137 case NVPTX::PTXCvtMode::RN:
138 O << ".rn";
139 return;
140 case NVPTX::PTXCvtMode::RZ:
141 O << ".rz";
142 return;
143 case NVPTX::PTXCvtMode::RM:
144 O << ".rm";
145 return;
146 case NVPTX::PTXCvtMode::RP:
147 O << ".rp";
148 return;
149 case NVPTX::PTXCvtMode::RNA:
150 O << ".rna";
151 return;
152 }
153 }
154 llvm_unreachable("Invalid conversion modifier");
155}
156
157void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O,
158 StringRef Modifier) {
159 const MCOperand &MO = MI->getOperand(i: OpNum);
160 int64_t Imm = MO.getImm();
161
162 if (Modifier == "ftz") {
163 // FTZ flag
164 if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG)
165 O << ".ftz";
166 return;
167 } else if (Modifier == "base") {
168 switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) {
169 default:
170 return;
171 case NVPTX::PTXCmpMode::EQ:
172 O << ".eq";
173 return;
174 case NVPTX::PTXCmpMode::NE:
175 O << ".ne";
176 return;
177 case NVPTX::PTXCmpMode::LT:
178 O << ".lt";
179 return;
180 case NVPTX::PTXCmpMode::LE:
181 O << ".le";
182 return;
183 case NVPTX::PTXCmpMode::GT:
184 O << ".gt";
185 return;
186 case NVPTX::PTXCmpMode::GE:
187 O << ".ge";
188 return;
189 case NVPTX::PTXCmpMode::LO:
190 O << ".lo";
191 return;
192 case NVPTX::PTXCmpMode::LS:
193 O << ".ls";
194 return;
195 case NVPTX::PTXCmpMode::HI:
196 O << ".hi";
197 return;
198 case NVPTX::PTXCmpMode::HS:
199 O << ".hs";
200 return;
201 case NVPTX::PTXCmpMode::EQU:
202 O << ".equ";
203 return;
204 case NVPTX::PTXCmpMode::NEU:
205 O << ".neu";
206 return;
207 case NVPTX::PTXCmpMode::LTU:
208 O << ".ltu";
209 return;
210 case NVPTX::PTXCmpMode::LEU:
211 O << ".leu";
212 return;
213 case NVPTX::PTXCmpMode::GTU:
214 O << ".gtu";
215 return;
216 case NVPTX::PTXCmpMode::GEU:
217 O << ".geu";
218 return;
219 case NVPTX::PTXCmpMode::NUM:
220 O << ".num";
221 return;
222 case NVPTX::PTXCmpMode::NotANumber:
223 O << ".nan";
224 return;
225 }
226 }
227 llvm_unreachable("Empty Modifier");
228}
229
230void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
231 raw_ostream &O, StringRef Modifier) {
232 const MCOperand &MO = MI->getOperand(i: OpNum);
233 int Imm = (int)MO.getImm();
234 if (Modifier == "sem") {
235 auto Ordering = NVPTX::Ordering(Imm);
236 switch (Ordering) {
237 case NVPTX::Ordering::NotAtomic:
238 return;
239 case NVPTX::Ordering::Relaxed:
240 O << ".relaxed";
241 return;
242 case NVPTX::Ordering::Acquire:
243 O << ".acquire";
244 return;
245 case NVPTX::Ordering::Release:
246 O << ".release";
247 return;
248 case NVPTX::Ordering::Volatile:
249 O << ".volatile";
250 return;
251 case NVPTX::Ordering::RelaxedMMIO:
252 O << ".mmio.relaxed";
253 return;
254 default:
255 report_fatal_error(reason: formatv(
256 Fmt: "NVPTX LdStCode Printer does not support \"{}\" sem modifier. "
257 "Loads/Stores cannot be AcquireRelease or SequentiallyConsistent.",
258 Vals: OrderingToString(Order: Ordering)));
259 }
260 } else if (Modifier == "scope") {
261 auto S = NVPTX::Scope(Imm);
262 switch (S) {
263 case NVPTX::Scope::Thread:
264 return;
265 case NVPTX::Scope::System:
266 O << ".sys";
267 return;
268 case NVPTX::Scope::Block:
269 O << ".cta";
270 return;
271 case NVPTX::Scope::Cluster:
272 O << ".cluster";
273 return;
274 case NVPTX::Scope::Device:
275 O << ".gpu";
276 return;
277 }
278 report_fatal_error(
279 reason: formatv(Fmt: "NVPTX LdStCode Printer does not support \"{}\" sco modifier.",
280 Vals: ScopeToString(S)));
281 } else if (Modifier == "addsp") {
282 auto A = NVPTX::AddressSpace(Imm);
283 switch (A) {
284 case NVPTX::AddressSpace::Generic:
285 return;
286 case NVPTX::AddressSpace::Global:
287 case NVPTX::AddressSpace::Const:
288 case NVPTX::AddressSpace::Shared:
289 case NVPTX::AddressSpace::SharedCluster:
290 case NVPTX::AddressSpace::Param:
291 case NVPTX::AddressSpace::Local:
292 O << "." << A;
293 return;
294 }
295 report_fatal_error(reason: formatv(
296 Fmt: "NVPTX LdStCode Printer does not support \"{}\" addsp modifier.",
297 Vals: AddressSpaceToString(A)));
298 } else if (Modifier == "sign") {
299 switch (Imm) {
300 case NVPTX::PTXLdStInstCode::Signed:
301 O << "s";
302 return;
303 case NVPTX::PTXLdStInstCode::Unsigned:
304 O << "u";
305 return;
306 case NVPTX::PTXLdStInstCode::Untyped:
307 O << "b";
308 return;
309 case NVPTX::PTXLdStInstCode::Float:
310 O << "f";
311 return;
312 default:
313 llvm_unreachable("Unknown register type");
314 }
315 }
316 llvm_unreachable(formatv("Unknown Modifier: {}", Modifier).str().c_str());
317}
318
319void NVPTXInstPrinter::printMmaCode(const MCInst *MI, int OpNum, raw_ostream &O,
320 StringRef Modifier) {
321 const MCOperand &MO = MI->getOperand(i: OpNum);
322 int Imm = (int)MO.getImm();
323 if (Modifier.empty() || Modifier == "version") {
324 O << Imm; // Just print out PTX version
325 return;
326 } else if (Modifier == "aligned") {
327 // PTX63 requires '.aligned' in the name of the instruction.
328 if (Imm >= 63)
329 O << ".aligned";
330 return;
331 }
332 llvm_unreachable("Unknown Modifier");
333}
334
335void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
336 raw_ostream &O, StringRef Modifier) {
337 printOperand(MI, OpNo: OpNum, O);
338
339 if (Modifier == "add") {
340 O << ", ";
341 printOperand(MI, OpNo: OpNum + 1, O);
342 } else {
343 if (MI->getOperand(i: OpNum + 1).isImm() &&
344 MI->getOperand(i: OpNum + 1).getImm() == 0)
345 return; // don't print ',0' or '+0'
346 O << "+";
347 printOperand(MI, OpNo: OpNum + 1, O);
348 }
349}
350
351void NVPTXInstPrinter::printOffseti32imm(const MCInst *MI, int OpNum,
352 raw_ostream &O) {
353 auto &Op = MI->getOperand(i: OpNum);
354 assert(Op.isImm() && "Invalid operand");
355 if (Op.getImm() != 0) {
356 O << "+";
357 printOperand(MI, OpNo: OpNum, O);
358 }
359}
360
361void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
362 raw_ostream &O) {
363 int64_t Imm = MI->getOperand(i: OpNum).getImm();
364 O << formatHex(Value: Imm) << "U";
365}
366
367void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
368 raw_ostream &O) {
369 const MCOperand &Op = MI->getOperand(i: OpNum);
370 assert(Op.isExpr() && "Call prototype is not an MCExpr?");
371 const MCExpr *Expr = Op.getExpr();
372 const MCSymbol &Sym = cast<MCSymbolRefExpr>(Val: Expr)->getSymbol();
373 O << Sym.getName();
374}
375
376void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
377 raw_ostream &O) {
378 const MCOperand &MO = MI->getOperand(i: OpNum);
379 int64_t Imm = MO.getImm();
380
381 switch (Imm) {
382 default:
383 return;
384 case NVPTX::PTXPrmtMode::NONE:
385 return;
386 case NVPTX::PTXPrmtMode::F4E:
387 O << ".f4e";
388 return;
389 case NVPTX::PTXPrmtMode::B4E:
390 O << ".b4e";
391 return;
392 case NVPTX::PTXPrmtMode::RC8:
393 O << ".rc8";
394 return;
395 case NVPTX::PTXPrmtMode::ECL:
396 O << ".ecl";
397 return;
398 case NVPTX::PTXPrmtMode::ECR:
399 O << ".ecr";
400 return;
401 case NVPTX::PTXPrmtMode::RC16:
402 O << ".rc16";
403 return;
404 }
405}
406
407void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
408 raw_ostream &O) {
409 const MCOperand &MO = MI->getOperand(i: OpNum);
410 using RedTy = nvvm::TMAReductionOp;
411
412 switch (static_cast<RedTy>(MO.getImm())) {
413 case RedTy::ADD:
414 O << ".add";
415 return;
416 case RedTy::MIN:
417 O << ".min";
418 return;
419 case RedTy::MAX:
420 O << ".max";
421 return;
422 case RedTy::INC:
423 O << ".inc";
424 return;
425 case RedTy::DEC:
426 O << ".dec";
427 return;
428 case RedTy::AND:
429 O << ".and";
430 return;
431 case RedTy::OR:
432 O << ".or";
433 return;
434 case RedTy::XOR:
435 O << ".xor";
436 return;
437 }
438 llvm_unreachable(
439 "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
440}
441
442void NVPTXInstPrinter::printCTAGroup(const MCInst *MI, int OpNum,
443 raw_ostream &O) {
444 const MCOperand &MO = MI->getOperand(i: OpNum);
445 using CGTy = nvvm::CTAGroupKind;
446
447 switch (static_cast<CGTy>(MO.getImm())) {
448 case CGTy::CG_NONE:
449 O << "";
450 return;
451 case CGTy::CG_1:
452 O << ".cta_group::1";
453 return;
454 case CGTy::CG_2:
455 O << ".cta_group::2";
456 return;
457 }
458 llvm_unreachable("Invalid cta_group in printCTAGroup");
459}
460
461void NVPTXInstPrinter::printCallOperand(const MCInst *MI, int OpNum,
462 raw_ostream &O, StringRef Modifier) {
463 const MCOperand &MO = MI->getOperand(i: OpNum);
464 assert(MO.isImm() && "Invalid operand");
465 const auto Imm = MO.getImm();
466
467 if (Modifier == "RetList") {
468 assert((Imm == 1 || Imm == 0) && "Invalid return list");
469 if (Imm)
470 O << " (retval0),";
471 return;
472 }
473
474 if (Modifier == "ParamList") {
475 assert(Imm >= 0 && "Invalid parameter list");
476 interleaveComma(c: llvm::seq(Size: Imm), os&: O,
477 each_fn: [&](const auto &I) { O << "param" << I; });
478 return;
479 }
480 llvm_unreachable("Invalid modifier");
481}
482