1//===- NVPTXInstrInfo.cpp - NVPTX Instruction Information -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the NVPTX implementation of the TargetInstrInfo class.
10//
11//===----------------------------------------------------------------------===//
12
13#include "NVPTXInstrInfo.h"
14#include "NVPTX.h"
15#include "NVPTXSubtarget.h"
16#include "llvm/CodeGen/MachineFunction.h"
17#include "llvm/CodeGen/MachineInstrBuilder.h"
18#include "llvm/CodeGen/MachineRegisterInfo.h"
19
20using namespace llvm;
21
22#define GET_INSTRINFO_CTOR_DTOR
23#include "NVPTXGenInstrInfo.inc"
24
// Pin the vtable to this file. anchor() is defined out of line here and its
// body is intentionally empty.
void NVPTXInstrInfo::anchor() {}
27
/// Construct the NVPTX instruction info for subtarget \p STI.
///
/// The generated base class receives \p STI together with this object's
/// RegInfo member; RegInfo itself is value-initialized.
// NOTE(review): C++ initializes base classes before members, so RegInfo is
// handed to the base constructor before its own initializer has run. That is
// only safe if the base merely binds a reference/pointer to it - confirm
// against the generated NVPTXGenInstrInfo constructor.
NVPTXInstrInfo::NVPTXInstrInfo(const NVPTXSubtarget &STI)
    : NVPTXGenInstrInfo(STI, RegInfo), RegInfo() {}
