| 1 | //===- LoopFlatten.cpp - Loop flattening pass------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This pass flattens pairs of nested loops into a single loop. |
| 10 | // |
| 11 | // The intention is to optimise loop nests like this, which together access an |
| 12 | // array linearly: |
| 13 | // |
| 14 | // for (int i = 0; i < N; ++i) |
| 15 | // for (int j = 0; j < M; ++j) |
| 16 | // f(A[i*M+j]); |
| 17 | // |
| 18 | // into one loop: |
| 19 | // |
| 20 | // for (int i = 0; i < (N*M); ++i) |
| 21 | // f(A[i]); |
| 22 | // |
| 23 | // It can also flatten loops where the induction variables are not used in the |
| 24 | // loop body. Flattening is only worth doing if the induction variables are only |
| 25 | // used in expressions like i*M+j. If they had any other uses, we would have to |
| 26 | // insert a div/mod to reconstruct the original values, so this wouldn't be profitable. |
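//
// For illustration, a sketch (not taken from the pass itself) of a nest that
// would not be profitable to flatten, because 'j' has a use other than i*M+j
// and would have to be reconstructed as i = x / M, j = x % M in the flattened
// loop:
//
//   for (int i = 0; i < N; ++i)
//     for (int j = 0; j < M; ++j)
//       f(A[i*M+j], j);   // the extra use of 'j' would force a div/mod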
| 27 | // |
| 28 | // We also need to prove that N*M will not overflow. The preferred solution is |
| 29 | // to widen the IV, which avoids overflow checks, so that is tried first. If |
| 30 | // the IV cannot be widened, then we try to determine that this new tripcount |
| 31 | // expression won't overflow. |
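//
// For example (an illustrative sketch): with 32-bit induction variables and
// N == M == 100000, the product N*M == 10^10 does not fit in 32 bits, but
// after widening the IVs (and the flattened trip count) to 64 bits it does:
//
//   for (long long i = 0; i < (long long)N * M; ++i)
//     f(A[i]);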
| 32 | // |
| 33 | // Q: Does LoopFlatten use SCEV? |
| 34 | // Short answer: Yes and no. |
| 35 | // |
| 36 | // Long answer: |
| 37 | // For this transformation to be valid, we require all uses of the induction |
| 38 | // variables to be linear expressions of the form i*M+j. The different Loop |
| 39 | // APIs are used to get some loop components like the induction variable, |
| 40 | // compare statement, etc. In addition, we do some pattern matching to find the |
| 41 | // linear expressions and other loop components like the loop increment. The |
| 42 | // latter are examples of expressions that do use the induction variables, but |
| 43 | // are safe to ignore when we check that all uses are of the form i*M+j. We keep |
| 44 | // track of all of this in the bookkeeping struct FlattenInfo. |
| 45 | // We assume the loops to be canonical, i.e. starting at 0 and incrementing by |
| 46 | // 1. This makes the RHS of the compare the loop trip count (with the right |
| 47 | // predicate). We then use SCEV to sanity-check that this trip count matches |
| 48 | // the trip count as computed by SCEV. |
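//
// For example (illustrative): for the canonical loop
//
//   for (int i = 0; i != N; ++i) { ... }
//
// the latch compare is 'icmp ne %inc, %N' (or ult), so its RHS 'N' is taken
// as the trip count and then cross-checked against the SCEV-computed trip
// count, i.e. the backedge-taken count plus one.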
| 49 | // |
| 50 | //===----------------------------------------------------------------------===// |
| 51 | |
| 52 | #include "llvm/Transforms/Scalar/LoopFlatten.h" |
| 53 | |
| 54 | #include "llvm/ADT/Statistic.h" |
| 55 | #include "llvm/Analysis/AssumptionCache.h" |
| 56 | #include "llvm/Analysis/LoopInfo.h" |
| 57 | #include "llvm/Analysis/LoopNestAnalysis.h" |
| 58 | #include "llvm/Analysis/MemorySSAUpdater.h" |
| 59 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" |
| 60 | #include "llvm/Analysis/ScalarEvolution.h" |
| 61 | #include "llvm/Analysis/TargetTransformInfo.h" |
| 62 | #include "llvm/Analysis/ValueTracking.h" |
| 63 | #include "llvm/IR/Dominators.h" |
| 64 | #include "llvm/IR/Function.h" |
| 65 | #include "llvm/IR/IRBuilder.h" |
| 66 | #include "llvm/IR/Module.h" |
| 67 | #include "llvm/IR/PatternMatch.h" |
| 68 | #include "llvm/Support/Debug.h" |
| 69 | #include "llvm/Support/raw_ostream.h" |
| 70 | #include "llvm/Transforms/Scalar/LoopPassManager.h" |
| 71 | #include "llvm/Transforms/Utils/Local.h" |
| 72 | #include "llvm/Transforms/Utils/LoopUtils.h" |
| 73 | #include "llvm/Transforms/Utils/LoopVersioning.h" |
| 74 | #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" |
| 75 | #include "llvm/Transforms/Utils/SimplifyIndVar.h" |
| 76 | #include <optional> |
| 77 | |
| 78 | using namespace llvm; |
| 79 | using namespace llvm::PatternMatch; |
| 80 | |
| 81 | #define DEBUG_TYPE "loop-flatten" |
| 82 | |
| 83 | STATISTIC(NumFlattened, "Number of loops flattened" ); |
| 84 | |
| 85 | static cl::opt<unsigned> RepeatedInstructionThreshold( |
| 86 | "loop-flatten-cost-threshold" , cl::Hidden, cl::init(Val: 2), |
| 87 | cl::desc("Limit on the cost of instructions that can be repeated due to " |
| 88 | "loop flattening" )); |
| 89 | |
| 90 | static cl::opt<bool> |
| 91 | AssumeNoOverflow("loop-flatten-assume-no-overflow" , cl::Hidden, |
| 92 | cl::init(Val: false), |
| 93 | cl::desc("Assume that the product of the two iteration " |
| 94 | "trip counts will never overflow" )); |
| 95 | |
| 96 | static cl::opt<bool> |
| 97 | WidenIV("loop-flatten-widen-iv" , cl::Hidden, cl::init(Val: true), |
| 98 | cl::desc("Widen the loop induction variables, if possible, so " |
| 99 | "overflow checks won't reject flattening" )); |
| 100 | |
| 101 | static cl::opt<bool> |
| 102 | VersionLoops("loop-flatten-version-loops" , cl::Hidden, cl::init(Val: true), |
| 103 | cl::desc("Version loops if flattened loop could overflow" )); |
| 104 | |
| 105 | namespace { |
| 106 | // We require all uses of both induction variables to match this pattern: |
| 107 | // |
| 108 | // (OuterPHI * InnerTripCount) + InnerPHI |
| 109 | // |
| 110 | // I.e., it needs to be a linear expression of the induction variables and the |
| 111 | // inner loop trip count. We keep track of all different expressions on which |
| 112 | // checks will be performed in this bookkeeping struct. |
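// For example (a sketch): the C expression A[i*M+j] typically shows up as
//
//   %mul = mul i32 %outer.iv, %M        ; OuterPHI * InnerTripCount
//   %add = add i32 %mul, %inner.iv      ; ... + InnerPHI
//
// feeding an address computation; matchLinearIVUser below also accepts the
// equivalent form written as a pair of GEPs.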
| 113 | // |
| 114 | struct FlattenInfo { |
| 115 | Loop *OuterLoop = nullptr; // The loop pair to be flattened. |
| 116 | Loop *InnerLoop = nullptr; |
| 117 | |
| 118 | PHINode *InnerInductionPHI = nullptr; // These PHINodes correspond to loop |
| 119 | PHINode *OuterInductionPHI = nullptr; // induction variables, which are |
| 120 | // expected to start at zero and |
| 121 | // increment by one on each loop. |
| 122 | |
| 123 | Value *InnerTripCount = nullptr; // The product of these two tripcounts |
| 124 | Value *OuterTripCount = nullptr; // will be the new flattened loop |
| 125 | // tripcount. Also used to recognise a |
| 126 | // linear expression that will be replaced. |
| 127 | |
| 128 | SmallPtrSet<Value *, 4> LinearIVUses; // Contains the linear expressions |
| 129 | // of the form i*M+j that will be |
| 130 | // replaced. |
| 131 | |
| 132 | BinaryOperator *InnerIncrement = nullptr; // Uses of induction variables in |
| 133 | BinaryOperator *OuterIncrement = nullptr; // loop control statements that |
| 134 | BranchInst *InnerBranch = nullptr; // are safe to ignore. |
| 135 | |
| 136 | BranchInst *OuterBranch = nullptr; // The instruction that needs to be |
| 137 | // updated with new tripcount. |
| 138 | |
| 139 | SmallPtrSet<PHINode *, 4> InnerPHIsToTransform; |
| 140 | |
| 141 | bool Widened = false; // Whether this holds the flatten info before or after |
| 142 | // widening. |
| 143 | |
| 144 | PHINode *NarrowInnerInductionPHI = nullptr; // Holds the old/narrow induction |
| 145 | PHINode *NarrowOuterInductionPHI = nullptr; // phis, i.e. the Phis before IV |
| 146 | // widening has been applied. Used to |
| 147 | // skip checks on phi nodes. |
| 148 | |
| 149 | Value *NewTripCount = nullptr; // The tripcount of the flattened loop. |
| 150 | |
| 151 | FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {} |
| 152 | |
| 153 | bool isNarrowInductionPhi(PHINode *Phi) { |
| 154 | // This can't be the narrow phi if we haven't widened the IV first. |
| 155 | if (!Widened) |
| 156 | return false; |
| 157 | return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi; |
| 158 | } |
| 159 | bool isInnerLoopIncrement(User *U) { |
| 160 | return InnerIncrement == U; |
| 161 | } |
| 162 | bool isOuterLoopIncrement(User *U) { |
| 163 | return OuterIncrement == U; |
| 164 | } |
| 165 | bool isInnerLoopTest(User *U) { |
| 166 | return InnerBranch->getCondition() == U; |
| 167 | } |
| 168 | |
| 169 | bool checkOuterInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) { |
| 170 | for (User *U : OuterInductionPHI->users()) { |
| 171 | if (isOuterLoopIncrement(U)) |
| 172 | continue; |
| 173 | |
| 174 | auto IsValidOuterPHIUses = [&] (User *U) -> bool { |
| 175 | LLVM_DEBUG(dbgs() << "Found use of outer induction variable: " ; U->dump()); |
| 176 | if (!ValidOuterPHIUses.count(Ptr: U)) { |
| 177 | LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n" ); |
| 178 | return false; |
| 179 | } |
| 180 | LLVM_DEBUG(dbgs() << "Use is optimisable\n" ); |
| 181 | return true; |
| 182 | }; |
| 183 | |
| 184 | if (auto *V = dyn_cast<TruncInst>(Val: U)) { |
| 185 | for (auto *K : V->users()) { |
| 186 | if (!IsValidOuterPHIUses(K)) |
| 187 | return false; |
| 188 | } |
| 189 | continue; |
| 190 | } |
| 191 | |
| 192 | if (!IsValidOuterPHIUses(U)) |
| 193 | return false; |
| 194 | } |
| 195 | return true; |
| 196 | } |
| 197 | |
| 198 | bool matchLinearIVUser(User *U, Value *InnerTripCount, |
| 199 | SmallPtrSet<Value *, 4> &ValidOuterPHIUses) { |
| 200 | LLVM_DEBUG(dbgs() << "Checking linear i*M+j expression for: " ; U->dump()); |
| 201 | Value *MatchedMul = nullptr; |
| 202 | Value *MatchedItCount = nullptr; |
| 203 | |
| 204 | bool IsAdd = match(V: U, P: m_c_Add(L: m_Specific(V: InnerInductionPHI), |
| 205 | R: m_Value(V&: MatchedMul))) && |
| 206 | match(V: MatchedMul, P: m_c_Mul(L: m_Specific(V: OuterInductionPHI), |
| 207 | R: m_Value(V&: MatchedItCount))); |
| 208 | |
| 209 | // Matches the same pattern as above, except it also looks for truncs |
| 210 | // on the phi, which can be the result of widening the induction variables. |
| 211 | bool IsAddTrunc = |
| 212 | match(V: U, P: m_c_Add(L: m_Trunc(Op: m_Specific(V: InnerInductionPHI)), |
| 213 | R: m_Value(V&: MatchedMul))) && |
| 214 | match(V: MatchedMul, P: m_c_Mul(L: m_Trunc(Op: m_Specific(V: OuterInductionPHI)), |
| 215 | R: m_Value(V&: MatchedItCount))); |
| 216 | |
| 217 | // Matches the pattern ptr+i*M+j, with the two additions being done via GEP. |
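// E.g. (an illustrative sketch), &A[i*M + j] may be emitted as two GEPs:
//   %row = getelementptr i32, ptr %A, i64 %mul   ; %mul = i * M
//   %elt = getelementptr i32, ptr %row, i64 %j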
| 218 | bool IsGEP = match(V: U, P: m_GEP(Ops: m_GEP(Ops: m_Value(), Ops: m_Value(V&: MatchedMul)), |
| 219 | Ops: m_Specific(V: InnerInductionPHI))) && |
| 220 | match(V: MatchedMul, P: m_c_Mul(L: m_Specific(V: OuterInductionPHI), |
| 221 | R: m_Value(V&: MatchedItCount))); |
| 222 | |
| 223 | if (!MatchedItCount) |
| 224 | return false; |
| 225 | |
| 226 | LLVM_DEBUG(dbgs() << "Matched multiplication: " ; MatchedMul->dump()); |
| 227 | LLVM_DEBUG(dbgs() << "Matched iteration count: " ; MatchedItCount->dump()); |
| 228 | |
| 229 | // The mul should not have any other uses. Widening may leave trivially dead |
| 230 | // uses, which can be ignored. |
| 231 | if (count_if(Range: MatchedMul->users(), P: [](User *U) { |
| 232 | return !isInstructionTriviallyDead(I: cast<Instruction>(Val: U)); |
| 233 | }) > 1) { |
| 234 | LLVM_DEBUG(dbgs() << "Multiply has more than one use\n" ); |
| 235 | return false; |
| 236 | } |
| 237 | |
| 238 | // Look through extends if the IV has been widened. Don't look through |
| 239 | // extends if we already looked through a trunc. |
| 240 | if (Widened && (IsAdd || IsGEP) && |
| 241 | (isa<SExtInst>(Val: MatchedItCount) || isa<ZExtInst>(Val: MatchedItCount))) { |
| 242 | assert(MatchedItCount->getType() == InnerInductionPHI->getType() && |
| 243 | "Unexpected type mismatch in types after widening" ); |
| 244 | MatchedItCount = isa<SExtInst>(Val: MatchedItCount) |
| 245 | ? dyn_cast<SExtInst>(Val: MatchedItCount)->getOperand(i_nocapture: 0) |
| 246 | : dyn_cast<ZExtInst>(Val: MatchedItCount)->getOperand(i_nocapture: 0); |
| 247 | } |
| 248 | |
| 249 | LLVM_DEBUG(dbgs() << "Looking for inner trip count: " ; |
| 250 | InnerTripCount->dump()); |
| 251 | |
| 252 | if ((IsAdd || IsAddTrunc || IsGEP) && MatchedItCount == InnerTripCount) { |
| 253 | LLVM_DEBUG(dbgs() << "Found. This sse is optimisable\n" ); |
| 254 | ValidOuterPHIUses.insert(Ptr: MatchedMul); |
| 255 | LinearIVUses.insert(Ptr: U); |
| 256 | return true; |
| 257 | } |
| 258 | |
| 259 | LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n" ); |
| 260 | return false; |
| 261 | } |
| 262 | |
| 263 | bool checkInnerInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) { |
| 264 | Value *SExtInnerTripCount = InnerTripCount; |
| 265 | if (Widened && |
| 266 | (isa<SExtInst>(Val: InnerTripCount) || isa<ZExtInst>(Val: InnerTripCount))) |
| 267 | SExtInnerTripCount = cast<Instruction>(Val: InnerTripCount)->getOperand(i: 0); |
| 268 | |
| 269 | for (User *U : InnerInductionPHI->users()) { |
| 270 | LLVM_DEBUG(dbgs() << "Checking User: " ; U->dump()); |
| 271 | if (isInnerLoopIncrement(U)) { |
| 272 | LLVM_DEBUG(dbgs() << "Use is inner loop increment, continuing\n" ); |
| 273 | continue; |
| 274 | } |
| 275 | |
| 276 | // After widening the IVs, a trunc instruction might have been introduced, |
| 277 | // so look through truncs. |
| 278 | if (isa<TruncInst>(Val: U)) { |
| 279 | if (!U->hasOneUse()) |
| 280 | return false; |
| 281 | U = *U->user_begin(); |
| 282 | } |
| 283 | |
| 284 | // If the use is in the compare (which is also the condition of the inner |
| 285 | // branch) then the compare has been altered by another transformation e.g |
| 286 | // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is |
| 287 | // a constant. Ignore this use as the compare gets removed later anyway. |
| 288 | if (isInnerLoopTest(U)) { |
| 289 | LLVM_DEBUG(dbgs() << "Use is the inner loop test, continuing\n" ); |
| 290 | continue; |
| 291 | } |
| 292 | |
| 293 | if (!matchLinearIVUser(U, InnerTripCount: SExtInnerTripCount, ValidOuterPHIUses)) { |
| 294 | LLVM_DEBUG(dbgs() << "Not a linear IV user\n" ); |
| 295 | return false; |
| 296 | } |
| 297 | LLVM_DEBUG(dbgs() << "Linear IV users found!\n" ); |
| 298 | } |
| 299 | return true; |
| 300 | } |
| 301 | }; |
| 302 | } // namespace |
| 303 | |
| 304 | static bool |
| 305 | setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment, |
| 306 | SmallPtrSetImpl<Instruction *> &IterationInstructions) { |
| 307 | TripCount = TC; |
| 308 | IterationInstructions.insert(Ptr: Increment); |
| 309 | LLVM_DEBUG(dbgs() << "Found Increment: " ; Increment->dump()); |
| 310 | LLVM_DEBUG(dbgs() << "Found trip count: " ; TripCount->dump()); |
| 311 | LLVM_DEBUG(dbgs() << "Successfully found all loop components\n" ); |
| 312 | return true; |
| 313 | } |
| 314 | |
| 315 | // Given the RHS of the loop latch compare instruction, verify with SCEV |
| 316 | // that this is indeed the loop tripcount. |
| 317 | // TODO: This used to be a straightforward check but has grown to be quite |
| 318 | // complicated now. It is therefore worth revisiting what the additional |
| 319 | // benefits are of this (compared to relying on canonical loops and pattern |
| 320 | // matching). |
| 321 | static bool verifyTripCount(Value *RHS, Loop *L, |
| 322 | SmallPtrSetImpl<Instruction *> &IterationInstructions, |
| 323 | PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment, |
| 324 | BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) { |
| 325 | const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); |
| 326 | if (isa<SCEVCouldNotCompute>(Val: BackedgeTakenCount)) { |
| 327 | LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n" ); |
| 328 | return false; |
| 329 | } |
| 330 | |
| 331 | // Evaluating in the trip count's type cannot overflow here, as the overflow |
| 332 | // checks are performed later in checkOverflow (and we first try to avoid the |
| 333 | // need for them by widening the IV). |
| 334 | const SCEV *SCEVTripCount = |
| 335 | SE->getTripCountFromExitCount(ExitCount: BackedgeTakenCount, |
| 336 | EvalTy: BackedgeTakenCount->getType(), L); |
| 337 | |
| 338 | const SCEV *SCEVRHS = SE->getSCEV(V: RHS); |
| 339 | if (SCEVRHS == SCEVTripCount) |
| 340 | return setLoopComponents(TC&: RHS, TripCount, Increment, IterationInstructions); |
| 341 | ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(Val: RHS); |
| 342 | if (ConstantRHS) { |
| 343 | const SCEV *BackedgeTCExt = nullptr; |
| 344 | if (IsWidened) { |
| 345 | const SCEV *SCEVTripCountExt; |
| 346 | // Find the extended backedge taken count and extended trip count using |
| 347 | // SCEV. One of these should now match the RHS of the compare. |
| 348 | BackedgeTCExt = SE->getZeroExtendExpr(Op: BackedgeTakenCount, Ty: RHS->getType()); |
| 349 | SCEVTripCountExt = SE->getTripCountFromExitCount(ExitCount: BackedgeTCExt, |
| 350 | EvalTy: RHS->getType(), L); |
| 351 | if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) { |
| 352 | LLVM_DEBUG(dbgs() << "Could not find valid trip count\n" ); |
| 353 | return false; |
| 354 | } |
| 355 | } |
| 356 | // If the RHS of the compare is equal to the backedge taken count we need |
| 357 | // to add one to get the trip count. |
| 358 | if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) { |
| 359 | Value *NewRHS = ConstantInt::get(Context&: ConstantRHS->getContext(), |
| 360 | V: ConstantRHS->getValue() + 1); |
| 361 | return setLoopComponents(TC&: NewRHS, TripCount, Increment, |
| 362 | IterationInstructions); |
| 363 | } |
| 364 | return setLoopComponents(TC&: RHS, TripCount, Increment, IterationInstructions); |
| 365 | } |
| 366 | // If the RHS isn't a constant then check that the reason it doesn't match |
| 367 | // the SCEV trip count is because the RHS is a ZExt or SExt instruction |
| 368 | // (and take the trip count to be the RHS). |
| 369 | if (!IsWidened) { |
| 370 | LLVM_DEBUG(dbgs() << "Could not find valid trip count\n" ); |
| 371 | return false; |
| 372 | } |
| 373 | auto *TripCountInst = dyn_cast<Instruction>(Val: RHS); |
| 374 | if (!TripCountInst) { |
| 375 | LLVM_DEBUG(dbgs() << "Could not find valid trip count\n" ); |
| 376 | return false; |
| 377 | } |
| 378 | if ((!isa<ZExtInst>(Val: TripCountInst) && !isa<SExtInst>(Val: TripCountInst)) || |
| 379 | SE->getSCEV(V: TripCountInst->getOperand(i: 0)) != SCEVTripCount) { |
| 380 | LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n" ); |
| 381 | return false; |
| 382 | } |
| 383 | return setLoopComponents(TC&: RHS, TripCount, Increment, IterationInstructions); |
| 384 | } |
| 385 | |
| 386 | // Finds the induction variable, increment and trip count for a simple loop that |
| 387 | // we can flatten. |
| 388 | static bool findLoopComponents( |
| 389 | Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions, |
| 390 | PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment, |
| 391 | BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) { |
| 392 | LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n" ); |
| 393 | |
| 394 | if (!L->isLoopSimplifyForm()) { |
| 395 | LLVM_DEBUG(dbgs() << "Loop is not in normal form\n" ); |
| 396 | return false; |
| 397 | } |
| 398 | |
| 399 | // Currently, to simplify the implementation, the Loop induction variable must |
| 400 | // start at zero and increment with a step size of one. |
| 401 | if (!L->isCanonical(SE&: *SE)) { |
| 402 | LLVM_DEBUG(dbgs() << "Loop is not canonical\n" ); |
| 403 | return false; |
| 404 | } |
| 405 | |
| 406 | // There must be exactly one exiting block, and it must be the same as the |
| 407 | // latch. |
| 408 | BasicBlock *Latch = L->getLoopLatch(); |
| 409 | if (L->getExitingBlock() != Latch) { |
| 410 | LLVM_DEBUG(dbgs() << "Exiting and latch block are different\n" ); |
| 411 | return false; |
| 412 | } |
| 413 | |
| 414 | // Find the induction PHI. If there is no induction PHI, we can't do the |
| 415 | // transformation. TODO: could other variables trigger this? Do we have to |
| 416 | // search for the best one? |
| 417 | InductionPHI = L->getInductionVariable(SE&: *SE); |
| 418 | if (!InductionPHI) { |
| 419 | LLVM_DEBUG(dbgs() << "Could not find induction PHI\n" ); |
| 420 | return false; |
| 421 | } |
| 422 | LLVM_DEBUG(dbgs() << "Found induction PHI: " ; InductionPHI->dump()); |
| 423 | |
| 424 | bool ContinueOnTrue = L->contains(BB: Latch->getTerminator()->getSuccessor(Idx: 0)); |
| 425 | auto IsValidPredicate = [&](ICmpInst::Predicate Pred) { |
| 426 | if (ContinueOnTrue) |
| 427 | return Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT; |
| 428 | else |
| 429 | return Pred == CmpInst::ICMP_EQ; |
| 430 | }; |
| 431 | |
| 432 | // Find Compare and make sure it is valid. getLatchCmpInst checks that the |
| 433 | // back branch of the latch is conditional. |
| 434 | ICmpInst *Compare = L->getLatchCmpInst(); |
| 435 | if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) || |
| 436 | Compare->hasNUsesOrMore(N: 2)) { |
| 437 | LLVM_DEBUG(dbgs() << "Could not find valid comparison\n" ); |
| 438 | return false; |
| 439 | } |
| 440 | BackBranch = cast<BranchInst>(Val: Latch->getTerminator()); |
| 441 | IterationInstructions.insert(Ptr: BackBranch); |
| 442 | LLVM_DEBUG(dbgs() << "Found back branch: " ; BackBranch->dump()); |
| 443 | IterationInstructions.insert(Ptr: Compare); |
| 444 | LLVM_DEBUG(dbgs() << "Found comparison: " ; Compare->dump()); |
| 445 | |
| 446 | // Find increment and trip count. |
| 447 | // There are exactly 2 incoming values to the induction phi; one from the |
| 448 | // pre-header and one from the latch. The incoming latch value is the |
| 449 | // increment variable. |
| 450 | Increment = |
| 451 | cast<BinaryOperator>(Val: InductionPHI->getIncomingValueForBlock(BB: Latch)); |
| 452 | if ((Compare->getOperand(i_nocapture: 0) != Increment || !Increment->hasNUses(N: 2)) && |
| 453 | !Increment->hasNUses(N: 1)) { |
| 454 | LLVM_DEBUG(dbgs() << "Could not find valid increment\n" ); |
| 455 | return false; |
| 456 | } |
| 457 | // The trip count is the RHS of the compare. If this doesn't match the trip |
| 458 | // count computed by SCEV then this is because the trip count variable |
| 459 | // has been widened so the types don't match, or because it is a constant and |
| 460 | // another transformation has changed the compare (e.g. icmp ult %inc, |
| 461 | // tripcount -> icmp ult %j, tripcount-1), or both. |
| 462 | Value *RHS = Compare->getOperand(i_nocapture: 1); |
| 463 | |
| 464 | return verifyTripCount(RHS, L, IterationInstructions, InductionPHI, TripCount, |
| 465 | Increment, BackBranch, SE, IsWidened); |
| 466 | } |
| 467 | |
| 468 | static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) { |
| 469 | // All PHIs in the inner and outer headers must either be: |
| 470 | // - The induction PHI, which we are going to rewrite as one induction in |
| 471 | // the new loop. This is already checked by findLoopComponents. |
| 472 | // - An outer header PHI with all incoming values from outside the loop. |
| 473 | // LoopSimplify guarantees we have a pre-header, so we don't need to |
| 474 | // worry about that here. |
| 475 | // - Pairs of PHIs in the inner and outer headers, which implement a |
| 476 | // loop-carried dependency that will still be valid in the new loop. To |
| 477 | // be valid, this variable must be modified only in the inner loop. |
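//
// For illustration (a sketch of the last case): a sum accumulated only in the
// inner loop forms such a PHI pair and stays valid after flattening:
//
//   int sum = 0;
//   for (int i = 0; i < N; ++i)
//     for (int j = 0; j < M; ++j)
//       sum += A[i*M+j];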
| 478 | |
| 479 | // The set of PHI nodes in the outer loop header that we know will still be |
| 480 | // valid after the transformation. These will not need to be modified (with |
| 481 | // the exception of the induction variable), but we do need to check that |
| 482 | // there are no unsafe PHI nodes. |
| 483 | SmallPtrSet<PHINode *, 4> SafeOuterPHIs; |
| 484 | SafeOuterPHIs.insert(Ptr: FI.OuterInductionPHI); |
| 485 | |
| 486 | // Check that all PHI nodes in the inner loop header match one of the valid |
| 487 | // patterns. |
| 488 | for (PHINode &InnerPHI : FI.InnerLoop->getHeader()->phis()) { |
| 489 | // The induction PHIs break these rules, and that's OK because we treat |
| 490 | // them specially when doing the transformation. |
| 491 | if (&InnerPHI == FI.InnerInductionPHI) |
| 492 | continue; |
| 493 | if (FI.isNarrowInductionPhi(Phi: &InnerPHI)) |
| 494 | continue; |
| 495 | |
| 496 | // Each inner loop PHI node must have two incoming values/blocks - one |
| 497 | // from the pre-header, and one from the latch. |
| 498 | assert(InnerPHI.getNumIncomingValues() == 2); |
| 499 | Value *PreHeaderValue = |
| 500 | InnerPHI.getIncomingValueForBlock(BB: FI.InnerLoop->getLoopPreheader()); |
| 501 | Value *LatchValue = |
| 502 | InnerPHI.getIncomingValueForBlock(BB: FI.InnerLoop->getLoopLatch()); |
| 503 | |
| 504 | // The incoming value from the outer loop must be the PHI node in the |
| 505 | // outer loop header, with no modifications made in the top of the outer |
| 506 | // loop. |
| 507 | PHINode *OuterPHI = dyn_cast<PHINode>(Val: PreHeaderValue); |
| 508 | if (!OuterPHI || OuterPHI->getParent() != FI.OuterLoop->getHeader()) { |
| 509 | LLVM_DEBUG(dbgs() << "value modified in top of outer loop\n" ); |
| 510 | return false; |
| 511 | } |
| 512 | |
| 513 | // The other incoming value must come from the inner loop, without any |
| 514 | // modifications in the tail end of the outer loop. We are in LCSSA form, |
| 515 | // so this will actually be a PHI in the inner loop's exit block, which |
| 516 | // only uses values from inside the inner loop. |
| 517 | PHINode *LCSSAPHI = dyn_cast<PHINode>( |
| 518 | Val: OuterPHI->getIncomingValueForBlock(BB: FI.OuterLoop->getLoopLatch())); |
| 519 | if (!LCSSAPHI) { |
| 520 | LLVM_DEBUG(dbgs() << "could not find LCSSA PHI\n" ); |
| 521 | return false; |
| 522 | } |
| 523 | |
| 524 | // The value used by the LCSSA PHI must be the same one that the inner |
| 525 | // loop's PHI uses. |
| 526 | if (LCSSAPHI->hasConstantValue() != LatchValue) { |
| 527 | LLVM_DEBUG( |
| 528 | dbgs() << "LCSSA PHI incoming value does not match latch value\n" ); |
| 529 | return false; |
| 530 | } |
| 531 | |
| 532 | LLVM_DEBUG(dbgs() << "PHI pair is safe:\n" ); |
| 533 | LLVM_DEBUG(dbgs() << " Inner: " ; InnerPHI.dump()); |
| 534 | LLVM_DEBUG(dbgs() << " Outer: " ; OuterPHI->dump()); |
| 535 | SafeOuterPHIs.insert(Ptr: OuterPHI); |
| 536 | FI.InnerPHIsToTransform.insert(Ptr: &InnerPHI); |
| 537 | } |
| 538 | |
| 539 | for (PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) { |
| 540 | if (FI.isNarrowInductionPhi(Phi: &OuterPHI)) |
| 541 | continue; |
| 542 | if (!SafeOuterPHIs.count(Ptr: &OuterPHI)) { |
| 543 | LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: " ; OuterPHI.dump()); |
| 544 | return false; |
| 545 | } |
| 546 | } |
| 547 | |
| 548 | LLVM_DEBUG(dbgs() << "checkPHIs: OK\n" ); |
| 549 | return true; |
| 550 | } |
| 551 | |
| 552 | static bool |
| 553 | checkOuterLoopInsts(FlattenInfo &FI, |
| 554 | SmallPtrSetImpl<Instruction *> &IterationInstructions, |
| 555 | const TargetTransformInfo *TTI) { |
| 556 | // Check for instructions in the outer but not inner loop. If any of these |
| 557 | // have side-effects then this transformation is not legal, and if there is |
| 558 | // a significant amount of code here which can't be optimised out, then it's |
| 559 | // not profitable (as these instructions would get executed for each |
| 560 | // iteration of the inner loop). |
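// For example (illustrative): a speculatable instruction in a block of the
// outer loop that is not part of the inner loop, such as '%t = shl i32 %x, 2'
// (with %x defined outside the loop), is neither loop control nor the i*M
// multiply, so its cost is added to RepeatedInstrCost below: after flattening
// it would execute once per inner iteration instead of once per outer
// iteration.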
| 561 | InstructionCost RepeatedInstrCost = 0; |
| 562 | for (auto *B : FI.OuterLoop->getBlocks()) { |
| 563 | if (FI.InnerLoop->contains(BB: B)) |
| 564 | continue; |
| 565 | |
| 566 | for (auto &I : *B) { |
| 567 | if (!isa<PHINode>(Val: &I) && !I.isTerminator() && |
| 568 | !isSafeToSpeculativelyExecute(I: &I)) { |
| 569 | LLVM_DEBUG(dbgs() << "Cannot flatten because instruction may have " |
| 570 | "side effects: " ; |
| 571 | I.dump()); |
| 572 | return false; |
| 573 | } |
| 574 | // The execution count of the outer loop's iteration instructions |
| 575 | // (increment, compare and branch) will be increased, but the |
| 576 | // equivalent instructions will be removed from the inner loop, so |
| 577 | // they make a net difference of zero. |
| 578 | if (IterationInstructions.count(Ptr: &I)) |
| 579 | continue; |
| 580 | // The unconditional branch to the inner loop's header will turn into |
| 581 | // a fall-through, so adds no cost. |
| 582 | BranchInst *Br = dyn_cast<BranchInst>(Val: &I); |
| 583 | if (Br && Br->isUnconditional() && |
| 584 | Br->getSuccessor(i: 0) == FI.InnerLoop->getHeader()) |
| 585 | continue; |
| 586 | // Multiplies of the outer iteration variable and inner iteration |
| 587 | // count will be optimised out. |
| 588 | if (match(V: &I, P: m_c_Mul(L: m_Specific(V: FI.OuterInductionPHI), |
| 589 | R: m_Specific(V: FI.InnerTripCount)))) |
| 590 | continue; |
| 591 | InstructionCost Cost = |
| 592 | TTI->getInstructionCost(U: &I, CostKind: TargetTransformInfo::TCK_SizeAndLatency); |
| 593 | LLVM_DEBUG(dbgs() << "Cost " << Cost << ": " ; I.dump()); |
| 594 | RepeatedInstrCost += Cost; |
| 595 | } |
| 596 | } |
| 597 | |
| 598 | LLVM_DEBUG(dbgs() << "Cost of instructions that will be repeated: " |
| 599 | << RepeatedInstrCost << "\n" ); |
| 600 | // Bail out if flattening the loops would cause instructions in the outer |
| 601 | // loop but not in the inner loop to be executed extra times. |
| 602 | if (RepeatedInstrCost > RepeatedInstructionThreshold) { |
| 603 | LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: not profitable, bailing.\n" ); |
| 604 | return false; |
| 605 | } |
| 606 | |
| 607 | LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: OK\n" ); |
| 608 | return true; |
| 609 | } |
| 610 |
| 613 | // We require all uses of both induction variables to match this pattern: |
| 614 | // |
| 615 | // (OuterPHI * InnerTripCount) + InnerPHI |
| 616 | // |
| 617 | // Any uses of the induction variables not matching that pattern would |
| 618 | // require a div/mod to reconstruct in the flattened loop, so the |
| 619 | // transformation wouldn't be profitable. |
| 620 | static bool checkIVUsers(FlattenInfo &FI) { |
| 621 | // Check that all uses of the inner loop's induction variable match the |
| 622 | // expected pattern, recording the uses of the outer IV. |
| 623 | SmallPtrSet<Value *, 4> ValidOuterPHIUses; |
| 624 | if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses)) |
| 625 | return false; |
| 626 | |
| 627 | // Check that there are no uses of the outer IV other than the ones found |
| 628 | // as part of the pattern above. |
| 629 | if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses)) |
| 630 | return false; |
| 631 | |
| 632 | LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n" ; |
| 633 | dbgs() << "Found " << FI.LinearIVUses.size() |
| 634 | << " value(s) that can be replaced:\n" ; |
| 635 | for (Value *V : FI.LinearIVUses) { |
| 636 | dbgs() << " " ; |
| 637 | V->dump(); |
| 638 | }); |
| 639 | return true; |
| 640 | } |
| 641 | |
| 642 | // Return an OverflowResult dependent on whether overflow of the multiplication |
| 643 | // of InnerTripCount and OuterTripCount can be assumed not to happen. |
| 644 | static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT, |
| 645 | AssumptionCache *AC) { |
| 646 | Function *F = FI.OuterLoop->getHeader()->getParent(); |
| 647 | const DataLayout &DL = F->getDataLayout(); |
| 648 | |
| 649 | // For debugging/testing. |
| 650 | if (AssumeNoOverflow) |
| 651 | return OverflowResult::NeverOverflows; |
| 652 | |
| 653 | // Check if the multiply could not overflow due to known ranges of the |
| 654 | // input values. |
| 655 | OverflowResult OR = computeOverflowForUnsignedMul( |
| 656 | LHS: FI.InnerTripCount, RHS: FI.OuterTripCount, |
| 657 | SQ: SimplifyQuery(DL, DT, AC, |
| 658 | FI.OuterLoop->getLoopPreheader()->getTerminator())); |
| 659 | if (OR != OverflowResult::MayOverflow) |
| 660 | return OR; |
| 661 | |
| 662 | auto CheckGEP = [&](GetElementPtrInst *GEP, Value *GEPOperand) { |
| 663 | for (Value *GEPUser : GEP->users()) { |
| 664 | auto *GEPUserInst = cast<Instruction>(Val: GEPUser); |
| 665 | if (!isa<LoadInst>(Val: GEPUserInst) && |
| 666 | !(isa<StoreInst>(Val: GEPUserInst) && GEP == GEPUserInst->getOperand(i: 1))) |
| 667 | continue; |
| 668 | if (!isGuaranteedToExecuteForEveryIteration(I: GEPUserInst, L: FI.InnerLoop)) |
| 669 | continue; |
| 670 | // The IV is used as the operand of a GEP which dominates the loop |
| 671 | // latch, and the IV is at least as wide as the address space of the |
| 672 | // GEP. In this case, the GEP would wrap around the address space |
| 673 | // before the IV increment wraps, which would be UB. |
| 674 | if (GEP->isInBounds() && |
| 675 | GEPOperand->getType()->getIntegerBitWidth() >= |
| 676 | DL.getPointerTypeSizeInBits(GEP->getType())) { |
| 677 | LLVM_DEBUG( |
| 678 | dbgs() << "use of linear IV would be UB if overflow occurred: " ; |
| 679 | GEP->dump()); |
| 680 | return true; |
| 681 | } |
| 682 | } |
| 683 | return false; |
| 684 | }; |
| 685 | |
| 686 | // Check if any IV user is, or is used by, a GEP that would cause UB if the |
| 687 | // multiply overflows. |
| 688 | for (Value *V : FI.LinearIVUses) { |
| 689 | if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: V)) |
| 690 | if (GEP->getNumIndices() == 1 && CheckGEP(GEP, GEP->getOperand(i_nocapture: 1))) |
| 691 | return OverflowResult::NeverOverflows; |
| 692 | for (Value *U : V->users()) |
| 693 | if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: U)) |
| 694 | if (CheckGEP(GEP, V)) |
| 695 | return OverflowResult::NeverOverflows; |
| 696 | } |
| 697 | |
| 698 | return OverflowResult::MayOverflow; |
| 699 | } |
| 700 | |
| 701 | static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, |
| 702 | ScalarEvolution *SE, AssumptionCache *AC, |
| 703 | const TargetTransformInfo *TTI) { |
| 704 | SmallPtrSet<Instruction *, 8> IterationInstructions; |
| 705 | if (!findLoopComponents(L: FI.InnerLoop, IterationInstructions, |
| 706 | InductionPHI&: FI.InnerInductionPHI, TripCount&: FI.InnerTripCount, |
| 707 | Increment&: FI.InnerIncrement, BackBranch&: FI.InnerBranch, SE, IsWidened: FI.Widened)) |
| 708 | return false; |
| 709 | if (!findLoopComponents(L: FI.OuterLoop, IterationInstructions, |
| 710 | InductionPHI&: FI.OuterInductionPHI, TripCount&: FI.OuterTripCount, |
| 711 | Increment&: FI.OuterIncrement, BackBranch&: FI.OuterBranch, SE, IsWidened: FI.Widened)) |
| 712 | return false; |
| 713 | |
| 714 | // Both of the loop trip count values must be invariant in the outer loop |
| 715 | // (non-instructions are all inherently invariant). |
| 716 | if (!FI.OuterLoop->isLoopInvariant(V: FI.InnerTripCount)) { |
| 717 | LLVM_DEBUG(dbgs() << "inner loop trip count not invariant\n" ); |
| 718 | return false; |
| 719 | } |
| 720 | if (!FI.OuterLoop->isLoopInvariant(V: FI.OuterTripCount)) { |
| 721 | LLVM_DEBUG(dbgs() << "outer loop trip count not invariant\n" ); |
| 722 | return false; |
| 723 | } |
| 724 | |
| 725 | if (!checkPHIs(FI, TTI)) |
| 726 | return false; |
| 727 | |
| 728 | // FIXME: it should be possible to handle different types correctly. |
| 729 | if (FI.InnerInductionPHI->getType() != FI.OuterInductionPHI->getType()) |
| 730 | return false; |
| 731 | |
| 732 | if (!checkOuterLoopInsts(FI, IterationInstructions, TTI)) |
| 733 | return false; |
| 734 | |
| 735 | // Find the values in the loop that can be replaced with the linearized |
| 736 | // induction variable, and check that there are no other uses of the inner |
| 737 | // or outer induction variable. If there were, we could still do this |
| 738 | // transformation, but we'd have to insert a div/mod to calculate the |
| 739 | // original IVs, so it wouldn't be profitable. |
| 740 | if (!checkIVUsers(FI)) |
| 741 | return false; |
| 742 | |
| 743 | LLVM_DEBUG(dbgs() << "CanFlattenLoopPair: OK\n" ); |
| 744 | return true; |
| 745 | } |
| 746 | |
| 747 | static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, |
| 748 | ScalarEvolution *SE, AssumptionCache *AC, |
| 749 | const TargetTransformInfo *TTI, LPMUpdater *U, |
| 750 | MemorySSAUpdater *MSSAU) { |
| 751 | Function *F = FI.OuterLoop->getHeader()->getParent(); |
| 752 | LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n" ); |
| 753 | { |
| 754 | using namespace ore; |
| 755 | OptimizationRemark Remark(DEBUG_TYPE, "Flattened", FI.InnerLoop->getStartLoc(), |
| 756 | FI.InnerLoop->getHeader()); |
| 757 | OptimizationRemarkEmitter ORE(F); |
| 758 | Remark << "Flattened into outer loop" ; |
| 759 | ORE.emit(OptDiag&: Remark); |
| 760 | } |
| 761 | |
| 762 | if (!FI.NewTripCount) { |
| 763 | FI.NewTripCount = BinaryOperator::CreateMul( |
| 764 | V1: FI.InnerTripCount, V2: FI.OuterTripCount, Name: "flatten.tripcount" , |
| 765 | InsertBefore: FI.OuterLoop->getLoopPreheader()->getTerminator()->getIterator()); |
| 766 | LLVM_DEBUG(dbgs() << "Created new trip count in preheader: " ; |
| 767 | FI.NewTripCount->dump()); |
| 768 | } |
| 769 | |
| 770 | // Fix up PHI nodes that take values from the inner loop back-edge, which |
| 771 | // we are about to remove. |
| 772 | FI.InnerInductionPHI->removeIncomingValue(BB: FI.InnerLoop->getLoopLatch()); |
| 773 | |
| 774 | // The old PHIs will be optimised away later, but for now we can't leave |
| 775 | // them in an invalid state, so we update them here too. |
| 776 | for (PHINode *PHI : FI.InnerPHIsToTransform) |
| 777 | PHI->removeIncomingValue(BB: FI.InnerLoop->getLoopLatch()); |
| 778 | |
| 779 | // Modify the trip count of the outer loop to be the product of the two |
| 780 | // trip counts. |
| 781 | cast<User>(Val: FI.OuterBranch->getCondition())->setOperand(i: 1, Val: FI.NewTripCount); |
| 782 | |
| 783 | // Replace the inner loop backedge with an unconditional branch to the exit. |
| 784 | BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock(); |
| 785 | BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock(); |
| 786 | Instruction *Term = InnerExitingBlock->getTerminator(); |
| 787 | Instruction *BI = BranchInst::Create(IfTrue: InnerExitBlock, InsertBefore: InnerExitingBlock); |
| 788 | BI->setDebugLoc(Term->getDebugLoc()); |
| 789 | Term->eraseFromParent(); |
| 790 | |
| 791 | // Update the DomTree and MemorySSA. |
| 792 | DT->deleteEdge(From: InnerExitingBlock, To: FI.InnerLoop->getHeader()); |
| 793 | if (MSSAU) |
| 794 | MSSAU->removeEdge(From: InnerExitingBlock, To: FI.InnerLoop->getHeader()); |
| 795 | |
| 796 | // Replace all uses of the polynomial calculated from the two induction |
| 797 | // variables with the single new induction variable. |
| 798 | IRBuilder<> Builder(FI.OuterInductionPHI->getParent()->getTerminator()); |
| 799 | for (Value *V : FI.LinearIVUses) { |
| 800 | Value *OuterValue = FI.OuterInductionPHI; |
| 801 | if (FI.Widened) |
| 802 | OuterValue = Builder.CreateTrunc(V: FI.OuterInductionPHI, DestTy: V->getType(), |
| 803 | Name: "flatten.trunciv" ); |
| 804 | |
| 805 | if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: V)) { |
| 806 | // Replace the GEP with one that uses OuterValue as the offset. |
| 807 | auto *InnerGEP = cast<GetElementPtrInst>(Val: GEP->getOperand(i_nocapture: 0)); |
| 808 | Value *Base = InnerGEP->getOperand(i_nocapture: 0); |
| 809 | // When the base of the GEP doesn't dominate the outer induction phi then |
| 810 | // we need to insert the new GEP where the old GEP was. |
| 811 | if (!DT->dominates(Def: Base, User: &*Builder.GetInsertPoint())) |
| 812 | Builder.SetInsertPoint(cast<Instruction>(Val: V)); |
| 813 | OuterValue = |
| 814 | Builder.CreateGEP(Ty: GEP->getSourceElementType(), Ptr: Base, IdxList: OuterValue, |
| 815 | Name: "flatten." + V->getName(), |
| 816 | NW: GEP->isInBounds() && InnerGEP->isInBounds()); |
| 817 | } |
| 818 | |
| 819 | LLVM_DEBUG(dbgs() << "Replacing: " ; V->dump(); dbgs() << "with: " ; |
| 820 | OuterValue->dump()); |
| 821 | V->replaceAllUsesWith(V: OuterValue); |
| 822 | } |
| 823 | |
| 824 | // Tell LoopInfo, SCEV and the pass manager that the inner loop has been |
| 825 | // deleted, and invalidate any outer loop information. |
| 826 | SE->forgetLoop(L: FI.OuterLoop); |
| 827 | SE->forgetBlockAndLoopDispositions(); |
| 828 | if (U) |
| 829 | U->markLoopAsDeleted(L&: *FI.InnerLoop, Name: FI.InnerLoop->getName()); |
| 830 | LI->erase(L: FI.InnerLoop); |
| 831 | |
| 832 | // Increment statistic value. |
| 833 | NumFlattened++; |
| 834 | |
| 835 | return true; |
| 836 | } |
| 837 | |
| 838 | static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, |
| 839 | ScalarEvolution *SE, AssumptionCache *AC, |
| 840 | const TargetTransformInfo *TTI) { |
| 841 | if (!WidenIV) { |
| 842 | LLVM_DEBUG(dbgs() << "Widening the IVs is disabled\n" ); |
| 843 | return false; |
| 844 | } |
| 845 | |
| 846 | LLVM_DEBUG(dbgs() << "Try widening the IVs\n" ); |
| 847 | Module *M = FI.InnerLoop->getHeader()->getParent()->getParent(); |
| 848 | auto &DL = M->getDataLayout(); |
| 849 | auto *InnerType = FI.InnerInductionPHI->getType(); |
| 850 | auto *OuterType = FI.OuterInductionPHI->getType(); |
| 851 | unsigned MaxLegalSize = DL.getLargestLegalIntTypeSizeInBits(); |
| 852 | auto *MaxLegalType = DL.getLargestLegalIntType(C&: M->getContext()); |
| 853 | |
| 854 | // If both induction types are less than the maximum legal integer width, |
| 855 | // promote both to the widest type available so we know calculating |
| 856 | // (OuterTripCount * InnerTripCount) as the new trip count is safe. |
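// For example (illustrative): with i32 IVs and a 64-bit largest legal integer
// type, each trip count is at most 2^32 - 1, so their product is below 2^64
// and the widened i64 multiply cannot overflow.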
| 857 | if (InnerType != OuterType || |
| 858 | InnerType->getScalarSizeInBits() >= MaxLegalSize || |
| 859 | MaxLegalType->getScalarSizeInBits() < |
| 860 | InnerType->getScalarSizeInBits() * 2) { |
| 861 | LLVM_DEBUG(dbgs() << "Can't widen the IV\n" ); |
| 862 | return false; |
| 863 | } |
| 864 | |
| 865 | SCEVExpander Rewriter(*SE, DL, "loopflatten" ); |
| 866 | SmallVector<WeakTrackingVH, 4> DeadInsts; |
| 867 | unsigned ElimExt = 0; |
| 868 | unsigned Widened = 0; |
| 869 | |
| 870 | auto CreateWideIV = [&](WideIVInfo WideIV, bool &Deleted) -> bool { |
| 871 | PHINode *WidePhi = |
| 872 | createWideIV(WI: WideIV, LI, SE, Rewriter, DT, DeadInsts, NumElimExt&: ElimExt, NumWidened&: Widened, |
| 873 | HasGuards: true /* HasGuards */, UsePostIncrementRanges: true /* UsePostIncrementRanges */); |
| 874 | if (!WidePhi) |
| 875 | return false; |
| 876 | LLVM_DEBUG(dbgs() << "Created wide phi: " ; WidePhi->dump()); |
| 877 | LLVM_DEBUG(dbgs() << "Deleting old phi: " ; WideIV.NarrowIV->dump()); |
| 878 | Deleted = RecursivelyDeleteDeadPHINode(PN: WideIV.NarrowIV); |
| 879 | return true; |
| 880 | }; |
| 881 | |
| 882 | bool Deleted; |
| 883 | if (!CreateWideIV({.NarrowIV: FI.InnerInductionPHI, .WidestNativeType: MaxLegalType, .IsSigned: false}, Deleted)) |
| 884 | return false; |
| 885 | // Add the narrow phi to the list, so that it will be adjusted later when the |
| 886 | // transformation is performed. |
| 887 | if (!Deleted) |
| 888 | FI.InnerPHIsToTransform.insert(Ptr: FI.InnerInductionPHI); |
| 889 | |
| 890 | if (!CreateWideIV({.NarrowIV: FI.OuterInductionPHI, .WidestNativeType: MaxLegalType, .IsSigned: false}, Deleted)) |
| 891 | return false; |
| 892 | |
| 893 | assert(Widened && "Widened IV expected" ); |
| 894 | FI.Widened = true; |
| 895 | |
| 896 | // Save the old/narrow induction phis, which we need to ignore in CheckPHIs. |
| 897 | FI.NarrowInnerInductionPHI = FI.InnerInductionPHI; |
| 898 | FI.NarrowOuterInductionPHI = FI.OuterInductionPHI; |
| 899 | |
| 900 | // After widening, rediscover all the loop components. |
| 901 | return CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI); |
| 902 | } |
| 903 | |
| 904 | static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, |
| 905 | ScalarEvolution *SE, AssumptionCache *AC, |
| 906 | const TargetTransformInfo *TTI, LPMUpdater *U, |
| 907 | MemorySSAUpdater *MSSAU, |
| 908 | const LoopAccessInfo &LAI) { |
| 909 | LLVM_DEBUG( |
| 910 | dbgs() << "Loop flattening running on outer loop " |
| 911 | << FI.OuterLoop->getHeader()->getName() << " and inner loop " |
| 912 | << FI.InnerLoop->getHeader()->getName() << " in " |
| 913 | << FI.OuterLoop->getHeader()->getParent()->getName() << "\n" ); |
| 914 | |
| 915 | if (!CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI)) |
| 916 | return false; |
| 917 | |
| 918 | // Check if we can widen the induction variables to avoid overflow checks. |
| 919 | bool CanFlatten = CanWidenIV(FI, DT, LI, SE, AC, TTI); |
| 920 | |
| 921 | // It can happen that, after widening the IVs, flattening is no longer |
| 922 | // possible or simply does not happen, e.g. when it is deemed unprofitable. So |
| 923 | // bail here if that is the case. |
| 924 | // TODO: IV widening without performing the actual flattening transformation |
| 925 | // is not ideal. While this codegen change should not matter much, it is an |
| 926 | // unnecessary change which is better to avoid. It's unlikely this happens |
| 927 | // often, because if it's unprofitable after widening, it should be |
| 928 | // unprofitable before widening as checked in the first round of checks. But |
| 929 | // 'RepeatedInstructionThreshold' is set to only 2, which can probably be |
| 930 | // relaxed. Because this is making a code change (the IV widening, but not |
| 931 | // the flattening), we return true here. |
| 932 | if (FI.Widened && !CanFlatten) |
| 933 | return true; |
| 934 | |
| 935 | // If we have widened and can perform the transformation, do that here. |
| 936 | if (CanFlatten) |
| 937 | return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); |
| 938 | |
| 939 | // Otherwise, if we haven't widened the IV, check if the new iteration |
| 940 | // variable might overflow. In this case, we need to version the loop, and |
| 941 | // select the original version at runtime if the iteration space is too |
| 942 | // large. |
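// Conceptually (a sketch of the versioned code produced below):
//
//   if (umul.with.overflow(OuterTripCount, InnerTripCount).overflow)
//     ... run the original nested loops ...
//   else
//     ... run the flattened loop with trip count OuterTripCount*InnerTripCount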
| 943 | OverflowResult OR = checkOverflow(FI, DT, AC); |
| 944 | if (OR == OverflowResult::AlwaysOverflowsHigh || |
| 945 | OR == OverflowResult::AlwaysOverflowsLow) { |
| 946 | LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n" ); |
| 947 | return false; |
| 948 | } else if (OR == OverflowResult::MayOverflow) { |
| 949 | Module *M = FI.OuterLoop->getHeader()->getParent()->getParent(); |
| 950 | const DataLayout &DL = M->getDataLayout(); |
| 951 | if (!VersionLoops) { |
| 952 | LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n" ); |
| 953 | return false; |
| 954 | } else if (!DL.isLegalInteger( |
| 955 | Width: FI.OuterTripCount->getType()->getScalarSizeInBits())) { |
| 956 | // If the trip count type isn't legal then it won't be possible to check |
| 957 | // for overflow using only a single multiply instruction, so don't |
| 958 | // flatten. |
| 959 | LLVM_DEBUG( |
| 960 | dbgs() << "Can't check overflow efficiently, not flattening\n" ); |
| 961 | return false; |
| 962 | } |
| 963 | LLVM_DEBUG(dbgs() << "Multiply might overflow, versioning loop\n" ); |
| 964 | |
| 965 | // Version the loop. The overflow check isn't a runtime pointer check, so we |
| 966 | // pass an empty list of runtime pointer checks, causing LoopVersioning to |
| 967 | // emit 'false' as the branch condition, and add our own check afterwards. |
| 968 | BasicBlock *CheckBlock = FI.OuterLoop->getLoopPreheader(); |
| 969 | ArrayRef<RuntimePointerCheck> Checks(nullptr, nullptr); |
| 970 | LoopVersioning LVer(LAI, Checks, FI.OuterLoop, LI, DT, SE); |
| 971 | LVer.versionLoop(); |
| 972 | |
| 973 | // Check for overflow by calculating the new tripcount using |
| 974 | // umul_with_overflow and then checking if it overflowed. |
| 975 | BranchInst *Br = cast<BranchInst>(Val: CheckBlock->getTerminator()); |
| 976 | assert(Br->isConditional() && |
| 977 | "Expected LoopVersioning to generate a conditional branch" ); |
| 978 | assert(match(Br->getCondition(), m_Zero()) && |
| 979 | "Expected branch condition to be false" ); |
| 980 | IRBuilder<> Builder(Br); |
| 981 | Value *Call = Builder.CreateIntrinsic( |
| 982 | ID: Intrinsic::umul_with_overflow, Types: FI.OuterTripCount->getType(), |
| 983 | Args: {FI.OuterTripCount, FI.InnerTripCount}, |
| 984 | /*FMFSource=*/nullptr, Name: "flatten.mul" ); |
| 985 | FI.NewTripCount = Builder.CreateExtractValue(Agg: Call, Idxs: 0, Name: "flatten.tripcount" ); |
| 986 | Value *Overflow = Builder.CreateExtractValue(Agg: Call, Idxs: 1, Name: "flatten.overflow" ); |
| 987 | Br->setCondition(Overflow); |
| 988 | } else { |
| 989 | LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n" ); |
| 990 | } |
| 991 | |
| 992 | return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); |
| 993 | } |
| 994 | |
| 995 | PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, |
| 996 | LoopStandardAnalysisResults &AR, |
| 997 | LPMUpdater &U) { |
| 998 | |
| 999 | bool Changed = false; |
| 1000 | |
| 1001 | std::optional<MemorySSAUpdater> MSSAU; |
| 1002 | if (AR.MSSA) { |
| 1003 | MSSAU = MemorySSAUpdater(AR.MSSA); |
| 1004 | if (VerifyMemorySSA) |
| 1005 | AR.MSSA->verifyMemorySSA(); |
| 1006 | } |
| 1007 | |
| 1008 | // The loop flattening pass requires loops to be |
| 1009 | // in simplified form, and also needs LCSSA. Running |
| 1010 | // this pass will simplify all loops that contain inner loops, |
| 1011 | // regardless of whether anything ends up being flattened. |
| 1012 | LoopAccessInfoManager LAIM(AR.SE, AR.AA, AR.DT, AR.LI, &AR.TTI, nullptr); |
| 1013 | for (Loop *InnerLoop : LN.getLoops()) { |
| 1014 | auto *OuterLoop = InnerLoop->getParentLoop(); |
| 1015 | if (!OuterLoop) |
| 1016 | continue; |
| 1017 | FlattenInfo FI(OuterLoop, InnerLoop); |
| 1018 | Changed |= |
| 1019 | FlattenLoopPair(FI, DT: &AR.DT, LI: &AR.LI, SE: &AR.SE, AC: &AR.AC, TTI: &AR.TTI, U: &U, |
| 1020 | MSSAU: MSSAU ? &*MSSAU : nullptr, LAI: LAIM.getInfo(L&: *OuterLoop)); |
| 1021 | } |
| 1022 | |
| 1023 | if (!Changed) |
| 1024 | return PreservedAnalyses::all(); |
| 1025 | |
| 1026 | if (AR.MSSA && VerifyMemorySSA) |
| 1027 | AR.MSSA->verifyMemorySSA(); |
| 1028 | |
| 1029 | auto PA = getLoopPassPreservedAnalyses(); |
| 1030 | if (AR.MSSA) |
| 1031 | PA.preserve<MemorySSAAnalysis>(); |
| 1032 | return PA; |
| 1033 | } |
| 1034 | |