1 | //===-- AArch64CondBrTuning.cpp --- Conditional branch tuning for AArch64 -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | /// \file |
9 | /// This file contains a pass that transforms CBZ/CBNZ/TBZ/TBNZ instructions |
10 | /// into a conditional branch (B.cond), when the NZCV flags can be set for |
11 | /// "free". This is preferred on targets that have more flexibility when |
12 | /// scheduling B.cond instructions as compared to CBZ/CBNZ/TBZ/TBNZ (assuming |
13 | /// all other variables are equal). This can also reduce register pressure. |
14 | /// |
15 | /// A few examples: |
16 | /// |
17 | /// 1) add w8, w0, w1 -> cmn w0, w1 ; CMN is an alias of ADDS. |
///    cbz  w8, .LBB0_2     -> b.eq .LBB0_2
19 | /// |
20 | /// 2) add w8, w0, w1 -> adds w8, w0, w1 ; w8 has multiple uses. |
21 | /// cbz w8, .LBB1_2 -> b.eq .LBB1_2 |
22 | /// |
23 | /// 3) sub w8, w0, w1 -> subs w8, w0, w1 ; w8 has multiple uses. |
24 | /// tbz w8, #31, .LBB6_2 -> b.pl .LBB6_2 |
25 | /// |
26 | //===----------------------------------------------------------------------===// |
27 | |
28 | #include "AArch64.h" |
29 | #include "AArch64Subtarget.h" |
30 | #include "llvm/CodeGen/MachineFunction.h" |
31 | #include "llvm/CodeGen/MachineFunctionPass.h" |
32 | #include "llvm/CodeGen/MachineInstrBuilder.h" |
33 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
34 | #include "llvm/CodeGen/Passes.h" |
35 | #include "llvm/CodeGen/TargetInstrInfo.h" |
36 | #include "llvm/CodeGen/TargetRegisterInfo.h" |
37 | #include "llvm/CodeGen/TargetSubtargetInfo.h" |
38 | #include "llvm/Support/Debug.h" |
39 | #include "llvm/Support/raw_ostream.h" |
40 | |
41 | using namespace llvm; |
42 | |
43 | #define DEBUG_TYPE "aarch64-cond-br-tuning" |
44 | #define AARCH64_CONDBR_TUNING_NAME "AArch64 Conditional Branch Tuning" |
45 | |
46 | namespace { |
47 | class AArch64CondBrTuning : public MachineFunctionPass { |
48 | const AArch64InstrInfo *TII; |
49 | const TargetRegisterInfo *TRI; |
50 | |
51 | MachineRegisterInfo *MRI; |
52 | |
53 | public: |
54 | static char ID; |
55 | AArch64CondBrTuning() : MachineFunctionPass(ID) { |
56 | initializeAArch64CondBrTuningPass(*PassRegistry::getPassRegistry()); |
57 | } |
58 | void getAnalysisUsage(AnalysisUsage &AU) const override; |
59 | bool runOnMachineFunction(MachineFunction &MF) override; |
60 | StringRef getPassName() const override { return AARCH64_CONDBR_TUNING_NAME; } |
61 | |
62 | private: |
63 | MachineInstr *getOperandDef(const MachineOperand &MO); |
64 | MachineInstr *convertToFlagSetting(MachineInstr &MI, bool IsFlagSetting, |
65 | bool Is64Bit); |
66 | MachineInstr *convertToCondBr(MachineInstr &MI); |
67 | bool tryToTuneBranch(MachineInstr &MI, MachineInstr &DefMI); |
68 | }; |
69 | } // end anonymous namespace |
70 | |
71 | char AArch64CondBrTuning::ID = 0; |
72 | |
73 | INITIALIZE_PASS(AArch64CondBrTuning, "aarch64-cond-br-tuning" , |
74 | AARCH64_CONDBR_TUNING_NAME, false, false) |
75 | |
76 | void AArch64CondBrTuning::getAnalysisUsage(AnalysisUsage &AU) const { |
77 | AU.setPreservesCFG(); |
78 | MachineFunctionPass::getAnalysisUsage(AU); |
79 | } |
80 | |
81 | MachineInstr *AArch64CondBrTuning::getOperandDef(const MachineOperand &MO) { |
82 | if (!MO.getReg().isVirtual()) |
83 | return nullptr; |
84 | return MRI->getUniqueVRegDef(Reg: MO.getReg()); |
85 | } |
86 | |
87 | MachineInstr *AArch64CondBrTuning::convertToFlagSetting(MachineInstr &MI, |
88 | bool IsFlagSetting, |
89 | bool Is64Bit) { |
90 | // If this is already the flag setting version of the instruction (e.g., SUBS) |
91 | // just make sure the implicit-def of NZCV isn't marked dead. |
92 | if (IsFlagSetting) { |
93 | for (MachineOperand &MO : MI.implicit_operands()) |
94 | if (MO.isReg() && MO.isDead() && MO.getReg() == AArch64::NZCV) |
95 | MO.setIsDead(false); |
96 | return &MI; |
97 | } |
98 | unsigned NewOpc = TII->convertToFlagSettingOpc(Opc: MI.getOpcode()); |
99 | Register NewDestReg = MI.getOperand(i: 0).getReg(); |
100 | if (MRI->hasOneNonDBGUse(RegNo: MI.getOperand(i: 0).getReg())) |
101 | NewDestReg = Is64Bit ? AArch64::XZR : AArch64::WZR; |
102 | |
103 | MachineInstrBuilder MIB = BuildMI(BB&: *MI.getParent(), I&: MI, MIMD: MI.getDebugLoc(), |
104 | MCID: TII->get(Opcode: NewOpc), DestReg: NewDestReg); |
105 | for (const MachineOperand &MO : llvm::drop_begin(RangeOrContainer: MI.operands())) |
106 | MIB.add(MO); |
107 | |
108 | return MIB; |
109 | } |
110 | |
111 | MachineInstr *AArch64CondBrTuning::convertToCondBr(MachineInstr &MI) { |
112 | AArch64CC::CondCode CC; |
113 | MachineBasicBlock *TargetMBB = TII->getBranchDestBlock(MI); |
114 | switch (MI.getOpcode()) { |
115 | default: |
116 | llvm_unreachable("Unexpected opcode!" ); |
117 | |
118 | case AArch64::CBZW: |
119 | case AArch64::CBZX: |
120 | CC = AArch64CC::EQ; |
121 | break; |
122 | case AArch64::CBNZW: |
123 | case AArch64::CBNZX: |
124 | CC = AArch64CC::NE; |
125 | break; |
126 | case AArch64::TBZW: |
127 | case AArch64::TBZX: |
128 | CC = AArch64CC::PL; |
129 | break; |
130 | case AArch64::TBNZW: |
131 | case AArch64::TBNZX: |
132 | CC = AArch64CC::MI; |
133 | break; |
134 | } |
135 | return BuildMI(BB&: *MI.getParent(), I&: MI, MIMD: MI.getDebugLoc(), MCID: TII->get(Opcode: AArch64::Bcc)) |
136 | .addImm(Val: CC) |
137 | .addMBB(MBB: TargetMBB); |
138 | } |
139 | |
140 | bool AArch64CondBrTuning::tryToTuneBranch(MachineInstr &MI, |
141 | MachineInstr &DefMI) { |
142 | // We don't want NZCV bits live across blocks. |
143 | if (MI.getParent() != DefMI.getParent()) |
144 | return false; |
145 | |
146 | bool IsFlagSetting = true; |
147 | unsigned MIOpc = MI.getOpcode(); |
148 | MachineInstr *NewCmp = nullptr, *NewBr = nullptr; |
149 | switch (DefMI.getOpcode()) { |
150 | default: |
151 | return false; |
152 | case AArch64::ADDWri: |
153 | case AArch64::ADDWrr: |
154 | case AArch64::ADDWrs: |
155 | case AArch64::ADDWrx: |
156 | case AArch64::ANDWri: |
157 | case AArch64::ANDWrr: |
158 | case AArch64::ANDWrs: |
159 | case AArch64::BICWrr: |
160 | case AArch64::BICWrs: |
161 | case AArch64::SUBWri: |
162 | case AArch64::SUBWrr: |
163 | case AArch64::SUBWrs: |
164 | case AArch64::SUBWrx: |
165 | IsFlagSetting = false; |
166 | [[fallthrough]]; |
167 | case AArch64::ADDSWri: |
168 | case AArch64::ADDSWrr: |
169 | case AArch64::ADDSWrs: |
170 | case AArch64::ADDSWrx: |
171 | case AArch64::ANDSWri: |
172 | case AArch64::ANDSWrr: |
173 | case AArch64::ANDSWrs: |
174 | case AArch64::BICSWrr: |
175 | case AArch64::BICSWrs: |
176 | case AArch64::SUBSWri: |
177 | case AArch64::SUBSWrr: |
178 | case AArch64::SUBSWrs: |
179 | case AArch64::SUBSWrx: |
180 | switch (MIOpc) { |
181 | default: |
182 | llvm_unreachable("Unexpected opcode!" ); |
183 | |
184 | case AArch64::CBZW: |
185 | case AArch64::CBNZW: |
186 | case AArch64::TBZW: |
187 | case AArch64::TBNZW: |
188 | // Check to see if the TBZ/TBNZ is checking the sign bit. |
189 | if ((MIOpc == AArch64::TBZW || MIOpc == AArch64::TBNZW) && |
190 | MI.getOperand(i: 1).getImm() != 31) |
191 | return false; |
192 | |
193 | // There must not be any instruction between DefMI and MI that clobbers or |
194 | // reads NZCV. |
195 | if (isNZCVTouchedInInstructionRange(DefMI, UseMI: MI, TRI)) |
196 | return false; |
197 | LLVM_DEBUG(dbgs() << " Replacing instructions:\n " ); |
198 | LLVM_DEBUG(DefMI.print(dbgs())); |
199 | LLVM_DEBUG(dbgs() << " " ); |
200 | LLVM_DEBUG(MI.print(dbgs())); |
201 | |
202 | NewCmp = convertToFlagSetting(MI&: DefMI, IsFlagSetting, /*Is64Bit=*/false); |
203 | NewBr = convertToCondBr(MI); |
204 | break; |
205 | } |
206 | break; |
207 | |
208 | case AArch64::ADDXri: |
209 | case AArch64::ADDXrr: |
210 | case AArch64::ADDXrs: |
211 | case AArch64::ADDXrx: |
212 | case AArch64::ANDXri: |
213 | case AArch64::ANDXrr: |
214 | case AArch64::ANDXrs: |
215 | case AArch64::BICXrr: |
216 | case AArch64::BICXrs: |
217 | case AArch64::SUBXri: |
218 | case AArch64::SUBXrr: |
219 | case AArch64::SUBXrs: |
220 | case AArch64::SUBXrx: |
221 | IsFlagSetting = false; |
222 | [[fallthrough]]; |
223 | case AArch64::ADDSXri: |
224 | case AArch64::ADDSXrr: |
225 | case AArch64::ADDSXrs: |
226 | case AArch64::ADDSXrx: |
227 | case AArch64::ANDSXri: |
228 | case AArch64::ANDSXrr: |
229 | case AArch64::ANDSXrs: |
230 | case AArch64::BICSXrr: |
231 | case AArch64::BICSXrs: |
232 | case AArch64::SUBSXri: |
233 | case AArch64::SUBSXrr: |
234 | case AArch64::SUBSXrs: |
235 | case AArch64::SUBSXrx: |
236 | switch (MIOpc) { |
237 | default: |
238 | llvm_unreachable("Unexpected opcode!" ); |
239 | |
240 | case AArch64::CBZX: |
241 | case AArch64::CBNZX: |
242 | case AArch64::TBZX: |
243 | case AArch64::TBNZX: { |
244 | // Check to see if the TBZ/TBNZ is checking the sign bit. |
245 | if ((MIOpc == AArch64::TBZX || MIOpc == AArch64::TBNZX) && |
246 | MI.getOperand(i: 1).getImm() != 63) |
247 | return false; |
248 | // There must not be any instruction between DefMI and MI that clobbers or |
249 | // reads NZCV. |
250 | if (isNZCVTouchedInInstructionRange(DefMI, UseMI: MI, TRI)) |
251 | return false; |
252 | LLVM_DEBUG(dbgs() << " Replacing instructions:\n " ); |
253 | LLVM_DEBUG(DefMI.print(dbgs())); |
254 | LLVM_DEBUG(dbgs() << " " ); |
255 | LLVM_DEBUG(MI.print(dbgs())); |
256 | |
257 | NewCmp = convertToFlagSetting(MI&: DefMI, IsFlagSetting, /*Is64Bit=*/true); |
258 | NewBr = convertToCondBr(MI); |
259 | break; |
260 | } |
261 | } |
262 | break; |
263 | } |
264 | (void)NewCmp; (void)NewBr; |
265 | assert(NewCmp && NewBr && "Expected new instructions." ); |
266 | |
267 | LLVM_DEBUG(dbgs() << " with instruction:\n " ); |
268 | LLVM_DEBUG(NewCmp->print(dbgs())); |
269 | LLVM_DEBUG(dbgs() << " " ); |
270 | LLVM_DEBUG(NewBr->print(dbgs())); |
271 | |
272 | // If this was a flag setting version of the instruction, we use the original |
273 | // instruction by just clearing the dead marked on the implicit-def of NCZV. |
274 | // Therefore, we should not erase this instruction. |
275 | if (!IsFlagSetting) |
276 | DefMI.eraseFromParent(); |
277 | MI.eraseFromParent(); |
278 | return true; |
279 | } |
280 | |
281 | bool AArch64CondBrTuning::runOnMachineFunction(MachineFunction &MF) { |
282 | if (skipFunction(F: MF.getFunction())) |
283 | return false; |
284 | |
285 | LLVM_DEBUG( |
286 | dbgs() << "********** AArch64 Conditional Branch Tuning **********\n" |
287 | << "********** Function: " << MF.getName() << '\n'); |
288 | |
289 | TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo()); |
290 | TRI = MF.getSubtarget().getRegisterInfo(); |
291 | MRI = &MF.getRegInfo(); |
292 | |
293 | bool Changed = false; |
294 | for (MachineBasicBlock &MBB : MF) { |
295 | bool LocalChange = false; |
296 | for (MachineInstr &MI : MBB.terminators()) { |
297 | switch (MI.getOpcode()) { |
298 | default: |
299 | break; |
300 | case AArch64::CBZW: |
301 | case AArch64::CBZX: |
302 | case AArch64::CBNZW: |
303 | case AArch64::CBNZX: |
304 | case AArch64::TBZW: |
305 | case AArch64::TBZX: |
306 | case AArch64::TBNZW: |
307 | case AArch64::TBNZX: |
308 | MachineInstr *DefMI = getOperandDef(MO: MI.getOperand(i: 0)); |
309 | LocalChange = (DefMI && tryToTuneBranch(MI, DefMI&: *DefMI)); |
310 | break; |
311 | } |
312 | // If the optimization was successful, we can't optimize any other |
313 | // branches because doing so would clobber the NZCV flags. |
314 | if (LocalChange) { |
315 | Changed = true; |
316 | break; |
317 | } |
318 | } |
319 | } |
320 | return Changed; |
321 | } |
322 | |
323 | FunctionPass *llvm::createAArch64CondBrTuning() { |
324 | return new AArch64CondBrTuning(); |
325 | } |
326 | |