1 | //===- NaryReassociate.cpp - Reassociate n-ary expressions ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This pass reassociates n-ary add expressions and eliminates the redundancy |
10 | // exposed by the reassociation. |
11 | // |
12 | // A motivating example: |
13 | // |
14 | // void foo(int a, int b) { |
15 | // bar(a + b); |
16 | // bar((a + 2) + b); |
17 | // } |
18 | // |
19 | // An ideal compiler should reassociate (a + 2) + b to (a + b) + 2 and simplify |
20 | // the above code to |
21 | // |
22 | // int t = a + b; |
23 | // bar(t); |
24 | // bar(t + 2); |
25 | // |
26 | // However, the Reassociate pass is unable to do that because it processes each |
27 | // instruction individually and believes (a + 2) + b is the best form according |
28 | // to its rank system. |
29 | // |
// To address this limitation, NaryReassociate reassociates an expression into
// a form that reuses existing instructions. As a result, NaryReassociate can
// reassociate (a + 2) + b in the example above to (a + b) + 2, because it
// detects that (a + b) has already been computed.
34 | // |
35 | // NaryReassociate works as follows. For every instruction in the form of (a + |
36 | // b) + c, it checks whether a + c or b + c is already computed by a dominating |
37 | // instruction. If so, it then reassociates (a + b) + c into (a + c) + b or (b + |
38 | // c) + a and removes the redundancy accordingly. To efficiently look up whether |
39 | // an expression is computed before, we store each instruction seen and its SCEV |
40 | // into an SCEV-to-instruction map. |
41 | // |
42 | // Although the algorithm pattern-matches only ternary additions, it |
// automatically handles many >3-ary expressions by visiting basic blocks in a
// depth-first pre-order of the dominator tree. For example, given
45 | // |
46 | // (a + c) + d |
47 | // ((a + b) + c) + d |
48 | // |
49 | // NaryReassociate first rewrites (a + b) + c to (a + c) + b, and then rewrites |
50 | // ((a + c) + b) + d into ((a + c) + d) + b. |
51 | // |
// Finally, the above dominator-based algorithm may need to run for multiple
// iterations before reaching a fixed point. One reason is that we split an
// operand only when it has a single use, yet each rewrite can eliminate an
// instruction and thereby decrease the use counts of its operands. As a
// result, an instruction that previously had multiple uses may become a
// single-use instruction and thus eligible for splitting. For
58 | // example, |
59 | // |
60 | // ac = a + c |
61 | // ab = a + b |
62 | // abc = ab + c |
63 | // ab2 = ab + b |
64 | // ab2c = ab2 + c |
65 | // |
// In the first iteration, we cannot reassociate abc to ac + b because ab is
// used twice. However, we can reassociate ab2c to abc + b in the first
// iteration. As a result, ab2 becomes dead and ab is used only once in the
// second iteration, which enables the abc rewrite.
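//
// In trace form (recapping the example above):
//
//   Iteration 1: ab2c = ab2 + c  ==>  ab2c = abc + b   (ab2 becomes dead)
//   Iteration 2: abc  = ab + c   ==>  abc  = ac + b    (ab is now single-use)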
70 | // |
71 | // Limitations and TODO items: |
72 | // |
// 1) We only consider n-ary adds, muls, GEPs, and min/max idioms for now.
// This should be extended and generalized.
75 | // |
76 | //===----------------------------------------------------------------------===// |
77 | |
78 | #include "llvm/Transforms/Scalar/NaryReassociate.h" |
79 | #include "llvm/ADT/DepthFirstIterator.h" |
80 | #include "llvm/ADT/SmallVector.h" |
81 | #include "llvm/Analysis/AssumptionCache.h" |
82 | #include "llvm/Analysis/ScalarEvolution.h" |
83 | #include "llvm/Analysis/ScalarEvolutionExpressions.h" |
84 | #include "llvm/Analysis/TargetLibraryInfo.h" |
85 | #include "llvm/Analysis/TargetTransformInfo.h" |
86 | #include "llvm/Analysis/ValueTracking.h" |
87 | #include "llvm/IR/BasicBlock.h" |
88 | #include "llvm/IR/Constants.h" |
89 | #include "llvm/IR/DataLayout.h" |
90 | #include "llvm/IR/DerivedTypes.h" |
91 | #include "llvm/IR/Dominators.h" |
92 | #include "llvm/IR/Function.h" |
93 | #include "llvm/IR/GetElementPtrTypeIterator.h" |
94 | #include "llvm/IR/IRBuilder.h" |
95 | #include "llvm/IR/InstrTypes.h" |
96 | #include "llvm/IR/Instruction.h" |
97 | #include "llvm/IR/Instructions.h" |
98 | #include "llvm/IR/Module.h" |
99 | #include "llvm/IR/Operator.h" |
100 | #include "llvm/IR/PatternMatch.h" |
101 | #include "llvm/IR/Type.h" |
102 | #include "llvm/IR/Value.h" |
103 | #include "llvm/IR/ValueHandle.h" |
104 | #include "llvm/InitializePasses.h" |
105 | #include "llvm/Pass.h" |
106 | #include "llvm/Support/Casting.h" |
107 | #include "llvm/Support/ErrorHandling.h" |
108 | #include "llvm/Transforms/Scalar.h" |
109 | #include "llvm/Transforms/Utils/Local.h" |
110 | #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" |
111 | #include <cassert> |
112 | #include <cstdint> |
113 | |
114 | using namespace llvm; |
115 | using namespace PatternMatch; |
116 | |
117 | #define DEBUG_TYPE "nary-reassociate" |
118 | |
119 | namespace { |
120 | |
121 | class NaryReassociateLegacyPass : public FunctionPass { |
122 | public: |
123 | static char ID; |
124 | |
125 | NaryReassociateLegacyPass() : FunctionPass(ID) { |
126 | initializeNaryReassociateLegacyPassPass(*PassRegistry::getPassRegistry()); |
127 | } |
128 | |
129 | bool doInitialization(Module &M) override { |
130 | return false; |
131 | } |
132 | |
133 | bool runOnFunction(Function &F) override; |
134 | |
135 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
136 | AU.addPreserved<DominatorTreeWrapperPass>(); |
137 | AU.addPreserved<ScalarEvolutionWrapperPass>(); |
138 | AU.addPreserved<TargetLibraryInfoWrapperPass>(); |
139 | AU.addRequired<AssumptionCacheTracker>(); |
140 | AU.addRequired<DominatorTreeWrapperPass>(); |
141 | AU.addRequired<ScalarEvolutionWrapperPass>(); |
142 | AU.addRequired<TargetLibraryInfoWrapperPass>(); |
143 | AU.addRequired<TargetTransformInfoWrapperPass>(); |
144 | AU.setPreservesCFG(); |
145 | } |
146 | |
147 | private: |
148 | NaryReassociatePass Impl; |
149 | }; |
150 | |
151 | } // end anonymous namespace |
152 | |
153 | char NaryReassociateLegacyPass::ID = 0; |
154 | |
INITIALIZE_PASS_BEGIN(NaryReassociateLegacyPass, "nary-reassociate",
                      "Nary reassociation", false, false)
157 | INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) |
158 | INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) |
159 | INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) |
160 | INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) |
161 | INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) |
INITIALIZE_PASS_END(NaryReassociateLegacyPass, "nary-reassociate",
                    "Nary reassociation", false, false)
164 | |
165 | FunctionPass *llvm::createNaryReassociatePass() { |
166 | return new NaryReassociateLegacyPass(); |
167 | } |
168 | |
169 | bool NaryReassociateLegacyPass::runOnFunction(Function &F) { |
170 | if (skipFunction(F)) |
171 | return false; |
172 | |
173 | auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); |
174 | auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); |
175 | auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); |
176 | auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); |
177 | auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); |
178 | |
  return Impl.runImpl(F, AC, DT, SE, TLI, TTI);
180 | } |
181 | |
182 | PreservedAnalyses NaryReassociatePass::run(Function &F, |
183 | FunctionAnalysisManager &AM) { |
  auto *AC = &AM.getResult<AssumptionAnalysis>(F);
  auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
  auto *TLI = &AM.getResult<TargetLibraryAnalysis>(F);
  auto *TTI = &AM.getResult<TargetIRAnalysis>(F);

  if (!runImpl(F, AC, DT, SE, TLI, TTI))
191 | return PreservedAnalyses::all(); |
192 | |
193 | PreservedAnalyses PA; |
194 | PA.preserveSet<CFGAnalyses>(); |
195 | PA.preserve<ScalarEvolutionAnalysis>(); |
196 | return PA; |
197 | } |
198 | |
199 | bool NaryReassociatePass::runImpl(Function &F, AssumptionCache *AC_, |
200 | DominatorTree *DT_, ScalarEvolution *SE_, |
201 | TargetLibraryInfo *TLI_, |
202 | TargetTransformInfo *TTI_) { |
203 | AC = AC_; |
204 | DT = DT_; |
205 | SE = SE_; |
206 | TLI = TLI_; |
207 | TTI = TTI_; |
208 | DL = &F.getDataLayout(); |
209 | |
210 | bool Changed = false, ChangedInThisIteration; |
211 | do { |
212 | ChangedInThisIteration = doOneIteration(F); |
213 | Changed |= ChangedInThisIteration; |
214 | } while (ChangedInThisIteration); |
215 | return Changed; |
216 | } |
217 | |
218 | bool NaryReassociatePass::doOneIteration(Function &F) { |
219 | bool Changed = false; |
220 | SeenExprs.clear(); |
221 | // Process the basic blocks in a depth first traversal of the dominator |
222 | // tree. This order ensures that all bases of a candidate are in Candidates |
223 | // when we process it. |
224 | SmallVector<WeakTrackingVH, 16> DeadInsts; |
  for (const auto Node : depth_first(DT)) {
226 | BasicBlock *BB = Node->getBlock(); |
227 | for (Instruction &OrigI : *BB) { |
228 | const SCEV *OrigSCEV = nullptr; |
      if (Instruction *NewI = tryReassociate(&OrigI, OrigSCEV)) {
        Changed = true;
        OrigI.replaceAllUsesWith(NewI);

        // Add 'OrigI' to the list of dead instructions.
        DeadInsts.push_back(WeakTrackingVH(&OrigI));
        // Add the rewritten instruction to SeenExprs; the original
        // instruction is deleted.
        const SCEV *NewSCEV = SE->getSCEV(NewI);
        SeenExprs[NewSCEV].push_back(WeakTrackingVH(NewI));
239 | |
        // Ideally, NewSCEV should equal OrigSCEV because tryReassociate(I)
        // is equivalent to I. However, ScalarEvolution::getSCEV may
        // weaken nsw, causing NewSCEV not to equal OrigSCEV. For example,
243 | // suppose we reassociate |
244 | // I = &a[sext(i +nsw j)] // assuming sizeof(a[0]) = 4 |
245 | // to |
246 | // NewI = &a[sext(i)] + sext(j). |
247 | // |
248 | // ScalarEvolution computes |
249 | // getSCEV(I) = a + 4 * sext(i + j) |
250 | // getSCEV(newI) = a + 4 * sext(i) + 4 * sext(j) |
251 | // which are different SCEVs. |
252 | // |
        // To alleviate this issue of ScalarEvolution not always capturing
        // equivalence, we also add NewI to SeenExprs[OrigSCEV], so that both
        // the SCEVs before and after tryReassociate(I) map to NewI.
256 | // |
257 | // This improvement is exercised in @reassociate_gep_nsw in |
258 | // nary-gep.ll. |
259 | if (NewSCEV != OrigSCEV) |
          SeenExprs[OrigSCEV].push_back(WeakTrackingVH(NewI));
      } else if (OrigSCEV)
        SeenExprs[OrigSCEV].push_back(WeakTrackingVH(&OrigI));
263 | } |
264 | } |
265 | // Delete all dead instructions from 'DeadInsts'. |
266 | // Please note ScalarEvolution is updated along the way. |
267 | RecursivelyDeleteTriviallyDeadInstructionsPermissive( |
      DeadInsts, TLI, nullptr, [this](Value *V) { SE->forgetValue(V); });
269 | |
270 | return Changed; |
271 | } |
272 | |
273 | template <typename PredT> |
274 | Instruction * |
275 | NaryReassociatePass::matchAndReassociateMinOrMax(Instruction *I, |
276 | const SCEV *&OrigSCEV) { |
277 | Value *LHS = nullptr; |
278 | Value *RHS = nullptr; |
279 | |
280 | auto MinMaxMatcher = |
281 | MaxMin_match<ICmpInst, bind_ty<Value>, bind_ty<Value>, PredT>( |
          m_Value(LHS), m_Value(RHS));
283 | if (match(I, MinMaxMatcher)) { |
    OrigSCEV = SE->getSCEV(I);
285 | if (auto *NewMinMax = dyn_cast_or_null<Instruction>( |
286 | tryReassociateMinOrMax(I, MinMaxMatcher, LHS, RHS))) |
287 | return NewMinMax; |
288 | if (auto *NewMinMax = dyn_cast_or_null<Instruction>( |
289 | tryReassociateMinOrMax(I, MinMaxMatcher, RHS, LHS))) |
290 | return NewMinMax; |
291 | } |
292 | return nullptr; |
293 | } |
294 | |
Instruction *NaryReassociatePass::tryReassociate(Instruction *I,
                                                 const SCEV *&OrigSCEV) {
  if (!SE->isSCEVable(I->getType()))
299 | return nullptr; |
300 | |
301 | switch (I->getOpcode()) { |
302 | case Instruction::Add: |
303 | case Instruction::Mul: |
    OrigSCEV = SE->getSCEV(I);
    return tryReassociateBinaryOp(cast<BinaryOperator>(I));
  case Instruction::GetElementPtr:
    OrigSCEV = SE->getSCEV(I);
    return tryReassociateGEP(cast<GetElementPtrInst>(I));
309 | default: |
310 | break; |
311 | } |
312 | |
313 | // Try to match signed/unsigned Min/Max. |
314 | Instruction *ResI = nullptr; |
  // TODO: Currently min/max reassociation is restricted to integer types only
  // because SCEVExpander may introduce incompatible forms of min/max for
  // pointer types.
318 | if (I->getType()->isIntegerTy()) |
319 | if ((ResI = matchAndReassociateMinOrMax<umin_pred_ty>(I, OrigSCEV)) || |
320 | (ResI = matchAndReassociateMinOrMax<smin_pred_ty>(I, OrigSCEV)) || |
321 | (ResI = matchAndReassociateMinOrMax<umax_pred_ty>(I, OrigSCEV)) || |
322 | (ResI = matchAndReassociateMinOrMax<smax_pred_ty>(I, OrigSCEV))) |
323 | return ResI; |
324 | |
325 | return nullptr; |
326 | } |
327 | |
328 | static bool isGEPFoldable(GetElementPtrInst *GEP, |
329 | const TargetTransformInfo *TTI) { |
330 | SmallVector<const Value *, 4> Indices(GEP->indices()); |
  return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(),
                         Indices) == TargetTransformInfo::TCC_Free;
333 | } |
334 | |
335 | Instruction *NaryReassociatePass::tryReassociateGEP(GetElementPtrInst *GEP) { |
336 | // Not worth reassociating GEP if it is foldable. |
337 | if (isGEPFoldable(GEP, TTI)) |
338 | return nullptr; |
339 | |
  gep_type_iterator GTI = gep_type_begin(*GEP);
341 | for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { |
342 | if (GTI.isSequential()) { |
      if (auto *NewGEP =
              tryReassociateGEPAtIndex(GEP, I - 1, GTI.getIndexedType())) {
345 | return NewGEP; |
346 | } |
347 | } |
348 | } |
349 | return nullptr; |
350 | } |
351 | |
352 | bool NaryReassociatePass::requiresSignExtension(Value *Index, |
353 | GetElementPtrInst *GEP) { |
354 | unsigned IndexSizeInBits = |
      DL->getIndexSizeInBits(GEP->getType()->getPointerAddressSpace());
  return cast<IntegerType>(Index->getType())->getBitWidth() < IndexSizeInBits;
357 | } |
358 | |
359 | GetElementPtrInst * |
360 | NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, |
361 | unsigned I, Type *IndexedType) { |
362 | SimplifyQuery SQ(*DL, DT, AC, GEP); |
  Value *IndexToSplit = GEP->getOperand(I + 1);
  if (SExtInst *SExt = dyn_cast<SExtInst>(IndexToSplit)) {
    IndexToSplit = SExt->getOperand(0);
  } else if (ZExtInst *ZExt = dyn_cast<ZExtInst>(IndexToSplit)) {
    // zext can be treated as sext if the source is non-negative.
    if (isKnownNonNegative(ZExt->getOperand(0), SQ))
      IndexToSplit = ZExt->getOperand(0);
370 | } |
371 | |
  if (AddOperator *AO = dyn_cast<AddOperator>(IndexToSplit)) {
373 | // If the I-th index needs sext and the underlying add is not equipped with |
374 | // nsw, we cannot split the add because |
375 | // sext(LHS + RHS) != sext(LHS) + sext(RHS). |
    if (requiresSignExtension(IndexToSplit, GEP) &&
        computeOverflowForSignedAdd(AO, SQ) != OverflowResult::NeverOverflows)
378 | return nullptr; |
379 | |
    Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);
381 | // IndexToSplit = LHS + RHS. |
382 | if (auto *NewGEP = tryReassociateGEPAtIndex(GEP, I, LHS, RHS, IndexedType)) |
383 | return NewGEP; |
384 | // Symmetrically, try IndexToSplit = RHS + LHS. |
385 | if (LHS != RHS) { |
386 | if (auto *NewGEP = |
              tryReassociateGEPAtIndex(GEP, I, RHS, LHS, IndexedType))
388 | return NewGEP; |
389 | } |
390 | } |
391 | return nullptr; |
392 | } |
393 | |
394 | GetElementPtrInst * |
395 | NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, |
396 | unsigned I, Value *LHS, |
397 | Value *RHS, Type *IndexedType) { |
398 | // Look for GEP's closest dominator that has the same SCEV as GEP except that |
399 | // the I-th index is replaced with LHS. |
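  // For example (hypothetical IR): given a dominating
  //   %c = getelementptr i32, ptr %p, i64 %a
  // and the current
  //   %g = getelementptr i32, ptr %p, i64 %sum   ; where %sum = add i64 %a, %b
  // with LHS = %a and RHS = %b, the candidate SCEV is that of &p[a]. If %c is
  // found in SeenExprs, %g is rebuilt below as a GEP off %c indexed by %b.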
400 | SmallVector<const SCEV *, 4> IndexExprs; |
401 | for (Use &Index : GEP->indices()) |
    IndexExprs.push_back(SE->getSCEV(Index));
403 | // Replace the I-th index with LHS. |
  IndexExprs[I] = SE->getSCEV(LHS);
  if (isKnownNonNegative(LHS, SimplifyQuery(*DL, DT, AC, GEP)) &&
      DL->getTypeSizeInBits(LHS->getType()).getFixedValue() <
          DL->getTypeSizeInBits(GEP->getOperand(I)->getType())
              .getFixedValue()) {
    // Zero-extend LHS if it is non-negative. InstCombine canonicalizes sext
    // to zext when the source operand is known non-negative, so we do the
    // same here to make it more likely that CandidateExpr has been computed
    // before. See @reassociate_gep_assume for an example of this
    // canonicalization.
    IndexExprs[I] =
        SE->getZeroExtendExpr(IndexExprs[I], GEP->getOperand(I)->getType());
  }
  const SCEV *CandidateExpr =
      SE->getGEPExpr(cast<GEPOperator>(GEP), IndexExprs);
418 | |
  Value *Candidate = findClosestMatchingDominator(CandidateExpr, GEP);
420 | if (Candidate == nullptr) |
421 | return nullptr; |
422 | |
423 | IRBuilder<> Builder(GEP); |
424 | // Candidate does not necessarily have the same pointer type as GEP. Use |
425 | // bitcast or pointer cast to make sure they have the same type, so that the |
426 | // later RAUW doesn't complain. |
  Candidate = Builder.CreateBitOrPointerCast(Candidate, GEP->getType());
428 | assert(Candidate->getType() == GEP->getType()); |
429 | |
430 | // NewGEP = (char *)Candidate + RHS * sizeof(IndexedType) |
  uint64_t IndexedSize = DL->getTypeAllocSize(IndexedType);
  Type *ElementType = GEP->getResultElementType();
  uint64_t ElementSize = DL->getTypeAllocSize(ElementType);
  // A complication: because I is not necessarily the last index of the GEP,
  // the size of the type at the I-th index (IndexedSize) is not necessarily
  // divisible by ElementSize. For example,
437 | // |
438 | // #pragma pack(1) |
439 | // struct S { |
440 | // int a[3]; |
441 | // int64 b[8]; |
442 | // }; |
443 | // #pragma pack() |
444 | // |
  // sizeof(S) = 12 + 64 = 76, which is indivisible by sizeof(int64) = 8.
446 | // |
  // We bail out on this case for now. TODO: emit an uglygep instead.
448 | if (IndexedSize % ElementSize != 0) |
449 | return nullptr; |
450 | |
  // NewGEP = &Candidate[RHS * (sizeof(IndexedType) / sizeof(Candidate[0]))]
  Type *PtrIdxTy = DL->getIndexType(GEP->getType());
  if (RHS->getType() != PtrIdxTy)
    RHS = Builder.CreateSExtOrTrunc(RHS, PtrIdxTy);
  if (IndexedSize != ElementSize) {
    RHS = Builder.CreateMul(
        RHS, ConstantInt::get(PtrIdxTy, IndexedSize / ElementSize));
  }
  GetElementPtrInst *NewGEP = cast<GetElementPtrInst>(
      Builder.CreateGEP(GEP->getResultElementType(), Candidate, RHS));
461 | NewGEP->setIsInBounds(GEP->isInBounds()); |
  NewGEP->takeName(GEP);
463 | return NewGEP; |
464 | } |
465 | |
466 | Instruction *NaryReassociatePass::tryReassociateBinaryOp(BinaryOperator *I) { |
  Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
  // There is no need to reassociate 0.
  if (SE->getSCEV(I)->isZero())
470 | return nullptr; |
471 | if (auto *NewI = tryReassociateBinaryOp(LHS, RHS, I)) |
472 | return NewI; |
  if (auto *NewI = tryReassociateBinaryOp(RHS, LHS, I))
474 | return NewI; |
475 | return nullptr; |
476 | } |
477 | |
478 | Instruction *NaryReassociatePass::tryReassociateBinaryOp(Value *LHS, Value *RHS, |
479 | BinaryOperator *I) { |
480 | Value *A = nullptr, *B = nullptr; |
481 | // To be conservative, we reassociate I only when it is the only user of (A op |
482 | // B). |
  if (LHS->hasOneUse() && matchTernaryOp(I, LHS, A, B)) {
484 | // I = (A op B) op RHS |
485 | // = (A op RHS) op B or (B op RHS) op A |
    const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B);
    const SCEV *RHSExpr = SE->getSCEV(RHS);
    if (BExpr != RHSExpr) {
      if (auto *NewI =
              tryReassociatedBinaryOp(getBinarySCEV(I, AExpr, RHSExpr), B, I))
        return NewI;
    }
    if (AExpr != RHSExpr) {
      if (auto *NewI =
              tryReassociatedBinaryOp(getBinarySCEV(I, BExpr, RHSExpr), A, I))
        return NewI;
    }
498 | } |
499 | return nullptr; |
500 | } |
501 | |
502 | Instruction *NaryReassociatePass::tryReassociatedBinaryOp(const SCEV *LHSExpr, |
503 | Value *RHS, |
504 | BinaryOperator *I) { |
505 | // Look for the closest dominator LHS of I that computes LHSExpr, and replace |
506 | // I with LHS op RHS. |
  auto *LHS = findClosestMatchingDominator(LHSExpr, I);
508 | if (LHS == nullptr) |
509 | return nullptr; |
510 | |
511 | Instruction *NewI = nullptr; |
512 | switch (I->getOpcode()) { |
  case Instruction::Add:
    NewI = BinaryOperator::CreateAdd(LHS, RHS, "", I->getIterator());
    break;
  case Instruction::Mul:
    NewI = BinaryOperator::CreateMul(LHS, RHS, "", I->getIterator());
    break;
  default:
    llvm_unreachable("Unexpected instruction.");
  }
  NewI->setDebugLoc(I->getDebugLoc());
  NewI->takeName(I);
524 | return NewI; |
525 | } |
526 | |
527 | bool NaryReassociatePass::matchTernaryOp(BinaryOperator *I, Value *V, |
528 | Value *&Op1, Value *&Op2) { |
529 | switch (I->getOpcode()) { |
  case Instruction::Add:
    return match(V, m_Add(m_Value(Op1), m_Value(Op2)));
  case Instruction::Mul:
    return match(V, m_Mul(m_Value(Op1), m_Value(Op2)));
  default:
    llvm_unreachable("Unexpected instruction.");
536 | } |
537 | return false; |
538 | } |
539 | |
540 | const SCEV *NaryReassociatePass::getBinarySCEV(BinaryOperator *I, |
541 | const SCEV *LHS, |
542 | const SCEV *RHS) { |
543 | switch (I->getOpcode()) { |
544 | case Instruction::Add: |
545 | return SE->getAddExpr(LHS, RHS); |
546 | case Instruction::Mul: |
547 | return SE->getMulExpr(LHS, RHS); |
548 | default: |
549 | llvm_unreachable("Unexpected instruction." ); |
550 | } |
551 | return nullptr; |
552 | } |
553 | |
554 | Instruction * |
555 | NaryReassociatePass::findClosestMatchingDominator(const SCEV *CandidateExpr, |
556 | Instruction *Dominatee) { |
  auto Pos = SeenExprs.find(CandidateExpr);
558 | if (Pos == SeenExprs.end()) |
559 | return nullptr; |
560 | |
561 | auto &Candidates = Pos->second; |
562 | // Because we process the basic blocks in pre-order of the dominator tree, a |
563 | // candidate that doesn't dominate the current instruction won't dominate any |
564 | // future instruction either. Therefore, we pop it out of the stack. This |
565 | // optimization makes the algorithm O(n). |
566 | while (!Candidates.empty()) { |
567 | // Candidates stores WeakTrackingVHs, so a candidate can be nullptr if it's |
568 | // removed during rewriting. |
569 | if (Value *Candidate = Candidates.pop_back_val()) { |
      Instruction *CandidateInstruction = cast<Instruction>(Candidate);
      if (!DT->dominates(CandidateInstruction, Dominatee))
572 | continue; |
573 | |
574 | // Make sure that the instruction is safe to reuse without introducing |
575 | // poison. |
576 | SmallVector<Instruction *> DropPoisonGeneratingInsts; |
      if (!SE->canReuseInstruction(CandidateExpr, CandidateInstruction,
                                   DropPoisonGeneratingInsts))
579 | continue; |
580 | |
581 | for (Instruction *I : DropPoisonGeneratingInsts) |
582 | I->dropPoisonGeneratingAnnotations(); |
583 | |
584 | return CandidateInstruction; |
585 | } |
586 | } |
587 | return nullptr; |
588 | } |
589 | |
template <typename MaxMinT> static SCEVTypes convertToSCEVType(MaxMinT &MM) {
591 | if (std::is_same_v<smax_pred_ty, typename MaxMinT::PredType>) |
592 | return scSMaxExpr; |
593 | else if (std::is_same_v<umax_pred_ty, typename MaxMinT::PredType>) |
594 | return scUMaxExpr; |
595 | else if (std::is_same_v<smin_pred_ty, typename MaxMinT::PredType>) |
596 | return scSMinExpr; |
597 | else if (std::is_same_v<umin_pred_ty, typename MaxMinT::PredType>) |
598 | return scUMinExpr; |
599 | |
600 | llvm_unreachable("Can't convert MinMax pattern to SCEV type" ); |
601 | return scUnknown; |
602 | } |
603 | |
604 | // Parameters: |
605 | // I - instruction matched by MaxMinMatch matcher |
606 | // MaxMinMatch - min/max idiom matcher |
607 | // LHS - first operand of I |
608 | // RHS - second operand of I |
609 | template <typename MaxMinT> |
610 | Value *NaryReassociatePass::tryReassociateMinOrMax(Instruction *I, |
611 | MaxMinT MaxMinMatch, |
612 | Value *LHS, Value *RHS) { |
613 | Value *A = nullptr, *B = nullptr; |
  MaxMinT m_MaxMin(m_Value(A), m_Value(B));
615 | |
  if (LHS->hasNUsesOrMore(3) ||
617 | // The optimization is profitable only if LHS can be removed in the end. |
618 | // In other words LHS should be used (directly or indirectly) by I only. |
619 | llvm::any_of(LHS->users(), |
620 | [&](auto *U) { |
621 | return U != I && |
622 | !(U->hasOneUser() && *U->users().begin() == I); |
623 | }) || |
624 | !match(LHS, m_MaxMin)) |
625 | return nullptr; |
626 | |
627 | auto tryCombination = [&](Value *A, const SCEV *AExpr, Value *B, |
628 | const SCEV *BExpr, Value *C, |
629 | const SCEV *CExpr) -> Value * { |
630 | SmallVector<const SCEV *, 2> Ops1{BExpr, AExpr}; |
    const SCEVTypes SCEVType = convertToSCEVType(m_MaxMin);
    const SCEV *R1Expr = SE->getMinMaxExpr(SCEVType, Ops1);
633 | |
    Instruction *R1MinMax = findClosestMatchingDominator(R1Expr, I);
635 | |
636 | if (!R1MinMax) |
637 | return nullptr; |
638 | |
639 | LLVM_DEBUG(dbgs() << "NARY: Found common sub-expr: " << *R1MinMax << "\n" ); |
640 | |
    SmallVector<const SCEV *, 2> Ops2{SE->getUnknown(C),
                                      SE->getUnknown(R1MinMax)};
    const SCEV *R2Expr = SE->getMinMaxExpr(SCEVType, Ops2);
644 | |
645 | SCEVExpander Expander(*SE, *DL, "nary-reassociate" ); |
646 | Value *NewMinMax = Expander.expandCodeFor(SH: R2Expr, Ty: I->getType(), I); |
647 | NewMinMax->setName(Twine(I->getName()).concat(Suffix: ".nary" )); |
648 | |
649 | LLVM_DEBUG(dbgs() << "NARY: Deleting: " << *I << "\n" |
650 | << "NARY: Inserting: " << *NewMinMax << "\n" ); |
651 | return NewMinMax; |
652 | }; |
653 | |
  const SCEV *AExpr = SE->getSCEV(A);
  const SCEV *BExpr = SE->getSCEV(B);
  const SCEV *RHSExpr = SE->getSCEV(RHS);
657 | |
658 | if (BExpr != RHSExpr) { |
659 | // Try (A op RHS) op B |
660 | if (auto *NewMinMax = tryCombination(A, AExpr, RHS, RHSExpr, B, BExpr)) |
661 | return NewMinMax; |
662 | } |
663 | |
664 | if (AExpr != RHSExpr) { |
665 | // Try (RHS op B) op A |
666 | if (auto *NewMinMax = tryCombination(RHS, RHSExpr, B, BExpr, A, AExpr)) |
667 | return NewMinMax; |
668 | } |
669 | |
670 | return nullptr; |
671 | } |
672 | |