//===-- AArch64CondBrTuning.cpp --- Conditional branch tuning for AArch64 -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file contains a pass that transforms CBZ/CBNZ/TBZ/TBNZ instructions
/// into a conditional branch (B.cond), when the NZCV flags can be set for
/// "free". This is preferred on targets that have more flexibility when
/// scheduling B.cond instructions as compared to CBZ/CBNZ/TBZ/TBNZ (assuming
/// all other variables are equal). This can also reduce register pressure.
///
/// A few examples:
///
///  1) add w8, w0, w1        -> cmn w0, w1       ; CMN is an alias of ADDS.
///     cbz w8, .LBB0_2       -> b.eq .LBB0_2
///
///  2) add w8, w0, w1        -> adds w8, w0, w1  ; w8 has multiple uses.
///     cbz w8, .LBB1_2       -> b.eq .LBB1_2
///
///  3) sub w8, w0, w1        -> subs w8, w0, w1  ; w8 has multiple uses.
///     tbz w8, #31, .LBB6_2  -> b.pl .LBB6_2
///
//===----------------------------------------------------------------------===//

#include "AArch64.h"
#include "AArch64Subtarget.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

#define DEBUG_TYPE "aarch64-cond-br-tuning"
#define AARCH64_CONDBR_TUNING_NAME "AArch64 Conditional Branch Tuning"

namespace {
class AArch64CondBrTuning : public MachineFunctionPass {
  const AArch64InstrInfo *TII;
  const TargetRegisterInfo *TRI;

  MachineRegisterInfo *MRI;

public:
  static char ID;
  AArch64CondBrTuning() : MachineFunctionPass(ID) {}
  void getAnalysisUsage(AnalysisUsage &AU) const override;
  bool runOnMachineFunction(MachineFunction &MF) override;
  StringRef getPassName() const override { return AARCH64_CONDBR_TUNING_NAME; }

private:
  MachineInstr *getOperandDef(const MachineOperand &MO);
  MachineInstr *convertToFlagSetting(MachineInstr &MI, bool IsFlagSetting,
                                     bool Is64Bit);
  MachineInstr *convertToCondBr(MachineInstr &MI);
  bool tryToTuneBranch(MachineInstr &MI, MachineInstr &DefMI);
};
} // end anonymous namespace

char AArch64CondBrTuning::ID = 0;

INITIALIZE_PASS(AArch64CondBrTuning, "aarch64-cond-br-tuning",
                AARCH64_CONDBR_TUNING_NAME, false, false)

void AArch64CondBrTuning::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesCFG();
  MachineFunctionPass::getAnalysisUsage(AU);
}

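// Return the instruction that uniquely defines the virtual register in MO, or
// nullptr if MO is not a virtual register or has no single definition.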
MachineInstr *AArch64CondBrTuning::getOperandDef(const MachineOperand &MO) {
  if (!MO.getReg().isVirtual())
    return nullptr;
  return MRI->getUniqueVRegDef(MO.getReg());
}

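// Return a flag-setting form of MI: if MI already sets NZCV (IsFlagSetting),
// revive any dead implicit-def of NZCV and reuse MI; otherwise rebuild MI with
// its flag-setting opcode (e.g., ADD -> ADDS).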
MachineInstr *AArch64CondBrTuning::convertToFlagSetting(MachineInstr &MI,
                                                        bool IsFlagSetting,
                                                        bool Is64Bit) {
  // If this is already the flag setting version of the instruction (e.g., SUBS)
  // just make sure the implicit-def of NZCV isn't marked dead.
  if (IsFlagSetting) {
    for (MachineOperand &MO : MI.implicit_operands())
      if (MO.isReg() && MO.isDead() && MO.getReg() == AArch64::NZCV)
        MO.setIsDead(false);
    return &MI;
  }
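  // Rebuild MI with its flag-setting opcode. If the branch being folded is the
  // only non-debug user of the result, the value itself is otherwise dead, so
  // retarget the def to WZR/XZR and free the register.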
  unsigned NewOpc = TII->convertToFlagSettingOpc(MI.getOpcode());
  Register NewDestReg = MI.getOperand(0).getReg();
  if (MRI->hasOneNonDBGUse(MI.getOperand(0).getReg()))
    NewDestReg = Is64Bit ? AArch64::XZR : AArch64::WZR;

  MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
                                    TII->get(NewOpc), NewDestReg);

  // If the MI has a debug instruction number, preserve that in the new Machine
  // Instruction that is created.
  if (MI.peekDebugInstrNum() != 0)
    MIB->setDebugInstrNum(MI.peekDebugInstrNum());

  for (const MachineOperand &MO : llvm::drop_begin(MI.operands()))
    MIB.add(MO);

  return MIB;
}

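// Build a B.cc to MI's branch target using the condition code equivalent to
// the original zero/sign-bit test; the caller erases MI afterwards.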
MachineInstr *AArch64CondBrTuning::convertToCondBr(MachineInstr &MI) {
  AArch64CC::CondCode CC;
  MachineBasicBlock *TargetMBB = TII->getBranchDestBlock(MI);
  switch (MI.getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");

  case AArch64::CBZW:
  case AArch64::CBZX:
    CC = AArch64CC::EQ;
    break;
  case AArch64::CBNZW:
  case AArch64::CBNZX:
    CC = AArch64CC::NE;
    break;
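  // TBZ/TBNZ only reach here when testing the sign bit (enforced in
  // tryToTuneBranch), so a clear sign bit maps to PL and a set one to MI.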
  case AArch64::TBZW:
  case AArch64::TBZX:
    CC = AArch64CC::PL;
    break;
  case AArch64::TBNZW:
  case AArch64::TBNZX:
    CC = AArch64CC::MI;
    break;
  }
  return BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), TII->get(AArch64::Bcc))
      .addImm(CC)
      .addMBB(TargetMBB);
}

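// Try to rewrite the CBZ/CBNZ/TBZ/TBNZ in MI as a B.cc that consumes NZCV
// flags produced by DefMI, the same-block definition of the tested register.
// Returns true if the branch was rewritten.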
bool AArch64CondBrTuning::tryToTuneBranch(MachineInstr &MI,
                                          MachineInstr &DefMI) {
  // We don't want NZCV bits live across blocks.
  if (MI.getParent() != DefMI.getParent())
    return false;

  bool IsFlagSetting = true;
  unsigned MIOpc = MI.getOpcode();
  MachineInstr *NewCmp = nullptr, *NewBr = nullptr;
  switch (DefMI.getOpcode()) {
  default:
    return false;
  case AArch64::ADDWri:
  case AArch64::ADDWrr:
  case AArch64::ADDWrs:
  case AArch64::ADDWrx:
  case AArch64::ANDWri:
  case AArch64::ANDWrr:
  case AArch64::ANDWrs:
  case AArch64::BICWrr:
  case AArch64::BICWrs:
  case AArch64::SUBWri:
  case AArch64::SUBWrr:
  case AArch64::SUBWrs:
  case AArch64::SUBWrx:
    IsFlagSetting = false;
    [[fallthrough]];
  case AArch64::ADDSWri:
  case AArch64::ADDSWrr:
  case AArch64::ADDSWrs:
  case AArch64::ADDSWrx:
  case AArch64::ANDSWri:
  case AArch64::ANDSWrr:
  case AArch64::ANDSWrs:
  case AArch64::BICSWrr:
  case AArch64::BICSWrs:
  case AArch64::SUBSWri:
  case AArch64::SUBSWrr:
  case AArch64::SUBSWrs:
  case AArch64::SUBSWrx:
    switch (MIOpc) {
    default:
      llvm_unreachable("Unexpected opcode!");

    case AArch64::CBZW:
    case AArch64::CBNZW:
    case AArch64::TBZW:
    case AArch64::TBNZW:
      // Check to see if the TBZ/TBNZ is checking the sign bit.
      if ((MIOpc == AArch64::TBZW || MIOpc == AArch64::TBNZW) &&
          MI.getOperand(1).getImm() != 31)
        return false;

      // There must not be any instruction between DefMI and MI that clobbers or
      // reads NZCV.
      if (isNZCVTouchedInInstructionRange(DefMI, MI, TRI))
        return false;
      LLVM_DEBUG(dbgs() << "  Replacing instructions:\n    ");
      LLVM_DEBUG(DefMI.print(dbgs()));
      LLVM_DEBUG(dbgs() << "    ");
      LLVM_DEBUG(MI.print(dbgs()));

      NewCmp = convertToFlagSetting(DefMI, IsFlagSetting, /*Is64Bit=*/false);
      NewBr = convertToCondBr(MI);
      break;
    }
    break;

  case AArch64::ADDXri:
  case AArch64::ADDXrr:
  case AArch64::ADDXrs:
  case AArch64::ADDXrx:
  case AArch64::ANDXri:
  case AArch64::ANDXrr:
  case AArch64::ANDXrs:
  case AArch64::BICXrr:
  case AArch64::BICXrs:
  case AArch64::SUBXri:
  case AArch64::SUBXrr:
  case AArch64::SUBXrs:
  case AArch64::SUBXrx:
    IsFlagSetting = false;
    [[fallthrough]];
  case AArch64::ADDSXri:
  case AArch64::ADDSXrr:
  case AArch64::ADDSXrs:
  case AArch64::ADDSXrx:
  case AArch64::ANDSXri:
  case AArch64::ANDSXrr:
  case AArch64::ANDSXrs:
  case AArch64::BICSXrr:
  case AArch64::BICSXrs:
  case AArch64::SUBSXri:
  case AArch64::SUBSXrr:
  case AArch64::SUBSXrs:
  case AArch64::SUBSXrx:
    switch (MIOpc) {
    default:
      llvm_unreachable("Unexpected opcode!");

    case AArch64::CBZX:
    case AArch64::CBNZX:
    case AArch64::TBZX:
    case AArch64::TBNZX: {
      // Check to see if the TBZ/TBNZ is checking the sign bit.
      if ((MIOpc == AArch64::TBZX || MIOpc == AArch64::TBNZX) &&
          MI.getOperand(1).getImm() != 63)
        return false;
      // There must not be any instruction between DefMI and MI that clobbers or
      // reads NZCV.
      if (isNZCVTouchedInInstructionRange(DefMI, MI, TRI))
        return false;
      LLVM_DEBUG(dbgs() << "  Replacing instructions:\n    ");
      LLVM_DEBUG(DefMI.print(dbgs()));
      LLVM_DEBUG(dbgs() << "    ");
      LLVM_DEBUG(MI.print(dbgs()));

      NewCmp = convertToFlagSetting(DefMI, IsFlagSetting, /*Is64Bit=*/true);
      NewBr = convertToCondBr(MI);
      break;
    }
    }
    break;
  }
  (void)NewCmp; (void)NewBr;
  assert(NewCmp && NewBr && "Expected new instructions.");

  LLVM_DEBUG(dbgs() << "  with instruction:\n    ");
  LLVM_DEBUG(NewCmp->print(dbgs()));
  LLVM_DEBUG(dbgs() << "    ");
  LLVM_DEBUG(NewBr->print(dbgs()));

  // If DefMI was already a flag-setting instruction, we reused it and only
  // cleared the dead marker on its implicit-def of NZCV, so it must not be
  // erased.
  if (!IsFlagSetting)
    DefMI.eraseFromParent();
  MI.eraseFromParent();
  return true;
}

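// Scan each block's terminators for CBZ/CBNZ/TBZ/TBNZ and try to fold each one
// into a flag-setting def plus B.cc. At most one branch per block is rewritten,
// since a second rewrite would clobber the NZCV flags.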
bool AArch64CondBrTuning::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()))
    return false;

  LLVM_DEBUG(
      dbgs() << "********** AArch64 Conditional Branch Tuning **********\n"
             << "********** Function: " << MF.getName() << '\n');

  TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
  TRI = MF.getSubtarget().getRegisterInfo();
  MRI = &MF.getRegInfo();

  bool Changed = false;
  for (MachineBasicBlock &MBB : MF) {
    bool LocalChange = false;
    for (MachineInstr &MI : MBB.terminators()) {
      switch (MI.getOpcode()) {
      default:
        break;
      case AArch64::CBZW:
      case AArch64::CBZX:
      case AArch64::CBNZW:
      case AArch64::CBNZX:
      case AArch64::TBZW:
      case AArch64::TBZX:
      case AArch64::TBNZW:
      case AArch64::TBNZX: {
        MachineInstr *DefMI = getOperandDef(MI.getOperand(0));
        LocalChange = (DefMI && tryToTuneBranch(MI, *DefMI));
        break;
      }
      }
      // If the optimization was successful, we can't optimize any other
      // branches because doing so would clobber the NZCV flags.
      if (LocalChange) {
        Changed = true;
        break;
      }
    }
  }
  return Changed;
}

FunctionPass *llvm::createAArch64CondBrTuning() {
  return new AArch64CondBrTuning();
}