1 | //===-- LoopPredication.cpp - Guard based loop predication pass -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// The LoopPredication pass tries to convert loop-variant range checks into
// loop-invariant ones by widening the checks across loop iterations. For
// example, it will convert
12 | // |
13 | // for (i = 0; i < n; i++) { |
14 | // guard(i < len); |
15 | // ... |
16 | // } |
17 | // |
18 | // to |
19 | // |
20 | // for (i = 0; i < n; i++) { |
21 | // guard(n - 1 < len); |
22 | // ... |
23 | // } |
24 | // |
// After this transformation the condition of the guard is loop invariant, so
// loop-unswitch can later unswitch the loop on this condition, which in effect
// predicates the loop on the widened condition:
28 | // |
29 | // if (n - 1 < len) |
30 | // for (i = 0; i < n; i++) { |
31 | // ... |
32 | // } |
33 | // else |
34 | // deoptimize |
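//
// For instance, with n = 10 and len = 8 the original loop runs iterations
// i = 0..7 and deoptimizes when i reaches 8, while the predicated form fails
// the up-front check (9 < 8) and deoptimizes before entering the loop.
// Deoptimizing earlier (or more often) than strictly necessary is always
// allowed by guard semantics, which is what makes the widening legal.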
35 | // |
36 | // It's tempting to rely on SCEV here, but it has proven to be problematic. |
37 | // Generally the facts SCEV provides about the increment step of add |
38 | // recurrences are true if the backedge of the loop is taken, which implicitly |
39 | // assumes that the guard doesn't fail. Using these facts to optimize the |
40 | // guard results in a circular logic where the guard is optimized under the |
41 | // assumption that it never fails. |
42 | // |
43 | // For example, in the loop below the induction variable will be marked as nuw |
// based on the guard. Based on nuw, the guard predicate will be considered
45 | // monotonic. Given a monotonic condition it's tempting to replace the induction |
46 | // variable in the condition with its value on the last iteration. But this |
47 | // transformation is not correct, e.g. e = 4, b = 5 breaks the loop. |
48 | // |
49 | // for (int i = b; i != e; i++) |
50 | // guard(i u< len) |
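//
// Concretely, with b = 5, e = 4 and, say, len = 6, the faulty widening would
// test the "last" value e - 1 = 3 (3 u< 6, always true), yet the original
// loop's guard already fails on the second iteration, when i reaches 6.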
51 | // |
52 | // One of the ways to reason about this problem is to use an inductive proof |
53 | // approach. Given the loop: |
54 | // |
55 | // if (B(0)) { |
56 | // do { |
57 | // I = PHI(0, I.INC) |
58 | // I.INC = I + Step |
59 | // guard(G(I)); |
60 | // } while (B(I)); |
61 | // } |
62 | // |
// where B(x) and G(x) are predicates that map integers to booleans, we want a
// loop invariant expression M such that the following program has the same
// semantics as the above:
66 | // |
67 | // if (B(0)) { |
68 | // do { |
69 | // I = PHI(0, I.INC) |
70 | // I.INC = I + Step |
71 | // guard(G(0) && M); |
72 | // } while (B(I)); |
73 | // } |
74 | // |
75 | // One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step) |
76 | // |
77 | // Informal proof that the transformation above is correct: |
78 | // |
79 | // By the definition of guards we can rewrite the guard condition to: |
80 | // G(I) && G(0) && M |
81 | // |
82 | // Let's prove that for each iteration of the loop: |
83 | // G(0) && M => G(I) |
// Then the guard condition above can be simplified to G(0) && M.
85 | // |
86 | // Induction base. |
87 | // G(0) && M => G(0) |
88 | // |
// Induction step. Assume G(0) && M => G(I) holds on the current iteration;
// we show that it also holds on the next one, i.e. that G(I + Step) follows:
91 | // |
92 | // B(I) is true because it's the backedge condition. |
93 | // G(I) is true because the backedge is guarded by this condition. |
94 | // |
95 | // So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step). |
96 | // |
97 | // Note that we can use anything stronger than M, i.e. any condition which |
98 | // implies M. |
99 | // |
// When Step = 1 (i.e. a forward iterating loop), the transformation is supported
101 | // when: |
102 | // * The loop has a single latch with the condition of the form: |
103 | // B(X) = latchStart + X <pred> latchLimit, |
104 | // where <pred> is u<, u<=, s<, or s<=. |
105 | // * The guard condition is of the form |
106 | // G(X) = guardStart + X u< guardLimit |
107 | // |
108 | // For the ult latch comparison case M is: |
// forall X . guardStart + X u< guardLimit && latchStart + X u< latchLimit =>
110 | // guardStart + X + 1 u< guardLimit |
111 | // |
112 | // The only way the antecedent can be true and the consequent can be false is |
113 | // if |
114 | // X == guardLimit - 1 - guardStart |
115 | // (and guardLimit is non-zero, but we won't use this latter fact). |
116 | // If X == guardLimit - 1 - guardStart then the second half of the antecedent is |
117 | // latchStart + guardLimit - 1 - guardStart u< latchLimit |
118 | // and its negation is |
119 | // latchStart + guardLimit - 1 - guardStart u>= latchLimit |
120 | // |
121 | // In other words, if |
122 | // latchLimit u<= latchStart + guardLimit - 1 - guardStart |
123 | // then: |
124 | // (the ranges below are written in ConstantRange notation, where [A, B) is the |
125 | // set for (I = A; I != B; I++ /*maywrap*/) yield(I);) |
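// For example, over i8 the wrapping range [250, 3) denotes {250, ..., 255, 0,
// 1, 2}.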
126 | // |
127 | // forall X . guardStart + X u< guardLimit && |
128 | // latchStart + X u< latchLimit => |
129 | // guardStart + X + 1 u< guardLimit |
130 | // == forall X . guardStart + X u< guardLimit && |
131 | // latchStart + X u< latchStart + guardLimit - 1 - guardStart => |
132 | // guardStart + X + 1 u< guardLimit |
133 | // == forall X . (guardStart + X) in [0, guardLimit) && |
134 | // (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) => |
135 | // (guardStart + X + 1) in [0, guardLimit) |
136 | // == forall X . X in [-guardStart, guardLimit - guardStart) && |
137 | // X in [-latchStart, guardLimit - 1 - guardStart) => |
138 | // X in [-guardStart - 1, guardLimit - guardStart - 1) |
139 | // == true |
140 | // |
141 | // So the widened condition is: |
142 | // guardStart u< guardLimit && |
143 | // latchStart + guardLimit - 1 - guardStart u>= latchLimit |
144 | // Similarly for ule condition the widened condition is: |
145 | // guardStart u< guardLimit && |
146 | // latchStart + guardLimit - 1 - guardStart u> latchLimit |
147 | // For slt condition the widened condition is: |
148 | // guardStart u< guardLimit && |
149 | // latchStart + guardLimit - 1 - guardStart s>= latchLimit |
150 | // For sle condition the widened condition is: |
151 | // guardStart u< guardLimit && |
152 | // latchStart + guardLimit - 1 - guardStart s> latchLimit |
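//
// As a concrete instance, the loop from the top of this file (after loop
// rotation, and assuming unsigned comparisons) has guardStart = 0,
// guardLimit = len, latchStart = 1 (the latch compares the incremented IV),
// latchLimit = n and a u< latch predicate, so the widened condition is
//   0 u< len && 1 + len - 1 - 0 u>= n,
// i.e. 0 u< len && len u>= n, which matches the "n - 1 < len" check shown
// above whenever the loop is entered (n u> 0).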
153 | // |
// When Step = -1 (i.e. a reverse iterating loop), the transformation is supported
155 | // when: |
156 | // * The loop has a single latch with the condition of the form: |
157 | // B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=. |
158 | // * The guard condition is of the form |
159 | // G(X) = X - 1 u< guardLimit |
160 | // |
161 | // For the ugt latch comparison case M is: |
162 | // forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit |
163 | // |
164 | // The only way the antecedent can be true and the consequent can be false is if |
165 | // X == 1. |
166 | // If X == 1 then the second half of the antecedent is |
167 | // 1 u> latchLimit, and its negation is latchLimit u>= 1. |
168 | // |
169 | // So the widened condition is: |
170 | // guardStart u< guardLimit && latchLimit u>= 1. |
171 | // Similarly for sgt condition the widened condition is: |
172 | // guardStart u< guardLimit && latchLimit s>= 1. |
173 | // For uge condition the widened condition is: |
174 | // guardStart u< guardLimit && latchLimit u> 1. |
175 | // For sge condition the widened condition is: |
176 | // guardStart u< guardLimit && latchLimit s> 1. |
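//
// For instance, with the IV starting at n and a latch condition X u> 1, the
// body runs for X = n, ..., 1 (assuming n u>= 1) and the range check
// X - 1 u< len tests the values n - 1 down to 0. The widened condition
//   (n - 1) u< len && 1 u>= 1
// reduces to (n - 1) u< len, which covers the largest value the range check
// ever sees.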
177 | //===----------------------------------------------------------------------===// |
178 | |
179 | #include "llvm/Transforms/Scalar/LoopPredication.h" |
180 | #include "llvm/ADT/Statistic.h" |
181 | #include "llvm/Analysis/AliasAnalysis.h" |
182 | #include "llvm/Analysis/BranchProbabilityInfo.h" |
183 | #include "llvm/Analysis/GuardUtils.h" |
184 | #include "llvm/Analysis/LoopInfo.h" |
185 | #include "llvm/Analysis/LoopPass.h" |
186 | #include "llvm/Analysis/MemorySSA.h" |
187 | #include "llvm/Analysis/MemorySSAUpdater.h" |
188 | #include "llvm/Analysis/ScalarEvolution.h" |
189 | #include "llvm/Analysis/ScalarEvolutionExpressions.h" |
190 | #include "llvm/IR/Function.h" |
191 | #include "llvm/IR/IntrinsicInst.h" |
192 | #include "llvm/IR/Module.h" |
193 | #include "llvm/IR/PatternMatch.h" |
194 | #include "llvm/IR/ProfDataUtils.h" |
195 | #include "llvm/Pass.h" |
196 | #include "llvm/Support/CommandLine.h" |
197 | #include "llvm/Support/Debug.h" |
198 | #include "llvm/Transforms/Scalar.h" |
199 | #include "llvm/Transforms/Utils/GuardUtils.h" |
200 | #include "llvm/Transforms/Utils/Local.h" |
201 | #include "llvm/Transforms/Utils/LoopUtils.h" |
202 | #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" |
203 | #include <optional> |
204 | |
205 | #define DEBUG_TYPE "loop-predication" |
206 | |
207 | STATISTIC(TotalConsidered, "Number of guards considered" ); |
208 | STATISTIC(TotalWidened, "Number of checks widened" ); |
209 | |
210 | using namespace llvm; |
211 | |
212 | static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation" , |
213 | cl::Hidden, cl::init(Val: true)); |
214 | |
215 | static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop" , |
216 | cl::Hidden, cl::init(Val: true)); |
217 | |
218 | static cl::opt<bool> |
219 | SkipProfitabilityChecks("loop-predication-skip-profitability-checks" , |
220 | cl::Hidden, cl::init(Val: false)); |
221 | |
222 | // This is the scale factor for the latch probability. We use this during |
223 | // profitability analysis to find other exiting blocks that have a much higher |
224 | // probability of exiting the loop instead of loop exiting via latch. |
225 | // This value should be greater than 1 for a sane profitability check. |
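// For example, with the default scale of 2.0, predication is abandoned as
// soon as some exit is more than twice as likely to leave the loop as the
// latch exit is.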
226 | static cl::opt<float> LatchExitProbabilityScale( |
227 | "loop-predication-latch-probability-scale" , cl::Hidden, cl::init(Val: 2.0), |
228 | cl::desc("scale factor for the latch probability. Value should be greater " |
229 | "than 1. Lower values are ignored" )); |
230 | |
231 | static cl::opt<bool> PredicateWidenableBranchGuards( |
232 | "loop-predication-predicate-widenable-branches-to-deopt" , cl::Hidden, |
233 | cl::desc("Whether or not we should predicate guards " |
234 | "expressed as widenable branches to deoptimize blocks" ), |
235 | cl::init(Val: true)); |
236 | |
237 | static cl::opt<bool> InsertAssumesOfPredicatedGuardsConditions( |
238 | "loop-predication-insert-assumes-of-predicated-guards-conditions" , |
239 | cl::Hidden, |
240 | cl::desc("Whether or not we should insert assumes of conditions of " |
241 | "predicated guards" ), |
242 | cl::init(Val: true)); |
243 | |
244 | namespace { |
245 | /// Represents an induction variable check: |
246 | /// icmp Pred, <induction variable>, <loop invariant limit> |
247 | struct LoopICmp { |
248 | ICmpInst::Predicate Pred; |
249 | const SCEVAddRecExpr *IV; |
250 | const SCEV *Limit; |
251 | LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV, |
252 | const SCEV *Limit) |
253 | : Pred(Pred), IV(IV), Limit(Limit) {} |
254 | LoopICmp() = default; |
255 | void dump() { |
256 | dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV |
257 | << ", Limit = " << *Limit << "\n" ; |
258 | } |
259 | }; |
260 | |
261 | class LoopPredication { |
262 | AliasAnalysis *AA; |
263 | DominatorTree *DT; |
264 | ScalarEvolution *SE; |
265 | LoopInfo *LI; |
266 | MemorySSAUpdater *MSSAU; |
267 | |
268 | Loop *L; |
269 | const DataLayout *DL; |
BasicBlock *Preheader;
271 | LoopICmp LatchCheck; |
272 | |
273 | bool isSupportedStep(const SCEV* Step); |
274 | std::optional<LoopICmp> parseLoopICmp(ICmpInst *ICI); |
275 | std::optional<LoopICmp> parseLoopLatchICmp(); |
276 | |
/// Return an insertion point suitable for inserting a safe-to-speculate
/// instruction whose only user will be 'User' which has operands 'Ops'. A
/// trivial result would be the User itself, but we try to return a
/// loop invariant location if possible.
281 | Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops); |
282 | /// Same as above, *except* that this uses the SCEV definition of invariant |
283 | /// which is that an expression *can be made* invariant via SCEVExpander. |
284 | /// Thus, this version is only suitable for finding an insert point to be |
285 | /// passed to SCEVExpander! |
286 | Instruction *findInsertPt(const SCEVExpander &Expander, Instruction *User, |
287 | ArrayRef<const SCEV *> Ops); |
288 | |
289 | /// Return true if the value is known to produce a single fixed value across |
290 | /// all iterations on which it executes. Note that this does not imply |
291 | /// speculation safety. That must be established separately. |
292 | bool isLoopInvariantValue(const SCEV* S); |
293 | |
294 | Value *expandCheck(SCEVExpander &Expander, Instruction *Guard, |
295 | ICmpInst::Predicate Pred, const SCEV *LHS, |
296 | const SCEV *RHS); |
297 | |
298 | std::optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, |
299 | SCEVExpander &Expander, |
300 | Instruction *Guard); |
301 | std::optional<Value *> |
302 | widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, |
303 | SCEVExpander &Expander, |
304 | Instruction *Guard); |
305 | std::optional<Value *> |
306 | widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, |
307 | SCEVExpander &Expander, |
308 | Instruction *Guard); |
309 | void widenChecks(SmallVectorImpl<Value *> &Checks, |
310 | SmallVectorImpl<Value *> &WidenedChecks, |
311 | SCEVExpander &Expander, Instruction *Guard); |
312 | bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); |
313 | bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander); |
314 | // If the loop always exits through another block in the loop, we should not |
315 | // predicate based on the latch check. For example, the latch check can be a |
316 | // very coarse grained check and there can be more fine grained exit checks |
317 | // within the loop. |
318 | bool isLoopProfitableToPredicate(); |
319 | |
320 | bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter); |
321 | |
322 | public: |
323 | LoopPredication(AliasAnalysis *AA, DominatorTree *DT, ScalarEvolution *SE, |
324 | LoopInfo *LI, MemorySSAUpdater *MSSAU) |
325 | : AA(AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU){}; |
326 | bool runOnLoop(Loop *L); |
327 | }; |
328 | |
329 | } // end namespace |
330 | |
331 | PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, |
332 | LoopStandardAnalysisResults &AR, |
333 | LPMUpdater &U) { |
334 | std::unique_ptr<MemorySSAUpdater> MSSAU; |
335 | if (AR.MSSA) |
336 | MSSAU = std::make_unique<MemorySSAUpdater>(args&: AR.MSSA); |
337 | LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, |
338 | MSSAU ? MSSAU.get() : nullptr); |
339 | if (!LP.runOnLoop(L: &L)) |
340 | return PreservedAnalyses::all(); |
341 | |
342 | auto PA = getLoopPassPreservedAnalyses(); |
343 | if (AR.MSSA) |
344 | PA.preserve<MemorySSAAnalysis>(); |
345 | return PA; |
346 | } |
347 | |
348 | std::optional<LoopICmp> LoopPredication::parseLoopICmp(ICmpInst *ICI) { |
349 | auto Pred = ICI->getPredicate(); |
350 | auto *LHS = ICI->getOperand(i_nocapture: 0); |
351 | auto *RHS = ICI->getOperand(i_nocapture: 1); |
352 | |
353 | const SCEV *LHSS = SE->getSCEV(V: LHS); |
354 | if (isa<SCEVCouldNotCompute>(Val: LHSS)) |
355 | return std::nullopt; |
356 | const SCEV *RHSS = SE->getSCEV(V: RHS); |
357 | if (isa<SCEVCouldNotCompute>(Val: RHSS)) |
358 | return std::nullopt; |
359 | |
// Canonicalize so that RHS is the loop-invariant bound and LHS is a loop-computable IV.
361 | if (SE->isLoopInvariant(S: LHSS, L)) { |
362 | std::swap(a&: LHS, b&: RHS); |
363 | std::swap(a&: LHSS, b&: RHSS); |
364 | Pred = ICmpInst::getSwappedPredicate(pred: Pred); |
365 | } |
366 | |
367 | const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Val: LHSS); |
368 | if (!AR || AR->getLoop() != L) |
369 | return std::nullopt; |
370 | |
371 | return LoopICmp(Pred, AR, RHSS); |
372 | } |
373 | |
374 | Value *LoopPredication::expandCheck(SCEVExpander &Expander, |
375 | Instruction *Guard, |
376 | ICmpInst::Predicate Pred, const SCEV *LHS, |
377 | const SCEV *RHS) { |
378 | Type *Ty = LHS->getType(); |
379 | assert(Ty == RHS->getType() && "expandCheck operands have different types?" ); |
380 | |
381 | if (SE->isLoopInvariant(S: LHS, L) && SE->isLoopInvariant(S: RHS, L)) { |
382 | IRBuilder<> Builder(Guard); |
383 | if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS)) |
384 | return Builder.getTrue(); |
385 | if (SE->isLoopEntryGuardedByCond(L, Pred: ICmpInst::getInversePredicate(pred: Pred), |
386 | LHS, RHS)) |
387 | return Builder.getFalse(); |
388 | } |
389 | |
390 | Value *LHSV = |
391 | Expander.expandCodeFor(SH: LHS, Ty, I: findInsertPt(Expander, User: Guard, Ops: {LHS})); |
392 | Value *RHSV = |
393 | Expander.expandCodeFor(SH: RHS, Ty, I: findInsertPt(Expander, User: Guard, Ops: {RHS})); |
394 | IRBuilder<> Builder(findInsertPt(User: Guard, Ops: {LHSV, RHSV})); |
395 | return Builder.CreateICmp(P: Pred, LHS: LHSV, RHS: RHSV); |
396 | } |
397 | |
// Returns true if it's safe to truncate the IV to RangeCheckType.
399 | // When the IV type is wider than the range operand type, we can still do loop |
400 | // predication, by generating SCEVs for the range and latch that are of the |
401 | // same type. We achieve this by generating a SCEV truncate expression for the |
402 | // latch IV. This is done iff truncation of the IV is a safe operation, |
403 | // without loss of information. |
404 | // Another way to achieve this is by generating a wider type SCEV for the |
405 | // range check operand, however, this needs a more involved check that |
406 | // operands do not overflow. This can lead to loss of information when the |
407 | // range operand is of the form: add i32 %offset, %iv. We need to prove that |
408 | // sext(x + y) is same as sext(x) + sext(y). |
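// (For example, with i8 operands x = 100 and y = 100, x + y wraps to -56, so
// sext(x + y) differs from sext(x) + sext(y) = 200.)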
409 | // This function returns true if we can safely represent the IV type in |
410 | // the RangeCheckType without loss of information. |
411 | static bool isSafeToTruncateWideIVType(const DataLayout &DL, |
412 | ScalarEvolution &SE, |
413 | const LoopICmp LatchCheck, |
414 | Type *RangeCheckType) { |
415 | if (!EnableIVTruncation) |
416 | return false; |
417 | assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedValue() > |
418 | DL.getTypeSizeInBits(RangeCheckType).getFixedValue() && |
419 | "Expected latch check IV type to be larger than range check operand " |
420 | "type!" ); |
421 | // The start and end values of the IV should be known. This is to guarantee |
422 | // that truncating the wide type will not lose information. |
423 | auto *Limit = dyn_cast<SCEVConstant>(Val: LatchCheck.Limit); |
424 | auto *Start = dyn_cast<SCEVConstant>(Val: LatchCheck.IV->getStart()); |
425 | if (!Limit || !Start) |
426 | return false; |
427 | // This check makes sure that the IV does not change sign during loop |
428 | // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE, |
429 | // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the |
430 | // IV wraps around, and the truncation of the IV would lose the range of |
431 | // iterations between 2^32 and 2^64. |
432 | if (!SE.getMonotonicPredicateType(LHS: LatchCheck.IV, Pred: LatchCheck.Pred)) |
433 | return false; |
434 | // The active bits should be less than the bits in the RangeCheckType. This |
435 | // guarantees that truncating the latch check to RangeCheckType is a safe |
436 | // operation. |
437 | auto RangeCheckTypeBitSize = |
438 | DL.getTypeSizeInBits(Ty: RangeCheckType).getFixedValue(); |
439 | return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize && |
440 | Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize; |
441 | } |
442 | |
443 | |
// Return a LoopICmp describing a latch check equivalent to LatchCheck but with
// the requested type, if it is safe to do so. May involve the use of a new IV.
446 | static std::optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL, |
447 | ScalarEvolution &SE, |
448 | const LoopICmp LatchCheck, |
449 | Type *RangeCheckType) { |
450 | |
451 | auto *LatchType = LatchCheck.IV->getType(); |
452 | if (RangeCheckType == LatchType) |
453 | return LatchCheck; |
454 | // For now, bail out if latch type is narrower than range type. |
455 | if (DL.getTypeSizeInBits(Ty: LatchType).getFixedValue() < |
456 | DL.getTypeSizeInBits(Ty: RangeCheckType).getFixedValue()) |
457 | return std::nullopt; |
458 | if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType)) |
459 | return std::nullopt; |
460 | // We can now safely identify the truncated version of the IV and limit for |
461 | // RangeCheckType. |
462 | LoopICmp NewLatchCheck; |
463 | NewLatchCheck.Pred = LatchCheck.Pred; |
464 | NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>( |
465 | Val: SE.getTruncateExpr(Op: LatchCheck.IV, Ty: RangeCheckType)); |
466 | if (!NewLatchCheck.IV) |
467 | return std::nullopt; |
468 | NewLatchCheck.Limit = SE.getTruncateExpr(Op: LatchCheck.Limit, Ty: RangeCheckType); |
LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType
<< " can be represented as range check type: "
<< *RangeCheckType << "\n");
472 | LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n" ); |
473 | LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n" ); |
474 | return NewLatchCheck; |
475 | } |
476 | |
477 | bool LoopPredication::isSupportedStep(const SCEV* Step) { |
478 | return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop); |
479 | } |
480 | |
481 | Instruction *LoopPredication::findInsertPt(Instruction *Use, |
482 | ArrayRef<Value*> Ops) { |
483 | for (Value *Op : Ops) |
484 | if (!L->isLoopInvariant(V: Op)) |
485 | return Use; |
486 | return Preheader->getTerminator(); |
487 | } |
488 | |
489 | Instruction *LoopPredication::findInsertPt(const SCEVExpander &Expander, |
490 | Instruction *Use, |
491 | ArrayRef<const SCEV *> Ops) { |
492 | // Subtlety: SCEV considers things to be invariant if the value produced is |
493 | // the same across iterations. This is not the same as being able to |
494 | // evaluate outside the loop, which is what we actually need here. |
495 | for (const SCEV *Op : Ops) |
496 | if (!SE->isLoopInvariant(S: Op, L) || |
497 | !Expander.isSafeToExpandAt(S: Op, InsertionPoint: Preheader->getTerminator())) |
498 | return Use; |
499 | return Preheader->getTerminator(); |
500 | } |
501 | |
502 | bool LoopPredication::isLoopInvariantValue(const SCEV* S) { |
503 | // Handling expressions which produce invariant results, but *haven't* yet |
504 | // been removed from the loop serves two important purposes. |
505 | // 1) Most importantly, it resolves a pass ordering cycle which would |
// otherwise need us to iterate licm, loop-predication, and either
507 | // loop-unswitch or loop-peeling to make progress on examples with lots of |
508 | // predicable range checks in a row. (Since, in the general case, we can't |
509 | // hoist the length checks until the dominating checks have been discharged |
510 | // as we can't prove doing so is safe.) |
511 | // 2) As a nice side effect, this exposes the value of peeling or unswitching |
512 | // much more obviously in the IR. Otherwise, the cost modeling for other |
513 | // transforms would end up needing to duplicate all of this logic to model a |
514 | // check which becomes predictable based on a modeled peel or unswitch. |
515 | // |
516 | // The cost of doing so in the worst case is an extra fill from the stack in |
517 | // the loop to materialize the loop invariant test value instead of checking |
// against the original IV which is presumably in a register inside the loop.
// Such cases are presumably rare, and hint at missing opportunities for
520 | // other passes. |
521 | |
522 | if (SE->isLoopInvariant(S, L)) |
// Note: This is the SCEV variant, so the original Value* may be within the
524 | // loop even though SCEV has proven it is loop invariant. |
525 | return true; |
526 | |
// Handle a particularly important case which SCEV doesn't yet know about and
// which shows up in range checks on arrays with immutable lengths.
529 | // TODO: This should be sunk inside SCEV. |
530 | if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Val: S)) |
531 | if (const auto *LI = dyn_cast<LoadInst>(Val: U->getValue())) |
532 | if (LI->isUnordered() && L->hasLoopInvariantOperands(I: LI)) |
533 | if (!isModSet(MRI: AA->getModRefInfoMask(P: LI->getOperand(i_nocapture: 0))) || |
534 | LI->hasMetadata(KindID: LLVMContext::MD_invariant_load)) |
535 | return true; |
536 | return false; |
537 | } |
538 | |
539 | std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( |
540 | LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, |
541 | Instruction *Guard) { |
542 | auto *Ty = RangeCheck.IV->getType(); |
543 | // Generate the widened condition for the forward loop: |
544 | // guardStart u< guardLimit && |
545 | // latchLimit <pred> guardLimit - 1 - guardStart + latchStart |
546 | // where <pred> depends on the latch condition predicate. See the file |
547 | // header comment for the reasoning. |
548 | // guardLimit - guardStart + latchStart - 1 |
549 | const SCEV *GuardStart = RangeCheck.IV->getStart(); |
550 | const SCEV *GuardLimit = RangeCheck.Limit; |
551 | const SCEV *LatchStart = LatchCheck.IV->getStart(); |
552 | const SCEV *LatchLimit = LatchCheck.Limit; |
553 | // Subtlety: We need all the values to be *invariant* across all iterations, |
554 | // but we only need to check expansion safety for those which *aren't* |
555 | // already guaranteed to dominate the guard. |
556 | if (!isLoopInvariantValue(S: GuardStart) || |
557 | !isLoopInvariantValue(S: GuardLimit) || |
558 | !isLoopInvariantValue(S: LatchStart) || |
559 | !isLoopInvariantValue(S: LatchLimit)) { |
560 | LLVM_DEBUG(dbgs() << "Can't expand limit check!\n" ); |
561 | return std::nullopt; |
562 | } |
563 | if (!Expander.isSafeToExpandAt(S: LatchStart, InsertionPoint: Guard) || |
564 | !Expander.isSafeToExpandAt(S: LatchLimit, InsertionPoint: Guard)) { |
565 | LLVM_DEBUG(dbgs() << "Can't expand limit check!\n" ); |
566 | return std::nullopt; |
567 | } |
568 | |
569 | // guardLimit - guardStart + latchStart - 1 |
570 | const SCEV *RHS = |
571 | SE->getAddExpr(LHS: SE->getMinusSCEV(LHS: GuardLimit, RHS: GuardStart), |
572 | RHS: SE->getMinusSCEV(LHS: LatchStart, RHS: SE->getOne(Ty))); |
573 | auto LimitCheckPred = |
574 | ICmpInst::getFlippedStrictnessPredicate(pred: LatchCheck.Pred); |
575 | |
576 | LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n" ); |
577 | LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n" ); |
578 | LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n" ); |
579 | |
580 | auto *LimitCheck = |
581 | expandCheck(Expander, Guard, Pred: LimitCheckPred, LHS: LatchLimit, RHS); |
582 | auto *FirstIterationCheck = expandCheck(Expander, Guard, Pred: RangeCheck.Pred, |
583 | LHS: GuardStart, RHS: GuardLimit); |
584 | IRBuilder<> Builder(findInsertPt(Use: Guard, Ops: {FirstIterationCheck, LimitCheck})); |
585 | return Builder.CreateFreeze( |
586 | V: Builder.CreateAnd(LHS: FirstIterationCheck, RHS: LimitCheck)); |
587 | } |
588 | |
589 | std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( |
590 | LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, |
591 | Instruction *Guard) { |
592 | auto *Ty = RangeCheck.IV->getType(); |
593 | const SCEV *GuardStart = RangeCheck.IV->getStart(); |
594 | const SCEV *GuardLimit = RangeCheck.Limit; |
595 | const SCEV *LatchStart = LatchCheck.IV->getStart(); |
596 | const SCEV *LatchLimit = LatchCheck.Limit; |
597 | // Subtlety: We need all the values to be *invariant* across all iterations, |
598 | // but we only need to check expansion safety for those which *aren't* |
599 | // already guaranteed to dominate the guard. |
600 | if (!isLoopInvariantValue(S: GuardStart) || |
601 | !isLoopInvariantValue(S: GuardLimit) || |
602 | !isLoopInvariantValue(S: LatchStart) || |
603 | !isLoopInvariantValue(S: LatchLimit)) { |
604 | LLVM_DEBUG(dbgs() << "Can't expand limit check!\n" ); |
605 | return std::nullopt; |
606 | } |
607 | if (!Expander.isSafeToExpandAt(S: LatchStart, InsertionPoint: Guard) || |
608 | !Expander.isSafeToExpandAt(S: LatchLimit, InsertionPoint: Guard)) { |
609 | LLVM_DEBUG(dbgs() << "Can't expand limit check!\n" ); |
610 | return std::nullopt; |
611 | } |
612 | // The decrement of the latch check IV should be the same as the |
613 | // rangeCheckIV. |
614 | auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(SE&: *SE); |
615 | if (RangeCheck.IV != PostDecLatchCheckIV) { |
616 | LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: " |
617 | << *PostDecLatchCheckIV |
618 | << " and RangeCheckIV: " << *RangeCheck.IV << "\n" ); |
619 | return std::nullopt; |
620 | } |
621 | |
622 | // Generate the widened condition for CountDownLoop: |
623 | // guardStart u< guardLimit && |
624 | // latchLimit <pred> 1. |
625 | // See the header comment for reasoning of the checks. |
626 | auto LimitCheckPred = |
627 | ICmpInst::getFlippedStrictnessPredicate(pred: LatchCheck.Pred); |
628 | auto *FirstIterationCheck = expandCheck(Expander, Guard, |
629 | Pred: ICmpInst::ICMP_ULT, |
630 | LHS: GuardStart, RHS: GuardLimit); |
631 | auto *LimitCheck = expandCheck(Expander, Guard, Pred: LimitCheckPred, LHS: LatchLimit, |
632 | RHS: SE->getOne(Ty)); |
633 | IRBuilder<> Builder(findInsertPt(Use: Guard, Ops: {FirstIterationCheck, LimitCheck})); |
634 | return Builder.CreateFreeze( |
635 | V: Builder.CreateAnd(LHS: FirstIterationCheck, RHS: LimitCheck)); |
636 | } |
637 | |
638 | static void normalizePredicate(ScalarEvolution *SE, Loop *L, |
639 | LoopICmp& RC) { |
640 | // LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the |
641 | // ULT/UGE form for ease of handling by our caller. |
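// For example, a latch check "i != limit" with a unit step and a provable
// start u<= limit becomes "i u< limit".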
642 | if (ICmpInst::isEquality(P: RC.Pred) && |
643 | RC.IV->getStepRecurrence(SE&: *SE)->isOne() && |
644 | SE->isKnownPredicate(Pred: ICmpInst::ICMP_ULE, LHS: RC.IV->getStart(), RHS: RC.Limit)) |
645 | RC.Pred = RC.Pred == ICmpInst::ICMP_NE ? |
646 | ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE; |
647 | } |
648 | |
/// If ICI can be widened to a loop invariant condition, emit the loop
/// invariant condition in the loop preheader and return it; otherwise
/// return std::nullopt.
652 | std::optional<Value *> |
653 | LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, |
654 | Instruction *Guard) { |
655 | LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n" ); |
656 | LLVM_DEBUG(ICI->dump()); |
657 | |
658 | // parseLoopStructure guarantees that the latch condition is: |
659 | // ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=. |
660 | // We are looking for the range checks of the form: |
661 | // i u< guardLimit |
662 | auto RangeCheck = parseLoopICmp(ICI); |
663 | if (!RangeCheck) { |
LLVM_DEBUG(dbgs() << "Failed to parse the guard condition!\n");
665 | return std::nullopt; |
666 | } |
667 | LLVM_DEBUG(dbgs() << "Guard check:\n" ); |
668 | LLVM_DEBUG(RangeCheck->dump()); |
669 | if (RangeCheck->Pred != ICmpInst::ICMP_ULT) { |
670 | LLVM_DEBUG(dbgs() << "Unsupported range check predicate(" |
671 | << RangeCheck->Pred << ")!\n" ); |
672 | return std::nullopt; |
673 | } |
674 | auto *RangeCheckIV = RangeCheck->IV; |
675 | if (!RangeCheckIV->isAffine()) { |
676 | LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n" ); |
677 | return std::nullopt; |
678 | } |
679 | const SCEV *Step = RangeCheckIV->getStepRecurrence(SE&: *SE); |
680 | // We cannot just compare with latch IV step because the latch and range IVs |
681 | // may have different types. |
682 | if (!isSupportedStep(Step)) { |
LLVM_DEBUG(dbgs() << "Range check and latch IVs have different steps!\n");
684 | return std::nullopt; |
685 | } |
686 | auto *Ty = RangeCheckIV->getType(); |
687 | auto CurrLatchCheckOpt = generateLoopLatchCheck(DL: *DL, SE&: *SE, LatchCheck, RangeCheckType: Ty); |
688 | if (!CurrLatchCheckOpt) { |
689 | LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check " |
690 | "corresponding to range type: " |
691 | << *Ty << "\n" ); |
692 | return std::nullopt; |
693 | } |
694 | |
695 | LoopICmp CurrLatchCheck = *CurrLatchCheckOpt; |
696 | // At this point, the range and latch step should have the same type, but need |
697 | // not have the same value (we support both 1 and -1 steps). |
698 | assert(Step->getType() == |
699 | CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() && |
700 | "Range and latch steps should be of same type!" ); |
701 | if (Step != CurrLatchCheck.IV->getStepRecurrence(SE&: *SE)) { |
702 | LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n" ); |
703 | return std::nullopt; |
704 | } |
705 | |
706 | if (Step->isOne()) |
707 | return widenICmpRangeCheckIncrementingLoop(LatchCheck: CurrLatchCheck, RangeCheck: *RangeCheck, |
708 | Expander, Guard); |
709 | else { |
710 | assert(Step->isAllOnesValue() && "Step should be -1!" ); |
711 | return widenICmpRangeCheckDecrementingLoop(LatchCheck: CurrLatchCheck, RangeCheck: *RangeCheck, |
712 | Expander, Guard); |
713 | } |
714 | } |
715 | |
716 | void LoopPredication::widenChecks(SmallVectorImpl<Value *> &Checks, |
717 | SmallVectorImpl<Value *> &WidenedChecks, |
718 | SCEVExpander &Expander, Instruction *Guard) { |
719 | for (auto &Check : Checks) |
720 | if (ICmpInst *ICI = dyn_cast<ICmpInst>(Val: Check)) |
721 | if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Guard)) { |
722 | WidenedChecks.push_back(Elt: Check); |
723 | Check = *NewRangeCheck; |
724 | } |
725 | } |
726 | |
727 | bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, |
728 | SCEVExpander &Expander) { |
729 | LLVM_DEBUG(dbgs() << "Processing guard:\n" ); |
730 | LLVM_DEBUG(Guard->dump()); |
731 | |
732 | TotalConsidered++; |
733 | SmallVector<Value *, 4> Checks; |
734 | SmallVector<Value *> WidenedChecks; |
735 | parseWidenableGuard(U: Guard, Checks); |
736 | widenChecks(Checks, WidenedChecks, Expander, Guard); |
737 | if (WidenedChecks.empty()) |
738 | return false; |
739 | |
740 | TotalWidened += WidenedChecks.size(); |
741 | |
742 | // Emit the new guard condition |
743 | IRBuilder<> Builder(findInsertPt(Use: Guard, Ops: Checks)); |
744 | Value *AllChecks = Builder.CreateAnd(Ops: Checks); |
745 | auto *OldCond = Guard->getOperand(i_nocapture: 0); |
746 | Guard->setOperand(i_nocapture: 0, Val_nocapture: AllChecks); |
747 | if (InsertAssumesOfPredicatedGuardsConditions) { |
748 | Builder.SetInsertPoint(&*++BasicBlock::iterator(Guard)); |
749 | Builder.CreateAssumption(Cond: OldCond); |
750 | } |
751 | RecursivelyDeleteTriviallyDeadInstructions(V: OldCond, TLI: nullptr /* TLI */, MSSAU); |
752 | |
753 | LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n" ); |
754 | return true; |
755 | } |
756 | |
757 | bool LoopPredication::widenWidenableBranchGuardConditions( |
758 | BranchInst *BI, SCEVExpander &Expander) { |
759 | assert(isGuardAsWidenableBranch(BI) && "Must be!" ); |
760 | LLVM_DEBUG(dbgs() << "Processing guard:\n" ); |
761 | LLVM_DEBUG(BI->dump()); |
762 | |
763 | TotalConsidered++; |
764 | SmallVector<Value *, 4> Checks; |
765 | SmallVector<Value *> WidenedChecks; |
766 | parseWidenableGuard(U: BI, Checks); |
// At the moment, our matching logic for widenable conditions implicitly
// assumes we preserve the form: (br (and Cond, WC())). FIXME
769 | auto WC = extractWidenableCondition(U: BI); |
770 | Checks.push_back(Elt: WC); |
771 | widenChecks(Checks, WidenedChecks, Expander, Guard: BI); |
772 | if (WidenedChecks.empty()) |
773 | return false; |
774 | |
775 | TotalWidened += WidenedChecks.size(); |
776 | |
777 | // Emit the new guard condition |
778 | IRBuilder<> Builder(findInsertPt(Use: BI, Ops: Checks)); |
779 | Value *AllChecks = Builder.CreateAnd(Ops: Checks); |
780 | auto *OldCond = BI->getCondition(); |
781 | BI->setCondition(AllChecks); |
782 | if (InsertAssumesOfPredicatedGuardsConditions) { |
783 | BasicBlock *IfTrueBB = BI->getSuccessor(i: 0); |
784 | Builder.SetInsertPoint(TheBB: IfTrueBB, IP: IfTrueBB->getFirstInsertionPt()); |
785 | // If this block has other predecessors, we might not be able to use Cond. |
786 | // In this case, create a Phi where every other input is `true` and input |
787 | // from guard block is Cond. |
788 | Value *AssumeCond = Builder.CreateAnd(Ops: WidenedChecks); |
789 | if (!IfTrueBB->getUniquePredecessor()) { |
790 | auto *GuardBB = BI->getParent(); |
791 | auto *PN = Builder.CreatePHI(Ty: AssumeCond->getType(), NumReservedValues: pred_size(BB: IfTrueBB), |
792 | Name: "assume.cond" ); |
793 | for (auto *Pred : predecessors(BB: IfTrueBB)) |
794 | PN->addIncoming(V: Pred == GuardBB ? AssumeCond : Builder.getTrue(), BB: Pred); |
795 | AssumeCond = PN; |
796 | } |
797 | Builder.CreateAssumption(Cond: AssumeCond); |
798 | } |
799 | RecursivelyDeleteTriviallyDeadInstructions(V: OldCond, TLI: nullptr /* TLI */, MSSAU); |
800 | assert(isGuardAsWidenableBranch(BI) && |
801 | "Stopped being a guard after transform?" ); |
802 | |
803 | LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n" ); |
804 | return true; |
805 | } |
806 | |
807 | std::optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { |
808 | using namespace PatternMatch; |
809 | |
810 | BasicBlock *LoopLatch = L->getLoopLatch(); |
811 | if (!LoopLatch) { |
812 | LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n" ); |
813 | return std::nullopt; |
814 | } |
815 | |
816 | auto *BI = dyn_cast<BranchInst>(Val: LoopLatch->getTerminator()); |
817 | if (!BI || !BI->isConditional()) { |
818 | LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n" ); |
819 | return std::nullopt; |
820 | } |
821 | BasicBlock *TrueDest = BI->getSuccessor(i: 0); |
822 | assert( |
823 | (TrueDest == L->getHeader() || BI->getSuccessor(1) == L->getHeader()) && |
824 | "One of the latch's destinations must be the header" ); |
825 | |
826 | auto *ICI = dyn_cast<ICmpInst>(Val: BI->getCondition()); |
827 | if (!ICI) { |
828 | LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n" ); |
829 | return std::nullopt; |
830 | } |
831 | auto Result = parseLoopICmp(ICI); |
832 | if (!Result) { |
833 | LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n" ); |
834 | return std::nullopt; |
835 | } |
836 | |
837 | if (TrueDest != L->getHeader()) |
838 | Result->Pred = ICmpInst::getInversePredicate(pred: Result->Pred); |
839 | |
840 | // Check affine first, so if it's not we don't try to compute the step |
841 | // recurrence. |
842 | if (!Result->IV->isAffine()) { |
843 | LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n" ); |
844 | return std::nullopt; |
845 | } |
846 | |
847 | const SCEV *Step = Result->IV->getStepRecurrence(SE&: *SE); |
848 | if (!isSupportedStep(Step)) { |
849 | LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n" ); |
850 | return std::nullopt; |
851 | } |
852 | |
853 | auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) { |
854 | if (Step->isOne()) { |
855 | return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT && |
856 | Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE; |
857 | } else { |
858 | assert(Step->isAllOnesValue() && "Step should be -1!" ); |
859 | return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT && |
860 | Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE; |
861 | } |
862 | }; |
863 | |
864 | normalizePredicate(SE, L, RC&: *Result); |
865 | if (IsUnsupportedPredicate(Step, Result->Pred)) { |
866 | LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred |
867 | << ")!\n" ); |
868 | return std::nullopt; |
869 | } |
870 | |
871 | return Result; |
872 | } |
873 | |
874 | bool LoopPredication::isLoopProfitableToPredicate() { |
875 | if (SkipProfitabilityChecks) |
876 | return true; |
877 | |
878 | SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges; |
879 | L->getExitEdges(ExitEdges); |
880 | // If there is only one exiting edge in the loop, it is always profitable to |
881 | // predicate the loop. |
882 | if (ExitEdges.size() == 1) |
883 | return true; |
884 | |
885 | // Calculate the exiting probabilities of all exiting edges from the loop, |
886 | // starting with the LatchExitProbability. |
887 | // Heuristic for profitability: If any of the exiting blocks' probability of |
888 | // exiting the loop is larger than exiting through the latch block, it's not |
889 | // profitable to predicate the loop. |
890 | auto *LatchBlock = L->getLoopLatch(); |
891 | assert(LatchBlock && "Should have a single latch at this point!" ); |
892 | auto *LatchTerm = LatchBlock->getTerminator(); |
893 | assert(LatchTerm->getNumSuccessors() == 2 && |
894 | "expected to be an exiting block with 2 succs!" ); |
895 | unsigned LatchBrExitIdx = |
896 | LatchTerm->getSuccessor(Idx: 0) == L->getHeader() ? 1 : 0; |
897 | // We compute branch probabilities without BPI. We do not rely on BPI since |
898 | // Loop predication is usually run in an LPM and BPI is only preserved |
899 | // lossily within loop pass managers, while BPI has an inherent notion of |
900 | // being complete for an entire function. |
901 | |
902 | // If the latch exits into a deoptimize or an unreachable block, do not |
903 | // predicate on that latch check. |
904 | auto *LatchExitBlock = LatchTerm->getSuccessor(Idx: LatchBrExitIdx); |
905 | if (isa<UnreachableInst>(Val: LatchTerm) || |
906 | LatchExitBlock->getTerminatingDeoptimizeCall()) |
907 | return false; |
908 | |
909 | // Latch terminator has no valid profile data, so nothing to check |
910 | // profitability on. |
911 | if (!hasValidBranchWeightMD(I: *LatchTerm)) |
912 | return true; |
913 | |
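// Probability of leaving the loop through the edge ExitingBlock -> ExitBlock,
// derived from branch weight metadata when available. For example, weights
// {1, 99} on a two-way branch whose first successor is the exit give an exit
// probability of 1/100; without valid weights every successor is treated as
// equally likely.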
914 | auto ComputeBranchProbability = |
915 | [&](const BasicBlock *ExitingBlock, |
916 | const BasicBlock *ExitBlock) -> BranchProbability { |
917 | auto *Term = ExitingBlock->getTerminator(); |
918 | unsigned NumSucc = Term->getNumSuccessors(); |
919 | if (MDNode *ProfileData = getValidBranchWeightMDNode(I: *Term)) { |
920 | SmallVector<uint32_t> Weights; |
921 | extractBranchWeights(ProfileData, Weights); |
922 | uint64_t Numerator = 0, Denominator = 0; |
923 | for (auto [i, Weight] : llvm::enumerate(First&: Weights)) { |
924 | if (Term->getSuccessor(Idx: i) == ExitBlock) |
925 | Numerator += Weight; |
926 | Denominator += Weight; |
927 | } |
928 | // If all weights are zero act as if there was no profile data |
929 | if (Denominator == 0) |
930 | return BranchProbability::getBranchProbability(Numerator: 1, Denominator: NumSucc); |
931 | return BranchProbability::getBranchProbability(Numerator, Denominator); |
932 | } else { |
933 | assert(LatchBlock != ExitingBlock && |
934 | "Latch term should always have profile data!" ); |
935 | // No profile data, so we choose the weight as 1/num_of_succ(Src) |
936 | return BranchProbability::getBranchProbability(Numerator: 1, Denominator: NumSucc); |
937 | } |
938 | }; |
939 | |
940 | BranchProbability LatchExitProbability = |
941 | ComputeBranchProbability(LatchBlock, LatchExitBlock); |
942 | |
// Protect against degenerate inputs provided by the user. Providing a value
// less than one can invert the definition of profitable loop predication.
945 | float ScaleFactor = LatchExitProbabilityScale; |
946 | if (ScaleFactor < 1) { |
947 | LLVM_DEBUG( |
948 | dbgs() |
949 | << "Ignored user setting for loop-predication-latch-probability-scale: " |
950 | << LatchExitProbabilityScale << "\n" ); |
951 | LLVM_DEBUG(dbgs() << "The value is set to 1.0\n" ); |
952 | ScaleFactor = 1.0; |
953 | } |
954 | const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor; |
955 | |
956 | for (const auto &ExitEdge : ExitEdges) { |
957 | BranchProbability ExitingBlockProbability = |
958 | ComputeBranchProbability(ExitEdge.first, ExitEdge.second); |
959 | // Some exiting edge has higher probability than the latch exiting edge. |
960 | // No longer profitable to predicate. |
961 | if (ExitingBlockProbability > LatchProbabilityThreshold) |
962 | return false; |
963 | } |
964 | |
965 | // We have concluded that the most probable way to exit from the |
966 | // loop is through the latch (or there's no profile information and all |
967 | // exits are equally likely). |
968 | return true; |
969 | } |
970 | |
971 | /// If we can (cheaply) find a widenable branch which controls entry into the |
972 | /// loop, return it. |
973 | static BranchInst *FindWidenableTerminatorAboveLoop(Loop *L, LoopInfo &LI) { |
// Walk back through any unconditionally executed blocks and see if we can find
975 | // a widenable condition which seems to control execution of this loop. Note |
976 | // that we predict that maythrow calls are likely untaken and thus that it's |
977 | // profitable to widen a branch before a maythrow call with a condition |
978 | // afterwards even though that may cause the slow path to run in a case where |
979 | // it wouldn't have otherwise. |
980 | BasicBlock *BB = L->getLoopPreheader(); |
981 | if (!BB) |
982 | return nullptr; |
983 | do { |
984 | if (BasicBlock *Pred = BB->getSinglePredecessor()) |
985 | if (BB == Pred->getSingleSuccessor()) { |
986 | BB = Pred; |
987 | continue; |
988 | } |
989 | break; |
990 | } while (true); |
991 | |
992 | if (BasicBlock *Pred = BB->getSinglePredecessor()) { |
993 | if (auto *BI = dyn_cast<BranchInst>(Val: Pred->getTerminator())) |
994 | if (BI->getSuccessor(i: 0) == BB && isWidenableBranch(U: BI)) |
995 | return BI; |
996 | } |
997 | return nullptr; |
998 | } |
999 | |
1000 | /// Return the minimum of all analyzeable exit counts. This is an upper bound |
1001 | /// on the actual exit count. If there are not at least two analyzeable exits, |
1002 | /// returns SCEVCouldNotCompute. |
1003 | static const SCEV *getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE, |
1004 | DominatorTree &DT, |
1005 | Loop *L) { |
1006 | SmallVector<BasicBlock *, 16> ExitingBlocks; |
1007 | L->getExitingBlocks(ExitingBlocks); |
1008 | |
1009 | SmallVector<const SCEV *, 4> ExitCounts; |
1010 | for (BasicBlock *ExitingBB : ExitingBlocks) { |
1011 | const SCEV *ExitCount = SE.getExitCount(L, ExitingBlock: ExitingBB); |
1012 | if (isa<SCEVCouldNotCompute>(Val: ExitCount)) |
1013 | continue; |
1014 | assert(DT.dominates(ExitingBB, L->getLoopLatch()) && |
1015 | "We should only have known counts for exiting blocks that " |
1016 | "dominate latch!" ); |
1017 | ExitCounts.push_back(Elt: ExitCount); |
1018 | } |
1019 | if (ExitCounts.size() < 2) |
1020 | return SE.getCouldNotCompute(); |
1021 | return SE.getUMinFromMismatchedTypes(Ops&: ExitCounts); |
1022 | } |
1023 | |
/// This implements an analogous, but entirely distinct, transform from the
/// main loop predication transform. This one is phrased in terms of using a
/// widenable branch *outside* the loop to allow us to simplify the exits of
/// the loop that follows it. This is close in spirit to the IndVarSimplify
/// transform of the same name, but is materially different: the widening
/// loosens the legality constraints sharply.
1030 | bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { |
1031 | // The transformation performed here aims to widen a widenable condition |
// above the loop such that all analyzeable exits leading to deopt are dead.
1033 | // It assumes that the latch is the dominant exit for profitability and that |
1034 | // exits branching to deoptimizing blocks are rarely taken. It relies on the |
1035 | // semantics of widenable expressions for legality. (i.e. being able to fall |
1036 | // down the widenable path spuriously allows us to ignore exit order, |
1037 | // unanalyzeable exits, side effects, exceptional exits, and other challenges |
1038 | // which restrict the applicability of the non-WC based version of this |
1039 | // transform in IndVarSimplify.) |
1040 | // |
1041 | // NOTE ON POISON/UNDEF - We're hoisting an expression above guards which may |
1042 | // imply flags on the expression being hoisted and inserting new uses (flags |
1043 | // are only correct for current uses). The result is that we may be |
1044 | // inserting a branch on the value which can be either poison or undef. In |
1045 | // this case, the branch can legally go either way; we just need to avoid |
1046 | // introducing UB. This is achieved through the use of the freeze |
1047 | // instruction. |
1048 | |
1049 | SmallVector<BasicBlock *, 16> ExitingBlocks; |
1050 | L->getExitingBlocks(ExitingBlocks); |
1051 | |
1052 | if (ExitingBlocks.empty()) |
1053 | return false; // Nothing to do. |
1054 | |
1055 | auto *Latch = L->getLoopLatch(); |
1056 | if (!Latch) |
1057 | return false; |
1058 | |
1059 | auto *WidenableBR = FindWidenableTerminatorAboveLoop(L, LI&: *LI); |
1060 | if (!WidenableBR) |
1061 | return false; |
1062 | |
1063 | const SCEV *LatchEC = SE->getExitCount(L, ExitingBlock: Latch); |
1064 | if (isa<SCEVCouldNotCompute>(Val: LatchEC)) |
1065 | return false; // profitability - want hot exit in analyzeable set |
1066 | |
1067 | // At this point, we have found an analyzeable latch, and a widenable |
1068 | // condition above the loop. If we have a widenable exit within the loop |
1069 | // (for which we can't compute exit counts), drop the ability to further |
// widen so that we gain the ability to analyze its exit count and perform this
1071 | // transform. TODO: It'd be nice to know for sure the exit became |
1072 | // analyzeable after dropping widenability. |
1073 | bool ChangedLoop = false; |
1074 | |
1075 | for (auto *ExitingBB : ExitingBlocks) { |
1076 | if (LI->getLoopFor(BB: ExitingBB) != L) |
1077 | continue; |
1078 | |
1079 | auto *BI = dyn_cast<BranchInst>(Val: ExitingBB->getTerminator()); |
1080 | if (!BI) |
1081 | continue; |
1082 | |
1083 | if (auto WC = extractWidenableCondition(U: BI)) |
1084 | if (L->contains(BB: BI->getSuccessor(i: 0))) { |
1085 | assert(WC->hasOneUse() && "Not appropriate widenable branch!" ); |
1086 | WC->user_back()->replaceUsesOfWith( |
1087 | From: WC, To: ConstantInt::getTrue(Context&: BI->getContext())); |
1088 | ChangedLoop = true; |
1089 | } |
1090 | } |
1091 | if (ChangedLoop) |
1092 | SE->forgetLoop(L); |
1093 | |
// The insertion point for the widening should be at the widenable call, not
// at the WidenableBR. If we do this at the WidenableBR, we can incorrectly
1096 | // change a loop-invariant condition to a loop-varying one. |
1097 | auto *IP = cast<Instruction>(Val: WidenableBR->getCondition()); |
1098 | |
1099 | // The use of umin(all analyzeable exits) instead of latch is subtle, but |
1100 | // important for profitability. We may have a loop which hasn't been fully |
1101 | // canonicalized just yet. If the exit we chose to widen is provably never |
1102 | // taken, we want the widened form to *also* be provably never taken. We |
1103 | // can't guarantee this as a current unanalyzeable exit may later become |
1104 | // analyzeable, but we can at least avoid the obvious cases. |
1105 | const SCEV *MinEC = getMinAnalyzeableBackedgeTakenCount(SE&: *SE, DT&: *DT, L); |
1106 | if (isa<SCEVCouldNotCompute>(Val: MinEC) || MinEC->getType()->isPointerTy() || |
1107 | !SE->isLoopInvariant(S: MinEC, L) || |
1108 | !Rewriter.isSafeToExpandAt(S: MinEC, InsertionPoint: IP)) |
1109 | return ChangedLoop; |
1110 | |
1111 | Rewriter.setInsertPoint(IP); |
1112 | IRBuilder<> B(IP); |
1113 | |
1114 | bool InvalidateLoop = false; |
1115 | Value *MinECV = nullptr; // lazily generated if needed |
1116 | for (BasicBlock *ExitingBB : ExitingBlocks) { |
1117 | // If our exiting block exits multiple loops, we can only rewrite the |
1118 | // innermost one. Otherwise, we're changing how many times the innermost |
1119 | // loop runs before it exits. |
1120 | if (LI->getLoopFor(BB: ExitingBB) != L) |
1121 | continue; |
1122 | |
1123 | // Can't rewrite non-branch yet. |
1124 | auto *BI = dyn_cast<BranchInst>(Val: ExitingBB->getTerminator()); |
1125 | if (!BI) |
1126 | continue; |
1127 | |
1128 | // If already constant, nothing to do. |
1129 | if (isa<Constant>(Val: BI->getCondition())) |
1130 | continue; |
1131 | |
1132 | const SCEV *ExitCount = SE->getExitCount(L, ExitingBlock: ExitingBB); |
1133 | if (isa<SCEVCouldNotCompute>(Val: ExitCount) || |
1134 | ExitCount->getType()->isPointerTy() || |
1135 | !Rewriter.isSafeToExpandAt(S: ExitCount, InsertionPoint: WidenableBR)) |
1136 | continue; |
1137 | |
1138 | const bool ExitIfTrue = !L->contains(BB: *succ_begin(BB: ExitingBB)); |
1139 | BasicBlock *ExitBB = BI->getSuccessor(i: ExitIfTrue ? 0 : 1); |
1140 | if (!ExitBB->getPostdominatingDeoptimizeCall()) |
1141 | continue; |
1142 | |
1143 | /// Here we can be fairly sure that executing this exit will most likely |
1144 | /// lead to executing llvm.experimental.deoptimize. |
1145 | /// This is a profitability heuristic, not a legality constraint. |
1146 | |
1147 | // If we found a widenable exit condition, do two things: |
1148 | // 1) fold the widened exit test into the widenable condition |
1149 | // 2) fold the branch to untaken - avoids infinite looping |
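//
// The test folded in below is ExitCount u> MinEC: we stay on the fast path
// only if some other analyzeable exit must fire strictly before this deopt
// exit; otherwise the widenable branch above the loop deoptimizes up front,
// which is why the in-loop branch can then be folded to untaken.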
1150 | |
1151 | Value *ECV = Rewriter.expandCodeFor(SH: ExitCount); |
1152 | if (!MinECV) |
1153 | MinECV = Rewriter.expandCodeFor(SH: MinEC); |
1154 | Value *RHS = MinECV; |
1155 | if (ECV->getType() != RHS->getType()) { |
1156 | Type *WiderTy = SE->getWiderType(Ty1: ECV->getType(), Ty2: RHS->getType()); |
1157 | ECV = B.CreateZExt(V: ECV, DestTy: WiderTy); |
1158 | RHS = B.CreateZExt(V: RHS, DestTy: WiderTy); |
1159 | } |
1160 | assert(!Latch || DT->dominates(ExitingBB, Latch)); |
1161 | Value *NewCond = B.CreateICmp(P: ICmpInst::ICMP_UGT, LHS: ECV, RHS); |
1162 | // Freeze poison or undef to an arbitrary bit pattern to ensure we can |
1163 | // branch without introducing UB. See NOTE ON POISON/UNDEF above for |
1164 | // context. |
1165 | NewCond = B.CreateFreeze(V: NewCond); |
1166 | |
1167 | widenWidenableBranch(WidenableBR, NewCond); |
1168 | |
1169 | Value *OldCond = BI->getCondition(); |
1170 | BI->setCondition(ConstantInt::get(Ty: OldCond->getType(), V: !ExitIfTrue)); |
1171 | InvalidateLoop = true; |
1172 | } |
1173 | |
1174 | if (InvalidateLoop) |
// We just mutated a bunch of loop exits, changing their exit counts
1176 | // widely. We need to force recomputation of the exit counts given these |
1177 | // changes. Note that all of the inserted exits are never taken, and |
1178 | // should be removed next time the CFG is modified. |
1179 | SE->forgetLoop(L); |
1180 | |
1181 | // Always return `true` since we have moved the WidenableBR's condition. |
1182 | return true; |
1183 | } |
1184 | |
1185 | bool LoopPredication::runOnLoop(Loop *Loop) { |
1186 | L = Loop; |
1187 | |
1188 | LLVM_DEBUG(dbgs() << "Analyzing " ); |
1189 | LLVM_DEBUG(L->dump()); |
1190 | |
1191 | Module *M = L->getHeader()->getModule(); |
1192 | |
1193 | // There is nothing to do if the module doesn't use guards |
1194 | auto *GuardDecl = |
1195 | Intrinsic::getDeclarationIfExists(M, id: Intrinsic::experimental_guard); |
1196 | bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty(); |
1197 | auto *WCDecl = Intrinsic::getDeclarationIfExists( |
1198 | M, id: Intrinsic::experimental_widenable_condition); |
1199 | bool HasWidenableConditions = |
1200 | PredicateWidenableBranchGuards && WCDecl && !WCDecl->use_empty(); |
1201 | if (!HasIntrinsicGuards && !HasWidenableConditions) |
1202 | return false; |
1203 | |
1204 | DL = &M->getDataLayout(); |
1205 | |
1206 | Preheader = L->getLoopPreheader(); |
1207 | if (!Preheader) |
1208 | return false; |
1209 | |
1210 | auto LatchCheckOpt = parseLoopLatchICmp(); |
1211 | if (!LatchCheckOpt) |
1212 | return false; |
1213 | LatchCheck = *LatchCheckOpt; |
1214 | |
1215 | LLVM_DEBUG(dbgs() << "Latch check:\n" ); |
1216 | LLVM_DEBUG(LatchCheck.dump()); |
1217 | |
1218 | if (!isLoopProfitableToPredicate()) { |
1219 | LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n" ); |
1220 | return false; |
1221 | } |
1222 | // Collect all the guards into a vector and process later, so as not |
1223 | // to invalidate the instruction iterator. |
1224 | SmallVector<IntrinsicInst *, 4> Guards; |
1225 | SmallVector<BranchInst *, 4> GuardsAsWidenableBranches; |
1226 | for (const auto BB : L->blocks()) { |
1227 | for (auto &I : *BB) |
1228 | if (isGuard(U: &I)) |
1229 | Guards.push_back(Elt: cast<IntrinsicInst>(Val: &I)); |
1230 | if (PredicateWidenableBranchGuards && |
1231 | isGuardAsWidenableBranch(U: BB->getTerminator())) |
1232 | GuardsAsWidenableBranches.push_back( |
1233 | Elt: cast<BranchInst>(Val: BB->getTerminator())); |
1234 | } |
1235 | |
1236 | SCEVExpander Expander(*SE, *DL, "loop-predication" ); |
1237 | bool Changed = false; |
1238 | for (auto *Guard : Guards) |
1239 | Changed |= widenGuardConditions(Guard, Expander); |
1240 | for (auto *Guard : GuardsAsWidenableBranches) |
1241 | Changed |= widenWidenableBranchGuardConditions(BI: Guard, Expander); |
1242 | Changed |= predicateLoopExits(L, Rewriter&: Expander); |
1243 | |
1244 | if (MSSAU && VerifyMemorySSA) |
1245 | MSSAU->getMemorySSA()->verifyMemorySSA(); |
1246 | return Changed; |
1247 | } |
1248 | |