1//=- AArch64ConditionOptimizer.cpp - Remove useless comparisons for AArch64 -=//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//
// This pass tries to make consecutive comparisons of values use the same
// operands to allow the CSE pass to remove duplicate instructions. It adjusts
// comparisons with immediate values by converting between inclusive and
// exclusive forms (signed GE <-> GT, LE <-> LT; unsigned HS <-> HI, LS <-> LO)
// and correcting immediate values to make them equal.
15//
16// The pass handles:
17// * Cross-block: SUBS/ADDS followed by conditional branches
18// * Intra-block: Select-family conditional instructions
19//
20//
21// Consider the following example in C:
22//
//     if ((a < 5 && ...) || (a > 5 && ...)) {
//          ~~~~~             ~~~~~
//            ^                 ^
//            x                 y
27//
28// Here both "x" and "y" expressions compare "a" with "5". When "x" evaluates
29// to "false", "y" can just check flags set by the first comparison. As a
30// result of the canonicalization employed by
31// SelectionDAGBuilder::visitSwitchCase, DAGCombine, and other target-specific
32// code, assembly ends up in the form that is not CSE friendly:
33//
34// ...
35// cmp w8, #4
36// b.gt .LBB0_3
37// ...
38// .LBB0_3:
39// cmp w8, #6
40// b.lt .LBB0_6
41// ...
42//
43// Same assembly after the pass:
44//
45// ...
46// cmp w8, #5
47// b.ge .LBB0_3
48// ...
49// .LBB0_3:
50// cmp w8, #5 // <-- CSE pass removes this instruction
51// b.le .LBB0_6
52// ...
53//
54// See optimizeCrossBlock() and optimizeIntraBlock() for implementation details.
55//
56// TODO: maybe handle TBNZ/TBZ the same way as CMP when used instead for "a < 0"
57// TODO: For cross-block:
58// - handle other conditional instructions (e.g. CSET)
59// - allow second branching to be anything if it doesn't require adjusting
60//
61//===----------------------------------------------------------------------===//
62
63#include "AArch64.h"
64#include "AArch64Subtarget.h"
65#include "MCTargetDesc/AArch64AddressingModes.h"
66#include "Utils/AArch64BaseInfo.h"
67#include "llvm/ADT/ArrayRef.h"
68#include "llvm/ADT/DepthFirstIterator.h"
69#include "llvm/ADT/SmallVector.h"
70#include "llvm/ADT/Statistic.h"
71#include "llvm/CodeGen/MachineBasicBlock.h"
72#include "llvm/CodeGen/MachineDominators.h"
73#include "llvm/CodeGen/MachineFunction.h"
74#include "llvm/CodeGen/MachineFunctionPass.h"
75#include "llvm/CodeGen/MachineInstr.h"
76#include "llvm/CodeGen/MachineOperand.h"
77#include "llvm/CodeGen/MachineRegisterInfo.h"
78#include "llvm/CodeGen/TargetInstrInfo.h"
79#include "llvm/CodeGen/TargetRegisterInfo.h"
80#include "llvm/CodeGen/TargetSubtargetInfo.h"
81#include "llvm/InitializePasses.h"
82#include "llvm/Pass.h"
83#include "llvm/Support/Debug.h"
84#include "llvm/Support/ErrorHandling.h"
85#include "llvm/Support/raw_ostream.h"
86#include <cassert>
87#include <cstdlib>
88
89using namespace llvm;
90
91#define DEBUG_TYPE "aarch64-condopt"
92
93STATISTIC(NumConditionsAdjusted, "Number of conditions adjusted");
94
95namespace {
96
97/// Bundles the parameters needed to adjust a comparison instruction.
98struct CmpInfo {
99 int Imm;
100 unsigned Opc;
101 AArch64CC::CondCode CC;
102};
103
104class AArch64ConditionOptimizerImpl {
105 /// Represents a comparison instruction paired with its consuming
106 /// conditional instruction
107 struct CmpCondPair {
108 MachineInstr *CmpMI;
109 MachineInstr *CondMI;
110 AArch64CC::CondCode CC;
111
112 int getImm() const { return CmpMI->getOperand(i: 2).getImm(); }
113 unsigned getOpc() const { return CmpMI->getOpcode(); }
114 };
115
116 const AArch64InstrInfo *TII;
117 const TargetRegisterInfo *TRI;
118 MachineDominatorTree *DomTree;
119 const MachineRegisterInfo *MRI;
120
121public:
122 bool run(MachineFunction &MF, MachineDominatorTree &MDT);
123
124private:
125 bool canAdjustCmp(MachineInstr &CmpMI);
126 bool registersMatch(MachineInstr *FirstMI, MachineInstr *SecondMI);
127 bool nzcvLivesOut(MachineBasicBlock *MBB);
128 MachineInstr *getBccTerminator(MachineBasicBlock *MBB);
129 MachineInstr *findAdjustableCmp(MachineInstr *CondMI);
130 CmpInfo getAdjustedCmpInfo(MachineInstr *CmpMI, AArch64CC::CondCode Cmp);
131 void updateCmpInstr(MachineInstr *CmpMI, int NewImm, unsigned NewOpc);
132 void updateCondInstr(MachineInstr *CondMI, AArch64CC::CondCode NewCC);
133 void applyCmpAdjustment(CmpCondPair &Pair, const CmpInfo &Info);
134 bool commitPendingPair(std::optional<CmpCondPair> &PendingPair,
135 SmallDenseMap<Register, CmpCondPair> &PairsByReg);
136 bool tryOptimizePair(CmpCondPair &First, CmpCondPair &Second);
137 bool optimizeIntraBlock(MachineBasicBlock &MBB);
138 bool optimizeCrossBlock(MachineBasicBlock &HBB);
139};
140
141class AArch64ConditionOptimizerLegacy : public MachineFunctionPass {
142public:
143 static char ID;
144 AArch64ConditionOptimizerLegacy() : MachineFunctionPass(ID) {}
145
146 void getAnalysisUsage(AnalysisUsage &AU) const override;
147 bool runOnMachineFunction(MachineFunction &MF) override;
148
149 StringRef getPassName() const override {
150 return "AArch64 Condition Optimizer";
151 }
152};
153
154} // end anonymous namespace
155
156char AArch64ConditionOptimizerLegacy::ID = 0;
157
158INITIALIZE_PASS_BEGIN(AArch64ConditionOptimizerLegacy, "aarch64-condopt",
159 "AArch64 CondOpt Pass", false, false)
160INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
161INITIALIZE_PASS_END(AArch64ConditionOptimizerLegacy, "aarch64-condopt",
162 "AArch64 CondOpt Pass", false, false)
163
164FunctionPass *llvm::createAArch64ConditionOptimizerLegacyPass() {
165 return new AArch64ConditionOptimizerLegacy();
166}
167
168void AArch64ConditionOptimizerLegacy::getAnalysisUsage(
169 AnalysisUsage &AU) const {
170 AU.addRequired<MachineDominatorTreeWrapperPass>();
171 AU.addPreserved<MachineDominatorTreeWrapperPass>();
172 MachineFunctionPass::getAnalysisUsage(AU);
173}
174
175// Verify that the MI's immediate is adjustable and it only sets flags (pure
176// cmp)
177bool AArch64ConditionOptimizerImpl::canAdjustCmp(MachineInstr &CmpMI) {
178 unsigned ShiftAmt = AArch64_AM::getShiftValue(Imm: CmpMI.getOperand(i: 3).getImm());
179 if (!CmpMI.getOperand(i: 2).isImm()) {
180 LLVM_DEBUG(dbgs() << "Immediate of cmp is symbolic, " << CmpMI << '\n');
181 return false;
182 } else if (CmpMI.getOperand(i: 2).getImm() << ShiftAmt >= 0xfff) {
183 LLVM_DEBUG(dbgs() << "Immediate of cmp may be out of range, " << CmpMI
184 << '\n');
185 return false;
186 } else if (!MRI->use_nodbg_empty(RegNo: CmpMI.getOperand(i: 0).getReg())) {
187 LLVM_DEBUG(dbgs() << "Destination of cmp is not dead, " << CmpMI << '\n');
188 return false;
189 }
190
191 return true;
192}
193
194// Ensure both compare MIs use the same register, tracing through copies.
195bool AArch64ConditionOptimizerImpl::registersMatch(MachineInstr *FirstMI,
196 MachineInstr *SecondMI) {
197 Register FirstReg = FirstMI->getOperand(i: 1).getReg();
198 Register SecondReg = SecondMI->getOperand(i: 1).getReg();
199 Register FirstCmpReg =
200 FirstReg.isVirtual() ? TRI->lookThruCopyLike(SrcReg: FirstReg, MRI) : FirstReg;
201 Register SecondCmpReg =
202 SecondReg.isVirtual() ? TRI->lookThruCopyLike(SrcReg: SecondReg, MRI) : SecondReg;
203 if (FirstCmpReg != SecondCmpReg) {
204 LLVM_DEBUG(dbgs() << "CMPs compare different registers\n");
205 return false;
206 }
207
208 return true;
209}
210
211// Check if NZCV lives out to any successor block.
212bool AArch64ConditionOptimizerImpl::nzcvLivesOut(MachineBasicBlock *MBB) {
213 for (auto *SuccBB : MBB->successors()) {
214 if (SuccBB->isLiveIn(Reg: AArch64::NZCV)) {
215 LLVM_DEBUG(dbgs() << "NZCV live into successor "
216 << printMBBReference(*SuccBB) << " from "
217 << printMBBReference(*MBB) << '\n');
218 return true;
219 }
220 }
221 return false;
222}
223
224// Returns true if the opcode is a comparison instruction (CMP/CMN).
225static bool isCmpInstruction(unsigned Opc) {
226 switch (Opc) {
227 // cmp is an alias for SUBS with a dead destination register.
228 case AArch64::SUBSWri:
229 case AArch64::SUBSXri:
230 // cmp is an alias for ADDS with a dead destination register.
231 case AArch64::ADDSWri:
232 case AArch64::ADDSXri:
233 return true;
234 default:
235 return false;
236 }
237}
238
239// Returns the Bcc terminator if present, otherwise nullptr.
240MachineInstr *
241AArch64ConditionOptimizerImpl::getBccTerminator(MachineBasicBlock *MBB) {
242 MachineBasicBlock::iterator Term = MBB->getFirstTerminator();
243 if (Term == MBB->end()) {
244 LLVM_DEBUG(dbgs() << "No terminator in " << printMBBReference(*MBB)
245 << '\n');
246 return nullptr;
247 }
248
249 if (Term->getOpcode() != AArch64::Bcc) {
250 LLVM_DEBUG(dbgs() << "Non-Bcc terminator in " << printMBBReference(*MBB)
251 << ": " << *Term);
252 return nullptr;
253 }
254
255 return &*Term;
256}
257
258// Find the CMP instruction controlling the given conditional instruction and
259// ensure it can be adjusted for CSE optimization. Searches backward from
260// CondMI, ensuring no NZCV interference. Returns nullptr if no suitable CMP
261// is found or if adjustments are not safe.
262MachineInstr *
263AArch64ConditionOptimizerImpl::findAdjustableCmp(MachineInstr *CondMI) {
264 assert(CondMI && "CondMI cannot be null");
265 MachineBasicBlock *MBB = CondMI->getParent();
266
267 // Search backward from the conditional to find the instruction controlling
268 // it.
269 for (MachineBasicBlock::iterator B = MBB->begin(),
270 It = MachineBasicBlock::iterator(CondMI);
271 It != B;) {
272 It = prev_nodbg(It, Begin: B);
273 MachineInstr &I = *It;
274 assert(!I.isTerminator() && "Spurious terminator");
275 // Ensure there is no use of NZCV between CMP and conditional.
276 if (I.readsRegister(Reg: AArch64::NZCV, /*TRI=*/nullptr))
277 return nullptr;
278
279 if (isCmpInstruction(Opc: I.getOpcode())) {
280 if (!canAdjustCmp(CmpMI&: I)) {
281 return nullptr;
282 }
283 return &I;
284 }
285
286 if (I.modifiesRegister(Reg: AArch64::NZCV, /*TRI=*/nullptr))
287 return nullptr;
288 }
289 LLVM_DEBUG(dbgs() << "Flags not defined in " << printMBBReference(*MBB)
290 << '\n');
291 return nullptr;
292}
293
294// Changes opcode adds <-> subs considering register operand width.
295static int getComplementOpc(int Opc) {
296 switch (Opc) {
297 case AArch64::ADDSWri: return AArch64::SUBSWri;
298 case AArch64::ADDSXri: return AArch64::SUBSXri;
299 case AArch64::SUBSWri: return AArch64::ADDSWri;
300 case AArch64::SUBSXri: return AArch64::ADDSXri;
301 default:
302 llvm_unreachable("Unexpected opcode");
303 }
304}
305
306// Changes form of comparison inclusive <-> exclusive.
307static AArch64CC::CondCode getAdjustedCmp(AArch64CC::CondCode Cmp) {
308 switch (Cmp) {
309 case AArch64CC::GT:
310 return AArch64CC::GE;
311 case AArch64CC::GE:
312 return AArch64CC::GT;
313 case AArch64CC::LT:
314 return AArch64CC::LE;
315 case AArch64CC::LE:
316 return AArch64CC::LT;
317 case AArch64CC::HI:
318 return AArch64CC::HS;
319 case AArch64CC::HS:
320 return AArch64CC::HI;
321 case AArch64CC::LO:
322 return AArch64CC::LS;
323 case AArch64CC::LS:
324 return AArch64CC::LO;
325 default:
326 llvm_unreachable("Unexpected condition code");
327 }
328}
329
330// Returns the adjusted immediate, opcode, and condition code for switching
331// between inclusive/exclusive forms (GT <-> GE, LT <-> LE).
332CmpInfo
333AArch64ConditionOptimizerImpl::getAdjustedCmpInfo(MachineInstr *CmpMI,
334 AArch64CC::CondCode Cmp) {
335 unsigned Opc = CmpMI->getOpcode();
336
337 bool IsSigned = Cmp == AArch64CC::GT || Cmp == AArch64CC::GE ||
338 Cmp == AArch64CC::LT || Cmp == AArch64CC::LE;
339
340 // CMN (compare with negative immediate) is an alias to ADDS (as
341 // "operand - negative" == "operand + positive")
342 bool Negative = (Opc == AArch64::ADDSWri || Opc == AArch64::ADDSXri);
343
344 int Correction = (Cmp == AArch64CC::GT || Cmp == AArch64CC::HI) ? 1 : -1;
345 // Negate Correction value for comparison with negative immediate (CMN).
346 if (Negative) {
347 Correction = -Correction;
348 }
349
350 const int OldImm = (int)CmpMI->getOperand(i: 2).getImm();
351 const int NewImm = std::abs(x: OldImm + Correction);
352
353 // Bail out on cmn 0 (ADDS with immediate 0). It is a valid instruction but
354 // doesn't set flags in a way we can safely transform, so skip optimization.
355 if (OldImm == 0 && Negative)
356 return {.Imm: OldImm, .Opc: Opc, .CC: Cmp};
357
358 if ((OldImm == 1 && Negative && Correction == -1) ||
359 (OldImm == 0 && Correction == -1)) {
360 // If we change opcodes for unsigned comparisons, this means we did an
361 // unsigned wrap (e.g., 0 wrapping to 0xFFFFFFFF), so return the old cmp.
362 // Note: For signed comparisons, opcode changes (cmn 1 ↔ cmp 0) are valid.
363 if (!IsSigned)
364 return {.Imm: OldImm, .Opc: Opc, .CC: Cmp};
365 Opc = getComplementOpc(Opc);
366 }
367
368 return {.Imm: NewImm, .Opc: Opc, .CC: getAdjustedCmp(Cmp)};
369}
370
371// Modifies a comparison instruction's immediate and opcode.
372void AArch64ConditionOptimizerImpl::updateCmpInstr(MachineInstr *CmpMI,
373 int NewImm,
374 unsigned NewOpc) {
375 CmpMI->getOperand(i: 2).setImm(NewImm);
376 CmpMI->setDesc(TII->get(Opcode: NewOpc));
377}
378
379// Modifies the condition code of a conditional instruction.
380void AArch64ConditionOptimizerImpl::updateCondInstr(MachineInstr *CondMI,
381 AArch64CC::CondCode NewCC) {
382 int CCOpIdx =
383 AArch64InstrInfo::findCondCodeUseOperandIdxForBranchOrSelect(Instr: *CondMI);
384 assert(CCOpIdx >= 0 && "Unsupported conditional instruction");
385 CondMI->getOperand(i: CCOpIdx).setImm(NewCC);
386 ++NumConditionsAdjusted;
387}
388
389// Applies a comparison adjustment to a cmp/cond instruction pair.
390void AArch64ConditionOptimizerImpl::applyCmpAdjustment(CmpCondPair &Pair,
391 const CmpInfo &Info) {
392 updateCmpInstr(CmpMI: Pair.CmpMI, NewImm: Info.Imm, NewOpc: Info.Opc);
393 updateCondInstr(CondMI: Pair.CondMI, NewCC: Info.CC);
394 Pair.CC = Info.CC;
395}
396
397// Extracts the condition code from the result of analyzeBranch.
398// Returns the CondCode or Invalid if the format is not a simple br.cond.
399static AArch64CC::CondCode parseCondCode(ArrayRef<MachineOperand> Cond) {
400 assert(!Cond.empty() && "Expected non-empty condition from analyzeBranch");
401 // A normal br.cond simply has the condition code.
402 if (Cond[0].getImm() != -1) {
403 assert(Cond.size() == 1 && "Unknown Cond array format");
404 return (AArch64CC::CondCode)(int)Cond[0].getImm();
405 }
406 return AArch64CC::CondCode::Invalid;
407}
408
409static bool isGreaterThan(AArch64CC::CondCode Cmp) {
410 return Cmp == AArch64CC::GT || Cmp == AArch64CC::HI;
411}
412
413static bool isLessThan(AArch64CC::CondCode Cmp) {
414 return Cmp == AArch64CC::LT || Cmp == AArch64CC::LO;
415}
416
417bool AArch64ConditionOptimizerImpl::tryOptimizePair(CmpCondPair &First,
418 CmpCondPair &Second) {
419 if (!((isGreaterThan(Cmp: First.CC) || isLessThan(Cmp: First.CC)) &&
420 (isGreaterThan(Cmp: Second.CC) || isLessThan(Cmp: Second.CC))))
421 return false;
422
423 int FirstImmTrueValue = First.getImm();
424 int SecondImmTrueValue = Second.getImm();
425
426 // Normalize immediate of CMN (ADDS) instructions
427 if (First.getOpc() == AArch64::ADDSWri || First.getOpc() == AArch64::ADDSXri)
428 FirstImmTrueValue = -FirstImmTrueValue;
429 if (Second.getOpc() == AArch64::ADDSWri ||
430 Second.getOpc() == AArch64::ADDSXri)
431 SecondImmTrueValue = -SecondImmTrueValue;
432
433 CmpInfo FirstAdj = getAdjustedCmpInfo(CmpMI: First.CmpMI, Cmp: First.CC);
434 CmpInfo SecondAdj = getAdjustedCmpInfo(CmpMI: Second.CmpMI, Cmp: Second.CC);
435
436 if (((isGreaterThan(Cmp: First.CC) && isLessThan(Cmp: Second.CC)) ||
437 (isLessThan(Cmp: First.CC) && isGreaterThan(Cmp: Second.CC))) &&
438 std::abs(x: SecondImmTrueValue - FirstImmTrueValue) == 2) {
439 // This branch transforms machine instructions that correspond to
440 //
441 // 1) (a > {SecondImm} && ...) || (a < {FirstImm} && ...)
442 // 2) (a < {SecondImm} && ...) || (a > {FirstImm} && ...)
443 //
444 // into
445 //
446 // 1) (a >= {NewImm} && ...) || (a <= {NewImm} && ...)
447 // 2) (a <= {NewImm} && ...) || (a >= {NewImm} && ...)
448
449 // Verify both adjustments converge to identical comparisons (same
450 // immediate and opcode). This ensures CSE can eliminate the duplicate.
451 if (FirstAdj.Imm != SecondAdj.Imm || FirstAdj.Opc != SecondAdj.Opc)
452 return false;
453
454 LLVM_DEBUG(dbgs() << "Optimized (opposite): "
455 << AArch64CC::getCondCodeName(First.CC) << " #"
456 << First.getImm() << ", "
457 << AArch64CC::getCondCodeName(Second.CC) << " #"
458 << Second.getImm() << " -> "
459 << AArch64CC::getCondCodeName(FirstAdj.CC) << " #"
460 << FirstAdj.Imm << ", "
461 << AArch64CC::getCondCodeName(SecondAdj.CC) << " #"
462 << SecondAdj.Imm << '\n');
463 applyCmpAdjustment(Pair&: First, Info: FirstAdj);
464 applyCmpAdjustment(Pair&: Second, Info: SecondAdj);
465 return true;
466
467 } else if (((isGreaterThan(Cmp: First.CC) && isGreaterThan(Cmp: Second.CC)) ||
468 (isLessThan(Cmp: First.CC) && isLessThan(Cmp: Second.CC))) &&
469 std::abs(x: SecondImmTrueValue - FirstImmTrueValue) == 1) {
470 // This branch transforms machine instructions that correspond to
471 //
472 // 1) (a > {SecondImm} && ...) || (a > {FirstImm} && ...)
473 // 2) (a < {SecondImm} && ...) || (a < {FirstImm} && ...)
474 //
475 // into
476 //
477 // 1) (a <= {NewImm} && ...) || (a > {NewImm} && ...)
478 // 2) (a < {NewImm} && ...) || (a >= {NewImm} && ...)
479
480 // GT -> GE transformation increases immediate value, so picking the
481 // smaller one; LT -> LE decreases immediate value so invert the choice.
482 bool AdjustFirst = (FirstImmTrueValue < SecondImmTrueValue);
483 if (isLessThan(Cmp: First.CC))
484 AdjustFirst = !AdjustFirst;
485
486 CmpCondPair &Target = AdjustFirst ? Second : First;
487 CmpCondPair &ToChange = AdjustFirst ? First : Second;
488 CmpInfo &Adj = AdjustFirst ? FirstAdj : SecondAdj;
489
490 // Verify the adjustment converges to the target's comparison (same
491 // immediate and opcode). This ensures CSE can eliminate the duplicate.
492 if (Adj.Imm != Target.getImm() || Adj.Opc != Target.getOpc())
493 return false;
494
495 LLVM_DEBUG(dbgs() << "Optimized (same-direction): "
496 << AArch64CC::getCondCodeName(ToChange.CC) << " #"
497 << ToChange.getImm() << " -> "
498 << AArch64CC::getCondCodeName(Adj.CC) << " #" << Adj.Imm
499 << '\n');
500 applyCmpAdjustment(Pair&: ToChange, Info: Adj);
501 return true;
502 }
503
504 // Other transformation cases almost never occur due to generation of < or >
505 // comparisons instead of <= and >=.
506 return false;
507}
508
509bool AArch64ConditionOptimizerImpl::commitPendingPair(
510 std::optional<CmpCondPair> &PendingPair,
511 SmallDenseMap<Register, CmpCondPair> &PairsByReg) {
512 if (!PendingPair)
513 return false;
514
515 Register Reg = PendingPair->CmpMI->getOperand(i: 1).getReg();
516 Register Key = Reg.isVirtual() ? TRI->lookThruCopyLike(SrcReg: Reg, MRI) : Reg;
517
518 auto MatchingPair = PairsByReg.find(Val: Key);
519 bool Changed = MatchingPair != PairsByReg.end() &&
520 tryOptimizePair(First&: MatchingPair->second, Second&: *PendingPair);
521
522 PairsByReg[Key] = *PendingPair;
523 PendingPair = std::nullopt;
524 return Changed;
525}
526
527// This function transforms cmps and their consuming conditionals (CmpCondPairs)
528// 1. Same direction: when both conditions are the same (e.g. GT/GT or LT/LT)
529// and immediates differ by 1
530// 2. Opposite direction: when both conditions are adjustable to a common middle
531// (e.g., GT/LT) and immediates differ by 2.
532// The compare instructions are made to match to enable CSE.
533// All cmp/cond pairs within a basic block are examined
534//
535// Example transformation:
536// cmp w8, #10
537// csinc w9, w0, w1, gt ; w9 = (w8 > 10) ? w0 : w1+1
538// cmp w8, #9
539// csinc w10, w0, w1, gt ; w10 = (w8 > 9) ? w0 : w1+1
540//
541// Into:
542// cmp w8, #10
543// csinc w9, w0, w1, gt ; w9 = (w8 > 10) ? w0 : w1+1
544// cmp w8, #10 ; <- CSE can remove the redundant cmp
545// csinc w10, w0, w1, ge ; w10 = (w8 >= 10) ? w0 : w1+1
546//
547bool AArch64ConditionOptimizerImpl::optimizeIntraBlock(MachineBasicBlock &MBB) {
548 SmallDenseMap<Register, CmpCondPair> PairsByReg;
549 std::optional<CmpCondPair> PendingPair;
550 MachineInstr *ActiveCmp = nullptr;
551 bool Changed = false;
552
553 for (MachineInstr &MI : MBB) {
554 if (MI.isDebugInstr())
555 continue;
556
557 if (isCmpInstruction(Opc: MI.getOpcode()) && canAdjustCmp(CmpMI&: MI)) {
558 Changed |= commitPendingPair(PendingPair, PairsByReg);
559 ActiveCmp = &MI;
560 continue;
561 }
562
563 if (MI.modifiesRegister(Reg: AArch64::NZCV, /*TRI=*/nullptr)) {
564 // Non-CMP clobber: commit any pending pair and reset all state, since
565 // unknown flag state at this point invalidates all prior pairs
566 Changed |= commitPendingPair(PendingPair, PairsByReg);
567 ActiveCmp = nullptr;
568 PairsByReg.clear();
569 continue;
570 }
571
572 if (AArch64InstrInfo::findCondCodeUseOperandIdxForBranchOrSelect(Instr: MI) >= 0 &&
573 !MI.isBranch()) {
574 if (PendingPair) {
575 // A second conditional consuming the same CMP would invalidate any
576 // optimization: modifying the CMP would silently change what both
577 // consumers compare against. Mark the CMP spent.
578 PendingPair = std::nullopt;
579 ActiveCmp = nullptr;
580 } else if (ActiveCmp) {
581 int CCOpIdx =
582 AArch64InstrInfo::findCondCodeUseOperandIdxForBranchOrSelect(Instr: MI);
583 assert(CCOpIdx >= 0 && "Unsupported conditional instruction");
584 AArch64CC::CondCode CC =
585 (AArch64CC::CondCode)(int)MI.getOperand(i: CCOpIdx).getImm();
586 PendingPair = CmpCondPair{.CmpMI: ActiveCmp, .CondMI: &MI, .CC: CC};
587 }
588 continue;
589 }
590
591 if (MI.readsRegister(Reg: AArch64::NZCV, /*TRI=*/nullptr)) {
592 ActiveCmp = nullptr;
593 PendingPair = std::nullopt;
594 continue;
595 }
596 }
597
598 // Only commit the final pending pair if NZCV doesn't live out: a cross-block
599 // consumer would be affected by any CMP adjustment we make.
600 if (!nzcvLivesOut(MBB: &MBB))
601 Changed |= commitPendingPair(PendingPair, PairsByReg);
602
603 return Changed;
604}
605
606// Optimizes CMP+Bcc pairs across two basic blocks in the dominator tree.
607bool AArch64ConditionOptimizerImpl::optimizeCrossBlock(MachineBasicBlock &HBB) {
608 SmallVector<MachineOperand, 4> HeadCondOperands;
609 MachineBasicBlock *TBB = nullptr, *FBB = nullptr;
610 if (TII->analyzeBranch(MBB&: HBB, TBB, FBB, Cond&: HeadCondOperands)) {
611 return false;
612 }
613
614 // Equivalence check is to skip loops.
615 if (!TBB || TBB == &HBB) {
616 return false;
617 }
618
619 SmallVector<MachineOperand, 4> TrueCondOperands;
620 MachineBasicBlock *TBB_TBB = nullptr, *TBB_FBB = nullptr;
621 if (TII->analyzeBranch(MBB&: *TBB, TBB&: TBB_TBB, FBB&: TBB_FBB, Cond&: TrueCondOperands)) {
622 return false;
623 }
624
625 MachineInstr *HeadBrMI = getBccTerminator(MBB: &HBB);
626 MachineInstr *TrueBrMI = getBccTerminator(MBB: TBB);
627 if (!HeadBrMI || !TrueBrMI)
628 return false;
629
630 // Since we may modify cmps in these blocks, make sure NZCV does not live out.
631 if (nzcvLivesOut(MBB: &HBB) || nzcvLivesOut(MBB: TBB))
632 return false;
633
634 // Find the CMPs controlling each branch
635 MachineInstr *HeadCmpMI = findAdjustableCmp(CondMI: HeadBrMI);
636 MachineInstr *TrueCmpMI = findAdjustableCmp(CondMI: TrueBrMI);
637 if (!HeadCmpMI || !TrueCmpMI)
638 return false;
639
640 if (!registersMatch(FirstMI: HeadCmpMI, SecondMI: TrueCmpMI))
641 return false;
642
643 AArch64CC::CondCode HeadCondCode = parseCondCode(Cond: HeadCondOperands);
644 AArch64CC::CondCode TrueCondCode = parseCondCode(Cond: TrueCondOperands);
645 if (HeadCondCode == AArch64CC::CondCode::Invalid ||
646 TrueCondCode == AArch64CC::CondCode::Invalid) {
647 return false;
648 }
649
650 LLVM_DEBUG(dbgs() << "Checking cross-block pair: "
651 << AArch64CC::getCondCodeName(HeadCondCode) << " #"
652 << HeadCmpMI->getOperand(2).getImm() << ", "
653 << AArch64CC::getCondCodeName(TrueCondCode) << " #"
654 << TrueCmpMI->getOperand(2).getImm() << '\n');
655
656 CmpCondPair Head{.CmpMI: HeadCmpMI, .CondMI: HeadBrMI, .CC: HeadCondCode};
657 CmpCondPair True{.CmpMI: TrueCmpMI, .CondMI: TrueBrMI, .CC: TrueCondCode};
658
659 return tryOptimizePair(First&: Head, Second&: True);
660}
661
662bool AArch64ConditionOptimizerLegacy::runOnMachineFunction(
663 MachineFunction &MF) {
664 if (skipFunction(F: MF.getFunction()))
665 return false;
666 MachineDominatorTree &MDT =
667 getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
668 return AArch64ConditionOptimizerImpl().run(MF, MDT);
669}
670
671bool AArch64ConditionOptimizerImpl::run(MachineFunction &MF,
672 MachineDominatorTree &MDT) {
673 LLVM_DEBUG(dbgs() << "********** AArch64 Conditional Compares **********\n"
674 << "********** Function: " << MF.getName() << '\n');
675
676 TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
677 TRI = MF.getSubtarget().getRegisterInfo();
678 DomTree = &MDT;
679 MRI = &MF.getRegInfo();
680
681 bool Changed = false;
682
683 // Visit blocks in dominator tree pre-order. The pre-order enables multiple
684 // cmp-conversions from the same head block.
685 // Note that updateDomTree() modifies the children of the DomTree node
686 // currently being visited. The df_iterator supports that; it doesn't look at
687 // child_begin() / child_end() until after a node has been visited.
688 for (MachineDomTreeNode *I : depth_first(G: DomTree)) {
689 MachineBasicBlock *HBB = I->getBlock();
690 Changed |= optimizeIntraBlock(MBB&: *HBB);
691 Changed |= optimizeCrossBlock(HBB&: *HBB);
692 }
693
694 return Changed;
695}
696
697PreservedAnalyses
698AArch64ConditionOptimizerPass::run(MachineFunction &MF,
699 MachineFunctionAnalysisManager &MFAM) {
700 auto &MDT = MFAM.getResult<MachineDominatorTreeAnalysis>(IR&: MF);
701 bool Changed = AArch64ConditionOptimizerImpl().run(MF, MDT);
702 if (!Changed)
703 return PreservedAnalyses::all();
704 PreservedAnalyses PA = getMachineFunctionPassPreservedAnalyses();
705 PA.preserveSet<CFGAnalyses>();
706 return PA;
707}
708