30
31void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
32 MachineBasicBlock::iterator I,
33 const DebugLoc &DL, Register DestReg,
34 Register SrcReg, bool KillSrc,
35 bool RenamableDest, bool RenamableSrc) const {
36 const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
37 const TargetRegisterClass *DestRC = MRI.getRegClass(Reg: DestReg);
38 const TargetRegisterClass *SrcRC = MRI.getRegClass(Reg: SrcReg);
39
40 if (DestRC != SrcRC)
41 report_fatal_error(reason: "Copy one register into another with a different width");
42
43 unsigned Op;
44 if (DestRC == &NVPTX::B1RegClass)
45 Op = NVPTX::MOV_B1_r;
46 else if (DestRC == &NVPTX::B16RegClass)
47 Op = NVPTX::MOV_B16_r;
48 else if (DestRC == &NVPTX::B32RegClass)
49 Op = NVPTX::MOV_B32_r;
50 else if (DestRC == &NVPTX::B64RegClass)
51 Op = NVPTX::MOV_B64_r;
52 else if (DestRC == &NVPTX::B128RegClass)
53 Op = NVPTX::MOV_B128_r;
54 else
55 llvm_unreachable("Bad register copy");
56
57 BuildMI(BB&: MBB, I, MIMD: DL, MCID: get(Opcode: Op), DestReg)
58 .addReg(RegNo: SrcReg, Flags: getKillRegState(B: KillSrc));
59}
60
61/// analyzeBranch - Analyze the branching code at the end of MBB, returning
62/// true if it cannot be understood (e.g. it's a switch dispatch or isn't
63/// implemented for a target). Upon success, this returns false and returns
64/// with the following information in various cases:
65///
66/// 1. If this block ends with no branches (it just falls through to its succ)
67/// just return false, leaving TBB/FBB null.
68/// 2. If this block ends with only an unconditional branch, it sets TBB to be
69/// the destination block.
70/// 3. If this block ends with an conditional branch and it falls through to
71/// an successor block, it sets TBB to be the branch destination block and a
72/// list of operands that evaluate the condition. These
73/// operands can be passed to other TargetInstrInfo methods to create new
74/// branches.
75/// 4. If this block ends with an conditional branch and an unconditional
76/// block, it returns the 'true' destination in TBB, the 'false' destination
77/// in FBB, and a list of operands that evaluate the condition. These
78/// operands can be passed to other TargetInstrInfo methods to create new
79/// branches.
80///
81/// Note that removeBranch and insertBranch must be implemented to support
82/// cases where this method returns success.
83///
84bool NVPTXInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
85 MachineBasicBlock *&TBB,
86 MachineBasicBlock *&FBB,
87 SmallVectorImpl<MachineOperand> &Cond,
88 bool AllowModify) const {
89 // If the block has no terminators, it just falls into the block after it.
90 MachineBasicBlock::iterator I = MBB.end();
91 if (I == MBB.begin() || !isUnpredicatedTerminator(MI: *--I))
92 return false;
93
94 // Get the last instruction in the block.
95 MachineInstr &LastInst = *I;
96
97 // If there is only one terminator instruction, process it.
98 if (I == MBB.begin() || !isUnpredicatedTerminator(MI: *--I)) {
99 if (LastInst.getOpcode() == NVPTX::GOTO) {
100 TBB = LastInst.getOperand(i: 0).getMBB();
101 return false;
102 } else if (LastInst.getOpcode() == NVPTX::CBranch) {
103 // Block ends with fall-through condbranch.
104 TBB = LastInst.getOperand(i: 1).getMBB();
105 Cond.push_back(Elt: LastInst.getOperand(i: 0));
106 Cond.push_back(Elt: LastInst.getOperand(i: 2));
107 return false;
108 }
109 // Otherwise, don't know what this is.
110 return true;
111 }
112
113 // Get the instruction before it if it's a terminator.
114 MachineInstr &SecondLastInst = *I;
115
116 // If there are three terminators, we don't know what sort of block this is.
117 if (I != MBB.begin() && isUnpredicatedTerminator(MI: *--I))
118 return true;
119
120 // If the block ends with NVPTX::GOTO and NVPTX:CBranch, handle it.
121 if (SecondLastInst.getOpcode() == NVPTX::CBranch &&
122 LastInst.getOpcode() == NVPTX::GOTO) {
123 TBB = SecondLastInst.getOperand(i: 1).getMBB();
124 Cond.push_back(Elt: SecondLastInst.getOperand(i: 0));
125 Cond.push_back(Elt: SecondLastInst.getOperand(i: 2));
126 FBB = LastInst.getOperand(i: 0).getMBB();
127 return false;
128 }
129
130 // If the block ends with two NVPTX:GOTOs, handle it. The second one is not
131 // executed, so remove it.
132 if (SecondLastInst.getOpcode() == NVPTX::GOTO &&
133 LastInst.getOpcode() == NVPTX::GOTO) {
134 TBB = SecondLastInst.getOperand(i: 0).getMBB();
135 I = LastInst;
136 if (AllowModify)
137 I->eraseFromParent();
138 return false;
139 }
140
141 // Otherwise, can't handle this.
142 return true;
143}
144
145unsigned NVPTXInstrInfo::removeBranch(MachineBasicBlock &MBB,
146 int *BytesRemoved) const {
147 assert(!BytesRemoved && "code size not handled");
148 MachineBasicBlock::iterator I = MBB.end();
149 if (I == MBB.begin())
150 return 0;
151 --I;
152 if (I->getOpcode() != NVPTX::GOTO && I->getOpcode() != NVPTX::CBranch)
153 return 0;
154
155 // Remove the branch.
156 I->eraseFromParent();
157
158 I = MBB.end();
159
160 if (I == MBB.begin())
161 return 1;
162 --I;
163 if (I->getOpcode() != NVPTX::CBranch)
164 return 1;
165
166 // Remove the branch.
167 I->eraseFromParent();
168 return 2;
169}
170
171unsigned NVPTXInstrInfo::insertBranch(MachineBasicBlock &MBB,
172 MachineBasicBlock *TBB,
173 MachineBasicBlock *FBB,
174 ArrayRef<MachineOperand> Cond,
175 const DebugLoc &DL,
176 int *BytesAdded) const {
177 assert(!BytesAdded && "code size not handled");
178
179 // Shouldn't be a fall through.
180 assert(TBB && "insertBranch must not be told to insert a fallthrough");
181 assert((Cond.size() == 2 || Cond.size() == 0) &&
182 "NVPTX branch conditions have two components!");
183
184 // One-way branch.
185 if (!FBB) {
186 if (Cond.empty()) // Unconditional branch
187 BuildMI(BB: &MBB, MIMD: DL, MCID: get(Opcode: NVPTX::GOTO)).addMBB(MBB: TBB);
188 else // Conditional branch
189 BuildMI(BB: &MBB, MIMD: DL, MCID: get(Opcode: NVPTX::CBranch))
190 .add(MO: Cond[0])
191 .addMBB(MBB: TBB)
192 .add(MO: Cond[1]);
193 return 1;
194 }
195
196 // Two-way Conditional Branch.
197 BuildMI(BB: &MBB, MIMD: DL, MCID: get(Opcode: NVPTX::CBranch)).add(MO: Cond[0]).addMBB(MBB: TBB).add(MO: Cond[1]);
198 BuildMI(BB: &MBB, MIMD: DL, MCID: get(Opcode: NVPTX::GOTO)).addMBB(MBB: FBB);
199 return 2;
200}
201
202bool NVPTXInstrInfo::reverseBranchCondition(
203 SmallVectorImpl<MachineOperand> &Cond) const {
204 assert(Cond.size() == 2 && "Invalid NVPTX branch condition!");
205 Cond[1].setImm(!Cond[1].getImm());
206 return false;
207}
208
209bool NVPTXInstrInfo::invertPredicateBranchInstr(MachineBasicBlock &MBB) const {
210 MachineBasicBlock *TBB = nullptr, *FBB = nullptr;
211 SmallVector<MachineOperand, 4> Cond;
212 if (analyzeBranch(MBB, TBB, FBB, Cond, /*AllowModify=*/false))
213 return false;
214 if (Cond.empty())
215 return false;
216 if (reverseBranchCondition(Cond))
217 return false;
218 DebugLoc DL = MBB.findBranchDebugLoc();
219 removeBranch(MBB);
220 insertBranch(MBB, TBB, FBB, Cond, DL);
221 return true;
222}
223
224static bool isIntegerSetp(const MachineInstr &MI) {
225 switch (MI.getOpcode()) {
226 case NVPTX::SETP_i16rr:
227 case NVPTX::SETP_i16ri:
228 case NVPTX::SETP_i16ir:
229 case NVPTX::SETP_i32rr:
230 case NVPTX::SETP_i32ri:
231 case NVPTX::SETP_i32ir:
232 case NVPTX::SETP_i64rr:
233 case NVPTX::SETP_i64ri:
234 case NVPTX::SETP_i64ir:
235 return true;
236 default:
237 return false;
238 }
239}
240
241static bool isScalarFloatSetp(const MachineInstr &MI) {
242 switch (MI.getOpcode()) {
243 case NVPTX::SETP_bf16rr:
244 case NVPTX::SETP_f16rr:
245 case NVPTX::SETP_f32rr:
246 case NVPTX::SETP_f32ri:
247 case NVPTX::SETP_f32ir:
248 case NVPTX::SETP_f64rr:
249 case NVPTX::SETP_f64ri:
250 case NVPTX::SETP_f64ir:
251 return true;
252 default:
253 return false;
254 }
255}
256
257static int64_t invertIntegerCmpMode(int64_t Mode) {
258 switch (Mode) {
259 case NVPTX::PTXCmpMode::EQ:
260 return NVPTX::PTXCmpMode::NE;
261 case NVPTX::PTXCmpMode::NE:
262 return NVPTX::PTXCmpMode::EQ;
263 case NVPTX::PTXCmpMode::LT:
264 return NVPTX::PTXCmpMode::GE;
265 case NVPTX::PTXCmpMode::LE:
266 return NVPTX::PTXCmpMode::GT;
267 case NVPTX::PTXCmpMode::GT:
268 return NVPTX::PTXCmpMode::LE;
269 case NVPTX::PTXCmpMode::GE:
270 return NVPTX::PTXCmpMode::LT;
271 case NVPTX::PTXCmpMode::LTU:
272 return NVPTX::PTXCmpMode::GEU;
273 case NVPTX::PTXCmpMode::LEU:
274 return NVPTX::PTXCmpMode::GTU;
275 case NVPTX::PTXCmpMode::GTU:
276 return NVPTX::PTXCmpMode::LEU;
277 case NVPTX::PTXCmpMode::GEU:
278 return NVPTX::PTXCmpMode::LTU;
279 default:
280 llvm_unreachable("Invalid integer comparison mode");
281 }
282}
283
284static int64_t invertScalarFloatCmpMode(int64_t Mode) {
285 switch (Mode) {
286 case NVPTX::PTXCmpMode::EQ:
287 return NVPTX::PTXCmpMode::NEU;
288 case NVPTX::PTXCmpMode::NE:
289 return NVPTX::PTXCmpMode::EQU;
290 case NVPTX::PTXCmpMode::EQU:
291 return NVPTX::PTXCmpMode::NE;
292 case NVPTX::PTXCmpMode::NEU:
293 return NVPTX::PTXCmpMode::EQ;
294 case NVPTX::PTXCmpMode::LT:
295 return NVPTX::PTXCmpMode::GEU;
296 case NVPTX::PTXCmpMode::LE:
297 return NVPTX::PTXCmpMode::GTU;
298 case NVPTX::PTXCmpMode::GT:
299 return NVPTX::PTXCmpMode::LEU;
300 case NVPTX::PTXCmpMode::GE:
301 return NVPTX::PTXCmpMode::LTU;
302 case NVPTX::PTXCmpMode::LTU:
303 return NVPTX::PTXCmpMode::GE;
304 case NVPTX::PTXCmpMode::LEU:
305 return NVPTX::PTXCmpMode::GT;
306 case NVPTX::PTXCmpMode::GTU:
307 return NVPTX::PTXCmpMode::LE;
308 case NVPTX::PTXCmpMode::GEU:
309 return NVPTX::PTXCmpMode::LT;
310 case NVPTX::PTXCmpMode::NUM:
311 return NVPTX::PTXCmpMode::NotANumber;
312 case NVPTX::PTXCmpMode::NotANumber:
313 return NVPTX::PTXCmpMode::NUM;
314 default:
315 llvm_unreachable("Invalid scalar float comparison mode");
316 }
317}
318
319static void invertScalarCompareInstr(MachineInstr &MI) {
320 MachineOperand &ModeOp = MI.getOperand(i: 3);
321
322 if (isIntegerSetp(MI))
323 ModeOp.setImm(invertIntegerCmpMode(Mode: ModeOp.getImm()));
324 else if (isScalarFloatSetp(MI))
325 ModeOp.setImm(invertScalarFloatCmpMode(Mode: ModeOp.getImm()));
326 else
327 llvm_unreachable("Invalid SETP instruction");
328}
329
330bool NVPTXInstrInfo::findCommutedOpIndices(const MachineInstr &MI,
331 unsigned &SrcOpIdx1,
332 unsigned &SrcOpIdx2) const {
333 if (isIntegerSetp(MI) || isScalarFloatSetp(MI))
334 return fixCommutedOpIndices(ResultIdx1&: SrcOpIdx1, ResultIdx2&: SrcOpIdx2, CommutableOpIdx1: 1, CommutableOpIdx2: 2);
335 return TargetInstrInfo::findCommutedOpIndices(MI, SrcOpIdx1, SrcOpIdx2);
336}
337
338MachineInstr *NVPTXInstrInfo::commuteInstructionImpl(MachineInstr &MI,
339 bool NewMI,
340 unsigned OpIdx1,
341 unsigned OpIdx2) const {
342 assert(!NewMI && "this should never be used");
343
344 if (!isIntegerSetp(MI) && !isScalarFloatSetp(MI))
345 return TargetInstrInfo::commuteInstructionImpl(MI, NewMI, OpIdx1, OpIdx2);
346
347 // For now all users must be invertible conditional branches.
348 // TODO: Support other users such as selects.
349 MachineRegisterInfo &MRI = MI.getParent()->getParent()->getRegInfo();
350 SmallVector<MachineBasicBlock *, 4> BranchMBBs;
351 for (MachineInstr &UseMI :
352 MRI.use_nodbg_instructions(Reg: MI.getOperand(i: 0).getReg())) {
353 if (!UseMI.isConditionalBranch())
354 return nullptr;
355 BranchMBBs.push_back(Elt: UseMI.getParent());
356 }
357
358 invertScalarCompareInstr(MI);
359 auto *Failed = llvm::find_if(Range&: BranchMBBs, P: [this](MachineBasicBlock *MBB) {
360 return !invertPredicateBranchInstr(MBB&: *MBB);
361 });
362 if (Failed == BranchMBBs.end())
363 return &MI;
364
365 // Couldn't invert one of the branches. Roll back the prefix we
366 // already inverted and the compare-mode flip.
367 for (MachineBasicBlock *MBB : make_range(x: BranchMBBs.begin(), y: Failed))
368 invertPredicateBranchInstr(MBB&: *MBB);
369 invertScalarCompareInstr(MI);
370 return nullptr;
371}
372