1 | //===-- NVPTXInstPrinter.cpp - PTX assembly instruction printing ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Print MCInst instructions to .ptx format. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "MCTargetDesc/NVPTXInstPrinter.h" |
14 | #include "NVPTX.h" |
15 | #include "NVPTXUtilities.h" |
16 | #include "llvm/ADT/StringRef.h" |
17 | #include "llvm/IR/NVVMIntrinsicUtils.h" |
18 | #include "llvm/MC/MCAsmInfo.h" |
19 | #include "llvm/MC/MCExpr.h" |
20 | #include "llvm/MC/MCInst.h" |
21 | #include "llvm/MC/MCInstrInfo.h" |
22 | #include "llvm/MC/MCSubtargetInfo.h" |
23 | #include "llvm/MC/MCSymbol.h" |
24 | #include "llvm/Support/ErrorHandling.h" |
25 | #include "llvm/Support/FormatVariadic.h" |
26 | #include <cctype> |
27 | using namespace llvm; |
28 | |
29 | #define DEBUG_TYPE "asm-printer" |
30 | |
31 | #include "NVPTXGenAsmWriter.inc" |
32 | |
33 | NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII, |
34 | const MCRegisterInfo &MRI) |
35 | : MCInstPrinter(MAI, MII, MRI) {} |
36 | |
37 | void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) { |
38 | // Decode the virtual register |
39 | // Must be kept in sync with NVPTXAsmPrinter::encodeVirtualRegister |
40 | unsigned RCId = (Reg.id() >> 28); |
41 | switch (RCId) { |
42 | default: report_fatal_error(reason: "Bad virtual register encoding" ); |
43 | case 0: |
44 | // This is actually a physical register, so defer to the autogenerated |
45 | // register printer |
46 | OS << getRegisterName(Reg); |
47 | return; |
48 | case 1: |
49 | OS << "%p" ; |
50 | break; |
51 | case 2: |
52 | OS << "%rs" ; |
53 | break; |
54 | case 3: |
55 | OS << "%r" ; |
56 | break; |
57 | case 4: |
58 | OS << "%rd" ; |
59 | break; |
60 | case 5: |
61 | OS << "%f" ; |
62 | break; |
63 | case 6: |
64 | OS << "%fd" ; |
65 | break; |
66 | case 7: |
67 | OS << "%rq" ; |
68 | break; |
69 | } |
70 | |
71 | unsigned VReg = Reg.id() & 0x0FFFFFFF; |
72 | OS << VReg; |
73 | } |
74 | |
75 | void NVPTXInstPrinter::printInst(const MCInst *MI, uint64_t Address, |
76 | StringRef Annot, const MCSubtargetInfo &STI, |
77 | raw_ostream &OS) { |
78 | printInstruction(MI, Address, O&: OS); |
79 | |
80 | // Next always print the annotation. |
81 | printAnnotation(OS, Annot); |
82 | } |
83 | |
84 | void NVPTXInstPrinter::printOperand(const MCInst *MI, unsigned OpNo, |
85 | raw_ostream &O) { |
86 | const MCOperand &Op = MI->getOperand(i: OpNo); |
87 | if (Op.isReg()) { |
88 | MCRegister Reg = Op.getReg(); |
89 | printRegName(OS&: O, Reg); |
90 | } else if (Op.isImm()) { |
91 | markup(OS&: O, M: Markup::Immediate) << formatImm(Value: Op.getImm()); |
92 | } else { |
93 | assert(Op.isExpr() && "Unknown operand kind in printOperand" ); |
94 | MAI.printExpr(O, *Op.getExpr()); |
95 | } |
96 | } |
97 | |
98 | void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O, |
99 | StringRef Modifier) { |
100 | const MCOperand &MO = MI->getOperand(i: OpNum); |
101 | int64_t Imm = MO.getImm(); |
102 | |
103 | if (Modifier == "ftz" ) { |
104 | // FTZ flag |
105 | if (Imm & NVPTX::PTXCvtMode::FTZ_FLAG) |
106 | O << ".ftz" ; |
107 | return; |
108 | } else if (Modifier == "sat" ) { |
109 | // SAT flag |
110 | if (Imm & NVPTX::PTXCvtMode::SAT_FLAG) |
111 | O << ".sat" ; |
112 | return; |
113 | } else if (Modifier == "relu" ) { |
114 | // RELU flag |
115 | if (Imm & NVPTX::PTXCvtMode::RELU_FLAG) |
116 | O << ".relu" ; |
117 | return; |
118 | } else if (Modifier == "base" ) { |
119 | // Default operand |
120 | switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) { |
121 | default: |
122 | return; |
123 | case NVPTX::PTXCvtMode::NONE: |
124 | return; |
125 | case NVPTX::PTXCvtMode::RNI: |
126 | O << ".rni" ; |
127 | return; |
128 | case NVPTX::PTXCvtMode::RZI: |
129 | O << ".rzi" ; |
130 | return; |
131 | case NVPTX::PTXCvtMode::RMI: |
132 | O << ".rmi" ; |
133 | return; |
134 | case NVPTX::PTXCvtMode::RPI: |
135 | O << ".rpi" ; |
136 | return; |
137 | case NVPTX::PTXCvtMode::RN: |
138 | O << ".rn" ; |
139 | return; |
140 | case NVPTX::PTXCvtMode::RZ: |
141 | O << ".rz" ; |
142 | return; |
143 | case NVPTX::PTXCvtMode::RM: |
144 | O << ".rm" ; |
145 | return; |
146 | case NVPTX::PTXCvtMode::RP: |
147 | O << ".rp" ; |
148 | return; |
149 | case NVPTX::PTXCvtMode::RNA: |
150 | O << ".rna" ; |
151 | return; |
152 | } |
153 | } |
154 | llvm_unreachable("Invalid conversion modifier" ); |
155 | } |
156 | |
157 | void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O, |
158 | StringRef Modifier) { |
159 | const MCOperand &MO = MI->getOperand(i: OpNum); |
160 | int64_t Imm = MO.getImm(); |
161 | |
162 | if (Modifier == "ftz" ) { |
163 | // FTZ flag |
164 | if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG) |
165 | O << ".ftz" ; |
166 | return; |
167 | } else if (Modifier == "base" ) { |
168 | switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) { |
169 | default: |
170 | return; |
171 | case NVPTX::PTXCmpMode::EQ: |
172 | O << ".eq" ; |
173 | return; |
174 | case NVPTX::PTXCmpMode::NE: |
175 | O << ".ne" ; |
176 | return; |
177 | case NVPTX::PTXCmpMode::LT: |
178 | O << ".lt" ; |
179 | return; |
180 | case NVPTX::PTXCmpMode::LE: |
181 | O << ".le" ; |
182 | return; |
183 | case NVPTX::PTXCmpMode::GT: |
184 | O << ".gt" ; |
185 | return; |
186 | case NVPTX::PTXCmpMode::GE: |
187 | O << ".ge" ; |
188 | return; |
189 | case NVPTX::PTXCmpMode::LO: |
190 | O << ".lo" ; |
191 | return; |
192 | case NVPTX::PTXCmpMode::LS: |
193 | O << ".ls" ; |
194 | return; |
195 | case NVPTX::PTXCmpMode::HI: |
196 | O << ".hi" ; |
197 | return; |
198 | case NVPTX::PTXCmpMode::HS: |
199 | O << ".hs" ; |
200 | return; |
201 | case NVPTX::PTXCmpMode::EQU: |
202 | O << ".equ" ; |
203 | return; |
204 | case NVPTX::PTXCmpMode::NEU: |
205 | O << ".neu" ; |
206 | return; |
207 | case NVPTX::PTXCmpMode::LTU: |
208 | O << ".ltu" ; |
209 | return; |
210 | case NVPTX::PTXCmpMode::LEU: |
211 | O << ".leu" ; |
212 | return; |
213 | case NVPTX::PTXCmpMode::GTU: |
214 | O << ".gtu" ; |
215 | return; |
216 | case NVPTX::PTXCmpMode::GEU: |
217 | O << ".geu" ; |
218 | return; |
219 | case NVPTX::PTXCmpMode::NUM: |
220 | O << ".num" ; |
221 | return; |
222 | case NVPTX::PTXCmpMode::NotANumber: |
223 | O << ".nan" ; |
224 | return; |
225 | } |
226 | } |
227 | llvm_unreachable("Empty Modifier" ); |
228 | } |
229 | |
230 | void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum, |
231 | raw_ostream &O, StringRef Modifier) { |
232 | const MCOperand &MO = MI->getOperand(i: OpNum); |
233 | int Imm = (int)MO.getImm(); |
234 | if (Modifier == "sem" ) { |
235 | auto Ordering = NVPTX::Ordering(Imm); |
236 | switch (Ordering) { |
237 | case NVPTX::Ordering::NotAtomic: |
238 | return; |
239 | case NVPTX::Ordering::Relaxed: |
240 | O << ".relaxed" ; |
241 | return; |
242 | case NVPTX::Ordering::Acquire: |
243 | O << ".acquire" ; |
244 | return; |
245 | case NVPTX::Ordering::Release: |
246 | O << ".release" ; |
247 | return; |
248 | case NVPTX::Ordering::Volatile: |
249 | O << ".volatile" ; |
250 | return; |
251 | case NVPTX::Ordering::RelaxedMMIO: |
252 | O << ".mmio.relaxed" ; |
253 | return; |
254 | default: |
255 | report_fatal_error(reason: formatv( |
256 | Fmt: "NVPTX LdStCode Printer does not support \"{}\" sem modifier. " |
257 | "Loads/Stores cannot be AcquireRelease or SequentiallyConsistent." , |
258 | Vals: OrderingToString(Order: Ordering))); |
259 | } |
260 | } else if (Modifier == "scope" ) { |
261 | auto S = NVPTX::Scope(Imm); |
262 | switch (S) { |
263 | case NVPTX::Scope::Thread: |
264 | return; |
265 | case NVPTX::Scope::System: |
266 | O << ".sys" ; |
267 | return; |
268 | case NVPTX::Scope::Block: |
269 | O << ".cta" ; |
270 | return; |
271 | case NVPTX::Scope::Cluster: |
272 | O << ".cluster" ; |
273 | return; |
274 | case NVPTX::Scope::Device: |
275 | O << ".gpu" ; |
276 | return; |
277 | } |
278 | report_fatal_error( |
279 | reason: formatv(Fmt: "NVPTX LdStCode Printer does not support \"{}\" sco modifier." , |
280 | Vals: ScopeToString(S))); |
281 | } else if (Modifier == "addsp" ) { |
282 | auto A = NVPTX::AddressSpace(Imm); |
283 | switch (A) { |
284 | case NVPTX::AddressSpace::Generic: |
285 | return; |
286 | case NVPTX::AddressSpace::Global: |
287 | case NVPTX::AddressSpace::Const: |
288 | case NVPTX::AddressSpace::Shared: |
289 | case NVPTX::AddressSpace::SharedCluster: |
290 | case NVPTX::AddressSpace::Param: |
291 | case NVPTX::AddressSpace::Local: |
292 | O << "." << A; |
293 | return; |
294 | } |
295 | report_fatal_error(reason: formatv( |
296 | Fmt: "NVPTX LdStCode Printer does not support \"{}\" addsp modifier." , |
297 | Vals: AddressSpaceToString(A))); |
298 | } else if (Modifier == "sign" ) { |
299 | switch (Imm) { |
300 | case NVPTX::PTXLdStInstCode::Signed: |
301 | O << "s" ; |
302 | return; |
303 | case NVPTX::PTXLdStInstCode::Unsigned: |
304 | O << "u" ; |
305 | return; |
306 | case NVPTX::PTXLdStInstCode::Untyped: |
307 | O << "b" ; |
308 | return; |
309 | case NVPTX::PTXLdStInstCode::Float: |
310 | O << "f" ; |
311 | return; |
312 | default: |
313 | llvm_unreachable("Unknown register type" ); |
314 | } |
315 | } |
316 | llvm_unreachable(formatv("Unknown Modifier: {}" , Modifier).str().c_str()); |
317 | } |
318 | |
319 | void NVPTXInstPrinter::printMmaCode(const MCInst *MI, int OpNum, raw_ostream &O, |
320 | StringRef Modifier) { |
321 | const MCOperand &MO = MI->getOperand(i: OpNum); |
322 | int Imm = (int)MO.getImm(); |
323 | if (Modifier.empty() || Modifier == "version" ) { |
324 | O << Imm; // Just print out PTX version |
325 | return; |
326 | } else if (Modifier == "aligned" ) { |
327 | // PTX63 requires '.aligned' in the name of the instruction. |
328 | if (Imm >= 63) |
329 | O << ".aligned" ; |
330 | return; |
331 | } |
332 | llvm_unreachable("Unknown Modifier" ); |
333 | } |
334 | |
335 | void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum, |
336 | raw_ostream &O, StringRef Modifier) { |
337 | printOperand(MI, OpNo: OpNum, O); |
338 | |
339 | if (Modifier == "add" ) { |
340 | O << ", " ; |
341 | printOperand(MI, OpNo: OpNum + 1, O); |
342 | } else { |
343 | if (MI->getOperand(i: OpNum + 1).isImm() && |
344 | MI->getOperand(i: OpNum + 1).getImm() == 0) |
345 | return; // don't print ',0' or '+0' |
346 | O << "+" ; |
347 | printOperand(MI, OpNo: OpNum + 1, O); |
348 | } |
349 | } |
350 | |
351 | void NVPTXInstPrinter::printOffseti32imm(const MCInst *MI, int OpNum, |
352 | raw_ostream &O) { |
353 | auto &Op = MI->getOperand(i: OpNum); |
354 | assert(Op.isImm() && "Invalid operand" ); |
355 | if (Op.getImm() != 0) { |
356 | O << "+" ; |
357 | printOperand(MI, OpNo: OpNum, O); |
358 | } |
359 | } |
360 | |
361 | void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum, |
362 | raw_ostream &O) { |
363 | int64_t Imm = MI->getOperand(i: OpNum).getImm(); |
364 | O << formatHex(Value: Imm) << "U" ; |
365 | } |
366 | |
367 | void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum, |
368 | raw_ostream &O) { |
369 | const MCOperand &Op = MI->getOperand(i: OpNum); |
370 | assert(Op.isExpr() && "Call prototype is not an MCExpr?" ); |
371 | const MCExpr *Expr = Op.getExpr(); |
372 | const MCSymbol &Sym = cast<MCSymbolRefExpr>(Val: Expr)->getSymbol(); |
373 | O << Sym.getName(); |
374 | } |
375 | |
376 | void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum, |
377 | raw_ostream &O) { |
378 | const MCOperand &MO = MI->getOperand(i: OpNum); |
379 | int64_t Imm = MO.getImm(); |
380 | |
381 | switch (Imm) { |
382 | default: |
383 | return; |
384 | case NVPTX::PTXPrmtMode::NONE: |
385 | return; |
386 | case NVPTX::PTXPrmtMode::F4E: |
387 | O << ".f4e" ; |
388 | return; |
389 | case NVPTX::PTXPrmtMode::B4E: |
390 | O << ".b4e" ; |
391 | return; |
392 | case NVPTX::PTXPrmtMode::RC8: |
393 | O << ".rc8" ; |
394 | return; |
395 | case NVPTX::PTXPrmtMode::ECL: |
396 | O << ".ecl" ; |
397 | return; |
398 | case NVPTX::PTXPrmtMode::ECR: |
399 | O << ".ecr" ; |
400 | return; |
401 | case NVPTX::PTXPrmtMode::RC16: |
402 | O << ".rc16" ; |
403 | return; |
404 | } |
405 | } |
406 | |
407 | void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum, |
408 | raw_ostream &O) { |
409 | const MCOperand &MO = MI->getOperand(i: OpNum); |
410 | using RedTy = nvvm::TMAReductionOp; |
411 | |
412 | switch (static_cast<RedTy>(MO.getImm())) { |
413 | case RedTy::ADD: |
414 | O << ".add" ; |
415 | return; |
416 | case RedTy::MIN: |
417 | O << ".min" ; |
418 | return; |
419 | case RedTy::MAX: |
420 | O << ".max" ; |
421 | return; |
422 | case RedTy::INC: |
423 | O << ".inc" ; |
424 | return; |
425 | case RedTy::DEC: |
426 | O << ".dec" ; |
427 | return; |
428 | case RedTy::AND: |
429 | O << ".and" ; |
430 | return; |
431 | case RedTy::OR: |
432 | O << ".or" ; |
433 | return; |
434 | case RedTy::XOR: |
435 | O << ".xor" ; |
436 | return; |
437 | } |
438 | llvm_unreachable( |
439 | "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode" ); |
440 | } |
441 | |
442 | void NVPTXInstPrinter::printCTAGroup(const MCInst *MI, int OpNum, |
443 | raw_ostream &O) { |
444 | const MCOperand &MO = MI->getOperand(i: OpNum); |
445 | using CGTy = nvvm::CTAGroupKind; |
446 | |
447 | switch (static_cast<CGTy>(MO.getImm())) { |
448 | case CGTy::CG_NONE: |
449 | O << "" ; |
450 | return; |
451 | case CGTy::CG_1: |
452 | O << ".cta_group::1" ; |
453 | return; |
454 | case CGTy::CG_2: |
455 | O << ".cta_group::2" ; |
456 | return; |
457 | } |
458 | llvm_unreachable("Invalid cta_group in printCTAGroup" ); |
459 | } |
460 | |
461 | void NVPTXInstPrinter::printCallOperand(const MCInst *MI, int OpNum, |
462 | raw_ostream &O, StringRef Modifier) { |
463 | const MCOperand &MO = MI->getOperand(i: OpNum); |
464 | assert(MO.isImm() && "Invalid operand" ); |
465 | const auto Imm = MO.getImm(); |
466 | |
467 | if (Modifier == "RetList" ) { |
468 | assert((Imm == 1 || Imm == 0) && "Invalid return list" ); |
469 | if (Imm) |
470 | O << " (retval0)," ; |
471 | return; |
472 | } |
473 | |
474 | if (Modifier == "ParamList" ) { |
475 | assert(Imm >= 0 && "Invalid parameter list" ); |
476 | interleaveComma(c: llvm::seq(Size: Imm), os&: O, |
477 | each_fn: [&](const auto &I) { O << "param" << I; }); |
478 | return; |
479 | } |
480 | llvm_unreachable("Invalid modifier" ); |
481 | } |
482 | |