1 | //===- LoopFlatten.cpp - Loop flattening pass------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This pass flattens pairs of nested loops into a single loop.
10 | // |
11 | // The intention is to optimise loop nests like this, which together access an |
12 | // array linearly: |
13 | // |
//   for (int i = 0; i < N; ++i)
//     for (int j = 0; j < M; ++j)
//       f(A[i*M+j]);
17 | // |
18 | // into one loop: |
19 | // |
//   for (int i = 0; i < (N*M); ++i)
//     f(A[i]);
22 | // |
23 | // It can also flatten loops where the induction variables are not used in the |
24 | // loop. This is only worth doing if the induction variables are only used in an |
25 | // expression like i*M+j. If they had any other uses, we would have to insert a |
26 | // div/mod to reconstruct the original values, so this wouldn't be profitable. |
27 | // |
28 | // We also need to prove that N*M will not overflow. The preferred solution is |
29 | // to widen the IV, which avoids overflow checks, so that is tried first. If |
30 | // the IV cannot be widened, then we try to determine that this new tripcount |
31 | // expression won't overflow. |
32 | // |
33 | // Q: Does LoopFlatten use SCEV? |
34 | // Short answer: Yes and no. |
35 | // |
36 | // Long answer: |
37 | // For this transformation to be valid, we require all uses of the induction |
38 | // variables to be linear expressions of the form i*M+j. The different Loop |
39 | // APIs are used to get some loop components like the induction variable, |
40 | // compare statement, etc. In addition, we do some pattern matching to find the |
41 | // linear expressions and other loop components like the loop increment. The |
42 | // latter are examples of expressions that do use the induction variable, but |
43 | // are safe to ignore when we check all uses to be of the form i*M+j. We keep |
44 | // track of all of this in bookkeeping struct FlattenInfo. |
// We assume the loops to be canonical, i.e. starting at 0 and incrementing by
// one. This makes the RHS of the compare the loop tripcount (with the right
// predicate). We then use SCEV to sanity-check that this tripcount matches the
// tripcount computed by SCEV.
49 | // |
50 | //===----------------------------------------------------------------------===// |
51 | |
52 | #include "llvm/Transforms/Scalar/LoopFlatten.h" |
53 | |
54 | #include "llvm/ADT/Statistic.h" |
55 | #include "llvm/Analysis/AssumptionCache.h" |
56 | #include "llvm/Analysis/LoopInfo.h" |
57 | #include "llvm/Analysis/LoopNestAnalysis.h" |
58 | #include "llvm/Analysis/MemorySSAUpdater.h" |
59 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" |
60 | #include "llvm/Analysis/ScalarEvolution.h" |
61 | #include "llvm/Analysis/TargetTransformInfo.h" |
62 | #include "llvm/Analysis/ValueTracking.h" |
63 | #include "llvm/IR/Dominators.h" |
64 | #include "llvm/IR/Function.h" |
65 | #include "llvm/IR/IRBuilder.h" |
66 | #include "llvm/IR/Module.h" |
67 | #include "llvm/IR/PatternMatch.h" |
68 | #include "llvm/Support/Debug.h" |
69 | #include "llvm/Support/raw_ostream.h" |
70 | #include "llvm/Transforms/Scalar/LoopPassManager.h" |
71 | #include "llvm/Transforms/Utils/Local.h" |
72 | #include "llvm/Transforms/Utils/LoopUtils.h" |
73 | #include "llvm/Transforms/Utils/LoopVersioning.h" |
74 | #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" |
75 | #include "llvm/Transforms/Utils/SimplifyIndVar.h" |
76 | #include <optional> |
77 | |
78 | using namespace llvm; |
79 | using namespace llvm::PatternMatch; |
80 | |
81 | #define DEBUG_TYPE "loop-flatten" |
82 | |
83 | STATISTIC(NumFlattened, "Number of loops flattened" ); |
84 | |
85 | static cl::opt<unsigned> RepeatedInstructionThreshold( |
86 | "loop-flatten-cost-threshold" , cl::Hidden, cl::init(Val: 2), |
87 | cl::desc("Limit on the cost of instructions that can be repeated due to " |
88 | "loop flattening" )); |
89 | |
90 | static cl::opt<bool> |
91 | AssumeNoOverflow("loop-flatten-assume-no-overflow" , cl::Hidden, |
92 | cl::init(Val: false), |
93 | cl::desc("Assume that the product of the two iteration " |
94 | "trip counts will never overflow" )); |
95 | |
96 | static cl::opt<bool> |
97 | WidenIV("loop-flatten-widen-iv" , cl::Hidden, cl::init(Val: true), |
98 | cl::desc("Widen the loop induction variables, if possible, so " |
99 | "overflow checks won't reject flattening" )); |
100 | |
101 | static cl::opt<bool> |
102 | VersionLoops("loop-flatten-version-loops" , cl::Hidden, cl::init(Val: true), |
103 | cl::desc("Version loops if flattened loop could overflow" )); |
104 | |
105 | namespace { |
106 | // We require all uses of both induction variables to match this pattern: |
107 | // |
108 | // (OuterPHI * InnerTripCount) + InnerPHI |
109 | // |
110 | // I.e., it needs to be a linear expression of the induction variables and the |
111 | // inner loop trip count. We keep track of all different expressions on which |
112 | // checks will be performed in this bookkeeping struct. |
113 | // |
114 | struct FlattenInfo { |
115 | Loop *OuterLoop = nullptr; // The loop pair to be flattened. |
116 | Loop *InnerLoop = nullptr; |
117 | |
118 | PHINode *InnerInductionPHI = nullptr; // These PHINodes correspond to loop |
119 | PHINode *OuterInductionPHI = nullptr; // induction variables, which are |
120 | // expected to start at zero and |
121 | // increment by one on each loop. |
122 | |
123 | Value *InnerTripCount = nullptr; // The product of these two tripcounts |
124 | Value *OuterTripCount = nullptr; // will be the new flattened loop |
125 | // tripcount. Also used to recognise a |
126 | // linear expression that will be replaced. |
127 | |
128 | SmallPtrSet<Value *, 4> LinearIVUses; // Contains the linear expressions |
129 | // of the form i*M+j that will be |
130 | // replaced. |
131 | |
132 | BinaryOperator *InnerIncrement = nullptr; // Uses of induction variables in |
133 | BinaryOperator *OuterIncrement = nullptr; // loop control statements that |
134 | BranchInst *InnerBranch = nullptr; // are safe to ignore. |
135 | |
136 | BranchInst *OuterBranch = nullptr; // The instruction that needs to be |
137 | // updated with new tripcount. |
138 | |
139 | SmallPtrSet<PHINode *, 4> InnerPHIsToTransform; |
140 | |
141 | bool Widened = false; // Whether this holds the flatten info before or after |
142 | // widening. |
143 | |
144 | PHINode *NarrowInnerInductionPHI = nullptr; // Holds the old/narrow induction |
145 | PHINode *NarrowOuterInductionPHI = nullptr; // phis, i.e. the Phis before IV |
146 | // has been applied. Used to skip |
147 | // checks on phi nodes. |
148 | |
149 | Value *NewTripCount = nullptr; // The tripcount of the flattened loop. |
150 | |
151 | FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){}; |
152 | |
153 | bool isNarrowInductionPhi(PHINode *Phi) { |
154 | // This can't be the narrow phi if we haven't widened the IV first. |
155 | if (!Widened) |
156 | return false; |
157 | return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi; |
158 | } |
159 | bool isInnerLoopIncrement(User *U) { |
160 | return InnerIncrement == U; |
161 | } |
162 | bool isOuterLoopIncrement(User *U) { |
163 | return OuterIncrement == U; |
164 | } |
165 | bool isInnerLoopTest(User *U) { |
166 | return InnerBranch->getCondition() == U; |
167 | } |
168 | |
169 | bool checkOuterInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) { |
170 | for (User *U : OuterInductionPHI->users()) { |
171 | if (isOuterLoopIncrement(U)) |
172 | continue; |
173 | |
174 | auto IsValidOuterPHIUses = [&] (User *U) -> bool { |
175 | LLVM_DEBUG(dbgs() << "Found use of outer induction variable: " ; U->dump()); |
176 | if (!ValidOuterPHIUses.count(Ptr: U)) { |
177 | LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n" ); |
178 | return false; |
179 | } |
180 | LLVM_DEBUG(dbgs() << "Use is optimisable\n" ); |
181 | return true; |
182 | }; |
183 | |
184 | if (auto *V = dyn_cast<TruncInst>(Val: U)) { |
185 | for (auto *K : V->users()) { |
186 | if (!IsValidOuterPHIUses(K)) |
187 | return false; |
188 | } |
189 | continue; |
190 | } |
191 | |
192 | if (!IsValidOuterPHIUses(U)) |
193 | return false; |
194 | } |
195 | return true; |
196 | } |
197 | |
198 | bool matchLinearIVUser(User *U, Value *InnerTripCount, |
199 | SmallPtrSet<Value *, 4> &ValidOuterPHIUses) { |
200 | LLVM_DEBUG(dbgs() << "Checking linear i*M+j expression for: " ; U->dump()); |
201 | Value *MatchedMul = nullptr; |
202 | Value *MatchedItCount = nullptr; |
203 | |
204 | bool IsAdd = match(V: U, P: m_c_Add(L: m_Specific(V: InnerInductionPHI), |
205 | R: m_Value(V&: MatchedMul))) && |
206 | match(V: MatchedMul, P: m_c_Mul(L: m_Specific(V: OuterInductionPHI), |
207 | R: m_Value(V&: MatchedItCount))); |
208 | |
209 | // Matches the same pattern as above, except it also looks for truncs |
210 | // on the phi, which can be the result of widening the induction variables. |
211 | bool IsAddTrunc = |
212 | match(V: U, P: m_c_Add(L: m_Trunc(Op: m_Specific(V: InnerInductionPHI)), |
213 | R: m_Value(V&: MatchedMul))) && |
214 | match(V: MatchedMul, P: m_c_Mul(L: m_Trunc(Op: m_Specific(V: OuterInductionPHI)), |
215 | R: m_Value(V&: MatchedItCount))); |
216 | |
217 | // Matches the pattern ptr+i*M+j, with the two additions being done via GEP. |
218 | bool IsGEP = match(V: U, P: m_GEP(Ops: m_GEP(Ops: m_Value(), Ops: m_Value(V&: MatchedMul)), |
219 | Ops: m_Specific(V: InnerInductionPHI))) && |
220 | match(V: MatchedMul, P: m_c_Mul(L: m_Specific(V: OuterInductionPHI), |
221 | R: m_Value(V&: MatchedItCount))); |
222 | |
223 | if (!MatchedItCount) |
224 | return false; |
225 | |
226 | LLVM_DEBUG(dbgs() << "Matched multiplication: " ; MatchedMul->dump()); |
227 | LLVM_DEBUG(dbgs() << "Matched iteration count: " ; MatchedItCount->dump()); |
228 | |
229 | // The mul should not have any other uses. Widening may leave trivially dead |
230 | // uses, which can be ignored. |
231 | if (count_if(Range: MatchedMul->users(), P: [](User *U) { |
232 | return !isInstructionTriviallyDead(I: cast<Instruction>(Val: U)); |
233 | }) > 1) { |
234 | LLVM_DEBUG(dbgs() << "Multiply has more than one use\n" ); |
235 | return false; |
236 | } |
237 | |
238 | // Look through extends if the IV has been widened. Don't look through |
239 | // extends if we already looked through a trunc. |
240 | if (Widened && (IsAdd || IsGEP) && |
241 | (isa<SExtInst>(Val: MatchedItCount) || isa<ZExtInst>(Val: MatchedItCount))) { |
242 | assert(MatchedItCount->getType() == InnerInductionPHI->getType() && |
243 | "Unexpected type mismatch in types after widening" ); |
244 | MatchedItCount = isa<SExtInst>(Val: MatchedItCount) |
245 | ? dyn_cast<SExtInst>(Val: MatchedItCount)->getOperand(i_nocapture: 0) |
246 | : dyn_cast<ZExtInst>(Val: MatchedItCount)->getOperand(i_nocapture: 0); |
247 | } |
248 | |
249 | LLVM_DEBUG(dbgs() << "Looking for inner trip count: " ; |
250 | InnerTripCount->dump()); |
251 | |
252 | if ((IsAdd || IsAddTrunc || IsGEP) && MatchedItCount == InnerTripCount) { |
253 | LLVM_DEBUG(dbgs() << "Found. This sse is optimisable\n" ); |
254 | ValidOuterPHIUses.insert(Ptr: MatchedMul); |
255 | LinearIVUses.insert(Ptr: U); |
256 | return true; |
257 | } |
258 | |
259 | LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n" ); |
260 | return false; |
261 | } |
262 | |
263 | bool checkInnerInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) { |
264 | Value *SExtInnerTripCount = InnerTripCount; |
265 | if (Widened && |
266 | (isa<SExtInst>(Val: InnerTripCount) || isa<ZExtInst>(Val: InnerTripCount))) |
267 | SExtInnerTripCount = cast<Instruction>(Val: InnerTripCount)->getOperand(i: 0); |
268 | |
269 | for (User *U : InnerInductionPHI->users()) { |
270 | LLVM_DEBUG(dbgs() << "Checking User: " ; U->dump()); |
271 | if (isInnerLoopIncrement(U)) { |
272 | LLVM_DEBUG(dbgs() << "Use is inner loop increment, continuing\n" ); |
273 | continue; |
274 | } |
275 | |
276 | // After widening the IVs, a trunc instruction might have been introduced, |
277 | // so look through truncs. |
278 | if (isa<TruncInst>(Val: U)) { |
279 | if (!U->hasOneUse()) |
280 | return false; |
281 | U = *U->user_begin(); |
282 | } |
283 | |
284 | // If the use is in the compare (which is also the condition of the inner |
285 | // branch) then the compare has been altered by another transformation e.g |
286 | // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is |
287 | // a constant. Ignore this use as the compare gets removed later anyway. |
288 | if (isInnerLoopTest(U)) { |
289 | LLVM_DEBUG(dbgs() << "Use is the inner loop test, continuing\n" ); |
290 | continue; |
291 | } |
292 | |
293 | if (!matchLinearIVUser(U, InnerTripCount: SExtInnerTripCount, ValidOuterPHIUses)) { |
294 | LLVM_DEBUG(dbgs() << "Not a linear IV user\n" ); |
295 | return false; |
296 | } |
297 | LLVM_DEBUG(dbgs() << "Linear IV users found!\n" ); |
298 | } |
299 | return true; |
300 | } |
301 | }; |
302 | } // namespace |
303 | |
304 | static bool |
305 | setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment, |
306 | SmallPtrSetImpl<Instruction *> &IterationInstructions) { |
307 | TripCount = TC; |
308 | IterationInstructions.insert(Ptr: Increment); |
309 | LLVM_DEBUG(dbgs() << "Found Increment: " ; Increment->dump()); |
310 | LLVM_DEBUG(dbgs() << "Found trip count: " ; TripCount->dump()); |
311 | LLVM_DEBUG(dbgs() << "Successfully found all loop components\n" ); |
312 | return true; |
313 | } |
314 | |
315 | // Given the RHS of the loop latch compare instruction, verify with SCEV |
316 | // that this is indeed the loop tripcount. |
317 | // TODO: This used to be a straightforward check but has grown to be quite |
318 | // complicated now. It is therefore worth revisiting what the additional |
319 | // benefits are of this (compared to relying on canonical loops and pattern |
320 | // matching). |
321 | static bool verifyTripCount(Value *RHS, Loop *L, |
322 | SmallPtrSetImpl<Instruction *> &IterationInstructions, |
323 | PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment, |
324 | BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) { |
325 | const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); |
326 | if (isa<SCEVCouldNotCompute>(Val: BackedgeTakenCount)) { |
327 | LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n" ); |
328 | return false; |
329 | } |
330 | |
331 | // Evaluating in the trip count's type can not overflow here as the overflow |
332 | // checks are performed in checkOverflow, but are first tried to avoid by |
333 | // widening the IV. |
334 | const SCEV *SCEVTripCount = |
335 | SE->getTripCountFromExitCount(ExitCount: BackedgeTakenCount, |
336 | EvalTy: BackedgeTakenCount->getType(), L); |
337 | |
338 | const SCEV *SCEVRHS = SE->getSCEV(V: RHS); |
339 | if (SCEVRHS == SCEVTripCount) |
340 | return setLoopComponents(TC&: RHS, TripCount, Increment, IterationInstructions); |
341 | ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(Val: RHS); |
342 | if (ConstantRHS) { |
343 | const SCEV *BackedgeTCExt = nullptr; |
344 | if (IsWidened) { |
345 | const SCEV *SCEVTripCountExt; |
346 | // Find the extended backedge taken count and extended trip count using |
347 | // SCEV. One of these should now match the RHS of the compare. |
348 | BackedgeTCExt = SE->getZeroExtendExpr(Op: BackedgeTakenCount, Ty: RHS->getType()); |
349 | SCEVTripCountExt = SE->getTripCountFromExitCount(ExitCount: BackedgeTCExt, |
350 | EvalTy: RHS->getType(), L); |
351 | if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) { |
352 | LLVM_DEBUG(dbgs() << "Could not find valid trip count\n" ); |
353 | return false; |
354 | } |
355 | } |
356 | // If the RHS of the compare is equal to the backedge taken count we need |
357 | // to add one to get the trip count. |
358 | if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) { |
359 | Value *NewRHS = ConstantInt::get(Context&: ConstantRHS->getContext(), |
360 | V: ConstantRHS->getValue() + 1); |
361 | return setLoopComponents(TC&: NewRHS, TripCount, Increment, |
362 | IterationInstructions); |
363 | } |
364 | return setLoopComponents(TC&: RHS, TripCount, Increment, IterationInstructions); |
365 | } |
366 | // If the RHS isn't a constant then check that the reason it doesn't match |
367 | // the SCEV trip count is because the RHS is a ZExt or SExt instruction |
368 | // (and take the trip count to be the RHS). |
369 | if (!IsWidened) { |
370 | LLVM_DEBUG(dbgs() << "Could not find valid trip count\n" ); |
371 | return false; |
372 | } |
373 | auto *TripCountInst = dyn_cast<Instruction>(Val: RHS); |
374 | if (!TripCountInst) { |
375 | LLVM_DEBUG(dbgs() << "Could not find valid trip count\n" ); |
376 | return false; |
377 | } |
378 | if ((!isa<ZExtInst>(Val: TripCountInst) && !isa<SExtInst>(Val: TripCountInst)) || |
379 | SE->getSCEV(V: TripCountInst->getOperand(i: 0)) != SCEVTripCount) { |
380 | LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n" ); |
381 | return false; |
382 | } |
383 | return setLoopComponents(TC&: RHS, TripCount, Increment, IterationInstructions); |
384 | } |
385 | |
386 | // Finds the induction variable, increment and trip count for a simple loop that |
387 | // we can flatten. |
388 | static bool findLoopComponents( |
389 | Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions, |
390 | PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment, |
391 | BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) { |
392 | LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n" ); |
393 | |
394 | if (!L->isLoopSimplifyForm()) { |
395 | LLVM_DEBUG(dbgs() << "Loop is not in normal form\n" ); |
396 | return false; |
397 | } |
398 | |
399 | // Currently, to simplify the implementation, the Loop induction variable must |
400 | // start at zero and increment with a step size of one. |
401 | if (!L->isCanonical(SE&: *SE)) { |
402 | LLVM_DEBUG(dbgs() << "Loop is not canonical\n" ); |
403 | return false; |
404 | } |
405 | |
406 | // There must be exactly one exiting block, and it must be the same at the |
407 | // latch. |
408 | BasicBlock *Latch = L->getLoopLatch(); |
409 | if (L->getExitingBlock() != Latch) { |
410 | LLVM_DEBUG(dbgs() << "Exiting and latch block are different\n" ); |
411 | return false; |
412 | } |
413 | |
414 | // Find the induction PHI. If there is no induction PHI, we can't do the |
415 | // transformation. TODO: could other variables trigger this? Do we have to |
416 | // search for the best one? |
417 | InductionPHI = L->getInductionVariable(SE&: *SE); |
418 | if (!InductionPHI) { |
419 | LLVM_DEBUG(dbgs() << "Could not find induction PHI\n" ); |
420 | return false; |
421 | } |
422 | LLVM_DEBUG(dbgs() << "Found induction PHI: " ; InductionPHI->dump()); |
423 | |
424 | bool ContinueOnTrue = L->contains(BB: Latch->getTerminator()->getSuccessor(Idx: 0)); |
425 | auto IsValidPredicate = [&](ICmpInst::Predicate Pred) { |
426 | if (ContinueOnTrue) |
427 | return Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT; |
428 | else |
429 | return Pred == CmpInst::ICMP_EQ; |
430 | }; |
431 | |
432 | // Find Compare and make sure it is valid. getLatchCmpInst checks that the |
433 | // back branch of the latch is conditional. |
434 | ICmpInst *Compare = L->getLatchCmpInst(); |
435 | if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) || |
436 | Compare->hasNUsesOrMore(N: 2)) { |
437 | LLVM_DEBUG(dbgs() << "Could not find valid comparison\n" ); |
438 | return false; |
439 | } |
440 | BackBranch = cast<BranchInst>(Val: Latch->getTerminator()); |
441 | IterationInstructions.insert(Ptr: BackBranch); |
442 | LLVM_DEBUG(dbgs() << "Found back branch: " ; BackBranch->dump()); |
443 | IterationInstructions.insert(Ptr: Compare); |
444 | LLVM_DEBUG(dbgs() << "Found comparison: " ; Compare->dump()); |
445 | |
446 | // Find increment and trip count. |
447 | // There are exactly 2 incoming values to the induction phi; one from the |
448 | // pre-header and one from the latch. The incoming latch value is the |
449 | // increment variable. |
450 | Increment = |
451 | cast<BinaryOperator>(Val: InductionPHI->getIncomingValueForBlock(BB: Latch)); |
452 | if ((Compare->getOperand(i_nocapture: 0) != Increment || !Increment->hasNUses(N: 2)) && |
453 | !Increment->hasNUses(N: 1)) { |
454 | LLVM_DEBUG(dbgs() << "Could not find valid increment\n" ); |
455 | return false; |
456 | } |
457 | // The trip count is the RHS of the compare. If this doesn't match the trip |
458 | // count computed by SCEV then this is because the trip count variable |
459 | // has been widened so the types don't match, or because it is a constant and |
460 | // another transformation has changed the compare (e.g. icmp ult %inc, |
461 | // tripcount -> icmp ult %j, tripcount-1), or both. |
462 | Value *RHS = Compare->getOperand(i_nocapture: 1); |
463 | |
464 | return verifyTripCount(RHS, L, IterationInstructions, InductionPHI, TripCount, |
465 | Increment, BackBranch, SE, IsWidened); |
466 | } |
467 | |
468 | static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) { |
469 | // All PHIs in the inner and outer headers must either be: |
470 | // - The induction PHI, which we are going to rewrite as one induction in |
471 | // the new loop. This is already checked by findLoopComponents. |
472 | // - An outer header PHI with all incoming values from outside the loop. |
473 | // LoopSimplify guarantees we have a pre-header, so we don't need to |
474 | // worry about that here. |
475 | // - Pairs of PHIs in the inner and outer headers, which implement a |
476 | // loop-carried dependency that will still be valid in the new loop. To |
477 | // be valid, this variable must be modified only in the inner loop. |
478 | |
479 | // The set of PHI nodes in the outer loop header that we know will still be |
480 | // valid after the transformation. These will not need to be modified (with |
481 | // the exception of the induction variable), but we do need to check that |
482 | // there are no unsafe PHI nodes. |
483 | SmallPtrSet<PHINode *, 4> SafeOuterPHIs; |
484 | SafeOuterPHIs.insert(Ptr: FI.OuterInductionPHI); |
485 | |
486 | // Check that all PHI nodes in the inner loop header match one of the valid |
487 | // patterns. |
488 | for (PHINode &InnerPHI : FI.InnerLoop->getHeader()->phis()) { |
489 | // The induction PHIs break these rules, and that's OK because we treat |
490 | // them specially when doing the transformation. |
491 | if (&InnerPHI == FI.InnerInductionPHI) |
492 | continue; |
493 | if (FI.isNarrowInductionPhi(Phi: &InnerPHI)) |
494 | continue; |
495 | |
496 | // Each inner loop PHI node must have two incoming values/blocks - one |
497 | // from the pre-header, and one from the latch. |
498 | assert(InnerPHI.getNumIncomingValues() == 2); |
499 | Value * = |
500 | InnerPHI.getIncomingValueForBlock(BB: FI.InnerLoop->getLoopPreheader()); |
501 | Value *LatchValue = |
502 | InnerPHI.getIncomingValueForBlock(BB: FI.InnerLoop->getLoopLatch()); |
503 | |
504 | // The incoming value from the outer loop must be the PHI node in the |
505 | // outer loop header, with no modifications made in the top of the outer |
506 | // loop. |
507 | PHINode *OuterPHI = dyn_cast<PHINode>(Val: PreHeaderValue); |
508 | if (!OuterPHI || OuterPHI->getParent() != FI.OuterLoop->getHeader()) { |
509 | LLVM_DEBUG(dbgs() << "value modified in top of outer loop\n" ); |
510 | return false; |
511 | } |
512 | |
513 | // The other incoming value must come from the inner loop, without any |
514 | // modifications in the tail end of the outer loop. We are in LCSSA form, |
515 | // so this will actually be a PHI in the inner loop's exit block, which |
516 | // only uses values from inside the inner loop. |
517 | PHINode *LCSSAPHI = dyn_cast<PHINode>( |
518 | Val: OuterPHI->getIncomingValueForBlock(BB: FI.OuterLoop->getLoopLatch())); |
519 | if (!LCSSAPHI) { |
520 | LLVM_DEBUG(dbgs() << "could not find LCSSA PHI\n" ); |
521 | return false; |
522 | } |
523 | |
524 | // The value used by the LCSSA PHI must be the same one that the inner |
525 | // loop's PHI uses. |
526 | if (LCSSAPHI->hasConstantValue() != LatchValue) { |
527 | LLVM_DEBUG( |
528 | dbgs() << "LCSSA PHI incoming value does not match latch value\n" ); |
529 | return false; |
530 | } |
531 | |
532 | LLVM_DEBUG(dbgs() << "PHI pair is safe:\n" ); |
533 | LLVM_DEBUG(dbgs() << " Inner: " ; InnerPHI.dump()); |
534 | LLVM_DEBUG(dbgs() << " Outer: " ; OuterPHI->dump()); |
535 | SafeOuterPHIs.insert(Ptr: OuterPHI); |
536 | FI.InnerPHIsToTransform.insert(Ptr: &InnerPHI); |
537 | } |
538 | |
539 | for (PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) { |
540 | if (FI.isNarrowInductionPhi(Phi: &OuterPHI)) |
541 | continue; |
542 | if (!SafeOuterPHIs.count(Ptr: &OuterPHI)) { |
543 | LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: " ; OuterPHI.dump()); |
544 | return false; |
545 | } |
546 | } |
547 | |
548 | LLVM_DEBUG(dbgs() << "checkPHIs: OK\n" ); |
549 | return true; |
550 | } |
551 | |
552 | static bool |
553 | checkOuterLoopInsts(FlattenInfo &FI, |
554 | SmallPtrSetImpl<Instruction *> &IterationInstructions, |
555 | const TargetTransformInfo *TTI) { |
556 | // Check for instructions in the outer but not inner loop. If any of these |
557 | // have side-effects then this transformation is not legal, and if there is |
558 | // a significant amount of code here which can't be optimised out that it's |
559 | // not profitable (as these instructions would get executed for each |
560 | // iteration of the inner loop). |
561 | InstructionCost RepeatedInstrCost = 0; |
562 | for (auto *B : FI.OuterLoop->getBlocks()) { |
563 | if (FI.InnerLoop->contains(BB: B)) |
564 | continue; |
565 | |
566 | for (auto &I : *B) { |
567 | if (!isa<PHINode>(Val: &I) && !I.isTerminator() && |
568 | !isSafeToSpeculativelyExecute(I: &I)) { |
569 | LLVM_DEBUG(dbgs() << "Cannot flatten because instruction may have " |
570 | "side effects: " ; |
571 | I.dump()); |
572 | return false; |
573 | } |
574 | // The execution count of the outer loop's iteration instructions |
575 | // (increment, compare and branch) will be increased, but the |
576 | // equivalent instructions will be removed from the inner loop, so |
577 | // they make a net difference of zero. |
578 | if (IterationInstructions.count(Ptr: &I)) |
579 | continue; |
580 | // The unconditional branch to the inner loop's header will turn into |
581 | // a fall-through, so adds no cost. |
582 | BranchInst *Br = dyn_cast<BranchInst>(Val: &I); |
583 | if (Br && Br->isUnconditional() && |
584 | Br->getSuccessor(i: 0) == FI.InnerLoop->getHeader()) |
585 | continue; |
586 | // Multiplies of the outer iteration variable and inner iteration |
587 | // count will be optimised out. |
588 | if (match(V: &I, P: m_c_Mul(L: m_Specific(V: FI.OuterInductionPHI), |
589 | R: m_Specific(V: FI.InnerTripCount)))) |
590 | continue; |
591 | InstructionCost Cost = |
592 | TTI->getInstructionCost(U: &I, CostKind: TargetTransformInfo::TCK_SizeAndLatency); |
593 | LLVM_DEBUG(dbgs() << "Cost " << Cost << ": " ; I.dump()); |
594 | RepeatedInstrCost += Cost; |
595 | } |
596 | } |
597 | |
598 | LLVM_DEBUG(dbgs() << "Cost of instructions that will be repeated: " |
599 | << RepeatedInstrCost << "\n" ); |
600 | // Bail out if flattening the loops would cause instructions in the outer |
601 | // loop but not in the inner loop to be executed extra times. |
602 | if (RepeatedInstrCost > RepeatedInstructionThreshold) { |
603 | LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: not profitable, bailing.\n" ); |
604 | return false; |
605 | } |
606 | |
607 | LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: OK\n" ); |
608 | return true; |
609 | } |
610 | |
611 | |
612 | |
613 | // We require all uses of both induction variables to match this pattern: |
614 | // |
615 | // (OuterPHI * InnerTripCount) + InnerPHI |
616 | // |
617 | // Any uses of the induction variables not matching that pattern would |
618 | // require a div/mod to reconstruct in the flattened loop, so the |
619 | // transformation wouldn't be profitable. |
620 | static bool checkIVUsers(FlattenInfo &FI) { |
621 | // Check that all uses of the inner loop's induction variable match the |
622 | // expected pattern, recording the uses of the outer IV. |
623 | SmallPtrSet<Value *, 4> ValidOuterPHIUses; |
624 | if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses)) |
625 | return false; |
626 | |
627 | // Check that there are no uses of the outer IV other than the ones found |
628 | // as part of the pattern above. |
629 | if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses)) |
630 | return false; |
631 | |
632 | LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n" ; |
633 | dbgs() << "Found " << FI.LinearIVUses.size() |
634 | << " value(s) that can be replaced:\n" ; |
635 | for (Value *V : FI.LinearIVUses) { |
636 | dbgs() << " " ; |
637 | V->dump(); |
638 | }); |
639 | return true; |
640 | } |
641 | |
642 | // Return an OverflowResult dependant on if overflow of the multiplication of |
643 | // InnerTripCount and OuterTripCount can be assumed not to happen. |
644 | static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT, |
645 | AssumptionCache *AC) { |
646 | Function *F = FI.OuterLoop->getHeader()->getParent(); |
647 | const DataLayout &DL = F->getDataLayout(); |
648 | |
649 | // For debugging/testing. |
650 | if (AssumeNoOverflow) |
651 | return OverflowResult::NeverOverflows; |
652 | |
653 | // Check if the multiply could not overflow due to known ranges of the |
654 | // input values. |
655 | OverflowResult OR = computeOverflowForUnsignedMul( |
656 | LHS: FI.InnerTripCount, RHS: FI.OuterTripCount, |
657 | SQ: SimplifyQuery(DL, DT, AC, |
658 | FI.OuterLoop->getLoopPreheader()->getTerminator())); |
659 | if (OR != OverflowResult::MayOverflow) |
660 | return OR; |
661 | |
662 | auto CheckGEP = [&](GetElementPtrInst *GEP, Value *GEPOperand) { |
663 | for (Value *GEPUser : GEP->users()) { |
664 | auto *GEPUserInst = cast<Instruction>(Val: GEPUser); |
665 | if (!isa<LoadInst>(Val: GEPUserInst) && |
666 | !(isa<StoreInst>(Val: GEPUserInst) && GEP == GEPUserInst->getOperand(i: 1))) |
667 | continue; |
668 | if (!isGuaranteedToExecuteForEveryIteration(I: GEPUserInst, L: FI.InnerLoop)) |
669 | continue; |
670 | // The IV is used as the operand of a GEP which dominates the loop |
671 | // latch, and the IV is at least as wide as the address space of the |
672 | // GEP. In this case, the GEP would wrap around the address space |
673 | // before the IV increment wraps, which would be UB. |
674 | if (GEP->isInBounds() && |
675 | GEPOperand->getType()->getIntegerBitWidth() >= |
676 | DL.getPointerTypeSizeInBits(GEP->getType())) { |
677 | LLVM_DEBUG( |
678 | dbgs() << "use of linear IV would be UB if overflow occurred: " ; |
679 | GEP->dump()); |
680 | return true; |
681 | } |
682 | } |
683 | return false; |
684 | }; |
685 | |
686 | // Check if any IV user is, or is used by, a GEP that would cause UB if the |
687 | // multiply overflows. |
688 | for (Value *V : FI.LinearIVUses) { |
689 | if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: V)) |
690 | if (GEP->getNumIndices() == 1 && CheckGEP(GEP, GEP->getOperand(i_nocapture: 1))) |
691 | return OverflowResult::NeverOverflows; |
692 | for (Value *U : V->users()) |
693 | if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: U)) |
694 | if (CheckGEP(GEP, V)) |
695 | return OverflowResult::NeverOverflows; |
696 | } |
697 | |
698 | return OverflowResult::MayOverflow; |
699 | } |
700 | |
701 | static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, |
702 | ScalarEvolution *SE, AssumptionCache *AC, |
703 | const TargetTransformInfo *TTI) { |
704 | SmallPtrSet<Instruction *, 8> IterationInstructions; |
705 | if (!findLoopComponents(L: FI.InnerLoop, IterationInstructions, |
706 | InductionPHI&: FI.InnerInductionPHI, TripCount&: FI.InnerTripCount, |
707 | Increment&: FI.InnerIncrement, BackBranch&: FI.InnerBranch, SE, IsWidened: FI.Widened)) |
708 | return false; |
709 | if (!findLoopComponents(L: FI.OuterLoop, IterationInstructions, |
710 | InductionPHI&: FI.OuterInductionPHI, TripCount&: FI.OuterTripCount, |
711 | Increment&: FI.OuterIncrement, BackBranch&: FI.OuterBranch, SE, IsWidened: FI.Widened)) |
712 | return false; |
713 | |
714 | // Both of the loop trip count values must be invariant in the outer loop |
715 | // (non-instructions are all inherently invariant). |
716 | if (!FI.OuterLoop->isLoopInvariant(V: FI.InnerTripCount)) { |
717 | LLVM_DEBUG(dbgs() << "inner loop trip count not invariant\n" ); |
718 | return false; |
719 | } |
720 | if (!FI.OuterLoop->isLoopInvariant(V: FI.OuterTripCount)) { |
721 | LLVM_DEBUG(dbgs() << "outer loop trip count not invariant\n" ); |
722 | return false; |
723 | } |
724 | |
725 | if (!checkPHIs(FI, TTI)) |
726 | return false; |
727 | |
728 | // FIXME: it should be possible to handle different types correctly. |
729 | if (FI.InnerInductionPHI->getType() != FI.OuterInductionPHI->getType()) |
730 | return false; |
731 | |
732 | if (!checkOuterLoopInsts(FI, IterationInstructions, TTI)) |
733 | return false; |
734 | |
735 | // Find the values in the loop that can be replaced with the linearized |
736 | // induction variable, and check that there are no other uses of the inner |
737 | // or outer induction variable. If there were, we could still do this |
738 | // transformation, but we'd have to insert a div/mod to calculate the |
739 | // original IVs, so it wouldn't be profitable. |
740 | if (!checkIVUsers(FI)) |
741 | return false; |
742 | |
743 | LLVM_DEBUG(dbgs() << "CanFlattenLoopPair: OK\n" ); |
744 | return true; |
745 | } |
746 | |
747 | static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, |
748 | ScalarEvolution *SE, AssumptionCache *AC, |
749 | const TargetTransformInfo *TTI, LPMUpdater *U, |
750 | MemorySSAUpdater *MSSAU) { |
751 | Function *F = FI.OuterLoop->getHeader()->getParent(); |
752 | LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n" ); |
753 | { |
754 | using namespace ore; |
755 | OptimizationRemark (DEBUG_TYPE, "Flattened" , FI.InnerLoop->getStartLoc(), |
756 | FI.InnerLoop->getHeader()); |
757 | OptimizationRemarkEmitter ORE(F); |
758 | Remark << "Flattened into outer loop" ; |
759 | ORE.emit(OptDiag&: Remark); |
760 | } |
761 | |
762 | if (!FI.NewTripCount) { |
763 | FI.NewTripCount = BinaryOperator::CreateMul( |
764 | V1: FI.InnerTripCount, V2: FI.OuterTripCount, Name: "flatten.tripcount" , |
765 | It: FI.OuterLoop->getLoopPreheader()->getTerminator()->getIterator()); |
766 | LLVM_DEBUG(dbgs() << "Created new trip count in preheader: " ; |
767 | FI.NewTripCount->dump()); |
768 | } |
769 | |
770 | // Fix up PHI nodes that take values from the inner loop back-edge, which |
771 | // we are about to remove. |
772 | FI.InnerInductionPHI->removeIncomingValue(BB: FI.InnerLoop->getLoopLatch()); |
773 | |
774 | // The old Phi will be optimised away later, but for now we can't leave |
775 | // leave it in an invalid state, so are updating them too. |
776 | for (PHINode *PHI : FI.InnerPHIsToTransform) |
777 | PHI->removeIncomingValue(BB: FI.InnerLoop->getLoopLatch()); |
778 | |
779 | // Modify the trip count of the outer loop to be the product of the two |
780 | // trip counts. |
781 | cast<User>(Val: FI.OuterBranch->getCondition())->setOperand(i: 1, Val: FI.NewTripCount); |
782 | |
783 | // Replace the inner loop backedge with an unconditional branch to the exit. |
784 | BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock(); |
785 | BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock(); |
786 | Instruction *Term = InnerExitingBlock->getTerminator(); |
787 | Instruction *BI = BranchInst::Create(IfTrue: InnerExitBlock, InsertBefore: InnerExitingBlock); |
788 | BI->setDebugLoc(Term->getDebugLoc()); |
789 | Term->eraseFromParent(); |
790 | |
791 | // Update the DomTree and MemorySSA. |
792 | DT->deleteEdge(From: InnerExitingBlock, To: FI.InnerLoop->getHeader()); |
793 | if (MSSAU) |
794 | MSSAU->removeEdge(From: InnerExitingBlock, To: FI.InnerLoop->getHeader()); |
795 | |
796 | // Replace all uses of the polynomial calculated from the two induction |
797 | // variables with the one new one. |
798 | IRBuilder<> Builder(FI.OuterInductionPHI->getParent()->getTerminator()); |
799 | for (Value *V : FI.LinearIVUses) { |
800 | Value *OuterValue = FI.OuterInductionPHI; |
801 | if (FI.Widened) |
802 | OuterValue = Builder.CreateTrunc(V: FI.OuterInductionPHI, DestTy: V->getType(), |
803 | Name: "flatten.trunciv" ); |
804 | |
805 | if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: V)) { |
806 | // Replace the GEP with one that uses OuterValue as the offset. |
807 | auto *InnerGEP = cast<GetElementPtrInst>(Val: GEP->getOperand(i_nocapture: 0)); |
808 | Value *Base = InnerGEP->getOperand(i_nocapture: 0); |
809 | // When the base of the GEP doesn't dominate the outer induction phi then |
810 | // we need to insert the new GEP where the old GEP was. |
811 | if (!DT->dominates(Def: Base, User: &*Builder.GetInsertPoint())) |
812 | Builder.SetInsertPoint(cast<Instruction>(Val: V)); |
813 | OuterValue = |
814 | Builder.CreateGEP(Ty: GEP->getSourceElementType(), Ptr: Base, IdxList: OuterValue, |
815 | Name: "flatten." + V->getName(), |
816 | NW: GEP->isInBounds() && InnerGEP->isInBounds()); |
817 | } |
818 | |
819 | LLVM_DEBUG(dbgs() << "Replacing: " ; V->dump(); dbgs() << "with: " ; |
820 | OuterValue->dump()); |
821 | V->replaceAllUsesWith(V: OuterValue); |
822 | } |
823 | |
824 | // Tell LoopInfo, SCEV and the pass manager that the inner loop has been |
825 | // deleted, and invalidate any outer loop information. |
826 | SE->forgetLoop(L: FI.OuterLoop); |
827 | SE->forgetBlockAndLoopDispositions(); |
828 | if (U) |
829 | U->markLoopAsDeleted(L&: *FI.InnerLoop, Name: FI.InnerLoop->getName()); |
830 | LI->erase(L: FI.InnerLoop); |
831 | |
832 | // Increment statistic value. |
833 | NumFlattened++; |
834 | |
835 | return true; |
836 | } |
837 | |
838 | static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, |
839 | ScalarEvolution *SE, AssumptionCache *AC, |
840 | const TargetTransformInfo *TTI) { |
841 | if (!WidenIV) { |
842 | LLVM_DEBUG(dbgs() << "Widening the IVs is disabled\n" ); |
843 | return false; |
844 | } |
845 | |
846 | LLVM_DEBUG(dbgs() << "Try widening the IVs\n" ); |
847 | Module *M = FI.InnerLoop->getHeader()->getParent()->getParent(); |
848 | auto &DL = M->getDataLayout(); |
849 | auto *InnerType = FI.InnerInductionPHI->getType(); |
850 | auto *OuterType = FI.OuterInductionPHI->getType(); |
851 | unsigned MaxLegalSize = DL.getLargestLegalIntTypeSizeInBits(); |
852 | auto *MaxLegalType = DL.getLargestLegalIntType(C&: M->getContext()); |
853 | |
854 | // If both induction types are less than the maximum legal integer width, |
855 | // promote both to the widest type available so we know calculating |
856 | // (OuterTripCount * InnerTripCount) as the new trip count is safe. |
857 | if (InnerType != OuterType || |
858 | InnerType->getScalarSizeInBits() >= MaxLegalSize || |
859 | MaxLegalType->getScalarSizeInBits() < |
860 | InnerType->getScalarSizeInBits() * 2) { |
861 | LLVM_DEBUG(dbgs() << "Can't widen the IV\n" ); |
862 | return false; |
863 | } |
864 | |
865 | SCEVExpander Rewriter(*SE, DL, "loopflatten" ); |
866 | SmallVector<WeakTrackingVH, 4> DeadInsts; |
867 | unsigned ElimExt = 0; |
868 | unsigned Widened = 0; |
869 | |
870 | auto CreateWideIV = [&](WideIVInfo WideIV, bool &Deleted) -> bool { |
871 | PHINode *WidePhi = |
872 | createWideIV(WI: WideIV, LI, SE, Rewriter, DT, DeadInsts, NumElimExt&: ElimExt, NumWidened&: Widened, |
873 | HasGuards: true /* HasGuards */, UsePostIncrementRanges: true /* UsePostIncrementRanges */); |
874 | if (!WidePhi) |
875 | return false; |
876 | LLVM_DEBUG(dbgs() << "Created wide phi: " ; WidePhi->dump()); |
877 | LLVM_DEBUG(dbgs() << "Deleting old phi: " ; WideIV.NarrowIV->dump()); |
878 | Deleted = RecursivelyDeleteDeadPHINode(PN: WideIV.NarrowIV); |
879 | return true; |
880 | }; |
881 | |
882 | bool Deleted; |
883 | if (!CreateWideIV({.NarrowIV: FI.InnerInductionPHI, .WidestNativeType: MaxLegalType, .IsSigned: false}, Deleted)) |
884 | return false; |
885 | // Add the narrow phi to list, so that it will be adjusted later when the |
886 | // the transformation is performed. |
887 | if (!Deleted) |
888 | FI.InnerPHIsToTransform.insert(Ptr: FI.InnerInductionPHI); |
889 | |
890 | if (!CreateWideIV({.NarrowIV: FI.OuterInductionPHI, .WidestNativeType: MaxLegalType, .IsSigned: false}, Deleted)) |
891 | return false; |
892 | |
893 | assert(Widened && "Widened IV expected" ); |
894 | FI.Widened = true; |
895 | |
896 | // Save the old/narrow induction phis, which we need to ignore in CheckPHIs. |
897 | FI.NarrowInnerInductionPHI = FI.InnerInductionPHI; |
898 | FI.NarrowOuterInductionPHI = FI.OuterInductionPHI; |
899 | |
900 | // After widening, rediscover all the loop components. |
901 | return CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI); |
902 | } |
903 | |
904 | static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, |
905 | ScalarEvolution *SE, AssumptionCache *AC, |
906 | const TargetTransformInfo *TTI, LPMUpdater *U, |
907 | MemorySSAUpdater *MSSAU, |
908 | const LoopAccessInfo &LAI) { |
909 | LLVM_DEBUG( |
910 | dbgs() << "Loop flattening running on outer loop " |
911 | << FI.OuterLoop->getHeader()->getName() << " and inner loop " |
912 | << FI.InnerLoop->getHeader()->getName() << " in " |
913 | << FI.OuterLoop->getHeader()->getParent()->getName() << "\n" ); |
914 | |
915 | if (!CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI)) |
916 | return false; |
917 | |
918 | // Check if we can widen the induction variables to avoid overflow checks. |
919 | bool CanFlatten = CanWidenIV(FI, DT, LI, SE, AC, TTI); |
920 | |
921 | // It can happen that after widening of the IV, flattening may not be |
922 | // possible/happening, e.g. when it is deemed unprofitable. So bail here if |
923 | // that is the case. |
924 | // TODO: IV widening without performing the actual flattening transformation |
925 | // is not ideal. While this codegen change should not matter much, it is an |
926 | // unnecessary change which is better to avoid. It's unlikely this happens |
927 | // often, because if it's unprofitibale after widening, it should be |
928 | // unprofitabe before widening as checked in the first round of checks. But |
929 | // 'RepeatedInstructionThreshold' is set to only 2, which can probably be |
930 | // relaxed. Because this is making a code change (the IV widening, but not |
931 | // the flattening), we return true here. |
932 | if (FI.Widened && !CanFlatten) |
933 | return true; |
934 | |
935 | // If we have widened and can perform the transformation, do that here. |
936 | if (CanFlatten) |
937 | return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); |
938 | |
939 | // Otherwise, if we haven't widened the IV, check if the new iteration |
940 | // variable might overflow. In this case, we need to version the loop, and |
941 | // select the original version at runtime if the iteration space is too |
942 | // large. |
943 | OverflowResult OR = checkOverflow(FI, DT, AC); |
944 | if (OR == OverflowResult::AlwaysOverflowsHigh || |
945 | OR == OverflowResult::AlwaysOverflowsLow) { |
946 | LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n" ); |
947 | return false; |
948 | } else if (OR == OverflowResult::MayOverflow) { |
949 | Module *M = FI.OuterLoop->getHeader()->getParent()->getParent(); |
950 | const DataLayout &DL = M->getDataLayout(); |
951 | if (!VersionLoops) { |
952 | LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n" ); |
953 | return false; |
954 | } else if (!DL.isLegalInteger( |
955 | Width: FI.OuterTripCount->getType()->getScalarSizeInBits())) { |
956 | // If the trip count type isn't legal then it won't be possible to check |
957 | // for overflow using only a single multiply instruction, so don't |
958 | // flatten. |
959 | LLVM_DEBUG( |
960 | dbgs() << "Can't check overflow efficiently, not flattening\n" ); |
961 | return false; |
962 | } |
963 | LLVM_DEBUG(dbgs() << "Multiply might overflow, versioning loop\n" ); |
964 | |
965 | // Version the loop. The overflow check isn't a runtime pointer check, so we |
966 | // pass an empty list of runtime pointer checks, causing LoopVersioning to |
967 | // emit 'false' as the branch condition, and add our own check afterwards. |
968 | BasicBlock *CheckBlock = FI.OuterLoop->getLoopPreheader(); |
969 | ArrayRef<RuntimePointerCheck> Checks(nullptr, nullptr); |
970 | LoopVersioning LVer(LAI, Checks, FI.OuterLoop, LI, DT, SE); |
971 | LVer.versionLoop(); |
972 | |
973 | // Check for overflow by calculating the new tripcount using |
974 | // umul_with_overflow and then checking if it overflowed. |
975 | BranchInst *Br = cast<BranchInst>(Val: CheckBlock->getTerminator()); |
976 | assert(Br->isConditional() && |
977 | "Expected LoopVersioning to generate a conditional branch" ); |
978 | assert(match(Br->getCondition(), m_Zero()) && |
979 | "Expected branch condition to be false" ); |
980 | IRBuilder<> Builder(Br); |
981 | Function *F = Intrinsic::getDeclaration(M, id: Intrinsic::umul_with_overflow, |
982 | Tys: FI.OuterTripCount->getType()); |
983 | Value *Call = Builder.CreateCall(Callee: F, Args: {FI.OuterTripCount, FI.InnerTripCount}, |
984 | Name: "flatten.mul" ); |
985 | FI.NewTripCount = Builder.CreateExtractValue(Agg: Call, Idxs: 0, Name: "flatten.tripcount" ); |
986 | Value *Overflow = Builder.CreateExtractValue(Agg: Call, Idxs: 1, Name: "flatten.overflow" ); |
987 | Br->setCondition(Overflow); |
988 | } else { |
989 | LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n" ); |
990 | } |
991 | |
992 | return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); |
993 | } |
994 | |
995 | PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, |
996 | LoopStandardAnalysisResults &AR, |
997 | LPMUpdater &U) { |
998 | |
999 | bool Changed = false; |
1000 | |
1001 | std::optional<MemorySSAUpdater> MSSAU; |
1002 | if (AR.MSSA) { |
1003 | MSSAU = MemorySSAUpdater(AR.MSSA); |
1004 | if (VerifyMemorySSA) |
1005 | AR.MSSA->verifyMemorySSA(); |
1006 | } |
1007 | |
1008 | // The loop flattening pass requires loops to be |
1009 | // in simplified form, and also needs LCSSA. Running |
1010 | // this pass will simplify all loops that contain inner loops, |
1011 | // regardless of whether anything ends up being flattened. |
1012 | LoopAccessInfoManager LAIM(AR.SE, AR.AA, AR.DT, AR.LI, &AR.TTI, nullptr); |
1013 | for (Loop *InnerLoop : LN.getLoops()) { |
1014 | auto *OuterLoop = InnerLoop->getParentLoop(); |
1015 | if (!OuterLoop) |
1016 | continue; |
1017 | FlattenInfo FI(OuterLoop, InnerLoop); |
1018 | Changed |= |
1019 | FlattenLoopPair(FI, DT: &AR.DT, LI: &AR.LI, SE: &AR.SE, AC: &AR.AC, TTI: &AR.TTI, U: &U, |
1020 | MSSAU: MSSAU ? &*MSSAU : nullptr, LAI: LAIM.getInfo(L&: *OuterLoop)); |
1021 | } |
1022 | |
1023 | if (!Changed) |
1024 | return PreservedAnalyses::all(); |
1025 | |
1026 | if (AR.MSSA && VerifyMemorySSA) |
1027 | AR.MSSA->verifyMemorySSA(); |
1028 | |
1029 | auto PA = getLoopPassPreservedAnalyses(); |
1030 | if (AR.MSSA) |
1031 | PA.preserve<MemorySSAAnalysis>(); |
1032 | return PA; |
1033 | } |
1034 | |