1//=- AArch64ConditionOptimizer.cpp - Remove useless comparisons for AArch64 -=//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//
10// This pass tries to make consecutive comparisons of values use the same
11// operands to allow the CSE pass to remove duplicate instructions. It adjusts
12// comparisons with immediate values by converting between inclusive and
13// exclusive forms (GE <-> GT, LE <-> LT) and correcting immediate values to
14// make them equal.
15//
16// The pass handles:
17// * Cross-block: SUBS/ADDS followed by conditional branches
18// * Intra-block: CSINC conditional instructions
19//
20//
21// Consider the following example in C:
22//
23// if ((a < 5 && ...) || (a > 5 && ...)) {
24// ~~~~~ ~~~~~
25// ^ ^
26// x y
27//
28// Here both "x" and "y" expressions compare "a" with "5". When "x" evaluates
29// to "false", "y" can just check flags set by the first comparison. As a
30// result of the canonicalization employed by
31// SelectionDAGBuilder::visitSwitchCase, DAGCombine, and other target-specific
// code, assembly ends up in a form that is not CSE-friendly:
33//
34// ...
35// cmp w8, #4
36// b.gt .LBB0_3
37// ...
38// .LBB0_3:
39// cmp w8, #6
40// b.lt .LBB0_6
41// ...
42//
43// Same assembly after the pass:
44//
45// ...
46// cmp w8, #5
47// b.ge .LBB0_3
48// ...
49// .LBB0_3:
50// cmp w8, #5 // <-- CSE pass removes this instruction
51// b.le .LBB0_6
52// ...
53//
54// See optimizeCrossBlock() and optimizeIntraBlock() for implementation details.
55//
56// TODO: maybe handle TBNZ/TBZ the same way as CMP when used instead for "a < 0"
57// TODO: For cross-block:
58// - handle other conditional instructions (e.g. CSET)
59// - allow second branching to be anything if it doesn't require adjusting
60// TODO: For intra-block:
61// - handle CINC and CSET (CSINC aliases) as their conditions are inverted
62// compared to CSINC.
63// - handle other non-CSINC conditional instructions
64//
65//===----------------------------------------------------------------------===//
66
67#include "AArch64.h"
68#include "AArch64Subtarget.h"
69#include "MCTargetDesc/AArch64AddressingModes.h"
70#include "Utils/AArch64BaseInfo.h"
71#include "llvm/ADT/ArrayRef.h"
72#include "llvm/ADT/DepthFirstIterator.h"
73#include "llvm/ADT/SmallVector.h"
74#include "llvm/ADT/Statistic.h"
75#include "llvm/CodeGen/MachineBasicBlock.h"
76#include "llvm/CodeGen/MachineDominators.h"
77#include "llvm/CodeGen/MachineFunction.h"
78#include "llvm/CodeGen/MachineFunctionPass.h"
79#include "llvm/CodeGen/MachineInstr.h"
80#include "llvm/CodeGen/MachineOperand.h"
81#include "llvm/CodeGen/MachineRegisterInfo.h"
82#include "llvm/CodeGen/TargetInstrInfo.h"
83#include "llvm/CodeGen/TargetRegisterInfo.h"
84#include "llvm/CodeGen/TargetSubtargetInfo.h"
85#include "llvm/InitializePasses.h"
86#include "llvm/Pass.h"
87#include "llvm/Support/Debug.h"
88#include "llvm/Support/ErrorHandling.h"
89#include "llvm/Support/raw_ostream.h"
90#include <cassert>
91#include <cstdlib>
92
93using namespace llvm;
94
95#define DEBUG_TYPE "aarch64-condopt"
96
97STATISTIC(NumConditionsAdjusted, "Number of conditions adjusted");
98
99namespace {
100
/// Bundles the parameters needed to adjust a comparison instruction:
/// the rewritten immediate and opcode for the CMP/CMN itself, plus the
/// rewritten condition code for the instruction consuming its flags.
struct CmpInfo {
  int Imm;                // Immediate to place in the cmp's imm operand.
  unsigned Opc;           // SUBS/ADDS (W/X) immediate-form opcode.
  AArch64CC::CondCode CC; // Condition code for the consuming instruction.
};
107
class AArch64ConditionOptimizerImpl {
  /// Represents a comparison instruction paired with its consuming
  /// conditional instruction
  struct CmpCondPair {
    MachineInstr *CmpMI;    // SUBS/ADDS-immediate that sets NZCV.
    MachineInstr *CondMI;   // Conditional instruction (Bcc/CSINC) reading it.
    AArch64CC::CondCode CC; // Condition the consumer evaluates.

    // Immediate operand of the comparison (operand 2 of the ri form).
    int getImm() const { return CmpMI->getOperand(i: 2).getImm(); }
    // Opcode of the comparison instruction.
    unsigned getOpc() const { return CmpMI->getOpcode(); }
  };

  const AArch64InstrInfo *TII;    // Opcode descriptors for rewrites.
  const TargetRegisterInfo *TRI;  // Used to look through register copies.
  MachineDominatorTree *DomTree;  // Drives the pre-order block walk.
  const MachineRegisterInfo *MRI; // SSA use lists for dead-def checks.

public:
  bool run(MachineFunction &MF, MachineDominatorTree &MDT);

private:
  // True if CmpMI's immediate is a literal that stays encodable after
  // adjustment and its destination register is dead (pure cmp).
  bool canAdjustCmp(MachineInstr &CmpMI);
  // True if both compares read the same underlying register.
  bool registersMatch(MachineInstr *FirstMI, MachineInstr *SecondMI);
  // True if NZCV is live into any successor of MBB.
  bool nzcvLivesOut(MachineBasicBlock *MBB);
  // Returns MBB's Bcc terminator, or nullptr.
  MachineInstr *getBccTerminator(MachineBasicBlock *MBB);
  // Walks backward from CondMI to the cmp defining its flags, if adjustable.
  MachineInstr *findAdjustableCmp(MachineInstr *CondMI);
  // Computes the inclusive<->exclusive rewrite of CmpMI under condition Cmp.
  CmpInfo getAdjustedCmpInfo(MachineInstr *CmpMI, AArch64CC::CondCode Cmp);
  // Rewrites the cmp's immediate and opcode in place.
  void updateCmpInstr(MachineInstr *CmpMI, int NewImm, unsigned NewOpc);
  // Rewrites the consumer's condition-code operand in place.
  void updateCondInstr(MachineInstr *CondMI, AArch64CC::CondCode NewCC);
  // Applies both halves of a CmpInfo rewrite to a pair.
  void applyCmpAdjustment(CmpCondPair &Pair, const CmpInfo &Info);
  // Core transformation shared by the intra- and cross-block drivers.
  bool tryOptimizePair(CmpCondPair &First, CmpCondPair &Second);
  // Handles two CMP+CSINC pairs within a single basic block.
  bool optimizeIntraBlock(MachineBasicBlock &MBB);
  // Handles CMP+Bcc pairs in a head block and its branch target.
  bool optimizeCrossBlock(MachineBasicBlock &HBB);
};
142
/// Legacy pass-manager wrapper that delegates the actual work to
/// AArch64ConditionOptimizerImpl.
class AArch64ConditionOptimizerLegacy : public MachineFunctionPass {
public:
  static char ID; // Pass identification.
  AArch64ConditionOptimizerLegacy() : MachineFunctionPass(ID) {}

  void getAnalysisUsage(AnalysisUsage &AU) const override;
  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override {
    return "AArch64 Condition Optimizer";
  }
};
155
156} // end anonymous namespace
157
// Legacy pass registration: depends on the machine dominator tree analysis.
char AArch64ConditionOptimizerLegacy::ID = 0;

