1 | //===-- LoopPredication.cpp - Guard based loop predication pass -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // The LoopPredication pass tries to convert loop variant range checks to loop |
10 | // invariant by widening checks across loop iterations. For example, it will |
11 | // convert |
12 | // |
13 | // for (i = 0; i < n; i++) { |
14 | // guard(i < len); |
15 | // ... |
16 | // } |
17 | // |
18 | // to |
19 | // |
20 | // for (i = 0; i < n; i++) { |
21 | // guard(n - 1 < len); |
22 | // ... |
23 | // } |
24 | // |
25 | // After this transformation the condition of the guard is loop invariant, so |
26 | // loop-unswitch can later unswitch the loop by this condition which basically |
27 | // predicates the loop by the widened condition: |
28 | // |
29 | // if (n - 1 < len) |
30 | // for (i = 0; i < n; i++) { |
31 | // ... |
32 | // } |
33 | // else |
34 | // deoptimize |
35 | // |
36 | // It's tempting to rely on SCEV here, but it has proven to be problematic. |
37 | // Generally the facts SCEV provides about the increment step of add |
38 | // recurrences are true if the backedge of the loop is taken, which implicitly |
39 | // assumes that the guard doesn't fail. Using these facts to optimize the |
40 | // guard results in a circular logic where the guard is optimized under the |
41 | // assumption that it never fails. |
42 | // |
43 | // For example, in the loop below the induction variable will be marked as nuw |
44 | // based on the guard. Based on nuw, the guard predicate will be considered |
45 | // monotonic. Given a monotonic condition it's tempting to replace the induction |
46 | // variable in the condition with its value on the last iteration. But this |
47 | // transformation is not correct, e.g. e = 4, b = 5 breaks the loop. |
48 | // |
49 | // for (int i = b; i != e; i++) |
50 | // guard(i u< len) |
51 | // |
52 | // One of the ways to reason about this problem is to use an inductive proof |
53 | // approach. Given the loop: |
54 | // |
55 | // if (B(0)) { |
56 | // do { |
57 | // I = PHI(0, I.INC) |
58 | // I.INC = I + Step |
59 | // guard(G(I)); |
60 | // } while (B(I)); |
61 | // } |
62 | // |
63 | // where B(x) and G(x) are predicates that map integers to booleans, we want a |
64 | // loop invariant expression M such that the following program has the same semantics |
65 | // as the above: |
66 | // |
67 | // if (B(0)) { |
68 | // do { |
69 | // I = PHI(0, I.INC) |
70 | // I.INC = I + Step |
71 | // guard(G(0) && M); |
72 | // } while (B(I)); |
73 | // } |
74 | // |
75 | // One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step) |
76 | // |
77 | // Informal proof that the transformation above is correct: |
78 | // |
79 | // By the definition of guards we can rewrite the guard condition to: |
80 | // G(I) && G(0) && M |
81 | // |
82 | // Let's prove that for each iteration of the loop: |
83 | // G(0) && M => G(I) |
84 | // And the condition above can be simplified to G(0) && M. |
85 | // |
86 | // Induction base. |
87 | // G(0) && M => G(0) |
88 | // |
89 | // Induction step. Assuming G(0) && M => G(I) holds on the current iteration, |
90 | // show that G(0) && M => G(I + Step) holds on the subsequent iteration: |
91 | // |
92 | // B(I) is true because it's the backedge condition. |
93 | // G(I) is true because the backedge is guarded by this condition. |
94 | // |
95 | // So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step). |
96 | // |
97 | // Note that we can use anything stronger than M, i.e. any condition which |
98 | // implies M. |
99 | // |
100 | // When Step = 1 (i.e. forward iterating loop), the transformation is supported |
101 | // when: |
102 | // * The loop has a single latch with the condition of the form: |
103 | // B(X) = latchStart + X <pred> latchLimit, |
104 | // where <pred> is u<, u<=, s<, or s<=. |
105 | // * The guard condition is of the form |
106 | // G(X) = guardStart + X u< guardLimit |
107 | // |
108 | // For the ult latch comparison case M is: |
109 | // forall X . guardStart + X u< guardLimit && latchStart + X u< latchLimit => |
110 | // guardStart + X + 1 u< guardLimit |
111 | // |
112 | // The only way the antecedent can be true and the consequent can be false is |
113 | // if |
114 | // X == guardLimit - 1 - guardStart |
115 | // (and guardLimit is non-zero, but we won't use this latter fact). |
116 | // If X == guardLimit - 1 - guardStart then the second half of the antecedent is |
117 | // latchStart + guardLimit - 1 - guardStart u< latchLimit |
118 | // and its negation is |
119 | // latchStart + guardLimit - 1 - guardStart u>= latchLimit |
120 | // |
121 | // In other words, if |
122 | // latchLimit u<= latchStart + guardLimit - 1 - guardStart |
123 | // then: |
124 | // (the ranges below are written in ConstantRange notation, where [A, B) is the |
125 | // set for (I = A; I != B; I++ /*maywrap*/) yield(I);) |
126 | // |
127 | // forall X . guardStart + X u< guardLimit && |
128 | // latchStart + X u< latchLimit => |
129 | // guardStart + X + 1 u< guardLimit |
130 | // == forall X . guardStart + X u< guardLimit && |
131 | // latchStart + X u< latchStart + guardLimit - 1 - guardStart => |
132 | // guardStart + X + 1 u< guardLimit |
133 | // == forall X . (guardStart + X) in [0, guardLimit) && |
134 | // (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) => |
135 | // (guardStart + X + 1) in [0, guardLimit) |
136 | // == forall X . X in [-guardStart, guardLimit - guardStart) && |
137 | // X in [-latchStart, guardLimit - 1 - guardStart) => |
138 | // X in [-guardStart - 1, guardLimit - guardStart - 1) |
139 | // == true |
140 | // |
141 | // So the widened condition is: |
142 | // guardStart u< guardLimit && |
143 | // latchStart + guardLimit - 1 - guardStart u>= latchLimit |
144 | // Similarly for ule condition the widened condition is: |
145 | // guardStart u< guardLimit && |
146 | // latchStart + guardLimit - 1 - guardStart u> latchLimit |
147 | // For slt condition the widened condition is: |
148 | // guardStart u< guardLimit && |
149 | // latchStart + guardLimit - 1 - guardStart s>= latchLimit |
150 | // For sle condition the widened condition is: |
151 | // guardStart u< guardLimit && |
152 | // latchStart + guardLimit - 1 - guardStart s> latchLimit |
153 | // |
154 | // When Step = -1 (i.e. reverse iterating loop), the transformation is supported |
155 | // when: |
156 | // * The loop has a single latch with the condition of the form: |
157 | // B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=. |
158 | // * The guard condition is of the form |
159 | // G(X) = X - 1 u< guardLimit |
160 | // |
161 | // For the ugt latch comparison case M is: |
162 | // forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit |
163 | // |
164 | // The only way the antecedent can be true and the consequent can be false is if |
165 | // X == 1. |
166 | // If X == 1 then the second half of the antecedent is |
167 | // 1 u> latchLimit, and its negation is latchLimit u>= 1. |
168 | // |
169 | // So the widened condition is: |
170 | // guardStart u< guardLimit && latchLimit u>= 1. |
171 | // Similarly for sgt condition the widened condition is: |
172 | // guardStart u< guardLimit && latchLimit s>= 1. |
173 | // For uge condition the widened condition is: |
174 | // guardStart u< guardLimit && latchLimit u> 1. |
175 | // For sge condition the widened condition is: |
176 | // guardStart u< guardLimit && latchLimit s> 1. |
177 | //===----------------------------------------------------------------------===// |
178 | |
179 | #include "llvm/Transforms/Scalar/LoopPredication.h" |
180 | #include "llvm/ADT/Statistic.h" |
181 | #include "llvm/Analysis/AliasAnalysis.h" |
182 | #include "llvm/Analysis/BranchProbabilityInfo.h" |
183 | #include "llvm/Analysis/GuardUtils.h" |
184 | #include "llvm/Analysis/LoopInfo.h" |
185 | #include "llvm/Analysis/LoopPass.h" |
186 | #include "llvm/Analysis/MemorySSA.h" |
187 | #include "llvm/Analysis/MemorySSAUpdater.h" |
188 | #include "llvm/Analysis/ScalarEvolution.h" |
189 | #include "llvm/Analysis/ScalarEvolutionExpressions.h" |
190 | #include "llvm/IR/Function.h" |
191 | #include "llvm/IR/IntrinsicInst.h" |
192 | #include "llvm/IR/Module.h" |
193 | #include "llvm/IR/PatternMatch.h" |
194 | #include "llvm/IR/ProfDataUtils.h" |
195 | #include "llvm/InitializePasses.h" |
196 | #include "llvm/Pass.h" |
197 | #include "llvm/Support/CommandLine.h" |
198 | #include "llvm/Support/Debug.h" |
199 | #include "llvm/Transforms/Scalar.h" |
200 | #include "llvm/Transforms/Utils/GuardUtils.h" |
201 | #include "llvm/Transforms/Utils/Local.h" |
202 | #include "llvm/Transforms/Utils/LoopUtils.h" |
203 | #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" |
204 | #include <optional> |
205 | |
206 | #define DEBUG_TYPE "loop-predication" |
207 | |
208 | STATISTIC(TotalConsidered, "Number of guards considered" ); |
209 | STATISTIC(TotalWidened, "Number of checks widened" ); |
210 | |
211 | using namespace llvm; |
212 | |
213 | static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation" , |
214 | cl::Hidden, cl::init(Val: true)); |
215 | |
216 | static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop" , |
217 | cl::Hidden, cl::init(Val: true)); |
218 | |
219 | static cl::opt<bool> |
220 | SkipProfitabilityChecks("loop-predication-skip-profitability-checks" , |
221 | cl::Hidden, cl::init(Val: false)); |
222 | |
223 | // This is the scale factor for the latch probability. We use this during |
224 | // profitability analysis to find other exiting blocks that have a much higher |
225 | // probability of exiting the loop instead of loop exiting via latch. |
226 | // This value should be greater than 1 for a sane profitability check. |
227 | static cl::opt<float> LatchExitProbabilityScale( |
228 | "loop-predication-latch-probability-scale" , cl::Hidden, cl::init(Val: 2.0), |
229 | cl::desc("scale factor for the latch probability. Value should be greater " |
230 | "than 1. Lower values are ignored" )); |
231 | |
232 | static cl::opt<bool> PredicateWidenableBranchGuards( |
233 | "loop-predication-predicate-widenable-branches-to-deopt" , cl::Hidden, |
234 | cl::desc("Whether or not we should predicate guards " |
235 | "expressed as widenable branches to deoptimize blocks" ), |
236 | cl::init(Val: true)); |
237 | |
238 | static cl::opt<bool> InsertAssumesOfPredicatedGuardsConditions( |
239 | "loop-predication-insert-assumes-of-predicated-guards-conditions" , |
240 | cl::Hidden, |
241 | cl::desc("Whether or not we should insert assumes of conditions of " |
242 | "predicated guards" ), |
243 | cl::init(Val: true)); |
244 | |
245 | namespace { |
246 | /// Represents an induction variable check: |
247 | /// icmp Pred, <induction variable>, <loop invariant limit> |
248 | struct LoopICmp { |
249 | ICmpInst::Predicate Pred; |
250 | const SCEVAddRecExpr *IV; |
251 | const SCEV *Limit; |
252 | LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV, |
253 | const SCEV *Limit) |
254 | : Pred(Pred), IV(IV), Limit(Limit) {} |
255 | LoopICmp() = default; |
256 | void dump() { |
257 | dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV |
258 | << ", Limit = " << *Limit << "\n" ; |
259 | } |
260 | }; |
261 | |
262 | class LoopPredication { |
263 | AliasAnalysis *AA; |
264 | DominatorTree *DT; |
265 | ScalarEvolution *SE; |
266 | LoopInfo *LI; |
267 | MemorySSAUpdater *MSSAU; |
268 | |
269 | Loop *L; |
270 | const DataLayout *DL; |
271 | BasicBlock *; |
272 | LoopICmp LatchCheck; |
273 | |
274 | bool isSupportedStep(const SCEV* Step); |
275 | std::optional<LoopICmp> parseLoopICmp(ICmpInst *ICI); |
276 | std::optional<LoopICmp> parseLoopLatchICmp(); |
277 | |
278 | /// Return an insertion point suitable for inserting a safe to speculate |
279 | /// instruction whose only user will be 'User' which has operands 'Ops'. A |
280 | /// trivial result would be the at the User itself, but we try to return a |
281 | /// loop invariant location if possible. |
282 | Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops); |
283 | /// Same as above, *except* that this uses the SCEV definition of invariant |
284 | /// which is that an expression *can be made* invariant via SCEVExpander. |
285 | /// Thus, this version is only suitable for finding an insert point to be |
286 | /// passed to SCEVExpander! |
287 | Instruction *findInsertPt(const SCEVExpander &Expander, Instruction *User, |
288 | ArrayRef<const SCEV *> Ops); |
289 | |
290 | /// Return true if the value is known to produce a single fixed value across |
291 | /// all iterations on which it executes. Note that this does not imply |
292 | /// speculation safety. That must be established separately. |
293 | bool isLoopInvariantValue(const SCEV* S); |
294 | |
295 | Value *expandCheck(SCEVExpander &Expander, Instruction *Guard, |
296 | ICmpInst::Predicate Pred, const SCEV *LHS, |
297 | const SCEV *RHS); |
298 | |
299 | std::optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, |
300 | SCEVExpander &Expander, |
301 | Instruction *Guard); |
302 | std::optional<Value *> |
303 | widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, |
304 | SCEVExpander &Expander, |
305 | Instruction *Guard); |
306 | std::optional<Value *> |
307 | widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, |
308 | SCEVExpander &Expander, |
309 | Instruction *Guard); |
310 | void widenChecks(SmallVectorImpl<Value *> &Checks, |
311 | SmallVectorImpl<Value *> &WidenedChecks, |
312 | SCEVExpander &Expander, Instruction *Guard); |
313 | bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); |
314 | bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander); |
315 | // If the loop always exits through another block in the loop, we should not |
316 | // predicate based on the latch check. For example, the latch check can be a |
317 | // very coarse grained check and there can be more fine grained exit checks |
318 | // within the loop. |
319 | bool isLoopProfitableToPredicate(); |
320 | |
321 | bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter); |
322 | |
323 | public: |
324 | LoopPredication(AliasAnalysis *AA, DominatorTree *DT, ScalarEvolution *SE, |
325 | LoopInfo *LI, MemorySSAUpdater *MSSAU) |
326 | : AA(AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU){}; |
327 | bool runOnLoop(Loop *L); |
328 | }; |
329 | |
330 | } // end namespace |
331 | |
332 | PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, |
333 | LoopStandardAnalysisResults &AR, |
334 | LPMUpdater &U) { |
335 | std::unique_ptr<MemorySSAUpdater> MSSAU; |
336 | if (AR.MSSA) |
337 | MSSAU = std::make_unique<MemorySSAUpdater>(args&: AR.MSSA); |
338 | LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, |
339 | MSSAU ? MSSAU.get() : nullptr); |
340 | if (!LP.runOnLoop(L: &L)) |
341 | return PreservedAnalyses::all(); |
342 | |
343 | auto PA = getLoopPassPreservedAnalyses(); |
344 | if (AR.MSSA) |
345 | PA.preserve<MemorySSAAnalysis>(); |
346 | return PA; |
347 | } |
348 | |
349 | std::optional<LoopICmp> LoopPredication::parseLoopICmp(ICmpInst *ICI) { |
350 | auto Pred = ICI->getPredicate(); |
351 | auto *LHS = ICI->getOperand(i_nocapture: 0); |
352 | auto *RHS = ICI->getOperand(i_nocapture: 1); |
353 | |
354 | const SCEV *LHSS = SE->getSCEV(V: LHS); |
355 | if (isa<SCEVCouldNotCompute>(Val: LHSS)) |
356 | return std::nullopt; |
357 | const SCEV *RHSS = SE->getSCEV(V: RHS); |
358 | if (isa<SCEVCouldNotCompute>(Val: RHSS)) |
359 | return std::nullopt; |
360 | |
361 | // Canonicalize RHS to be loop invariant bound, LHS - a loop computable IV |
362 | if (SE->isLoopInvariant(S: LHSS, L)) { |
363 | std::swap(a&: LHS, b&: RHS); |
364 | std::swap(a&: LHSS, b&: RHSS); |
365 | Pred = ICmpInst::getSwappedPredicate(pred: Pred); |
366 | } |
367 | |
368 | const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Val: LHSS); |
369 | if (!AR || AR->getLoop() != L) |
370 | return std::nullopt; |
371 | |
372 | return LoopICmp(Pred, AR, RHSS); |
373 | } |
374 | |
375 | Value *LoopPredication::expandCheck(SCEVExpander &Expander, |
376 | Instruction *Guard, |
377 | ICmpInst::Predicate Pred, const SCEV *LHS, |
378 | const SCEV *RHS) { |
379 | Type *Ty = LHS->getType(); |
380 | assert(Ty == RHS->getType() && "expandCheck operands have different types?" ); |
381 | |
382 | if (SE->isLoopInvariant(S: LHS, L) && SE->isLoopInvariant(S: RHS, L)) { |
383 | IRBuilder<> Builder(Guard); |
384 | if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS)) |
385 | return Builder.getTrue(); |
386 | if (SE->isLoopEntryGuardedByCond(L, Pred: ICmpInst::getInversePredicate(pred: Pred), |
387 | LHS, RHS)) |
388 | return Builder.getFalse(); |
389 | } |
390 | |
391 | Value *LHSV = |
392 | Expander.expandCodeFor(SH: LHS, Ty, I: findInsertPt(Expander, User: Guard, Ops: {LHS})); |
393 | Value *RHSV = |
394 | Expander.expandCodeFor(SH: RHS, Ty, I: findInsertPt(Expander, User: Guard, Ops: {RHS})); |
395 | IRBuilder<> Builder(findInsertPt(User: Guard, Ops: {LHSV, RHSV})); |
396 | return Builder.CreateICmp(P: Pred, LHS: LHSV, RHS: RHSV); |
397 | } |
398 | |
399 | // Returns true if its safe to truncate the IV to RangeCheckType. |
400 | // When the IV type is wider than the range operand type, we can still do loop |
401 | // predication, by generating SCEVs for the range and latch that are of the |
402 | // same type. We achieve this by generating a SCEV truncate expression for the |
403 | // latch IV. This is done iff truncation of the IV is a safe operation, |
404 | // without loss of information. |
405 | // Another way to achieve this is by generating a wider type SCEV for the |
406 | // range check operand, however, this needs a more involved check that |
407 | // operands do not overflow. This can lead to loss of information when the |
408 | // range operand is of the form: add i32 %offset, %iv. We need to prove that |
409 | // sext(x + y) is same as sext(x) + sext(y). |
410 | // This function returns true if we can safely represent the IV type in |
411 | // the RangeCheckType without loss of information. |
412 | static bool isSafeToTruncateWideIVType(const DataLayout &DL, |
413 | ScalarEvolution &SE, |
414 | const LoopICmp LatchCheck, |
415 | Type *RangeCheckType) { |
416 | if (!EnableIVTruncation) |
417 | return false; |
418 | assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedValue() > |
419 | DL.getTypeSizeInBits(RangeCheckType).getFixedValue() && |
420 | "Expected latch check IV type to be larger than range check operand " |
421 | "type!" ); |
422 | // The start and end values of the IV should be known. This is to guarantee |
423 | // that truncating the wide type will not lose information. |
424 | auto *Limit = dyn_cast<SCEVConstant>(Val: LatchCheck.Limit); |
425 | auto *Start = dyn_cast<SCEVConstant>(Val: LatchCheck.IV->getStart()); |
426 | if (!Limit || !Start) |
427 | return false; |
428 | // This check makes sure that the IV does not change sign during loop |
429 | // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE, |
430 | // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the |
431 | // IV wraps around, and the truncation of the IV would lose the range of |
432 | // iterations between 2^32 and 2^64. |
433 | if (!SE.getMonotonicPredicateType(LHS: LatchCheck.IV, Pred: LatchCheck.Pred)) |
434 | return false; |
435 | // The active bits should be less than the bits in the RangeCheckType. This |
436 | // guarantees that truncating the latch check to RangeCheckType is a safe |
437 | // operation. |
438 | auto RangeCheckTypeBitSize = |
439 | DL.getTypeSizeInBits(Ty: RangeCheckType).getFixedValue(); |
440 | return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize && |
441 | Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize; |
442 | } |
443 | |
444 | |
445 | // Return an LoopICmp describing a latch check equivlent to LatchCheck but with |
446 | // the requested type if safe to do so. May involve the use of a new IV. |
447 | static std::optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL, |
448 | ScalarEvolution &SE, |
449 | const LoopICmp LatchCheck, |
450 | Type *RangeCheckType) { |
451 | |
452 | auto *LatchType = LatchCheck.IV->getType(); |
453 | if (RangeCheckType == LatchType) |
454 | return LatchCheck; |
455 | // For now, bail out if latch type is narrower than range type. |
456 | if (DL.getTypeSizeInBits(Ty: LatchType).getFixedValue() < |
457 | DL.getTypeSizeInBits(Ty: RangeCheckType).getFixedValue()) |
458 | return std::nullopt; |
459 | if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType)) |
460 | return std::nullopt; |
461 | // We can now safely identify the truncated version of the IV and limit for |
462 | // RangeCheckType. |
463 | LoopICmp NewLatchCheck; |
464 | NewLatchCheck.Pred = LatchCheck.Pred; |
465 | NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>( |
466 | Val: SE.getTruncateExpr(Op: LatchCheck.IV, Ty: RangeCheckType)); |
467 | if (!NewLatchCheck.IV) |
468 | return std::nullopt; |
469 | NewLatchCheck.Limit = SE.getTruncateExpr(Op: LatchCheck.Limit, Ty: RangeCheckType); |
470 | LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType |
471 | << "can be represented as range check type:" |
472 | << *RangeCheckType << "\n" ); |
473 | LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n" ); |
474 | LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n" ); |
475 | return NewLatchCheck; |
476 | } |
477 | |
478 | bool LoopPredication::isSupportedStep(const SCEV* Step) { |
479 | return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop); |
480 | } |
481 | |
482 | Instruction *LoopPredication::findInsertPt(Instruction *Use, |
483 | ArrayRef<Value*> Ops) { |
484 | for (Value *Op : Ops) |
485 | if (!L->isLoopInvariant(V: Op)) |
486 | return Use; |
487 | return Preheader->getTerminator(); |
488 | } |
489 | |
490 | Instruction *LoopPredication::findInsertPt(const SCEVExpander &Expander, |
491 | Instruction *Use, |
492 | ArrayRef<const SCEV *> Ops) { |
493 | // Subtlety: SCEV considers things to be invariant if the value produced is |
494 | // the same across iterations. This is not the same as being able to |
495 | // evaluate outside the loop, which is what we actually need here. |
496 | for (const SCEV *Op : Ops) |
497 | if (!SE->isLoopInvariant(S: Op, L) || |
498 | !Expander.isSafeToExpandAt(S: Op, InsertionPoint: Preheader->getTerminator())) |
499 | return Use; |
500 | return Preheader->getTerminator(); |
501 | } |
502 | |
503 | bool LoopPredication::isLoopInvariantValue(const SCEV* S) { |
504 | // Handling expressions which produce invariant results, but *haven't* yet |
505 | // been removed from the loop serves two important purposes. |
506 | // 1) Most importantly, it resolves a pass ordering cycle which would |
507 | // otherwise need us to iteration licm, loop-predication, and either |
508 | // loop-unswitch or loop-peeling to make progress on examples with lots of |
509 | // predicable range checks in a row. (Since, in the general case, we can't |
510 | // hoist the length checks until the dominating checks have been discharged |
511 | // as we can't prove doing so is safe.) |
512 | // 2) As a nice side effect, this exposes the value of peeling or unswitching |
513 | // much more obviously in the IR. Otherwise, the cost modeling for other |
514 | // transforms would end up needing to duplicate all of this logic to model a |
515 | // check which becomes predictable based on a modeled peel or unswitch. |
516 | // |
517 | // The cost of doing so in the worst case is an extra fill from the stack in |
518 | // the loop to materialize the loop invariant test value instead of checking |
519 | // against the original IV which is presumable in a register inside the loop. |
520 | // Such cases are presumably rare, and hint at missing oppurtunities for |
521 | // other passes. |
522 | |
523 | if (SE->isLoopInvariant(S, L)) |
524 | // Note: This the SCEV variant, so the original Value* may be within the |
525 | // loop even though SCEV has proven it is loop invariant. |
526 | return true; |
527 | |
528 | // Handle a particular important case which SCEV doesn't yet know about which |
529 | // shows up in range checks on arrays with immutable lengths. |
530 | // TODO: This should be sunk inside SCEV. |
531 | if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Val: S)) |
532 | if (const auto *LI = dyn_cast<LoadInst>(Val: U->getValue())) |
533 | if (LI->isUnordered() && L->hasLoopInvariantOperands(I: LI)) |
534 | if (!isModSet(MRI: AA->getModRefInfoMask(P: LI->getOperand(i_nocapture: 0))) || |
535 | LI->hasMetadata(KindID: LLVMContext::MD_invariant_load)) |
536 | return true; |
537 | return false; |
538 | } |
539 | |
540 | std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( |
541 | LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, |
542 | Instruction *Guard) { |
543 | auto *Ty = RangeCheck.IV->getType(); |
544 | // Generate the widened condition for the forward loop: |
545 | // guardStart u< guardLimit && |
546 | // latchLimit <pred> guardLimit - 1 - guardStart + latchStart |
547 | // where <pred> depends on the latch condition predicate. See the file |
548 | // header comment for the reasoning. |
549 | // guardLimit - guardStart + latchStart - 1 |
550 | const SCEV *GuardStart = RangeCheck.IV->getStart(); |
551 | const SCEV *GuardLimit = RangeCheck.Limit; |
552 | const SCEV *LatchStart = LatchCheck.IV->getStart(); |
553 | const SCEV *LatchLimit = LatchCheck.Limit; |
554 | // Subtlety: We need all the values to be *invariant* across all iterations, |
555 | // but we only need to check expansion safety for those which *aren't* |
556 | // already guaranteed to dominate the guard. |
557 | if (!isLoopInvariantValue(S: GuardStart) || |
558 | !isLoopInvariantValue(S: GuardLimit) || |
559 | !isLoopInvariantValue(S: LatchStart) || |
560 | !isLoopInvariantValue(S: LatchLimit)) { |
561 | LLVM_DEBUG(dbgs() << "Can't expand limit check!\n" ); |
562 | return std::nullopt; |
563 | } |
564 | if (!Expander.isSafeToExpandAt(S: LatchStart, InsertionPoint: Guard) || |
565 | !Expander.isSafeToExpandAt(S: LatchLimit, InsertionPoint: Guard)) { |
566 | LLVM_DEBUG(dbgs() << "Can't expand limit check!\n" ); |
567 | return std::nullopt; |
568 | } |
569 | |
570 | // guardLimit - guardStart + latchStart - 1 |
571 | const SCEV *RHS = |
572 | SE->getAddExpr(LHS: SE->getMinusSCEV(LHS: GuardLimit, RHS: GuardStart), |
573 | RHS: SE->getMinusSCEV(LHS: LatchStart, RHS: SE->getOne(Ty))); |
574 | auto LimitCheckPred = |
575 | ICmpInst::getFlippedStrictnessPredicate(pred: LatchCheck.Pred); |
576 | |
577 | LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n" ); |
578 | LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n" ); |
579 | LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n" ); |
580 | |
581 | auto *LimitCheck = |
582 | expandCheck(Expander, Guard, Pred: LimitCheckPred, LHS: LatchLimit, RHS); |
583 | auto *FirstIterationCheck = expandCheck(Expander, Guard, Pred: RangeCheck.Pred, |
584 | LHS: GuardStart, RHS: GuardLimit); |
585 | IRBuilder<> Builder(findInsertPt(Use: Guard, Ops: {FirstIterationCheck, LimitCheck})); |
586 | return Builder.CreateFreeze( |
587 | V: Builder.CreateAnd(LHS: FirstIterationCheck, RHS: LimitCheck)); |
588 | } |
589 | |
590 | std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( |
591 | LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, |
592 | Instruction *Guard) { |
593 | auto *Ty = RangeCheck.IV->getType(); |
594 | const SCEV *GuardStart = RangeCheck.IV->getStart(); |
595 | const SCEV *GuardLimit = RangeCheck.Limit; |
596 | const SCEV *LatchStart = LatchCheck.IV->getStart(); |
597 | const SCEV *LatchLimit = LatchCheck.Limit; |
598 | // Subtlety: We need all the values to be *invariant* across all iterations, |
599 | // but we only need to check expansion safety for those which *aren't* |
600 | // already guaranteed to dominate the guard. |
601 | if (!isLoopInvariantValue(S: GuardStart) || |
602 | !isLoopInvariantValue(S: GuardLimit) || |
603 | !isLoopInvariantValue(S: LatchStart) || |
604 | !isLoopInvariantValue(S: LatchLimit)) { |
605 | LLVM_DEBUG(dbgs() << "Can't expand limit check!\n" ); |
606 | return std::nullopt; |
607 | } |
608 | if (!Expander.isSafeToExpandAt(S: LatchStart, InsertionPoint: Guard) || |
609 | !Expander.isSafeToExpandAt(S: LatchLimit, InsertionPoint: Guard)) { |
610 | LLVM_DEBUG(dbgs() << "Can't expand limit check!\n" ); |
611 | return std::nullopt; |
612 | } |
613 | // The decrement of the latch check IV should be the same as the |
614 | // rangeCheckIV. |
615 | auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(SE&: *SE); |
616 | if (RangeCheck.IV != PostDecLatchCheckIV) { |
617 | LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: " |
618 | << *PostDecLatchCheckIV |
619 | << " and RangeCheckIV: " << *RangeCheck.IV << "\n" ); |
620 | return std::nullopt; |
621 | } |
622 | |
623 | // Generate the widened condition for CountDownLoop: |
624 | // guardStart u< guardLimit && |
625 | // latchLimit <pred> 1. |
626 | // See the header comment for reasoning of the checks. |
627 | auto LimitCheckPred = |
628 | ICmpInst::getFlippedStrictnessPredicate(pred: LatchCheck.Pred); |
629 | auto *FirstIterationCheck = expandCheck(Expander, Guard, |
630 | Pred: ICmpInst::ICMP_ULT, |
631 | LHS: GuardStart, RHS: GuardLimit); |
632 | auto *LimitCheck = expandCheck(Expander, Guard, Pred: LimitCheckPred, LHS: LatchLimit, |
633 | RHS: SE->getOne(Ty)); |
634 | IRBuilder<> Builder(findInsertPt(Use: Guard, Ops: {FirstIterationCheck, LimitCheck})); |
635 | return Builder.CreateFreeze( |
636 | V: Builder.CreateAnd(LHS: FirstIterationCheck, RHS: LimitCheck)); |
637 | } |
638 | |
639 | static void normalizePredicate(ScalarEvolution *SE, Loop *L, |
640 | LoopICmp& RC) { |
641 | // LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the |
642 | // ULT/UGE form for ease of handling by our caller. |
643 | if (ICmpInst::isEquality(P: RC.Pred) && |
644 | RC.IV->getStepRecurrence(SE&: *SE)->isOne() && |
645 | SE->isKnownPredicate(Pred: ICmpInst::ICMP_ULE, LHS: RC.IV->getStart(), RHS: RC.Limit)) |
646 | RC.Pred = RC.Pred == ICmpInst::ICMP_NE ? |
647 | ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE; |
648 | } |
649 | |
650 | /// If ICI can be widened to a loop invariant condition emits the loop |
651 | /// invariant condition in the loop preheader and return it, otherwise |
652 | /// returns std::nullopt. |
653 | std::optional<Value *> |
654 | LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, |
655 | Instruction *Guard) { |
656 | LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n" ); |
657 | LLVM_DEBUG(ICI->dump()); |
658 | |
659 | // parseLoopStructure guarantees that the latch condition is: |
660 | // ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=. |
661 | // We are looking for the range checks of the form: |
662 | // i u< guardLimit |
663 | auto RangeCheck = parseLoopICmp(ICI); |
664 | if (!RangeCheck) { |
665 | LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n" ); |
666 | return std::nullopt; |
667 | } |
668 | LLVM_DEBUG(dbgs() << "Guard check:\n" ); |
669 | LLVM_DEBUG(RangeCheck->dump()); |
670 | if (RangeCheck->Pred != ICmpInst::ICMP_ULT) { |
671 | LLVM_DEBUG(dbgs() << "Unsupported range check predicate(" |
672 | << RangeCheck->Pred << ")!\n" ); |
673 | return std::nullopt; |
674 | } |
675 | auto *RangeCheckIV = RangeCheck->IV; |
676 | if (!RangeCheckIV->isAffine()) { |
677 | LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n" ); |
678 | return std::nullopt; |
679 | } |
680 | auto *Step = RangeCheckIV->getStepRecurrence(SE&: *SE); |
681 | // We cannot just compare with latch IV step because the latch and range IVs |
682 | // may have different types. |
683 | if (!isSupportedStep(Step)) { |
684 | LLVM_DEBUG(dbgs() << "Range check and latch have IVs different steps!\n" ); |
685 | return std::nullopt; |
686 | } |
687 | auto *Ty = RangeCheckIV->getType(); |
688 | auto CurrLatchCheckOpt = generateLoopLatchCheck(DL: *DL, SE&: *SE, LatchCheck, RangeCheckType: Ty); |
689 | if (!CurrLatchCheckOpt) { |
690 | LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check " |
691 | "corresponding to range type: " |
692 | << *Ty << "\n" ); |
693 | return std::nullopt; |
694 | } |
695 | |
696 | LoopICmp CurrLatchCheck = *CurrLatchCheckOpt; |
697 | // At this point, the range and latch step should have the same type, but need |
698 | // not have the same value (we support both 1 and -1 steps). |
699 | assert(Step->getType() == |
700 | CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() && |
701 | "Range and latch steps should be of same type!" ); |
702 | if (Step != CurrLatchCheck.IV->getStepRecurrence(SE&: *SE)) { |
703 | LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n" ); |
704 | return std::nullopt; |
705 | } |
706 | |
707 | if (Step->isOne()) |
708 | return widenICmpRangeCheckIncrementingLoop(LatchCheck: CurrLatchCheck, RangeCheck: *RangeCheck, |
709 | Expander, Guard); |
710 | else { |
711 | assert(Step->isAllOnesValue() && "Step should be -1!" ); |
712 | return widenICmpRangeCheckDecrementingLoop(LatchCheck: CurrLatchCheck, RangeCheck: *RangeCheck, |
713 | Expander, Guard); |
714 | } |
715 | } |
716 | |
717 | void LoopPredication::widenChecks(SmallVectorImpl<Value *> &Checks, |
718 | SmallVectorImpl<Value *> &WidenedChecks, |
719 | SCEVExpander &Expander, Instruction *Guard) { |
720 | for (auto &Check : Checks) |
721 | if (ICmpInst *ICI = dyn_cast<ICmpInst>(Val: Check)) |
722 | if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Guard)) { |
723 | WidenedChecks.push_back(Elt: Check); |
724 | Check = *NewRangeCheck; |
725 | } |
726 | } |
727 | |
728 | bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, |
729 | SCEVExpander &Expander) { |
730 | LLVM_DEBUG(dbgs() << "Processing guard:\n" ); |
731 | LLVM_DEBUG(Guard->dump()); |
732 | |
733 | TotalConsidered++; |
734 | SmallVector<Value *, 4> Checks; |
735 | SmallVector<Value *> WidenedChecks; |
736 | parseWidenableGuard(U: Guard, Checks); |
737 | widenChecks(Checks, WidenedChecks, Expander, Guard); |
738 | if (WidenedChecks.empty()) |
739 | return false; |
740 | |
741 | TotalWidened += WidenedChecks.size(); |
742 | |
743 | // Emit the new guard condition |
744 | IRBuilder<> Builder(findInsertPt(Use: Guard, Ops: Checks)); |
745 | Value *AllChecks = Builder.CreateAnd(Ops: Checks); |
746 | auto *OldCond = Guard->getOperand(i_nocapture: 0); |
747 | Guard->setOperand(i_nocapture: 0, Val_nocapture: AllChecks); |
748 | if (InsertAssumesOfPredicatedGuardsConditions) { |
749 | Builder.SetInsertPoint(&*++BasicBlock::iterator(Guard)); |
750 | Builder.CreateAssumption(Cond: OldCond); |
751 | } |
752 | RecursivelyDeleteTriviallyDeadInstructions(V: OldCond, TLI: nullptr /* TLI */, MSSAU); |
753 | |
754 | LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n" ); |
755 | return true; |
756 | } |
757 | |
758 | bool LoopPredication::widenWidenableBranchGuardConditions( |
759 | BranchInst *BI, SCEVExpander &Expander) { |
760 | assert(isGuardAsWidenableBranch(BI) && "Must be!" ); |
761 | LLVM_DEBUG(dbgs() << "Processing guard:\n" ); |
762 | LLVM_DEBUG(BI->dump()); |
763 | |
764 | TotalConsidered++; |
765 | SmallVector<Value *, 4> Checks; |
766 | SmallVector<Value *> WidenedChecks; |
767 | parseWidenableGuard(U: BI, Checks); |
768 | // At the moment, our matching logic for wideable conditions implicitly |
769 | // assumes we preserve the form: (br (and Cond, WC())). FIXME |
770 | auto WC = extractWidenableCondition(U: BI); |
771 | Checks.push_back(Elt: WC); |
772 | widenChecks(Checks, WidenedChecks, Expander, Guard: BI); |
773 | if (WidenedChecks.empty()) |
774 | return false; |
775 | |
776 | TotalWidened += WidenedChecks.size(); |
777 | |
778 | // Emit the new guard condition |
779 | IRBuilder<> Builder(findInsertPt(Use: BI, Ops: Checks)); |
780 | Value *AllChecks = Builder.CreateAnd(Ops: Checks); |
781 | auto *OldCond = BI->getCondition(); |
782 | BI->setCondition(AllChecks); |
783 | if (InsertAssumesOfPredicatedGuardsConditions) { |
784 | BasicBlock *IfTrueBB = BI->getSuccessor(i: 0); |
785 | Builder.SetInsertPoint(TheBB: IfTrueBB, IP: IfTrueBB->getFirstInsertionPt()); |
786 | // If this block has other predecessors, we might not be able to use Cond. |
787 | // In this case, create a Phi where every other input is `true` and input |
788 | // from guard block is Cond. |
789 | Value *AssumeCond = Builder.CreateAnd(Ops: WidenedChecks); |
790 | if (!IfTrueBB->getUniquePredecessor()) { |
791 | auto *GuardBB = BI->getParent(); |
792 | auto *PN = Builder.CreatePHI(Ty: AssumeCond->getType(), NumReservedValues: pred_size(BB: IfTrueBB), |
793 | Name: "assume.cond" ); |
794 | for (auto *Pred : predecessors(BB: IfTrueBB)) |
795 | PN->addIncoming(V: Pred == GuardBB ? AssumeCond : Builder.getTrue(), BB: Pred); |
796 | AssumeCond = PN; |
797 | } |
798 | Builder.CreateAssumption(Cond: AssumeCond); |
799 | } |
800 | RecursivelyDeleteTriviallyDeadInstructions(V: OldCond, TLI: nullptr /* TLI */, MSSAU); |
801 | assert(isGuardAsWidenableBranch(BI) && |
802 | "Stopped being a guard after transform?" ); |
803 | |
804 | LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n" ); |
805 | return true; |
806 | } |
807 | |
808 | std::optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { |
809 | using namespace PatternMatch; |
810 | |
811 | BasicBlock *LoopLatch = L->getLoopLatch(); |
812 | if (!LoopLatch) { |
813 | LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n" ); |
814 | return std::nullopt; |
815 | } |
816 | |
817 | auto *BI = dyn_cast<BranchInst>(Val: LoopLatch->getTerminator()); |
818 | if (!BI || !BI->isConditional()) { |
819 | LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n" ); |
820 | return std::nullopt; |
821 | } |
822 | BasicBlock *TrueDest = BI->getSuccessor(i: 0); |
823 | assert( |
824 | (TrueDest == L->getHeader() || BI->getSuccessor(1) == L->getHeader()) && |
825 | "One of the latch's destinations must be the header" ); |
826 | |
827 | auto *ICI = dyn_cast<ICmpInst>(Val: BI->getCondition()); |
828 | if (!ICI) { |
829 | LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n" ); |
830 | return std::nullopt; |
831 | } |
832 | auto Result = parseLoopICmp(ICI); |
833 | if (!Result) { |
834 | LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n" ); |
835 | return std::nullopt; |
836 | } |
837 | |
838 | if (TrueDest != L->getHeader()) |
839 | Result->Pred = ICmpInst::getInversePredicate(pred: Result->Pred); |
840 | |
841 | // Check affine first, so if it's not we don't try to compute the step |
842 | // recurrence. |
843 | if (!Result->IV->isAffine()) { |
844 | LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n" ); |
845 | return std::nullopt; |
846 | } |
847 | |
848 | auto *Step = Result->IV->getStepRecurrence(SE&: *SE); |
849 | if (!isSupportedStep(Step)) { |
850 | LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n" ); |
851 | return std::nullopt; |
852 | } |
853 | |
854 | auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) { |
855 | if (Step->isOne()) { |
856 | return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT && |
857 | Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE; |
858 | } else { |
859 | assert(Step->isAllOnesValue() && "Step should be -1!" ); |
860 | return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT && |
861 | Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE; |
862 | } |
863 | }; |
864 | |
865 | normalizePredicate(SE, L, RC&: *Result); |
866 | if (IsUnsupportedPredicate(Step, Result->Pred)) { |
867 | LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred |
868 | << ")!\n" ); |
869 | return std::nullopt; |
870 | } |
871 | |
872 | return Result; |
873 | } |
874 | |
875 | bool LoopPredication::isLoopProfitableToPredicate() { |
876 | if (SkipProfitabilityChecks) |
877 | return true; |
878 | |
879 | SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges; |
880 | L->getExitEdges(ExitEdges); |
881 | // If there is only one exiting edge in the loop, it is always profitable to |
882 | // predicate the loop. |
883 | if (ExitEdges.size() == 1) |
884 | return true; |
885 | |
886 | // Calculate the exiting probabilities of all exiting edges from the loop, |
887 | // starting with the LatchExitProbability. |
888 | // Heuristic for profitability: If any of the exiting blocks' probability of |
889 | // exiting the loop is larger than exiting through the latch block, it's not |
890 | // profitable to predicate the loop. |
891 | auto *LatchBlock = L->getLoopLatch(); |
892 | assert(LatchBlock && "Should have a single latch at this point!" ); |
893 | auto *LatchTerm = LatchBlock->getTerminator(); |
894 | assert(LatchTerm->getNumSuccessors() == 2 && |
895 | "expected to be an exiting block with 2 succs!" ); |
896 | unsigned LatchBrExitIdx = |
897 | LatchTerm->getSuccessor(Idx: 0) == L->getHeader() ? 1 : 0; |
898 | // We compute branch probabilities without BPI. We do not rely on BPI since |
899 | // Loop predication is usually run in an LPM and BPI is only preserved |
900 | // lossily within loop pass managers, while BPI has an inherent notion of |
901 | // being complete for an entire function. |
902 | |
903 | // If the latch exits into a deoptimize or an unreachable block, do not |
904 | // predicate on that latch check. |
905 | auto *LatchExitBlock = LatchTerm->getSuccessor(Idx: LatchBrExitIdx); |
906 | if (isa<UnreachableInst>(Val: LatchTerm) || |
907 | LatchExitBlock->getTerminatingDeoptimizeCall()) |
908 | return false; |
909 | |
910 | // Latch terminator has no valid profile data, so nothing to check |
911 | // profitability on. |
912 | if (!hasValidBranchWeightMD(I: *LatchTerm)) |
913 | return true; |
914 | |
915 | auto ComputeBranchProbability = |
916 | [&](const BasicBlock *ExitingBlock, |
917 | const BasicBlock *ExitBlock) -> BranchProbability { |
918 | auto *Term = ExitingBlock->getTerminator(); |
919 | unsigned NumSucc = Term->getNumSuccessors(); |
920 | if (MDNode *ProfileData = getValidBranchWeightMDNode(I: *Term)) { |
921 | SmallVector<uint32_t> Weights; |
922 | extractBranchWeights(ProfileData, Weights); |
923 | uint64_t Numerator = 0, Denominator = 0; |
924 | for (auto [i, Weight] : llvm::enumerate(First&: Weights)) { |
925 | if (Term->getSuccessor(Idx: i) == ExitBlock) |
926 | Numerator += Weight; |
927 | Denominator += Weight; |
928 | } |
929 | // If all weights are zero act as if there was no profile data |
930 | if (Denominator == 0) |
931 | return BranchProbability::getBranchProbability(Numerator: 1, Denominator: NumSucc); |
932 | return BranchProbability::getBranchProbability(Numerator, Denominator); |
933 | } else { |
934 | assert(LatchBlock != ExitingBlock && |
935 | "Latch term should always have profile data!" ); |
936 | // No profile data, so we choose the weight as 1/num_of_succ(Src) |
937 | return BranchProbability::getBranchProbability(Numerator: 1, Denominator: NumSucc); |
938 | } |
939 | }; |
940 | |
941 | BranchProbability LatchExitProbability = |
942 | ComputeBranchProbability(LatchBlock, LatchExitBlock); |
943 | |
944 | // Protect against degenerate inputs provided by the user. Providing a value |
945 | // less than one, can invert the definition of profitable loop predication. |
946 | float ScaleFactor = LatchExitProbabilityScale; |
947 | if (ScaleFactor < 1) { |
948 | LLVM_DEBUG( |
949 | dbgs() |
950 | << "Ignored user setting for loop-predication-latch-probability-scale: " |
951 | << LatchExitProbabilityScale << "\n" ); |
952 | LLVM_DEBUG(dbgs() << "The value is set to 1.0\n" ); |
953 | ScaleFactor = 1.0; |
954 | } |
955 | const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor; |
956 | |
957 | for (const auto &ExitEdge : ExitEdges) { |
958 | BranchProbability ExitingBlockProbability = |
959 | ComputeBranchProbability(ExitEdge.first, ExitEdge.second); |
960 | // Some exiting edge has higher probability than the latch exiting edge. |
961 | // No longer profitable to predicate. |
962 | if (ExitingBlockProbability > LatchProbabilityThreshold) |
963 | return false; |
964 | } |
965 | |
966 | // We have concluded that the most probable way to exit from the |
967 | // loop is through the latch (or there's no profile information and all |
968 | // exits are equally likely). |
969 | return true; |
970 | } |
971 | |
972 | /// If we can (cheaply) find a widenable branch which controls entry into the |
973 | /// loop, return it. |
974 | static BranchInst *FindWidenableTerminatorAboveLoop(Loop *L, LoopInfo &LI) { |
975 | // Walk back through any unconditional executed blocks and see if we can find |
976 | // a widenable condition which seems to control execution of this loop. Note |
977 | // that we predict that maythrow calls are likely untaken and thus that it's |
978 | // profitable to widen a branch before a maythrow call with a condition |
979 | // afterwards even though that may cause the slow path to run in a case where |
980 | // it wouldn't have otherwise. |
981 | BasicBlock *BB = L->getLoopPreheader(); |
982 | if (!BB) |
983 | return nullptr; |
984 | do { |
985 | if (BasicBlock *Pred = BB->getSinglePredecessor()) |
986 | if (BB == Pred->getSingleSuccessor()) { |
987 | BB = Pred; |
988 | continue; |
989 | } |
990 | break; |
991 | } while (true); |
992 | |
993 | if (BasicBlock *Pred = BB->getSinglePredecessor()) { |
994 | if (auto *BI = dyn_cast<BranchInst>(Val: Pred->getTerminator())) |
995 | if (BI->getSuccessor(i: 0) == BB && isWidenableBranch(U: BI)) |
996 | return BI; |
997 | } |
998 | return nullptr; |
999 | } |
1000 | |
1001 | /// Return the minimum of all analyzeable exit counts. This is an upper bound |
1002 | /// on the actual exit count. If there are not at least two analyzeable exits, |
1003 | /// returns SCEVCouldNotCompute. |
1004 | static const SCEV *getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE, |
1005 | DominatorTree &DT, |
1006 | Loop *L) { |
1007 | SmallVector<BasicBlock *, 16> ExitingBlocks; |
1008 | L->getExitingBlocks(ExitingBlocks); |
1009 | |
1010 | SmallVector<const SCEV *, 4> ExitCounts; |
1011 | for (BasicBlock *ExitingBB : ExitingBlocks) { |
1012 | const SCEV *ExitCount = SE.getExitCount(L, ExitingBlock: ExitingBB); |
1013 | if (isa<SCEVCouldNotCompute>(Val: ExitCount)) |
1014 | continue; |
1015 | assert(DT.dominates(ExitingBB, L->getLoopLatch()) && |
1016 | "We should only have known counts for exiting blocks that " |
1017 | "dominate latch!" ); |
1018 | ExitCounts.push_back(Elt: ExitCount); |
1019 | } |
1020 | if (ExitCounts.size() < 2) |
1021 | return SE.getCouldNotCompute(); |
1022 | return SE.getUMinFromMismatchedTypes(Ops&: ExitCounts); |
1023 | } |
1024 | |
1025 | /// This implements an analogous, but entirely distinct transform from the main |
1026 | /// loop predication transform. This one is phrased in terms of using a |
1027 | /// widenable branch *outside* the loop to allow us to simplify loop exits in a |
1028 | /// following loop. This is close in spirit to the IndVarSimplify transform |
1029 | /// of the same name, but is materially different widening loosens legality |
1030 | /// sharply. |
1031 | bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { |
1032 | // The transformation performed here aims to widen a widenable condition |
1033 | // above the loop such that all analyzeable exit leading to deopt are dead. |
1034 | // It assumes that the latch is the dominant exit for profitability and that |
1035 | // exits branching to deoptimizing blocks are rarely taken. It relies on the |
1036 | // semantics of widenable expressions for legality. (i.e. being able to fall |
1037 | // down the widenable path spuriously allows us to ignore exit order, |
1038 | // unanalyzeable exits, side effects, exceptional exits, and other challenges |
1039 | // which restrict the applicability of the non-WC based version of this |
1040 | // transform in IndVarSimplify.) |
1041 | // |
1042 | // NOTE ON POISON/UNDEF - We're hoisting an expression above guards which may |
1043 | // imply flags on the expression being hoisted and inserting new uses (flags |
1044 | // are only correct for current uses). The result is that we may be |
1045 | // inserting a branch on the value which can be either poison or undef. In |
1046 | // this case, the branch can legally go either way; we just need to avoid |
1047 | // introducing UB. This is achieved through the use of the freeze |
1048 | // instruction. |
1049 | |
1050 | SmallVector<BasicBlock *, 16> ExitingBlocks; |
1051 | L->getExitingBlocks(ExitingBlocks); |
1052 | |
1053 | if (ExitingBlocks.empty()) |
1054 | return false; // Nothing to do. |
1055 | |
1056 | auto *Latch = L->getLoopLatch(); |
1057 | if (!Latch) |
1058 | return false; |
1059 | |
1060 | auto *WidenableBR = FindWidenableTerminatorAboveLoop(L, LI&: *LI); |
1061 | if (!WidenableBR) |
1062 | return false; |
1063 | |
1064 | const SCEV *LatchEC = SE->getExitCount(L, ExitingBlock: Latch); |
1065 | if (isa<SCEVCouldNotCompute>(Val: LatchEC)) |
1066 | return false; // profitability - want hot exit in analyzeable set |
1067 | |
1068 | // At this point, we have found an analyzeable latch, and a widenable |
1069 | // condition above the loop. If we have a widenable exit within the loop |
1070 | // (for which we can't compute exit counts), drop the ability to further |
1071 | // widen so that we gain ability to analyze it's exit count and perform this |
1072 | // transform. TODO: It'd be nice to know for sure the exit became |
1073 | // analyzeable after dropping widenability. |
1074 | bool ChangedLoop = false; |
1075 | |
1076 | for (auto *ExitingBB : ExitingBlocks) { |
1077 | if (LI->getLoopFor(BB: ExitingBB) != L) |
1078 | continue; |
1079 | |
1080 | auto *BI = dyn_cast<BranchInst>(Val: ExitingBB->getTerminator()); |
1081 | if (!BI) |
1082 | continue; |
1083 | |
1084 | if (auto WC = extractWidenableCondition(U: BI)) |
1085 | if (L->contains(BB: BI->getSuccessor(i: 0))) { |
1086 | assert(WC->hasOneUse() && "Not appropriate widenable branch!" ); |
1087 | WC->user_back()->replaceUsesOfWith( |
1088 | From: WC, To: ConstantInt::getTrue(Context&: BI->getContext())); |
1089 | ChangedLoop = true; |
1090 | } |
1091 | } |
1092 | if (ChangedLoop) |
1093 | SE->forgetLoop(L); |
1094 | |
1095 | // The insertion point for the widening should be at the widenably call, not |
1096 | // at the WidenableBR. If we do this at the widenableBR, we can incorrectly |
1097 | // change a loop-invariant condition to a loop-varying one. |
1098 | auto *IP = cast<Instruction>(Val: WidenableBR->getCondition()); |
1099 | |
1100 | // The use of umin(all analyzeable exits) instead of latch is subtle, but |
1101 | // important for profitability. We may have a loop which hasn't been fully |
1102 | // canonicalized just yet. If the exit we chose to widen is provably never |
1103 | // taken, we want the widened form to *also* be provably never taken. We |
1104 | // can't guarantee this as a current unanalyzeable exit may later become |
1105 | // analyzeable, but we can at least avoid the obvious cases. |
1106 | const SCEV *MinEC = getMinAnalyzeableBackedgeTakenCount(SE&: *SE, DT&: *DT, L); |
1107 | if (isa<SCEVCouldNotCompute>(Val: MinEC) || MinEC->getType()->isPointerTy() || |
1108 | !SE->isLoopInvariant(S: MinEC, L) || |
1109 | !Rewriter.isSafeToExpandAt(S: MinEC, InsertionPoint: IP)) |
1110 | return ChangedLoop; |
1111 | |
1112 | Rewriter.setInsertPoint(IP); |
1113 | IRBuilder<> B(IP); |
1114 | |
1115 | bool InvalidateLoop = false; |
1116 | Value *MinECV = nullptr; // lazily generated if needed |
1117 | for (BasicBlock *ExitingBB : ExitingBlocks) { |
1118 | // If our exiting block exits multiple loops, we can only rewrite the |
1119 | // innermost one. Otherwise, we're changing how many times the innermost |
1120 | // loop runs before it exits. |
1121 | if (LI->getLoopFor(BB: ExitingBB) != L) |
1122 | continue; |
1123 | |
1124 | // Can't rewrite non-branch yet. |
1125 | auto *BI = dyn_cast<BranchInst>(Val: ExitingBB->getTerminator()); |
1126 | if (!BI) |
1127 | continue; |
1128 | |
1129 | // If already constant, nothing to do. |
1130 | if (isa<Constant>(Val: BI->getCondition())) |
1131 | continue; |
1132 | |
1133 | const SCEV *ExitCount = SE->getExitCount(L, ExitingBlock: ExitingBB); |
1134 | if (isa<SCEVCouldNotCompute>(Val: ExitCount) || |
1135 | ExitCount->getType()->isPointerTy() || |
1136 | !Rewriter.isSafeToExpandAt(S: ExitCount, InsertionPoint: WidenableBR)) |
1137 | continue; |
1138 | |
1139 | const bool ExitIfTrue = !L->contains(BB: *succ_begin(BB: ExitingBB)); |
1140 | BasicBlock *ExitBB = BI->getSuccessor(i: ExitIfTrue ? 0 : 1); |
1141 | if (!ExitBB->getPostdominatingDeoptimizeCall()) |
1142 | continue; |
1143 | |
1144 | /// Here we can be fairly sure that executing this exit will most likely |
1145 | /// lead to executing llvm.experimental.deoptimize. |
1146 | /// This is a profitability heuristic, not a legality constraint. |
1147 | |
1148 | // If we found a widenable exit condition, do two things: |
1149 | // 1) fold the widened exit test into the widenable condition |
1150 | // 2) fold the branch to untaken - avoids infinite looping |
1151 | |
1152 | Value *ECV = Rewriter.expandCodeFor(SH: ExitCount); |
1153 | if (!MinECV) |
1154 | MinECV = Rewriter.expandCodeFor(SH: MinEC); |
1155 | Value *RHS = MinECV; |
1156 | if (ECV->getType() != RHS->getType()) { |
1157 | Type *WiderTy = SE->getWiderType(Ty1: ECV->getType(), Ty2: RHS->getType()); |
1158 | ECV = B.CreateZExt(V: ECV, DestTy: WiderTy); |
1159 | RHS = B.CreateZExt(V: RHS, DestTy: WiderTy); |
1160 | } |
1161 | assert(!Latch || DT->dominates(ExitingBB, Latch)); |
1162 | Value *NewCond = B.CreateICmp(P: ICmpInst::ICMP_UGT, LHS: ECV, RHS); |
1163 | // Freeze poison or undef to an arbitrary bit pattern to ensure we can |
1164 | // branch without introducing UB. See NOTE ON POISON/UNDEF above for |
1165 | // context. |
1166 | NewCond = B.CreateFreeze(V: NewCond); |
1167 | |
1168 | widenWidenableBranch(WidenableBR, NewCond); |
1169 | |
1170 | Value *OldCond = BI->getCondition(); |
1171 | BI->setCondition(ConstantInt::get(Ty: OldCond->getType(), V: !ExitIfTrue)); |
1172 | InvalidateLoop = true; |
1173 | } |
1174 | |
1175 | if (InvalidateLoop) |
1176 | // We just mutated a bunch of loop exits changing there exit counts |
1177 | // widely. We need to force recomputation of the exit counts given these |
1178 | // changes. Note that all of the inserted exits are never taken, and |
1179 | // should be removed next time the CFG is modified. |
1180 | SE->forgetLoop(L); |
1181 | |
1182 | // Always return `true` since we have moved the WidenableBR's condition. |
1183 | return true; |
1184 | } |
1185 | |
1186 | bool LoopPredication::runOnLoop(Loop *Loop) { |
1187 | L = Loop; |
1188 | |
1189 | LLVM_DEBUG(dbgs() << "Analyzing " ); |
1190 | LLVM_DEBUG(L->dump()); |
1191 | |
1192 | Module *M = L->getHeader()->getModule(); |
1193 | |
1194 | // There is nothing to do if the module doesn't use guards |
1195 | auto *GuardDecl = |
1196 | M->getFunction(Name: Intrinsic::getName(id: Intrinsic::experimental_guard)); |
1197 | bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty(); |
1198 | auto *WCDecl = M->getFunction( |
1199 | Name: Intrinsic::getName(id: Intrinsic::experimental_widenable_condition)); |
1200 | bool HasWidenableConditions = |
1201 | PredicateWidenableBranchGuards && WCDecl && !WCDecl->use_empty(); |
1202 | if (!HasIntrinsicGuards && !HasWidenableConditions) |
1203 | return false; |
1204 | |
1205 | DL = &M->getDataLayout(); |
1206 | |
1207 | Preheader = L->getLoopPreheader(); |
1208 | if (!Preheader) |
1209 | return false; |
1210 | |
1211 | auto LatchCheckOpt = parseLoopLatchICmp(); |
1212 | if (!LatchCheckOpt) |
1213 | return false; |
1214 | LatchCheck = *LatchCheckOpt; |
1215 | |
1216 | LLVM_DEBUG(dbgs() << "Latch check:\n" ); |
1217 | LLVM_DEBUG(LatchCheck.dump()); |
1218 | |
1219 | if (!isLoopProfitableToPredicate()) { |
1220 | LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n" ); |
1221 | return false; |
1222 | } |
1223 | // Collect all the guards into a vector and process later, so as not |
1224 | // to invalidate the instruction iterator. |
1225 | SmallVector<IntrinsicInst *, 4> Guards; |
1226 | SmallVector<BranchInst *, 4> GuardsAsWidenableBranches; |
1227 | for (const auto BB : L->blocks()) { |
1228 | for (auto &I : *BB) |
1229 | if (isGuard(U: &I)) |
1230 | Guards.push_back(Elt: cast<IntrinsicInst>(Val: &I)); |
1231 | if (PredicateWidenableBranchGuards && |
1232 | isGuardAsWidenableBranch(U: BB->getTerminator())) |
1233 | GuardsAsWidenableBranches.push_back( |
1234 | Elt: cast<BranchInst>(Val: BB->getTerminator())); |
1235 | } |
1236 | |
1237 | SCEVExpander Expander(*SE, *DL, "loop-predication" ); |
1238 | bool Changed = false; |
1239 | for (auto *Guard : Guards) |
1240 | Changed |= widenGuardConditions(Guard, Expander); |
1241 | for (auto *Guard : GuardsAsWidenableBranches) |
1242 | Changed |= widenWidenableBranchGuardConditions(BI: Guard, Expander); |
1243 | Changed |= predicateLoopExits(L, Rewriter&: Expander); |
1244 | |
1245 | if (MSSAU && VerifyMemorySSA) |
1246 | MSSAU->getMemorySSA()->verifyMemorySSA(); |
1247 | return Changed; |
1248 | } |
1249 | |