//===- InductiveRangeCheckElimination.cpp - -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The InductiveRangeCheckElimination pass splits a loop's iteration space into
// three disjoint ranges. It does so in a way such that the middle loop
// provably does not need range checks. As an example, it will convert
//
//   len = < known positive >
//   for (i = 0; i < n; i++) {
//     if (0 <= i && i < len) {
//       do_something();
//     } else {
//       throw_out_of_bounds();
//     }
//   }
//
// to
//
//   len = < known positive >
//   limit = smin(n, len)
//   // no first segment
//   for (i = 0; i < limit; i++) {
//     if (0 <= i && i < len) { // this check is fully redundant
//       do_something();
//     } else {
//       throw_out_of_bounds();
//     }
//   }
//   for (i = limit; i < n; i++) {
//     if (0 <= i && i < len) {
//       do_something();
//     } else {
//       throw_out_of_bounds();
//     }
//   }
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/InductiveRangeCheckElimination.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/PriorityWorklist.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopConstrainer.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <algorithm>
#include <cassert>
#include <optional>
#include <utility>

using namespace llvm;
using namespace llvm::PatternMatch;

static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden,
                                        cl::init(64));

static cl::opt<bool> PrintChangedLoops("irce-print-changed-loops", cl::Hidden,
                                       cl::init(false));

static cl::opt<bool> PrintRangeChecks("irce-print-range-checks", cl::Hidden,
                                      cl::init(false));

static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks",
                                             cl::Hidden, cl::init(false));

static cl::opt<unsigned> MinEliminatedChecks("irce-min-eliminated-checks",
                                             cl::Hidden, cl::init(10));

static cl::opt<bool> AllowUnsignedLatchCondition("irce-allow-unsigned-latch",
                                                 cl::Hidden, cl::init(true));

static cl::opt<bool> AllowNarrowLatchCondition(
    "irce-allow-narrow-latch", cl::Hidden, cl::init(true),
    cl::desc("If set to true, IRCE may eliminate wide range checks in loops "
             "with narrow latch condition."));

static cl::opt<unsigned> MaxTypeSizeForOverflowCheck(
    "irce-max-type-size-for-overflow-check", cl::Hidden, cl::init(32),
    cl::desc("Maximum size of a range check type for which a runtime overflow "
             "check of its limit's computation can be produced"));

static cl::opt<bool>
    PrintScaledBoundaryRangeChecks("irce-print-scaled-boundary-range-checks",
                                   cl::Hidden, cl::init(false));

#define DEBUG_TYPE "irce"

namespace {

/// An inductive range check is a conditional branch in a loop with a condition
/// that is provably true for some contiguous range of values taken by the
/// containing loop's induction variable.
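///
/// For instance, in the loop from the file header comment above, the branch on
/// "0 <= i && i < len" (with i the loop's induction variable and len a
/// loop-invariant, known-positive value) is an inductive range check: it is
/// provably true for every i in [0, len).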
///
class InductiveRangeCheck {

  const SCEV *Begin = nullptr;
  const SCEV *Step = nullptr;
  const SCEV *End = nullptr;
  Use *CheckUse = nullptr;

  static bool parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE,
                                  const SCEVAddRecExpr *&Index,
                                  const SCEV *&End);

  static void
  extractRangeChecksFromCond(Loop *L, ScalarEvolution &SE, Use &ConditionUse,
                             SmallVectorImpl<InductiveRangeCheck> &Checks,
                             SmallPtrSetImpl<Value *> &Visited);

  static bool parseIvAgaisntLimit(Loop *L, Value *LHS, Value *RHS,
                                  ICmpInst::Predicate Pred, ScalarEvolution &SE,
                                  const SCEVAddRecExpr *&Index,
                                  const SCEV *&End);

  static bool reassociateSubLHS(Loop *L, Value *VariantLHS, Value *InvariantRHS,
                                ICmpInst::Predicate Pred, ScalarEvolution &SE,
                                const SCEVAddRecExpr *&Index, const SCEV *&End);

public:
  const SCEV *getBegin() const { return Begin; }
  const SCEV *getStep() const { return Step; }
  const SCEV *getEnd() const { return End; }

  void print(raw_ostream &OS) const {
    OS << "InductiveRangeCheck:\n";
    OS << " Begin: ";
    Begin->print(OS);
    OS << " Step: ";
    Step->print(OS);
    OS << " End: ";
    End->print(OS);
    OS << "\n CheckUse: ";
    getCheckUse()->getUser()->print(OS);
    OS << " Operand: " << getCheckUse()->getOperandNo() << "\n";
  }

  LLVM_DUMP_METHOD
  void dump() {
    print(dbgs());
  }

  Use *getCheckUse() const { return CheckUse; }

  /// Represents a signed integer range [Range.getBegin(), Range.getEnd()). If
  /// R.getEnd() <= R.getBegin(), then R denotes the empty range.

  class Range {
    const SCEV *Begin;
    const SCEV *End;

  public:
    Range(const SCEV *Begin, const SCEV *End) : Begin(Begin), End(End) {
      assert(Begin->getType() == End->getType() && "ill-typed range!");
    }

    Type *getType() const { return Begin->getType(); }
    const SCEV *getBegin() const { return Begin; }
    const SCEV *getEnd() const { return End; }
    bool isEmpty(ScalarEvolution &SE, bool IsSigned) const {
      if (Begin == End)
        return true;
      if (IsSigned)
        return SE.isKnownPredicate(ICmpInst::ICMP_SGE, Begin, End);
      else
        return SE.isKnownPredicate(ICmpInst::ICMP_UGE, Begin, End);
    }
  };

  /// This is the value the condition of the branch needs to evaluate to for
  /// the branch to take the hot successor (see (1) above).
  bool getPassingDirection() { return true; }

  /// Computes a range for the induction variable (IndVar) in which the range
  /// check is redundant and can be constant-folded away. The induction
  /// variable is not required to be the canonical {0,+,1} induction variable.
  std::optional<Range> computeSafeIterationSpace(ScalarEvolution &SE,
                                                 const SCEVAddRecExpr *IndVar,
                                                 bool IsLatchSigned) const;

  /// Parse out a set of inductive range checks from \p BI and append them to
  /// \p Checks.
  ///
  /// NB! There may be conditions feeding into \p BI that aren't inductive
  /// range checks, and hence don't end up in \p Checks.
  static void extractRangeChecksFromBranch(
      BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo *BPI,
      std::optional<uint64_t> EstimatedTripCount,
      SmallVectorImpl<InductiveRangeCheck> &Checks, bool &Changed);
};

class InductiveRangeCheckElimination {
  ScalarEvolution &SE;
  BranchProbabilityInfo *BPI;
  DominatorTree &DT;
  LoopInfo &LI;

  using GetBFIFunc =
      std::optional<llvm::function_ref<llvm::BlockFrequencyInfo &()>>;
  GetBFIFunc GetBFI;

  // Returns the estimated number of iterations based on block frequency info
  // if available, or on branch probability info. Nullopt is returned if the
  // number of iterations cannot be estimated.
  std::optional<uint64_t> estimatedTripCount(const Loop &L);

public:
  InductiveRangeCheckElimination(ScalarEvolution &SE,
                                 BranchProbabilityInfo *BPI, DominatorTree &DT,
                                 LoopInfo &LI, GetBFIFunc GetBFI = std::nullopt)
      : SE(SE), BPI(BPI), DT(DT), LI(LI), GetBFI(GetBFI) {}

  bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop);
};

} // end anonymous namespace

/// Parse a single ICmp instruction, `ICI`, into a range check. If `ICI` cannot
/// be interpreted as a range check, return false. Otherwise set `Index` to the
/// SCEV being range checked, and set `End` to the upper or lower limit that
/// `Index` is being checked against.
bool InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
                                              ScalarEvolution &SE,
                                              const SCEVAddRecExpr *&Index,
                                              const SCEV *&End) {
  auto IsLoopInvariant = [&SE, L](Value *V) {
    return SE.isLoopInvariant(SE.getSCEV(V), L);
  };

  ICmpInst::Predicate Pred = ICI->getPredicate();
  Value *LHS = ICI->getOperand(0);
  Value *RHS = ICI->getOperand(1);

  if (!LHS->getType()->isIntegerTy())
    return false;

  // Canonicalize to the `Index Pred Invariant` comparison
  if (IsLoopInvariant(LHS)) {
    std::swap(LHS, RHS);
    Pred = CmpInst::getSwappedPredicate(Pred);
  } else if (!IsLoopInvariant(RHS))
    // Both LHS and RHS are loop variant
    return false;

  if (parseIvAgaisntLimit(L, LHS, RHS, Pred, SE, Index, End))
    return true;

  if (reassociateSubLHS(L, LHS, RHS, Pred, SE, Index, End))
    return true;

  // TODO: support ReassociateAddLHS
  return false;
}

// Try to parse range check in the form of "IV vs Limit"
bool InductiveRangeCheck::parseIvAgaisntLimit(Loop *L, Value *LHS, Value *RHS,
                                              ICmpInst::Predicate Pred,
                                              ScalarEvolution &SE,
                                              const SCEVAddRecExpr *&Index,
                                              const SCEV *&End) {

  auto SIntMaxSCEV = [&](Type *T) {
    unsigned BitWidth = cast<IntegerType>(T)->getBitWidth();
    return SE.getConstant(APInt::getSignedMaxValue(BitWidth));
  };

  const auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(LHS));
  if (!AddRec)
    return false;

  // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L".
  // We can potentially do much better here.
  // If we want to adjust upper bound for the unsigned range check as we do it
  // for signed one, we will need to pick Unsigned max
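  //
  // For example (illustration only): the check "i >= 0" on an affine IV
  // {0,+,1} is recorded as the range check 0 <= i < INT_SMAX, while "i < len"
  // is recorded as 0 <= i < len.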
  switch (Pred) {
  default:
    return false;

  case ICmpInst::ICMP_SGE:
    if (match(RHS, m_ConstantInt<0>())) {
      Index = AddRec;
      End = SIntMaxSCEV(Index->getType());
      return true;
    }
    return false;

  case ICmpInst::ICMP_SGT:
    if (match(RHS, m_ConstantInt<-1>())) {
      Index = AddRec;
      End = SIntMaxSCEV(Index->getType());
      return true;
    }
    return false;

  case ICmpInst::ICMP_SLT:
  case ICmpInst::ICMP_ULT:
    Index = AddRec;
    End = SE.getSCEV(RHS);
    return true;

  case ICmpInst::ICMP_SLE:
  case ICmpInst::ICMP_ULE:
    const SCEV *One = SE.getOne(RHS->getType());
    const SCEV *RHSS = SE.getSCEV(RHS);
    bool Signed = Pred == ICmpInst::ICMP_SLE;
    if (SE.willNotOverflow(Instruction::BinaryOps::Add, Signed, RHSS, One)) {
      Index = AddRec;
      End = SE.getAddExpr(RHSS, One);
      return true;
    }
    return false;
  }

  llvm_unreachable("default clause returns!");
}

// Try to parse range check in the form of "IV - Offset vs Limit" or "Offset -
// IV vs Limit"
bool InductiveRangeCheck::reassociateSubLHS(
    Loop *L, Value *VariantLHS, Value *InvariantRHS, ICmpInst::Predicate Pred,
    ScalarEvolution &SE, const SCEVAddRecExpr *&Index, const SCEV *&End) {
  Value *LHS, *RHS;
  if (!match(VariantLHS, m_Sub(m_Value(LHS), m_Value(RHS))))
    return false;

  const SCEV *IV = SE.getSCEV(LHS);
  const SCEV *Offset = SE.getSCEV(RHS);
  const SCEV *Limit = SE.getSCEV(InvariantRHS);

  bool OffsetSubtracted = false;
  if (SE.isLoopInvariant(IV, L))
    // "Offset - IV vs Limit"
    std::swap(IV, Offset);
  else if (SE.isLoopInvariant(Offset, L))
    // "IV - Offset vs Limit"
    OffsetSubtracted = true;
  else
    return false;

  const auto *AddRec = dyn_cast<SCEVAddRecExpr>(IV);
  if (!AddRec)
    return false;

  // In order to turn "IV - Offset < Limit" into "IV < Limit + Offset", we need
  // to be able to freely move values from the left side of the inequality to
  // the right side (just as in normal linear arithmetic). Overflows make
  // things much more complicated, so we want to avoid them.
  //
  // Let's prove that the initial subtraction doesn't overflow with all IV's
  // values from the safe range constructed for that check.
  //
  // [Case 1] IV - Offset < Limit
  // It doesn't overflow if:
  //   SINT_MIN <= IV - Offset <= SINT_MAX
  // In terms of scaled SINT we need to prove:
  //   SINT_MIN + Offset <= IV <= SINT_MAX + Offset
  // Safe range will be constructed:
  //   0 <= IV < Limit + Offset
  // It means that 'IV - Offset' doesn't underflow, because:
  //   SINT_MIN + Offset < 0 <= IV
  // and doesn't overflow:
  //   IV < Limit + Offset <= SINT_MAX + Offset
  //
  // [Case 2] Offset - IV > Limit
  // It doesn't overflow if:
  //   SINT_MIN <= Offset - IV <= SINT_MAX
  // In terms of scaled SINT we need to prove:
  //   -SINT_MIN >= IV - Offset >= -SINT_MAX
  //   Offset - SINT_MIN >= IV >= Offset - SINT_MAX
  // Safe range will be constructed:
  //   0 <= IV < Offset - Limit
  // It means that 'Offset - IV' doesn't underflow, because
  //   Offset - SINT_MAX < 0 <= IV
  // and doesn't overflow:
  //   IV < Offset - Limit <= Offset - SINT_MIN
  //
  // For the computed upper boundary of the IV's range (Offset +/- Limit) we
  // don't know exactly whether it overflows or not. So if we can't prove this
  // fact at compile time, we scale boundary computations to a wider type with
  // the intention to add runtime overflow check.
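  //
  // A worked example (illustrative numbers only): suppose the range check type
  // is i8 (SINT_MAX = 127) and the comparison is "IV - 20 < 100" (Case 1 with
  // Offset = 20, Limit = 100). Limit + Offset = 120 provably fits in i8, so no
  // scaling is needed and the safe range is 0 <= IV < 120. Had the comparison
  // been "IV - 100 < 100", Limit + Offset = 200 would not fit in i8, so the
  // boundary would be computed in i16 and guarded by a runtime overflow check.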

  auto getExprScaledIfOverflow = [&](Instruction::BinaryOps BinOp,
                                     const SCEV *LHS,
                                     const SCEV *RHS) -> const SCEV * {
    const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
                                              SCEV::NoWrapFlags, unsigned);
    switch (BinOp) {
    default:
      llvm_unreachable("Unsupported binary op");
    case Instruction::Add:
      Operation = &ScalarEvolution::getAddExpr;
      break;
    case Instruction::Sub:
      Operation = &ScalarEvolution::getMinusSCEV;
      break;
    }

    if (SE.willNotOverflow(BinOp, ICmpInst::isSigned(Pred), LHS, RHS,
                           cast<Instruction>(VariantLHS)))
      return (SE.*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0);

    // We couldn't prove that the expression does not overflow.
    // Then scale it to a wider type to check for overflow at runtime.
    auto *Ty = cast<IntegerType>(LHS->getType());
    if (Ty->getBitWidth() > MaxTypeSizeForOverflowCheck)
      return nullptr;

    auto WideTy = IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
    return (SE.*Operation)(SE.getSignExtendExpr(LHS, WideTy),
                           SE.getSignExtendExpr(RHS, WideTy), SCEV::FlagAnyWrap,
                           0);
  };

  if (OffsetSubtracted)
    // "IV - Offset < Limit" -> "IV" < Offset + Limit
    Limit = getExprScaledIfOverflow(Instruction::BinaryOps::Add, Offset, Limit);
  else {
    // "Offset - IV > Limit" -> "IV" < Offset - Limit
    Limit = getExprScaledIfOverflow(Instruction::BinaryOps::Sub, Offset, Limit);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }

  if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) {
    // "Expr <= Limit" -> "Expr < Limit + 1"
    if (Pred == ICmpInst::ICMP_SLE && Limit)
      Limit = getExprScaledIfOverflow(Instruction::BinaryOps::Add, Limit,
                                      SE.getOne(Limit->getType()));
    if (Limit) {
      Index = AddRec;
      End = Limit;
      return true;
    }
  }
  return false;
}

void InductiveRangeCheck::extractRangeChecksFromCond(
    Loop *L, ScalarEvolution &SE, Use &ConditionUse,
    SmallVectorImpl<InductiveRangeCheck> &Checks,
    SmallPtrSetImpl<Value *> &Visited) {
  Value *Condition = ConditionUse.get();
  if (!Visited.insert(Condition).second)
    return;

  // TODO: Do the same for OR, XOR, NOT etc?
  if (match(Condition, m_LogicalAnd(m_Value(), m_Value()))) {
    extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(0),
                               Checks, Visited);
    extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(1),
                               Checks, Visited);
    return;
  }

  ICmpInst *ICI = dyn_cast<ICmpInst>(Condition);
  if (!ICI)
    return;

  const SCEV *End = nullptr;
  const SCEVAddRecExpr *IndexAddRec = nullptr;
  if (!parseRangeCheckICmp(L, ICI, SE, IndexAddRec, End))
    return;

  assert(IndexAddRec && "IndexAddRec was not computed");
  assert(End && "End was not computed");

  if ((IndexAddRec->getLoop() != L) || !IndexAddRec->isAffine())
    return;

  InductiveRangeCheck IRC;
  IRC.End = End;
  IRC.Begin = IndexAddRec->getStart();
  IRC.Step = IndexAddRec->getStepRecurrence(SE);
  IRC.CheckUse = &ConditionUse;
  Checks.push_back(IRC);
}

void InductiveRangeCheck::extractRangeChecksFromBranch(
    BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo *BPI,
    std::optional<uint64_t> EstimatedTripCount,
    SmallVectorImpl<InductiveRangeCheck> &Checks, bool &Changed) {
  if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch())
    return;

  unsigned IndexLoopSucc = L->contains(BI->getSuccessor(0)) ? 0 : 1;
  assert(L->contains(BI->getSuccessor(IndexLoopSucc)) &&
         "No edges coming to loop?");

  if (!SkipProfitabilityChecks && BPI) {
    auto SuccessProbability =
        BPI->getEdgeProbability(BI->getParent(), IndexLoopSucc);
    if (EstimatedTripCount) {
      auto EstimatedEliminatedChecks =
          SuccessProbability.scale(*EstimatedTripCount);
      if (EstimatedEliminatedChecks < MinEliminatedChecks) {
        LLVM_DEBUG(dbgs() << "irce: could not prove profitability for branch "
                          << *BI << ": "
                          << "estimated eliminated checks too low "
                          << EstimatedEliminatedChecks << "\n";);
        return;
      }
    } else {
      BranchProbability LikelyTaken(15, 16);
      if (SuccessProbability < LikelyTaken) {
        LLVM_DEBUG(dbgs() << "irce: could not prove profitability for branch "
                          << *BI << ": "
                          << "could not estimate trip count "
                          << "and branch success probability too low "
                          << SuccessProbability << "\n";);
        return;
      }
    }
  }

  // IRCE expects the branch's true edge to go into the loop. Invert the branch
  // otherwise.
  if (IndexLoopSucc != 0) {
    IRBuilder<> Builder(BI);
    InvertBranch(BI, Builder);
    if (BPI)
      BPI->swapSuccEdgesProbabilities(BI->getParent());
    Changed = true;
  }

  SmallPtrSet<Value *, 8> Visited;
  InductiveRangeCheck::extractRangeChecksFromCond(L, SE, BI->getOperandUse(0),
                                                  Checks, Visited);
}

/// If the type of \p S matches with \p Ty, return \p S. Otherwise, return
/// signed or unsigned extension of \p S to type \p Ty.
static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE,
                                bool Signed) {
  return Signed ? SE.getNoopOrSignExtend(S, Ty) : SE.getNoopOrZeroExtend(S, Ty);
}

// Compute a safe set of limits for the main loop to run in -- effectively the
// intersection of `Range' and the iteration space of the original loop.
// Return std::nullopt if unable to compute the set of subranges.
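// In the returned SubRanges, LowLimit is only set when a pre-loop is actually
// needed and HighLimit only when a post-loop is needed (compare the "no first
// segment" case in the example at the top of this file).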
static std::optional<LoopConstrainer::SubRanges>
calculateSubRanges(ScalarEvolution &SE, const Loop &L,
                   InductiveRangeCheck::Range &Range,
                   const LoopStructure &MainLoopStructure) {
  auto *RTy = cast<IntegerType>(Range.getType());
  // We only support wide range checks and narrow latches.
  if (!AllowNarrowLatchCondition && RTy != MainLoopStructure.ExitCountTy)
    return std::nullopt;
  if (RTy->getBitWidth() < MainLoopStructure.ExitCountTy->getBitWidth())
    return std::nullopt;

  LoopConstrainer::SubRanges Result;

  bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
  // I think we can be more aggressive here and make this nuw / nsw if the
  // addition that feeds into the icmp for the latch's terminating branch is
  // nuw / nsw. In any case, a wrapping 2's complement addition is safe.
  const SCEV *Start = NoopOrExtend(SE.getSCEV(MainLoopStructure.IndVarStart),
                                   RTy, SE, IsSignedPredicate);
  const SCEV *End = NoopOrExtend(SE.getSCEV(MainLoopStructure.LoopExitAt), RTy,
                                 SE, IsSignedPredicate);

  bool Increasing = MainLoopStructure.IndVarIncreasing;

  // We compute `Smallest` and `Greatest` such that [Smallest, Greatest), or
  // [Smallest, GreatestSeen] is the range of values the induction variable
  // takes.

  const SCEV *Smallest = nullptr, *Greatest = nullptr, *GreatestSeen = nullptr;

  const SCEV *One = SE.getOne(RTy);
  if (Increasing) {
    Smallest = Start;
    Greatest = End;
    // No overflow, because the range [Smallest, GreatestSeen] is not empty.
    GreatestSeen = SE.getMinusSCEV(End, One);
  } else {
    // These two computations may sign-overflow. Here is why that is okay:
    //
    // We know that the induction variable does not sign-overflow on any
    // iteration except the last one, and it starts at `Start` and ends at
    // `End`, decrementing by one every time.
    //
    //  * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the
    //    induction variable is decreasing we know that the smallest value
    //    the loop body is actually executed with is `INT_SMIN` == `Smallest`.
    //
    //  * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`. In
    //    that case, `Clamp` will always return `Smallest` and
    //    [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`)
    //    will be an empty range. Returning an empty range is always safe.

    Smallest = SE.getAddExpr(End, One);
    Greatest = SE.getAddExpr(Start, One);
    GreatestSeen = Start;
  }

  auto Clamp = [&SE, Smallest, Greatest, IsSignedPredicate](const SCEV *S) {
    return IsSignedPredicate
               ? SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S))
               : SE.getUMaxExpr(Smallest, SE.getUMinExpr(Greatest, S));
  };

  // In some cases we can prove that we don't need a pre or post loop.
  ICmpInst::Predicate PredLE =
      IsSignedPredicate ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
  ICmpInst::Predicate PredLT =
      IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;

  bool ProvablyNoPreloop =
      SE.isKnownPredicate(PredLE, Range.getBegin(), Smallest);
  if (!ProvablyNoPreloop)
    Result.LowLimit = Clamp(Range.getBegin());

  bool ProvablyNoPostLoop =
      SE.isKnownPredicate(PredLT, GreatestSeen, Range.getEnd());
  if (!ProvablyNoPostLoop)
    Result.HighLimit = Clamp(Range.getEnd());

  return Result;
}

/// Computes and returns a range of values for the induction variable (IndVar)
/// in which the range check can be safely elided. If it cannot compute such a
/// range, returns std::nullopt.
std::optional<InductiveRangeCheck::Range>
InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
                                               const SCEVAddRecExpr *IndVar,
                                               bool IsLatchSigned) const {
  // We can handle the case where the types of the latch check and the range
  // check don't match, as long as the latch check type is no wider than the
  // range check type.
  auto *IVType = dyn_cast<IntegerType>(IndVar->getType());
  auto *RCType = dyn_cast<IntegerType>(getBegin()->getType());
  auto *EndType = dyn_cast<IntegerType>(getEnd()->getType());
  // Do not work with pointer types.
  if (!IVType || !RCType)
    return std::nullopt;
  if (IVType->getBitWidth() > RCType->getBitWidth())
    return std::nullopt;

  // IndVar is of the form "A + B * I" (where "I" is the canonical induction
  // variable, that may or may not exist as a real llvm::Value in the loop) and
  // this inductive range check is a range check on the "C + D * I" ("C" is
  // getBegin() and "D" is getStep()). We rewrite the value being range
  // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA".
  //
  // The actual inequalities we solve are of the form
  //
  //   0 <= M + 1 * IndVar < L given L >= 0  (i.e. N == 1)
  //
  // Here L stands for upper limit of the safe iteration space.
  // The inequality is satisfied by (0 - M) <= IndVar < (L - M). To avoid
  // overflows when calculating (0 - M) and (L - M) we, depending on type of
  // IV's iteration space, limit the calculations by borders of the iteration
  // space. For example, if IndVar is unsigned, (0 - M) overflows for any M > 0.
  // If we figured out that "anything greater than (-M) is safe", we strengthen
  // this to "everything greater than 0 is safe", assuming that values between
  // -M and 0 just do not exist in unsigned iteration space, and we don't want
  // to deal with overflown values.
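  //
  // A concrete instance (illustrative numbers only): with the canonical IV
  // {0,+,1} (A = 0, B = 1) and a range check on "i - 5" against a limit "len"
  // (C = -5, D = 1), we get N = 1 and M = C - A = -5, so the safe iteration
  // space is (0 - M) <= IndVar < (len - M), i.e. [5, len + 5), before any
  // clamping against the borders of the iteration space.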

  if (!IndVar->isAffine())
    return std::nullopt;

  const SCEV *A = NoopOrExtend(IndVar->getStart(), RCType, SE, IsLatchSigned);
  const SCEVConstant *B = dyn_cast<SCEVConstant>(
      NoopOrExtend(IndVar->getStepRecurrence(SE), RCType, SE, IsLatchSigned));
  if (!B)
    return std::nullopt;
  assert(!B->isZero() && "Recurrence with zero step?");

  const SCEV *C = getBegin();
  const SCEVConstant *D = dyn_cast<SCEVConstant>(getStep());
  if (D != B)
    return std::nullopt;

  assert(!D->getValue()->isZero() && "Recurrence with zero step?");
  unsigned BitWidth = RCType->getBitWidth();
  const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth));
  const SCEV *SIntMin = SE.getConstant(APInt::getSignedMinValue(BitWidth));

  // Subtract Y from X so that it does not go through the border of the IV
  // iteration space. Mathematically, it is equivalent to:
  //
  //    ClampedSubtract(X, Y) = min(max(X - Y, INT_MIN), INT_MAX).  [1]
  //
  // In [1], 'X - Y' is a mathematical subtraction (the result is not bounded
  // to any width of bit grid). But after we take min/max, the result is
  // guaranteed to be within [INT_MIN, INT_MAX].
  //
  // In [1], INT_MAX and INT_MIN are respectively the signed or unsigned
  // max/min values, depending on the type of the latch condition that defines
  // the IV iteration space.
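  //
  // For example (signed case, i8, illustrative numbers only): with X = 100 and
  // Y = -50 the mathematical difference is 150, which does not fit in i8;
  // smax(Y, X - SINT_MAX) = smax(-50, -27) = -27, so ClampedSubtract returns
  // 100 - (-27) = 127 = SINT_MAX, i.e. the result saturates at the border.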
  auto ClampedSubtract = [&](const SCEV *X, const SCEV *Y) {
    // FIXME: The current implementation assumes that X is in [0, SINT_MAX].
    // This is required to ensure that SINT_MAX - X does not overflow signed and
    // that X - Y does not overflow unsigned if Y is negative. Can we lift this
    // restriction and make it work for negative X either?
    if (IsLatchSigned) {
      // X is a number from signed range, Y is interpreted as signed.
      // Even if Y is SINT_MAX, (X - Y) does not reach SINT_MIN. So the only
      // thing we should care about is that we didn't cross SINT_MAX.
      // So, if Y is positive, we subtract Y safely.
      //   Rule 1: Y > 0 ---> Y.
      // If 0 <= -Y <= (SINT_MAX - X), we subtract Y safely.
      //   Rule 2: Y >=s (X - SINT_MAX) ---> Y.
      // If 0 <= (SINT_MAX - X) < -Y, we can only subtract (X - SINT_MAX).
      //   Rule 3: Y <s (X - SINT_MAX) ---> (X - SINT_MAX).
      // It gives us smax(Y, X - SINT_MAX) to subtract in all cases.
      const SCEV *XMinusSIntMax = SE.getMinusSCEV(X, SIntMax);
      return SE.getMinusSCEV(X, SE.getSMaxExpr(Y, XMinusSIntMax),
                             SCEV::FlagNSW);
    } else
      // X is a number from unsigned range, Y is interpreted as signed.
      // Even if Y is SINT_MIN, (X - Y) does not reach UINT_MAX. So the only
      // thing we should care about is that we didn't cross zero.
      // So, if Y is negative, we subtract Y safely.
      //   Rule 1: Y <s 0 ---> Y.
      // If 0 <= Y <= X, we subtract Y safely.
      //   Rule 2: Y <=s X ---> Y.
      // If 0 <= X < Y, we should stop at 0 and can only subtract X.
      //   Rule 3: Y >s X ---> X.
      // It gives us smin(X, Y) to subtract in all cases.
      return SE.getMinusSCEV(X, SE.getSMinExpr(X, Y), SCEV::FlagNUW);
  };
  const SCEV *M = SE.getMinusSCEV(C, A);
  const SCEV *Zero = SE.getZero(M->getType());

  // This function returns a SCEV equal to 1 if X is non-negative, 0 otherwise.
  auto SCEVCheckNonNegative = [&](const SCEV *X) {
    const Loop *L = IndVar->getLoop();
    const SCEV *Zero = SE.getZero(X->getType());
    const SCEV *One = SE.getOne(X->getType());
    // Can we trivially prove that X is a non-negative or negative value?
    if (isKnownNonNegativeInLoop(X, L, SE))
      return One;
    else if (isKnownNegativeInLoop(X, L, SE))
      return Zero;
    // If not, we will have to figure it out during the execution.
    // Function smax(smin(X, 0), -1) + 1 equals to 1 if X >= 0 and 0 if X < 0.
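    // For instance, X = 5 gives smax(smin(5, 0), -1) + 1 = smax(0, -1) + 1 =
    // 0 + 1 = 1, while X = -3 gives smax(-3, -1) + 1 = -1 + 1 = 0.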
    const SCEV *NegOne = SE.getNegativeSCEV(One);
    return SE.getAddExpr(SE.getSMaxExpr(SE.getSMinExpr(X, Zero), NegOne), One);
  };

  // This function returns SCEV equal to 1 if X will not overflow in terms of
  // range check type, 0 otherwise.
  auto SCEVCheckWillNotOverflow = [&](const SCEV *X) {
    // X doesn't overflow if SINT_MAX >= X.
    // Then if (SINT_MAX - X) >= 0, X doesn't overflow
    const SCEV *SIntMaxExt = SE.getSignExtendExpr(SIntMax, X->getType());
    const SCEV *OverflowCheck =
        SCEVCheckNonNegative(SE.getMinusSCEV(SIntMaxExt, X));

    // X doesn't underflow if X >= SINT_MIN.
    // Then if (X - SINT_MIN) >= 0, X doesn't underflow
    const SCEV *SIntMinExt = SE.getSignExtendExpr(SIntMin, X->getType());
    const SCEV *UnderflowCheck =
        SCEVCheckNonNegative(SE.getMinusSCEV(X, SIntMinExt));

    return SE.getMulExpr(OverflowCheck, UnderflowCheck);
  };

  // FIXME: Current implementation of ClampedSubtract implicitly assumes that
  // X is non-negative (in sense of a signed value). We need to re-implement
  // this function in a way that it will correctly handle negative X as well.
  // We use it twice: for X = 0 everything is fine, but for X = getEnd() we can
  // end up with a negative X and produce wrong results. So currently we ensure
  // that if getEnd() is negative then both ends of the safe range are zero.
  // Note that this may pessimize elimination of unsigned range checks against
  // negative values.
  const SCEV *REnd = getEnd();
  const SCEV *EndWillNotOverflow = SE.getOne(RCType);

  auto PrintRangeCheck = [&](raw_ostream &OS) {
    auto L = IndVar->getLoop();
    OS << "irce: in function ";
    OS << L->getHeader()->getParent()->getName();
    OS << ", in ";
    L->print(OS);
    OS << "there is range check with scaled boundary:\n";
    print(OS);
  };

  if (EndType->getBitWidth() > RCType->getBitWidth()) {
    assert(EndType->getBitWidth() == RCType->getBitWidth() * 2);
    if (PrintScaledBoundaryRangeChecks)
      PrintRangeCheck(errs());
    // End is computed in the extended type but will be truncated to the
    // narrower range check type. Therefore we need a check that the result
    // will not overflow in terms of the narrow type.
    EndWillNotOverflow =
        SE.getTruncateExpr(SCEVCheckWillNotOverflow(REnd), RCType);
    REnd = SE.getTruncateExpr(REnd, RCType);
  }

  const SCEV *RuntimeChecks =
      SE.getMulExpr(SCEVCheckNonNegative(REnd), EndWillNotOverflow);
  const SCEV *Begin = SE.getMulExpr(ClampedSubtract(Zero, M), RuntimeChecks);
  const SCEV *End = SE.getMulExpr(ClampedSubtract(REnd, M), RuntimeChecks);

  return InductiveRangeCheck::Range(Begin, End);
}

static std::optional<InductiveRangeCheck::Range>
IntersectSignedRange(ScalarEvolution &SE,
                     const std::optional<InductiveRangeCheck::Range> &R1,
                     const InductiveRangeCheck::Range &R2) {
  if (R2.isEmpty(SE, /* IsSigned */ true))
    return std::nullopt;
  if (!R1)
    return R2;
  auto &R1Value = *R1;
  // We never return empty ranges from this function, and R1 is supposed to be
  // a result of intersection. Thus, R1 is never empty.
  assert(!R1Value.isEmpty(SE, /* IsSigned */ true) &&
         "We should never have empty R1!");

  // TODO: we could widen the smaller range and have this work; but for now we
  // bail out to keep things simple.
  if (R1Value.getType() != R2.getType())
    return std::nullopt;

  const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin());
  const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd());

  // If the resulting range is empty, just return std::nullopt.
  auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd);
  if (Ret.isEmpty(SE, /* IsSigned */ true))
    return std::nullopt;
  return Ret;
}

static std::optional<InductiveRangeCheck::Range>
IntersectUnsignedRange(ScalarEvolution &SE,
                       const std::optional<InductiveRangeCheck::Range> &R1,
                       const InductiveRangeCheck::Range &R2) {
  if (R2.isEmpty(SE, /* IsSigned */ false))
    return std::nullopt;
  if (!R1)
    return R2;
  auto &R1Value = *R1;
  // We never return empty ranges from this function, and R1 is supposed to be
  // a result of intersection. Thus, R1 is never empty.
  assert(!R1Value.isEmpty(SE, /* IsSigned */ false) &&
         "We should never have empty R1!");

  // TODO: we could widen the smaller range and have this work; but for now we
  // bail out to keep things simple.
  if (R1Value.getType() != R2.getType())
    return std::nullopt;

  const SCEV *NewBegin = SE.getUMaxExpr(R1Value.getBegin(), R2.getBegin());
  const SCEV *NewEnd = SE.getUMinExpr(R1Value.getEnd(), R2.getEnd());

  // If the resulting range is empty, just return std::nullopt.
  auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd);
  if (Ret.isEmpty(SE, /* IsSigned */ false))
    return std::nullopt;
  return Ret;
}

PreservedAnalyses IRCEPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
  // There are no loops in the function. Return before computing other
  // expensive analyses.
  if (LI.empty())
    return PreservedAnalyses::all();
  auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  auto &BPI = AM.getResult<BranchProbabilityAnalysis>(F);

  // Get BFI analysis result on demand. Please note that modification of
  // CFG invalidates this analysis and we should handle it.
  auto getBFI = [&F, &AM]() -> BlockFrequencyInfo & {
    return AM.getResult<BlockFrequencyAnalysis>(F);
  };
  InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI, {getBFI});

  bool Changed = false;
  {
    bool CFGChanged = false;
    for (const auto &L : LI) {
      CFGChanged |= simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr,
                                 /*PreserveLCSSA=*/false);
      Changed |= formLCSSARecursively(*L, DT, &LI, &SE);
    }
    Changed |= CFGChanged;

    if (CFGChanged && !SkipProfitabilityChecks) {
      PreservedAnalyses PA = PreservedAnalyses::all();
      PA.abandon<BlockFrequencyAnalysis>();
      AM.invalidate(F, PA);
    }
  }

  SmallPriorityWorklist<Loop *, 4> Worklist;
  appendLoopsToWorklist(LI, Worklist);
  auto LPMAddNewLoop = [&Worklist](Loop *NL, bool IsSubloop) {
    if (!IsSubloop)
      appendLoopsToWorklist(*NL, Worklist);
  };

  while (!Worklist.empty()) {
    Loop *L = Worklist.pop_back_val();
    if (IRCE.run(L, LPMAddNewLoop)) {
      Changed = true;
      if (!SkipProfitabilityChecks) {
        PreservedAnalyses PA = PreservedAnalyses::all();
        PA.abandon<BlockFrequencyAnalysis>();
        AM.invalidate(F, PA);
      }
    }
  }

  if (!Changed)
    return PreservedAnalyses::all();
  return getLoopPassPreservedAnalyses();
}

std::optional<uint64_t>
InductiveRangeCheckElimination::estimatedTripCount(const Loop &L) {
  if (GetBFI) {
    BlockFrequencyInfo &BFI = (*GetBFI)();
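    // The header executes once per iteration while the preheader executes once
    // per entry into the loop, so the ratio of their block frequencies gives
    // an estimate of the average trip count.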
    uint64_t hFreq = BFI.getBlockFreq(L.getHeader()).getFrequency();
    uint64_t phFreq = BFI.getBlockFreq(L.getLoopPreheader()).getFrequency();
    if (phFreq == 0 || hFreq == 0)
      return std::nullopt;
    return {hFreq / phFreq};
  }

  if (!BPI)
    return std::nullopt;

  auto *Latch = L.getLoopLatch();
  if (!Latch)
    return std::nullopt;
  auto *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
  if (!LatchBr)
    return std::nullopt;

  auto LatchBrExitIdx = LatchBr->getSuccessor(0) == L.getHeader() ? 1 : 0;
  BranchProbability ExitProbability =
      BPI->getEdgeProbability(Latch, LatchBrExitIdx);
  if (ExitProbability.isUnknown() || ExitProbability.isZero())
    return std::nullopt;

  return {ExitProbability.scaleByInverse(1)};
}

bool InductiveRangeCheckElimination::run(
    Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop) {
  if (L->getBlocks().size() >= LoopSizeCutoff) {
    LLVM_DEBUG(dbgs() << "irce: giving up constraining loop, too large\n");
    return false;
  }

  BasicBlock *Preheader = L->getLoopPreheader();
  if (!Preheader) {
    LLVM_DEBUG(dbgs() << "irce: loop has no preheader, leaving\n");
    return false;
  }

  auto EstimatedTripCount = estimatedTripCount(*L);
  if (!SkipProfitabilityChecks && EstimatedTripCount &&
      *EstimatedTripCount < MinEliminatedChecks) {
    LLVM_DEBUG(dbgs() << "irce: could not prove profitability: "
                      << "the estimated number of iterations is "
                      << *EstimatedTripCount << "\n");
    return false;
  }

  LLVMContext &Context = Preheader->getContext();
  SmallVector<InductiveRangeCheck, 16> RangeChecks;
  bool Changed = false;

  for (auto *BBI : L->getBlocks())
    if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator()))
      InductiveRangeCheck::extractRangeChecksFromBranch(
          TBI, L, SE, BPI, EstimatedTripCount, RangeChecks, Changed);

  if (RangeChecks.empty())
    return Changed;

  auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) {
    OS << "irce: looking at loop "; L->print(OS);
    OS << "irce: loop has " << RangeChecks.size()
       << " inductive range checks: \n";
    for (InductiveRangeCheck &IRC : RangeChecks)
      IRC.print(OS);
  };

  LLVM_DEBUG(PrintRecognizedRangeChecks(dbgs()));

  if (PrintRangeChecks)
    PrintRecognizedRangeChecks(errs());

  const char *FailureReason = nullptr;
  std::optional<LoopStructure> MaybeLoopStructure =
      LoopStructure::parseLoopStructure(SE, *L, AllowUnsignedLatchCondition,
                                        FailureReason);
  if (!MaybeLoopStructure) {
    LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: "
                      << FailureReason << "\n";);
    return Changed;
  }
  LoopStructure LS = *MaybeLoopStructure;
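  // The add-rec handed to computeSafeIterationSpace below is LS.IndVarBase
  // shifted back by one step (our reading: this yields the IV value seen at
  // the start of each iteration rather than the already-incremented value
  // used by the latch check).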
  const SCEVAddRecExpr *IndVar = cast<SCEVAddRecExpr>(
      SE.getMinusSCEV(SE.getSCEV(LS.IndVarBase), SE.getSCEV(LS.IndVarStep)));

  std::optional<InductiveRangeCheck::Range> SafeIterRange;

  SmallVector<InductiveRangeCheck, 4> RangeChecksToEliminate;
  // Based on the type of the latch predicate, we interpret the IV iteration
  // range as a signed or unsigned range. We use different min/max functions
  // (signed or unsigned) when intersecting this range with the safe iteration
  // ranges implied by the range checks.
  auto IntersectRange =
      LS.IsSignedPredicate ? IntersectSignedRange : IntersectUnsignedRange;

  for (InductiveRangeCheck &IRC : RangeChecks) {
    auto Result = IRC.computeSafeIterationSpace(SE, IndVar,
                                                LS.IsSignedPredicate);
    if (Result) {
      auto MaybeSafeIterRange = IntersectRange(SE, SafeIterRange, *Result);
      if (MaybeSafeIterRange) {
        assert(!MaybeSafeIterRange->isEmpty(SE, LS.IsSignedPredicate) &&
               "We should never return empty ranges!");
        RangeChecksToEliminate.push_back(IRC);
        SafeIterRange = *MaybeSafeIterRange;
      }
    }
  }

  if (!SafeIterRange)
    return Changed;

  std::optional<LoopConstrainer::SubRanges> MaybeSR =
      calculateSubRanges(SE, *L, *SafeIterRange, LS);
  if (!MaybeSR) {
    LLVM_DEBUG(dbgs() << "irce: could not compute subranges\n");
    return false;
  }

  LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT,
                     SafeIterRange->getBegin()->getType(), *MaybeSR);

  if (LC.run()) {
    Changed = true;

    auto PrintConstrainedLoopInfo = [L]() {
      dbgs() << "irce: in function ";
      dbgs() << L->getHeader()->getParent()->getName() << ": ";
      dbgs() << "constrained ";
      L->print(dbgs());
    };

    LLVM_DEBUG(PrintConstrainedLoopInfo());

    if (PrintChangedLoops)
      PrintConstrainedLoopInfo();

    // Optimize away the now-redundant range checks.

    for (InductiveRangeCheck &IRC : RangeChecksToEliminate) {
      ConstantInt *FoldedRangeCheck = IRC.getPassingDirection()
                                          ? ConstantInt::getTrue(Context)
                                          : ConstantInt::getFalse(Context);
      IRC.getCheckUse()->set(FoldedRangeCheck);
    }
  }

  return Changed;
}