INITIALIZE_PASS_BEGIN(AArch64ConditionOptimizerLegacy, "aarch64-condopt",
                      "AArch64 CondOpt Pass", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_END(AArch64ConditionOptimizerLegacy, "aarch64-condopt",
                    "AArch64 CondOpt Pass", false, false)
165
166FunctionPass *llvm::createAArch64ConditionOptimizerLegacyPass() {
167 return new AArch64ConditionOptimizerLegacy();
168}
169
// Declare analysis dependencies: the dominator tree is required to drive the
// pre-order block walk and is preserved because this pass only rewrites
// instructions in place (no blocks are added or removed).
void AArch64ConditionOptimizerLegacy::getAnalysisUsage(
    AnalysisUsage &AU) const {
  AU.addRequired<MachineDominatorTreeWrapperPass>();
  AU.addPreserved<MachineDominatorTreeWrapperPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
176
177// Verify that the MI's immediate is adjustable and it only sets flags (pure
178// cmp)
179bool AArch64ConditionOptimizerImpl::canAdjustCmp(MachineInstr &CmpMI) {
180 unsigned ShiftAmt = AArch64_AM::getShiftValue(Imm: CmpMI.getOperand(i: 3).getImm());
181 if (!CmpMI.getOperand(i: 2).isImm()) {
182 LLVM_DEBUG(dbgs() << "Immediate of cmp is symbolic, " << CmpMI << '\n');
183 return false;
184 } else if (CmpMI.getOperand(i: 2).getImm() << ShiftAmt >= 0xfff) {
185 LLVM_DEBUG(dbgs() << "Immediate of cmp may be out of range, " << CmpMI
186 << '\n');
187 return false;
188 } else if (!MRI->use_nodbg_empty(RegNo: CmpMI.getOperand(i: 0).getReg())) {
189 LLVM_DEBUG(dbgs() << "Destination of cmp is not dead, " << CmpMI << '\n');
190 return false;
191 }
192
193 return true;
194}
195
196// Ensure both compare MIs use the same register, tracing through copies.
197bool AArch64ConditionOptimizerImpl::registersMatch(MachineInstr *FirstMI,
198 MachineInstr *SecondMI) {
199 Register FirstReg = FirstMI->getOperand(i: 1).getReg();
200 Register SecondReg = SecondMI->getOperand(i: 1).getReg();
201 Register FirstCmpReg =
202 FirstReg.isVirtual() ? TRI->lookThruCopyLike(SrcReg: FirstReg, MRI) : FirstReg;
203 Register SecondCmpReg =
204 SecondReg.isVirtual() ? TRI->lookThruCopyLike(SrcReg: SecondReg, MRI) : SecondReg;
205 if (FirstCmpReg != SecondCmpReg) {
206 LLVM_DEBUG(dbgs() << "CMPs compare different registers\n");
207 return false;
208 }
209
210 return true;
211}
212
213// Check if NZCV lives out to any successor block.
214bool AArch64ConditionOptimizerImpl::nzcvLivesOut(MachineBasicBlock *MBB) {
215 for (auto *SuccBB : MBB->successors()) {
216 if (SuccBB->isLiveIn(Reg: AArch64::NZCV)) {
217 LLVM_DEBUG(dbgs() << "NZCV live into successor "
218 << printMBBReference(*SuccBB) << " from "
219 << printMBBReference(*MBB) << '\n');
220 return true;
221 }
222 }
223 return false;
224}
225
226// Returns true if the opcode is a comparison instruction (CMP/CMN).
227static bool isCmpInstruction(unsigned Opc) {
228 switch (Opc) {
229 // cmp is an alias for SUBS with a dead destination register.
230 case AArch64::SUBSWri:
231 case AArch64::SUBSXri:
232 // cmp is an alias for ADDS with a dead destination register.
233 case AArch64::ADDSWri:
234 case AArch64::ADDSXri:
235 return true;
236 default:
237 return false;
238 }
239}
240
241static bool isCSINCInstruction(unsigned Opc) {
242 return Opc == AArch64::CSINCWr || Opc == AArch64::CSINCXr;
243}
244
245// Returns the Bcc terminator if present, otherwise nullptr.
246MachineInstr *
247AArch64ConditionOptimizerImpl::getBccTerminator(MachineBasicBlock *MBB) {
248 MachineBasicBlock::iterator Term = MBB->getFirstTerminator();
249 if (Term == MBB->end()) {
250 LLVM_DEBUG(dbgs() << "No terminator in " << printMBBReference(*MBB)
251 << '\n');
252 return nullptr;
253 }
254
255 if (Term->getOpcode() != AArch64::Bcc) {
256 LLVM_DEBUG(dbgs() << "Non-Bcc terminator in " << printMBBReference(*MBB)
257 << ": " << *Term);
258 return nullptr;
259 }
260
261 return &*Term;
262}
263
// Find the CMP instruction controlling the given conditional instruction and
// ensure it can be adjusted for CSE optimization. Searches backward from
// CondMI, ensuring no NZCV interference. Returns nullptr if no suitable CMP
// is found or if adjustments are not safe.
MachineInstr *
AArch64ConditionOptimizerImpl::findAdjustableCmp(MachineInstr *CondMI) {
  assert(CondMI && "CondMI cannot be null");
  MachineBasicBlock *MBB = CondMI->getParent();

  // Search backward from the conditional to find the instruction controlling
  // it.
  for (MachineBasicBlock::iterator B = MBB->begin(),
                                   It = MachineBasicBlock::iterator(CondMI);
       It != B;) {
    It = prev_nodbg(It, Begin: B);
    MachineInstr &I = *It;
    assert(!I.isTerminator() && "Spurious terminator");
    // Ensure there is no use of NZCV between CMP and conditional. Another
    // reader of the flags would see different results if the cmp were
    // adjusted, so bail out.
    if (I.readsRegister(Reg: AArch64::NZCV, /*TRI=*/nullptr))
      return nullptr;

    // Found the flag-setting cmp; it is usable only if its immediate can be
    // safely adjusted (literal, in range, dead destination).
    if (isCmpInstruction(Opc: I.getOpcode())) {
      if (!canAdjustCmp(CmpMI&: I)) {
        return nullptr;
      }
      return &I;
    }

    // Some other instruction defines NZCV (not a plain cmp/cmn) - give up.
    if (I.modifiesRegister(Reg: AArch64::NZCV, /*TRI=*/nullptr))
      return nullptr;
  }
  LLVM_DEBUG(dbgs() << "Flags not defined in " << printMBBReference(*MBB)
                    << '\n');
  return nullptr;
}
299
300// Changes opcode adds <-> subs considering register operand width.
301static int getComplementOpc(int Opc) {
302 switch (Opc) {
303 case AArch64::ADDSWri: return AArch64::SUBSWri;
304 case AArch64::ADDSXri: return AArch64::SUBSXri;
305 case AArch64::SUBSWri: return AArch64::ADDSWri;
306 case AArch64::SUBSXri: return AArch64::ADDSXri;
307 default:
308 llvm_unreachable("Unexpected opcode");
309 }
310}
311
312// Changes form of comparison inclusive <-> exclusive.
313static AArch64CC::CondCode getAdjustedCmp(AArch64CC::CondCode Cmp) {
314 switch (Cmp) {
315 case AArch64CC::GT:
316 return AArch64CC::GE;
317 case AArch64CC::GE:
318 return AArch64CC::GT;
319 case AArch64CC::LT:
320 return AArch64CC::LE;
321 case AArch64CC::LE:
322 return AArch64CC::LT;
323 case AArch64CC::HI:
324 return AArch64CC::HS;
325 case AArch64CC::HS:
326 return AArch64CC::HI;
327 case AArch64CC::LO:
328 return AArch64CC::LS;
329 case AArch64CC::LS:
330 return AArch64CC::LO;
331 default:
332 llvm_unreachable("Unexpected condition code");
333 }
334}
335
// Returns the adjusted immediate, opcode, and condition code for switching
// between inclusive/exclusive forms (GT <-> GE, LT <-> LE). When no safe
// adjustment exists, the original (Imm, Opc, Cmp) triple is returned
// unchanged, which callers detect by comparing against the inputs.
CmpInfo
AArch64ConditionOptimizerImpl::getAdjustedCmpInfo(MachineInstr *CmpMI,
                                                  AArch64CC::CondCode Cmp) {
  unsigned Opc = CmpMI->getOpcode();

  bool IsSigned = Cmp == AArch64CC::GT || Cmp == AArch64CC::GE ||
                  Cmp == AArch64CC::LT || Cmp == AArch64CC::LE;

  // CMN (compare with negative immediate) is an alias to ADDS (as
  // "operand - negative" == "operand + positive")
  bool Negative = (Opc == AArch64::ADDSWri || Opc == AArch64::ADDSXri);

  // Strict "greater" becomes inclusive by raising the bound by one; strict
  // "less" becomes inclusive by lowering it by one.
  int Correction = (Cmp == AArch64CC::GT || Cmp == AArch64CC::HI) ? 1 : -1;
  // Negate Correction value for comparison with negative immediate (CMN).
  if (Negative) {
    Correction = -Correction;
  }

  const int OldImm = (int)CmpMI->getOperand(i: 2).getImm();
  // abs() because a sign flip is represented by switching between the CMP
  // and CMN encodings (handled via getComplementOpc below), both of which
  // store a non-negative immediate.
  const int NewImm = std::abs(x: OldImm + Correction);

  // Bail out on cmn 0 (ADDS with immediate 0). It is a valid instruction but
  // doesn't set flags in a way we can safely transform, so skip optimization.
  if (OldImm == 0 && Negative)
    return {.Imm: OldImm, .Opc: Opc, .CC: Cmp};

  // The adjustment crosses zero: cmn 1 -> cmp 0, or cmp 0 -> cmn 1.
  if ((OldImm == 1 && Negative && Correction == -1) ||
      (OldImm == 0 && Correction == -1)) {
    // If we change opcodes for unsigned comparisons, this means we did an
    // unsigned wrap (e.g., 0 wrapping to 0xFFFFFFFF), so return the old cmp.
    // Note: For signed comparisons, opcode changes (cmn 1 ↔ cmp 0) are valid.
    if (!IsSigned)
      return {.Imm: OldImm, .Opc: Opc, .CC: Cmp};
    Opc = getComplementOpc(Opc);
  }

  return {.Imm: NewImm, .Opc: Opc, .CC: getAdjustedCmp(Cmp)};
}
376
377// Modifies a comparison instruction's immediate and opcode.
378void AArch64ConditionOptimizerImpl::updateCmpInstr(MachineInstr *CmpMI,
379 int NewImm,
380 unsigned NewOpc) {
381 CmpMI->getOperand(i: 2).setImm(NewImm);
382 CmpMI->setDesc(TII->get(Opcode: NewOpc));
383}
384
385// Modifies the condition code of a conditional instruction.
386void AArch64ConditionOptimizerImpl::updateCondInstr(MachineInstr *CondMI,
387 AArch64CC::CondCode NewCC) {
388 int CCOpIdx =
389 AArch64InstrInfo::findCondCodeUseOperandIdxForBranchOrSelect(Instr: *CondMI);
390 assert(CCOpIdx >= 0 && "Unsupported conditional instruction");
391 CondMI->getOperand(i: CCOpIdx).setImm(NewCC);
392 ++NumConditionsAdjusted;
393}
394
395// Applies a comparison adjustment to a cmp/cond instruction pair.
396void AArch64ConditionOptimizerImpl::applyCmpAdjustment(CmpCondPair &Pair,
397 const CmpInfo &Info) {
398 updateCmpInstr(CmpMI: Pair.CmpMI, NewImm: Info.Imm, NewOpc: Info.Opc);
399 updateCondInstr(CondMI: Pair.CondMI, NewCC: Info.CC);
400}
401
402// Extracts the condition code from the result of analyzeBranch.
403// Returns the CondCode or Invalid if the format is not a simple br.cond.
404static AArch64CC::CondCode parseCondCode(ArrayRef<MachineOperand> Cond) {
405 assert(!Cond.empty() && "Expected non-empty condition from analyzeBranch");
406 // A normal br.cond simply has the condition code.
407 if (Cond[0].getImm() != -1) {
408 assert(Cond.size() == 1 && "Unknown Cond array format");
409 return (AArch64CC::CondCode)(int)Cond[0].getImm();
410 }
411 return AArch64CC::CondCode::Invalid;
412}
413
414static bool isGreaterThan(AArch64CC::CondCode Cmp) {
415 return Cmp == AArch64CC::GT || Cmp == AArch64CC::HI;
416}
417
418static bool isLessThan(AArch64CC::CondCode Cmp) {
419 return Cmp == AArch64CC::LT || Cmp == AArch64CC::LO;
420}
421
// Given two (cmp, conditional) pairs over the same register, try to adjust
// one or both comparisons so that the two cmps become identical (same opcode
// and immediate), which lets the later CSE pass delete the duplicate.
// Only strict inequalities (GT/LT and their unsigned forms HI/LO) are
// handled, since those are the only forms getAdjustedCmp() can toggle.
// Returns true and applies the rewrite if a profitable adjustment exists.
bool AArch64ConditionOptimizerImpl::tryOptimizePair(CmpCondPair &First,
                                                    CmpCondPair &Second) {
  if (!((isGreaterThan(Cmp: First.CC) || isLessThan(Cmp: First.CC)) &&
        (isGreaterThan(Cmp: Second.CC) || isLessThan(Cmp: Second.CC))))
    return false;

  int FirstImmTrueValue = First.getImm();
  int SecondImmTrueValue = Second.getImm();

  // Normalize immediate of CMN (ADDS) instructions, whose encoded immediate
  // represents a negative comparison value.
  if (First.getOpc() == AArch64::ADDSWri || First.getOpc() == AArch64::ADDSXri)
    FirstImmTrueValue = -FirstImmTrueValue;
  if (Second.getOpc() == AArch64::ADDSWri ||
      Second.getOpc() == AArch64::ADDSXri)
    SecondImmTrueValue = -SecondImmTrueValue;

  CmpInfo FirstAdj = getAdjustedCmpInfo(CmpMI: First.CmpMI, Cmp: First.CC);
  CmpInfo SecondAdj = getAdjustedCmpInfo(CmpMI: Second.CmpMI, Cmp: Second.CC);

  // Each adjustment moves an immediate by exactly one, so opposite-direction
  // comparisons can meet in the middle only when the immediates differ by 2.
  if (((isGreaterThan(Cmp: First.CC) && isLessThan(Cmp: Second.CC)) ||
       (isLessThan(Cmp: First.CC) && isGreaterThan(Cmp: Second.CC))) &&
      std::abs(x: SecondImmTrueValue - FirstImmTrueValue) == 2) {
    // This branch transforms machine instructions that correspond to
    //
    // 1) (a > {SecondImm} && ...) || (a < {FirstImm} && ...)
    // 2) (a < {SecondImm} && ...) || (a > {FirstImm} && ...)
    //
    // into
    //
    // 1) (a >= {NewImm} && ...) || (a <= {NewImm} && ...)
    // 2) (a <= {NewImm} && ...) || (a >= {NewImm} && ...)

    // Verify both adjustments converge to identical comparisons (same
    // immediate and opcode). This ensures CSE can eliminate the duplicate.
    if (FirstAdj.Imm != SecondAdj.Imm || FirstAdj.Opc != SecondAdj.Opc)
      return false;

    LLVM_DEBUG(dbgs() << "Optimized (opposite): "
                      << AArch64CC::getCondCodeName(First.CC) << " #"
                      << First.getImm() << ", "
                      << AArch64CC::getCondCodeName(Second.CC) << " #"
                      << Second.getImm() << " -> "
                      << AArch64CC::getCondCodeName(FirstAdj.CC) << " #"
                      << FirstAdj.Imm << ", "
                      << AArch64CC::getCondCodeName(SecondAdj.CC) << " #"
                      << SecondAdj.Imm << '\n');
    applyCmpAdjustment(Pair&: First, Info: FirstAdj);
    applyCmpAdjustment(Pair&: Second, Info: SecondAdj);
    return true;

  } else if (((isGreaterThan(Cmp: First.CC) && isGreaterThan(Cmp: Second.CC)) ||
              (isLessThan(Cmp: First.CC) && isLessThan(Cmp: Second.CC))) &&
             std::abs(x: SecondImmTrueValue - FirstImmTrueValue) == 1) {
    // This branch transforms machine instructions that correspond to
    //
    // 1) (a > {SecondImm} && ...) || (a > {FirstImm} && ...)
    // 2) (a < {SecondImm} && ...) || (a < {FirstImm} && ...)
    //
    // into
    //
    // 1) (a <= {NewImm} && ...) || (a > {NewImm} && ...)
    // 2) (a < {NewImm} && ...) || (a >= {NewImm} && ...)

    // GT -> GE transformation increases immediate value, so picking the
    // smaller one; LT -> LE decreases immediate value so invert the choice.
    bool AdjustFirst = (FirstImmTrueValue < SecondImmTrueValue);
    if (isLessThan(Cmp: First.CC))
      AdjustFirst = !AdjustFirst;

    // Only one of the pair is rewritten (ToChange); the other (Target) keeps
    // its comparison, which the adjusted one must converge to.
    CmpCondPair &Target = AdjustFirst ? Second : First;
    CmpCondPair &ToChange = AdjustFirst ? First : Second;
    CmpInfo &Adj = AdjustFirst ? FirstAdj : SecondAdj;

    // Verify the adjustment converges to the target's comparison (same
    // immediate and opcode). This ensures CSE can eliminate the duplicate.
    if (Adj.Imm != Target.getImm() || Adj.Opc != Target.getOpc())
      return false;

    LLVM_DEBUG(dbgs() << "Optimized (same-direction): "
                      << AArch64CC::getCondCodeName(ToChange.CC) << " #"
                      << ToChange.getImm() << " -> "
                      << AArch64CC::getCondCodeName(Adj.CC) << " #" << Adj.Imm
                      << '\n');
    applyCmpAdjustment(Pair&: ToChange, Info: Adj);
    return true;
  }

  // Other transformation cases almost never occur due to generation of < or >
  // comparisons instead of <= and >=.
  return false;
}
513
// This function transforms two CMP+CSINC pairs within the same basic block
// when both conditions are the same (GT/GT or LT/LT) and immediates differ
// by 1.
//
// Example transformation:
//   cmp w8, #10
//   csinc w9, w0, w1, gt    ; w9 = (w8 > 10) ? w0 : w1+1
//   cmp w8, #9
//   csinc w10, w0, w1, gt   ; w10 = (w8 > 9) ? w0 : w1+1
//
// Into:
//   cmp w8, #10
//   csinc w9, w0, w1, gt    ; w9 = (w8 > 10) ? w0 : w1+1
//   csinc w10, w0, w1, ge   ; w10 = (w8 >= 10) ? w0 : w1+1
//
// The second CMP is eliminated, enabling CSE to remove the redundant
// comparison.
bool AArch64ConditionOptimizerImpl::optimizeIntraBlock(MachineBasicBlock &MBB) {
  MachineInstr *FirstCSINC = nullptr;
  MachineInstr *SecondCSINC = nullptr;

  // Find two CSINC instructions. Only the first two in the block are
  // considered; the loop stops as soon as a pair is found.
  for (MachineInstr &MI : MBB) {
    if (isCSINCInstruction(Opc: MI.getOpcode())) {
      if (!FirstCSINC) {
        FirstCSINC = &MI;
      } else if (!SecondCSINC) {
        SecondCSINC = &MI;
        break; // Found both
      }
    }
  }

  if (!FirstCSINC || !SecondCSINC) {
    return false;
  }

  // Since we may modify cmps in this MBB, make sure NZCV does not live out.
  if (nzcvLivesOut(MBB: &MBB))
    return false;

  // Find the CMPs controlling each CSINC. This also verifies no other
  // instruction reads or clobbers NZCV between each CMP and its CSINC.
  MachineInstr *FirstCmpMI = findAdjustableCmp(CondMI: FirstCSINC);
  MachineInstr *SecondCmpMI = findAdjustableCmp(CondMI: SecondCSINC);
  if (!FirstCmpMI || !SecondCmpMI)
    return false;

  // Ensure we have two distinct CMPs; a shared CMP means there is no
  // duplicate comparison to eliminate.
  if (FirstCmpMI == SecondCmpMI) {
    LLVM_DEBUG(dbgs() << "Both CSINCs already controlled by same CMP\n");
    return false;
  }

  if (!registersMatch(FirstMI: FirstCmpMI, SecondMI: SecondCmpMI))
    return false;

  // Check that nothing else modifies the flags between the first CMP and second
  // conditional (the second CMP itself is the expected flag writer, so it is
  // exempted from the check).
  for (auto It = std::next(x: MachineBasicBlock::iterator(FirstCmpMI));
       It != std::next(x: MachineBasicBlock::iterator(SecondCSINC)); ++It) {
    if (&*It != SecondCmpMI &&
        It->modifiesRegister(Reg: AArch64::NZCV, /*TRI=*/nullptr)) {
      LLVM_DEBUG(dbgs() << "Flags modified between CMPs by: " << *It << '\n');
      return false;
    }
  }

  // Check flags aren't read after second conditional within the same block;
  // a later reader would observe the rewritten comparison.
  for (auto It = std::next(x: MachineBasicBlock::iterator(SecondCSINC));
       It != MBB.end(); ++It) {
    if (It->readsRegister(Reg: AArch64::NZCV, /*TRI=*/nullptr)) {
      LLVM_DEBUG(dbgs() << "Flags read after second CSINC by: " << *It << '\n');
      return false;
    }
  }

  // Extract condition codes from both CSINCs (operand 3)
  AArch64CC::CondCode FirstCondCode =
      (AArch64CC::CondCode)(int)FirstCSINC->getOperand(i: 3).getImm();
  AArch64CC::CondCode SecondCondCode =
      (AArch64CC::CondCode)(int)SecondCSINC->getOperand(i: 3).getImm();

  LLVM_DEBUG(dbgs() << "Comparing intra-block CSINCs: "
                    << AArch64CC::getCondCodeName(FirstCondCode) << " #"
                    << FirstCmpMI->getOperand(2).getImm() << " and "
                    << AArch64CC::getCondCodeName(SecondCondCode) << " #"
                    << SecondCmpMI->getOperand(2).getImm() << '\n');

  CmpCondPair First{.CmpMI: FirstCmpMI, .CondMI: FirstCSINC, .CC: FirstCondCode};
  CmpCondPair Second{.CmpMI: SecondCmpMI, .CondMI: SecondCSINC, .CC: SecondCondCode};

  return tryOptimizePair(First, Second);
}
607
// Optimizes CMP+Bcc pairs across two basic blocks: HBB (the dominating
// "head" block) and TBB, the target its conditional branch jumps to.
bool AArch64ConditionOptimizerImpl::optimizeCrossBlock(MachineBasicBlock &HBB) {
  SmallVector<MachineOperand, 4> HeadCondOperands;
  MachineBasicBlock *TBB = nullptr, *FBB = nullptr;
  // analyzeBranch returns true when the branch is not analyzable.
  if (TII->analyzeBranch(MBB&: HBB, TBB, FBB, Cond&: HeadCondOperands)) {
    return false;
  }

  // Equivalence check is to skip loops.
  if (!TBB || TBB == &HBB) {
    return false;
  }

  // The branch target must itself end in an analyzable conditional branch;
  // only its condition operands are used, the targets are ignored.
  SmallVector<MachineOperand, 4> TrueCondOperands;
  MachineBasicBlock *TBB_TBB = nullptr, *TBB_FBB = nullptr;
  if (TII->analyzeBranch(MBB&: *TBB, TBB&: TBB_TBB, FBB&: TBB_FBB, Cond&: TrueCondOperands)) {
    return false;
  }

  // Both blocks must terminate with a plain Bcc (the conditional instruction
  // whose condition code will be rewritten).
  MachineInstr *HeadBrMI = getBccTerminator(MBB: &HBB);
  MachineInstr *TrueBrMI = getBccTerminator(MBB: TBB);
  if (!HeadBrMI || !TrueBrMI)
    return false;

  // Since we may modify cmps in these blocks, make sure NZCV does not live out.
  if (nzcvLivesOut(MBB: &HBB) || nzcvLivesOut(MBB: TBB))
    return false;

  // Find the CMPs controlling each branch
  MachineInstr *HeadCmpMI = findAdjustableCmp(CondMI: HeadBrMI);
  MachineInstr *TrueCmpMI = findAdjustableCmp(CondMI: TrueBrMI);
  if (!HeadCmpMI || !TrueCmpMI)
    return false;

  if (!registersMatch(FirstMI: HeadCmpMI, SecondMI: TrueCmpMI))
    return false;

  // Only the simple single-operand br.cond encoding is supported.
  AArch64CC::CondCode HeadCondCode = parseCondCode(Cond: HeadCondOperands);
  AArch64CC::CondCode TrueCondCode = parseCondCode(Cond: TrueCondOperands);
  if (HeadCondCode == AArch64CC::CondCode::Invalid ||
      TrueCondCode == AArch64CC::CondCode::Invalid) {
    return false;
  }

  LLVM_DEBUG(dbgs() << "Checking cross-block pair: "
                    << AArch64CC::getCondCodeName(HeadCondCode) << " #"
                    << HeadCmpMI->getOperand(2).getImm() << ", "
                    << AArch64CC::getCondCodeName(TrueCondCode) << " #"
                    << TrueCmpMI->getOperand(2).getImm() << '\n');

  CmpCondPair Head{.CmpMI: HeadCmpMI, .CondMI: HeadBrMI, .CC: HeadCondCode};
  CmpCondPair True{.CmpMI: TrueCmpMI, .CondMI: TrueBrMI, .CC: TrueCondCode};

  return tryOptimizePair(First&: Head, Second&: True);
}
663
664bool AArch64ConditionOptimizerLegacy::runOnMachineFunction(
665 MachineFunction &MF) {
666 if (skipFunction(F: MF.getFunction()))
667 return false;
668 MachineDominatorTree &MDT =
669 getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
670 return AArch64ConditionOptimizerImpl().run(MF, MDT);
671}
672
// Shared driver: walks every block of MF and applies both the intra-block
// (CMP+CSINC) and cross-block (CMP+Bcc) optimizations.
bool AArch64ConditionOptimizerImpl::run(MachineFunction &MF,
                                        MachineDominatorTree &MDT) {
  LLVM_DEBUG(dbgs() << "********** AArch64 Conditional Compares **********\n"
                    << "********** Function: " << MF.getName() << '\n');

  TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
  TRI = MF.getSubtarget().getRegisterInfo();
  DomTree = &MDT;
  MRI = &MF.getRegInfo();

  bool Changed = false;

  // Visit blocks in dominator tree pre-order. The pre-order enables multiple
  // cmp-conversions from the same head block.
  // The optimizations only rewrite instructions in place - no blocks are
  // added or removed - so the dominator tree stays valid during the walk.
  for (MachineDomTreeNode *I : depth_first(G: DomTree)) {
    MachineBasicBlock *HBB = I->getBlock();
    Changed |= optimizeIntraBlock(MBB&: *HBB);
    Changed |= optimizeCrossBlock(HBB&: *HBB);
  }

  return Changed;
}
698
699PreservedAnalyses
700AArch64ConditionOptimizerPass::run(MachineFunction &MF,
701 MachineFunctionAnalysisManager &MFAM) {
702 auto &MDT = MFAM.getResult<MachineDominatorTreeAnalysis>(IR&: MF);
703 bool Changed = AArch64ConditionOptimizerImpl().run(MF, MDT);
704 if (!Changed)
705 return PreservedAnalyses::all();
706 PreservedAnalyses PA = getMachineFunctionPassPreservedAnalyses();
707 PA.preserveSet<CFGAnalyses>();
708 return PA;
709}
710