//===- AggressiveInstCombine.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the aggressive expression pattern combiner classes.
// Currently, it handles expression patterns for:
//  * Truncate instruction
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
#include "AggressiveInstCombineInternal.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "aggressive-instcombine"

STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded");
STATISTIC(NumGuardedRotates,
          "Number of guarded rotates transformed into funnel shifts");
STATISTIC(NumGuardedFunnelShifts,
          "Number of guarded funnel shifts transformed into funnel shifts");
STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized");

static cl::opt<unsigned> MaxInstrsToScan(
    "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden,
    cl::desc("Max number of instructions to scan for aggressive instcombine."));

static cl::opt<unsigned> StrNCmpInlineThreshold(
    "strncmp-inline-threshold", cl::init(3), cl::Hidden,
    cl::desc("The maximum length of a constant string for a builtin string cmp "
             "call eligible for inlining. The default value is 3."));

static cl::opt<unsigned>
    MemChrInlineThreshold("memchr-inline-threshold", cl::init(3), cl::Hidden,
                          cl::desc("The maximum length of a constant string to "
                                   "inline a memchr call."));

/// Match a pattern for a bitwise funnel/rotate operation that partially guards
/// against undefined behavior by branching around the funnel-shift/rotation
/// when the shift amount is 0.
static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
  if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2)
    return false;

  // As with the one-use checks below, this is not strictly necessary, but we
  // are being cautious to avoid potential perf regressions on targets that
  // do not actually have a funnel/rotate instruction (where the funnel shift
  // would be expanded back into math/shift/logic ops).
  if (!isPowerOf2_32(I.getType()->getScalarSizeInBits()))
    return false;

  // Match V to funnel shift left/right and capture the source operands and
  // shift amount.
  auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1,
                             Value *&ShAmt) {
    unsigned Width = V->getType()->getScalarSizeInBits();

    // fshl(ShVal0, ShVal1, ShAmt)
    //  == (ShVal0 << ShAmt) | (ShVal1 >> (Width - ShAmt))
    if (match(V, m_OneUse(m_c_Or(
                     m_Shl(m_Value(ShVal0), m_Value(ShAmt)),
                     m_LShr(m_Value(ShVal1),
                            m_Sub(m_SpecificInt(Width), m_Deferred(ShAmt))))))) {
      return Intrinsic::fshl;
    }

    // fshr(ShVal0, ShVal1, ShAmt)
    //  == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt))
    if (match(V,
              m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width),
                                                           m_Value(ShAmt))),
                              m_LShr(m_Value(ShVal1), m_Deferred(ShAmt)))))) {
      return Intrinsic::fshr;
    }

    return Intrinsic::not_intrinsic;
  };

  // One phi operand must be a funnel/rotate operation, and the other phi
  // operand must be the source value of that funnel/rotate operation:
  // phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ]
  // phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ]
  // phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ]
  PHINode &Phi = cast<PHINode>(I);
  unsigned FunnelOp = 0, GuardOp = 1;
  Value *P0 = Phi.getOperand(0), *P1 = Phi.getOperand(1);
  Value *ShVal0, *ShVal1, *ShAmt;
  Intrinsic::ID IID = matchFunnelShift(P0, ShVal0, ShVal1, ShAmt);
  if (IID == Intrinsic::not_intrinsic ||
      (IID == Intrinsic::fshl && ShVal0 != P1) ||
      (IID == Intrinsic::fshr && ShVal1 != P1)) {
    IID = matchFunnelShift(P1, ShVal0, ShVal1, ShAmt);
    if (IID == Intrinsic::not_intrinsic ||
        (IID == Intrinsic::fshl && ShVal0 != P0) ||
        (IID == Intrinsic::fshr && ShVal1 != P0))
      return false;
    assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
           "Pattern must match funnel shift left or right");
    std::swap(FunnelOp, GuardOp);
  }

  // The incoming block with our source operand must be the "guard" block.
  // That must contain a cmp+branch to avoid the funnel/rotate when the shift
  // amount is equal to 0. The other incoming block is the block with the
  // funnel/rotate.
  BasicBlock *GuardBB = Phi.getIncomingBlock(GuardOp);
  BasicBlock *FunnelBB = Phi.getIncomingBlock(FunnelOp);
  Instruction *TermI = GuardBB->getTerminator();

  // Ensure that the shift values dominate each block.
  if (!DT.dominates(ShVal0, TermI) || !DT.dominates(ShVal1, TermI))
    return false;

  ICmpInst::Predicate Pred;
  BasicBlock *PhiBB = Phi.getParent();
  if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()),
                         m_SpecificBB(PhiBB), m_SpecificBB(FunnelBB))))
    return false;

  if (Pred != CmpInst::ICMP_EQ)
    return false;

  IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt());

  if (ShVal0 == ShVal1)
    ++NumGuardedRotates;
  else
    ++NumGuardedFunnelShifts;

  // If this is not a rotate then the select was blocking poison from the
  // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
  bool IsFshl = IID == Intrinsic::fshl;
  if (ShVal0 != ShVal1) {
    if (IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal1))
      ShVal1 = Builder.CreateFreeze(ShVal1);
    else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal0))
      ShVal0 = Builder.CreateFreeze(ShVal0);
  }

  // We matched a variation of this IR pattern:
  // GuardBB:
  //   %cmp = icmp eq i32 %ShAmt, 0
  //   br i1 %cmp, label %PhiBB, label %FunnelBB
  // FunnelBB:
  //   %sub = sub i32 32, %ShAmt
  //   %shr = lshr i32 %ShVal1, %sub
  //   %shl = shl i32 %ShVal0, %ShAmt
  //   %fsh = or i32 %shr, %shl
  //   br label %PhiBB
  // PhiBB:
  //   %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ]
  // -->
  // llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt)
  Function *F = Intrinsic::getDeclaration(Phi.getModule(), IID, Phi.getType());
  Phi.replaceAllUsesWith(Builder.CreateCall(F, {ShVal0, ShVal1, ShAmt}));
  return true;
}

/// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and
/// the bit indexes (Mask) needed by a masked compare. If we're matching a chain
/// of 'and' ops, then we also need to capture the fact that we saw an
/// "and X, 1", so that's an extra return value for that case.
struct MaskOps {
  Value *Root = nullptr;
  APInt Mask;
  bool MatchAndChain;
  bool FoundAnd1 = false;

  MaskOps(unsigned BitWidth, bool MatchAnds)
      : Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds) {}
};

/// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
/// chain of 'and' or 'or' instructions looking for shift ops of a common source
/// value. Examples:
///   or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
/// returns { X, 0x129 }
///   and (and (X >> 1), 1), (X >> 4)
/// returns { X, 0x12 }
static bool matchAndOrChain(Value *V, MaskOps &MOps) {
  Value *Op0, *Op1;
  if (MOps.MatchAndChain) {
    // Recurse through a chain of 'and' operands. This requires an extra check
    // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere
    // in the chain to know that all of the high bits are cleared.
    if (match(V, m_And(m_Value(Op0), m_One()))) {
      MOps.FoundAnd1 = true;
      return matchAndOrChain(Op0, MOps);
    }
    if (match(V, m_And(m_Value(Op0), m_Value(Op1))))
      return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
  } else {
    // Recurse through a chain of 'or' operands.
    if (match(V, m_Or(m_Value(Op0), m_Value(Op1))))
      return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
  }

  // We need a shift-right or a bare value representing a compare of bit 0 of
  // the original source operand.
  Value *Candidate;
  const APInt *BitIndex = nullptr;
  if (!match(V, m_LShr(m_Value(Candidate), m_APInt(BitIndex))))
    Candidate = V;

  // Initialize result source operand.
  if (!MOps.Root)
    MOps.Root = Candidate;

  // The shift constant is out-of-range? This code hasn't been simplified.
  if (BitIndex && BitIndex->uge(MOps.Mask.getBitWidth()))
    return false;

  // Fill in the mask bit derived from the shift constant.
  MOps.Mask.setBit(BitIndex ? BitIndex->getZExtValue() : 0);
  return MOps.Root == Candidate;
}

/// Match patterns that correspond to "any-bits-set" and "all-bits-set".
/// These will include a chain of 'or' or 'and'-shifted bits from a
/// common source value:
///   and (or  (lshr X, C), ...), 1 --> (X & CMask) != 0
///   and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask
/// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns
/// that differ only with a final 'not' of the result. We expect that final
/// 'not' to be folded with the compare that we create here (invert predicate).
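///
/// For example (illustrative IR; value names are arbitrary), an 'any-bits-set'
/// chain such as
///   %s3 = lshr i32 %x, 3
///   %o1 = or i32 %x, %s3
///   %s8 = lshr i32 %x, 8
///   %o2 = or i32 %o1, %s8
///   %r  = and i32 %o2, 1
/// becomes
///   %m = and i32 %x, 265        ; mask 0x109 covers bits {0, 3, 8}
///   %c = icmp ne i32 %m, 0
///   %r = zext i1 %c to i32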
static bool foldAnyOrAllBitsSet(Instruction &I) {
  // The 'any-bits-set' ('or' chain) pattern is simpler to match because the
  // final "and X, 1" instruction must be the final op in the sequence.
  bool MatchAllBitsSet;
  if (match(&I, m_c_And(m_OneUse(m_And(m_Value(), m_Value())), m_Value())))
    MatchAllBitsSet = true;
  else if (match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One())))
    MatchAllBitsSet = false;
  else
    return false;

  MaskOps MOps(I.getType()->getScalarSizeInBits(), MatchAllBitsSet);
  if (MatchAllBitsSet) {
    if (!matchAndOrChain(cast<BinaryOperator>(&I), MOps) || !MOps.FoundAnd1)
      return false;
  } else {
    if (!matchAndOrChain(cast<BinaryOperator>(&I)->getOperand(0), MOps))
      return false;
  }

  // The pattern was found. Create a masked compare that replaces all of the
  // shift and logic ops.
  IRBuilder<> Builder(&I);
  Constant *Mask = ConstantInt::get(I.getType(), MOps.Mask);
  Value *And = Builder.CreateAnd(MOps.Root, Mask);
  Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask)
                               : Builder.CreateIsNotNull(And);
  Value *Zext = Builder.CreateZExt(Cmp, I.getType());
  I.replaceAllUsesWith(Zext);
  ++NumAnyOrAllBitsSet;
  return true;
}

// Try to recognize the function below as a popcount intrinsic.
// This is the "best" algorithm from
// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
// Also used in TargetLowering::expandCTPOP().
//
// int popcount(unsigned int i) {
//   i = i - ((i >> 1) & 0x55555555);
//   i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
//   i = ((i + (i >> 4)) & 0x0F0F0F0F);
//   return (i * 0x01010101) >> 24;
// }
static bool tryToRecognizePopCount(Instruction &I) {
  if (I.getOpcode() != Instruction::LShr)
    return false;

  Type *Ty = I.getType();
  if (!Ty->isIntOrIntVectorTy())
    return false;

  unsigned Len = Ty->getScalarSizeInBits();
  // FIXME: fix Len == 8 and other irregular type lengths.
  if (!(Len <= 128 && Len > 8 && Len % 8 == 0))
    return false;

  APInt Mask55 = APInt::getSplat(Len, APInt(8, 0x55));
  APInt Mask33 = APInt::getSplat(Len, APInt(8, 0x33));
  APInt Mask0F = APInt::getSplat(Len, APInt(8, 0x0F));
  APInt Mask01 = APInt::getSplat(Len, APInt(8, 0x01));
  APInt MaskShift = APInt(Len, Len - 8);

  Value *Op0 = I.getOperand(0);
  Value *Op1 = I.getOperand(1);
  Value *MulOp0;
  // Matching "(i * 0x01010101...) >> 24".
  if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) &&
      match(Op1, m_SpecificInt(MaskShift))) {
    Value *ShiftOp0;
    // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
    if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)),
                                    m_Deferred(ShiftOp0)),
                            m_SpecificInt(Mask0F)))) {
      Value *AndOp0;
      // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)".
      if (match(ShiftOp0,
                m_c_Add(m_And(m_Value(AndOp0), m_SpecificInt(Mask33)),
                        m_And(m_LShr(m_Deferred(AndOp0), m_SpecificInt(2)),
                              m_SpecificInt(Mask33))))) {
        Value *Root, *SubOp1;
        // Matching "i - ((i >> 1) & 0x55555555...)".
        if (match(AndOp0, m_Sub(m_Value(Root), m_Value(SubOp1))) &&
            match(SubOp1, m_And(m_LShr(m_Specific(Root), m_SpecificInt(1)),
                                m_SpecificInt(Mask55)))) {
          LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
          IRBuilder<> Builder(&I);
          Function *Func = Intrinsic::getDeclaration(
              I.getModule(), Intrinsic::ctpop, I.getType());
          I.replaceAllUsesWith(Builder.CreateCall(Func, {Root}));
          ++NumPopCountRecognized;
          return true;
        }
      }
    }
  }

  return false;
}

/// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), providing C1 and
/// C2 saturate the value of the fp conversion. The transform is not reversible
/// as the fptosi.sat is more defined than the input - all values produce a
/// valid value for the fptosi.sat, whereas some inputs that were out of range
/// of the integer conversion produce poison for the original. The reversed
/// pattern may use fmax and fmin instead. As we cannot directly reverse the
/// transform, and it is not always profitable, we make it conditional on the
/// cost being reported as lower by TTI.
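///
/// For example (illustrative, scalar case), the clamp
///   smin(smax(fptosi float %x to i32, i32 -128), i32 127)
/// can be rewritten as
///   sext (call i8 @llvm.fptosi.sat.i8.f32(float %x)) to i32
/// when TTI reports the saturating conversion plus the extend as cheaper than
/// the fptosi+smin+smax sequence.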
static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
  // Look for smin(smax(fptosi(x), C1), C2), converting to fptosi_sat.
  Value *In;
  const APInt *MinC, *MaxC;
  if (!match(&I, m_SMax(m_OneUse(m_SMin(m_OneUse(m_FPToSI(m_Value(In))),
                                        m_APInt(MinC))),
                        m_APInt(MaxC))) &&
      !match(&I, m_SMin(m_OneUse(m_SMax(m_OneUse(m_FPToSI(m_Value(In))),
                                        m_APInt(MaxC))),
                        m_APInt(MinC))))
    return false;

  // Check that the constants clamp a saturate.
  if (!(*MinC + 1).isPowerOf2() || -*MaxC != *MinC + 1)
    return false;

  Type *IntTy = I.getType();
  Type *FpTy = In->getType();
  Type *SatTy =
      IntegerType::get(IntTy->getContext(), (*MinC + 1).exactLogBase2() + 1);
  if (auto *VecTy = dyn_cast<VectorType>(IntTy))
    SatTy = VectorType::get(SatTy, VecTy->getElementCount());

  // Get the cost of the intrinsic, and check that against the cost of
  // fptosi+smin+smax.
  InstructionCost SatCost = TTI.getIntrinsicInstrCost(
      IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}),
      TTI::TCK_RecipThroughput);
  SatCost += TTI.getCastInstrCost(Instruction::SExt, IntTy, SatTy,
                                  TTI::CastContextHint::None,
                                  TTI::TCK_RecipThroughput);

  InstructionCost MinMaxCost = TTI.getCastInstrCost(
      Instruction::FPToSI, IntTy, FpTy, TTI::CastContextHint::None,
      TTI::TCK_RecipThroughput);
  MinMaxCost += TTI.getIntrinsicInstrCost(
      IntrinsicCostAttributes(Intrinsic::smin, IntTy, {IntTy}),
      TTI::TCK_RecipThroughput);
  MinMaxCost += TTI.getIntrinsicInstrCost(
      IntrinsicCostAttributes(Intrinsic::smax, IntTy, {IntTy}),
      TTI::TCK_RecipThroughput);

  if (SatCost >= MinMaxCost)
    return false;

  IRBuilder<> Builder(&I);
  Function *Fn = Intrinsic::getDeclaration(I.getModule(), Intrinsic::fptosi_sat,
                                           {SatTy, FpTy});
  Value *Sat = Builder.CreateCall(Fn, In);
  I.replaceAllUsesWith(Builder.CreateSExt(Sat, IntTy));
  return true;
}

/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
/// pessimistic codegen that has to account for setting errno and can enable
/// vectorization.
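///
/// For example (illustrative), when TTI.haveFastSqrt() holds, a call such as
///   %r = call nnan double @sqrt(double %x)
/// becomes
///   %r = call nnan double @llvm.sqrt.f64(double %x)
/// since a result that cannot be a NaN implies the libcall would not set errno.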
static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI,
                     TargetLibraryInfo &TLI, AssumptionCache &AC,
                     DominatorTree &DT) {

  Module *M = Call->getModule();

  // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
  // (because NNAN or the operand arg must not be less than -0.0) and (3) we
  // would not end up lowering to a libcall anyway (which could change the value
  // of errno), then:
  // (1) errno won't be set.
  // (2) it is safe to convert this to an intrinsic call.
  Type *Ty = Call->getType();
  Value *Arg = Call->getArgOperand(0);
  if (TTI.haveFastSqrt(Ty) &&
      (Call->hasNoNaNs() ||
       cannotBeOrderedLessThanZero(
           Arg, 0,
           SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) {
    IRBuilder<> Builder(Call);
    IRBuilderBase::FastMathFlagGuard Guard(Builder);
    Builder.setFastMathFlags(Call->getFastMathFlags());

    Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
    Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt");
    Call->replaceAllUsesWith(NewSqrt);

    // Explicitly erase the old call because a call with side effects is not
    // trivially dead.
    Call->eraseFromParent();
    return true;
  }

  return false;
}

// Check if this array of constants represents a cttz table.
// Iterate over the elements from \p Table by trying to find/match all
// the numbers from 0 to \p InputBits that should represent cttz results.
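//
// For example (illustrative), with InputBits == 32, Mul == 0x077CB531 and
// Shift == 27, every entry Element stored at index i must satisfy
// (((Mul << Element) >> Shift) & 0x1F) == i, i.e. looking up
// (1 << Element) * Mul in the table must yield Element, which is exactly the
// de Bruijn cttz scheme.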
static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
                        uint64_t Shift, uint64_t InputBits) {
  unsigned Length = Table.getNumElements();
  if (Length < InputBits || Length > InputBits * 2)
    return false;

  APInt Mask = APInt::getBitsSetFrom(InputBits, Shift);
  unsigned Matched = 0;

  for (unsigned i = 0; i < Length; i++) {
    uint64_t Element = Table.getElementAsInteger(i);
    if (Element >= InputBits)
      continue;

    // Check if \p Element matches a concrete answer. It could fail for some
    // elements that are never accessed, so we keep iterating over each element
    // from the table. The number of matched elements should be equal to the
    // number of potential right answers which is \p InputBits actually.
    if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i)
      Matched++;
  }

  return Matched == InputBits;
}

// Try to recognize a table-based ctz implementation.
// E.g., an example in C (for more cases please see the llvm/tests):
// int f(unsigned x) {
//    static const char table[32] =
//      {0, 1, 28, 2, 29, 14, 24, 3, 30,
//       22, 20, 15, 25, 17, 4, 8, 31, 27,
//       13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};
//    return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27];
// }
// this can be lowered to a `cttz` instruction.
// There is also a special case when the element is 0.
//
// Here are some examples of LLVM IR for a 64-bit target:
//
// CASE 1:
// %sub = sub i32 0, %x
// %and = and i32 %sub, %x
// %mul = mul i32 %and, 125613361
// %shr = lshr i32 %mul, 27
// %idxprom = zext i32 %shr to i64
// %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0,
//     i64 %idxprom
// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
//
// CASE 2:
// %sub = sub i32 0, %x
// %and = and i32 %sub, %x
// %mul = mul i32 %and, 72416175
// %shr = lshr i32 %mul, 26
// %idxprom = zext i32 %shr to i64
// %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table,
//     i64 0, i64 %idxprom
// %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
//
// CASE 3:
// %sub = sub i32 0, %x
// %and = and i32 %sub, %x
// %mul = mul i32 %and, 81224991
// %shr = lshr i32 %mul, 27
// %idxprom = zext i32 %shr to i64
// %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table,
//     i64 0, i64 %idxprom
// %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
//
// CASE 4:
// %sub = sub i64 0, %x
// %and = and i64 %sub, %x
// %mul = mul i64 %and, 283881067100198605
// %shr = lshr i64 %mul, 58
// %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0,
//     i64 %shr
// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
//
// All this can be lowered to @llvm.cttz.i32/64 intrinsic.
static bool tryToRecognizeTableBasedCttz(Instruction &I) {
  LoadInst *LI = dyn_cast<LoadInst>(&I);
  if (!LI)
    return false;

  Type *AccessType = LI->getType();
  if (!AccessType->isIntegerTy())
    return false;

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand());
  if (!GEP || !GEP->isInBounds() || GEP->getNumIndices() != 2)
    return false;

  if (!GEP->getSourceElementType()->isArrayTy())
    return false;

  uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements();
  if (ArraySize != 32 && ArraySize != 64)
    return false;

  GlobalVariable *GVTable = dyn_cast<GlobalVariable>(GEP->getPointerOperand());
  if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
    return false;

  ConstantDataArray *ConstData =
      dyn_cast<ConstantDataArray>(GVTable->getInitializer());
  if (!ConstData)
    return false;

  if (!match(GEP->idx_begin()->get(), m_ZeroInt()))
    return false;

  Value *Idx2 = std::next(GEP->idx_begin())->get();
  Value *X1;
  uint64_t MulConst, ShiftConst;
  // FIXME: 64-bit targets have `i64` type for the GEP index, so this match will
  // probably fail for other (e.g. 32-bit) targets.
  if (!match(Idx2, m_ZExtOrSelf(
                       m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)),
                                    m_ConstantInt(MulConst)),
                              m_ConstantInt(ShiftConst)))))
    return false;

  unsigned InputBits = X1->getType()->getScalarSizeInBits();
  if (InputBits != 32 && InputBits != 64)
    return false;

  // Shift should extract top 5..7 bits.
  if (InputBits - Log2_32(InputBits) != ShiftConst &&
      InputBits - Log2_32(InputBits) - 1 != ShiftConst)
    return false;

  if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits))
    return false;

  auto ZeroTableElem = ConstData->getElementAsInteger(0);
  bool DefinedForZero = ZeroTableElem == InputBits;

  IRBuilder<> B(LI);
  ConstantInt *BoolConst = B.getInt1(!DefinedForZero);
  Type *XType = X1->getType();
  auto Cttz = B.CreateIntrinsic(Intrinsic::cttz, {XType}, {X1, BoolConst});
  Value *ZExtOrTrunc = nullptr;

  if (DefinedForZero) {
    ZExtOrTrunc = B.CreateZExtOrTrunc(Cttz, AccessType);
  } else {
    // If the value in elem 0 isn't the same as InputBits, we still want to
    // produce the value from the table.
    auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0));
    auto Select =
        B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz);

    // NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target
    // it should be handled as: `cttz(x) & (typeSize - 1)`.

    ZExtOrTrunc = B.CreateZExtOrTrunc(Select, AccessType);
  }

  LI->replaceAllUsesWith(ZExtOrTrunc);

  return true;
}

/// This is used by foldLoadsRecursive() to capture a Root Load node which is
/// of type or(load, load) and recursively build the wide load. Also capture the
/// shift amount, zero extend type and loadSize.
struct LoadOps {
  LoadInst *Root = nullptr;
  LoadInst *RootInsert = nullptr;
  bool FoundRoot = false;
  uint64_t LoadSize = 0;
  const APInt *Shift = nullptr;
  Type *ZextType;
  AAMDNodes AATags;
};

// Identify and merge consecutive loads recursively, matching chains of the form
//   (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
//   (ZExt(L1) << shift1) | ZExt(L2)             -> ZExt(L3)
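//
// For example (illustrative, little-endian), a byte-reassembly chain such as
//   %b0 = load i8, ptr %p
//   %p1 = getelementptr i8, ptr %p, i64 1
//   %b1 = load i8, ptr %p1
//   %z0 = zext i8 %b0 to i16
//   %z1 = zext i8 %b1 to i16
//   %s1 = shl i16 %z1, 8
//   %or = or i16 %z0, %s1
// can be replaced by a single
//   %or = load i16, ptr %p
// provided no store may alias the range and TTI allows the (possibly
// misaligned) wider access.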
static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
                               AliasAnalysis &AA) {
  const APInt *ShAmt2 = nullptr;
  Value *X;
  Instruction *L1, *L2;

  // Go to the last node with loads.
  if (match(V, m_OneUse(m_c_Or(
                   m_Value(X),
                   m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
                                  m_APInt(ShAmt2)))))) ||
      match(V, m_OneUse(m_Or(m_Value(X),
                             m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
    if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
      // Avoid a partial chain merge.
      return false;
  } else
    return false;

  // Check if the pattern has loads.
  LoadInst *LI1 = LOps.Root;
  const APInt *ShAmt1 = LOps.Shift;
  if (LOps.FoundRoot == false &&
      (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
       match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
                               m_APInt(ShAmt1)))))) {
    LI1 = dyn_cast<LoadInst>(L1);
  }
  LoadInst *LI2 = dyn_cast<LoadInst>(L2);

  // Bail out if the loads are the same instruction, not simple (atomic or
  // volatile), or in different address spaces.
  if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() ||
      LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace())
    return false;

  // Check if the loads come from the same BB.
  if (LI1->getParent() != LI2->getParent())
    return false;

  // Find the data layout.
  bool IsBigEndian = DL.isBigEndian();

  // Check if loads are consecutive and same size.
  Value *Load1Ptr = LI1->getPointerOperand();
  APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
  Load1Ptr =
      Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset1,
                                                  /* AllowNonInbounds */ true);

  Value *Load2Ptr = LI2->getPointerOperand();
  APInt Offset2(DL.getIndexTypeSizeInBits(Load2Ptr->getType()), 0);
  Load2Ptr =
      Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset2,
                                                  /* AllowNonInbounds */ true);

  // Verify if both loads have same base pointers and load sizes are same.
  uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits();
  uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits();
  if (Load1Ptr != Load2Ptr || LoadSize1 != LoadSize2)
    return false;

  // Only support load sizes that are at least 8 bits and a power of 2.
  if (LoadSize1 < 8 || !isPowerOf2_64(LoadSize1))
    return false;

  // Use alias analysis to check for stores between the loads.
  LoadInst *Start = LOps.FoundRoot ? LOps.RootInsert : LI1, *End = LI2;
  MemoryLocation Loc;
  if (!Start->comesBefore(End)) {
    std::swap(Start, End);
    Loc = MemoryLocation::get(End);
    if (LOps.FoundRoot)
      Loc = Loc.getWithNewSize(LOps.LoadSize);
  } else
    Loc = MemoryLocation::get(End);
  unsigned NumScanned = 0;
  for (Instruction &Inst :
       make_range(Start->getIterator(), End->getIterator())) {
    if (Inst.mayWriteToMemory() && isModSet(AA.getModRefInfo(&Inst, Loc)))
      return false;

    // Ignore debug info so that's not counted against MaxInstrsToScan.
    // Otherwise debug info could affect codegen.
    if (!isa<DbgInfoIntrinsic>(Inst) && ++NumScanned > MaxInstrsToScan)
      return false;
  }

  // Make sure the load with the lower offset is at LI1.
  bool Reverse = false;
  if (Offset2.slt(Offset1)) {
    std::swap(LI1, LI2);
    std::swap(ShAmt1, ShAmt2);
    std::swap(Offset1, Offset2);
    std::swap(Load1Ptr, Load2Ptr);
    std::swap(LoadSize1, LoadSize2);
    Reverse = true;
  }

  // On big endian targets, swap the shifts.
  if (IsBigEndian)
    std::swap(ShAmt1, ShAmt2);

  // Find the shift amounts.
  uint64_t Shift1 = 0, Shift2 = 0;
  if (ShAmt1)
    Shift1 = ShAmt1->getZExtValue();
  if (ShAmt2)
    Shift2 = ShAmt2->getZExtValue();

  // First load is always LI1. This is where we put the new load.
  // Use the merged load size available from LI1 for forward loads.
  if (LOps.FoundRoot) {
    if (!Reverse)
      LoadSize1 = LOps.LoadSize;
    else
      LoadSize2 = LOps.LoadSize;
  }

  // Verify that the shift amounts line up with the load sizes, which also
  // verifies that the loads are consecutive.
  uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
  uint64_t PrevSize =
      DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
  if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
    return false;

  // Update LOps.
  AAMDNodes AATags1 = LOps.AATags;
  AAMDNodes AATags2 = LI2->getAAMetadata();
  if (LOps.FoundRoot == false) {
    LOps.FoundRoot = true;
    AATags1 = LI1->getAAMetadata();
  }
  LOps.LoadSize = LoadSize1 + LoadSize2;
  LOps.RootInsert = Start;

  // Concatenate the AATags of the merged loads.
  LOps.AATags = AATags1.concat(AATags2);

  LOps.Root = LI1;
  LOps.Shift = ShAmt1;
  LOps.ZextType = X->getType();
  return true;
}

// For a given BB instruction, evaluate all loads in the chain that form a
// pattern which suggests that the loads can be combined. The one and only use
// of the loads is to form a wider load.
static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
                                 TargetTransformInfo &TTI, AliasAnalysis &AA,
                                 const DominatorTree &DT) {
  // Only consider load chains of scalar values.
  if (isa<VectorType>(I.getType()))
    return false;

  LoadOps LOps;
  if (!foldLoadsRecursive(&I, LOps, DL, AA) || !LOps.FoundRoot)
    return false;

  IRBuilder<> Builder(&I);
  LoadInst *NewLoad = nullptr, *LI1 = LOps.Root;

  IntegerType *WiderType = IntegerType::get(I.getContext(), LOps.LoadSize);
  // TTI-based checks to decide whether to proceed with the wider load.
  bool Allowed = TTI.isTypeLegal(WiderType);
  if (!Allowed)
    return false;

  unsigned AS = LI1->getPointerAddressSpace();
  unsigned Fast = 0;
  Allowed = TTI.allowsMisalignedMemoryAccesses(I.getContext(), LOps.LoadSize,
                                               AS, LI1->getAlign(), &Fast);
  if (!Allowed || !Fast)
    return false;

  // Get the Index and Ptr for the new GEP.
  Value *Load1Ptr = LI1->getPointerOperand();
  Builder.SetInsertPoint(LOps.RootInsert);
  if (!DT.dominates(Load1Ptr, LOps.RootInsert)) {
    APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
    Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets(
        DL, Offset1, /* AllowNonInbounds */ true);
    Load1Ptr = Builder.CreatePtrAdd(Load1Ptr,
                                    Builder.getInt32(Offset1.getZExtValue()));
  }
  // Generate the wider load.
  NewLoad = Builder.CreateAlignedLoad(WiderType, Load1Ptr, LI1->getAlign(),
                                      LI1->isVolatile(), "");
  NewLoad->takeName(LI1);
  // Set the new load's AATags metadata.
  if (LOps.AATags)
    NewLoad->setAAMetadata(LOps.AATags);

  Value *NewOp = NewLoad;
  // Check if a zero extend is needed.
  if (LOps.ZextType)
    NewOp = Builder.CreateZExt(NewOp, LOps.ZextType);

  // Check if a shift is needed. We need to shift by the amount of load1's
  // shift if it is not zero.
  if (LOps.Shift)
    NewOp = Builder.CreateShl(NewOp,
                              ConstantInt::get(I.getContext(), *LOps.Shift));
  I.replaceAllUsesWith(NewOp);

  return true;
}

// Calculate the GEP stride and the accumulated constant ModOffset. Return
// Stride and ModOffset.
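//
// For example (illustrative), for a pointer computed as
//   getelementptr inbounds [1024 x i32], ptr @g, i64 0, i64 %i
// the variable index has scale 4 and there is no constant offset, so the
// returned stride is 4 and ModOffset is 0: any load through this GEP reads @g
// at a multiple of 4 bytes.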
static std::pair<APInt, APInt>
getStrideAndModOffsetOfGEP(Value *PtrOp, const DataLayout &DL) {
  unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
  std::optional<APInt> Stride;
  APInt ModOffset(BW, 0);
  // Return a minimum GEP stride, the greatest common divisor of consecutive GEP
  // index scales (cf. Bézout's identity).
  while (auto *GEP = dyn_cast<GEPOperator>(PtrOp)) {
    MapVector<Value *, APInt> VarOffsets;
    if (!GEP->collectOffset(DL, BW, VarOffsets, ModOffset))
      break;

    for (auto [V, Scale] : VarOffsets) {
      // Only keep a power-of-two factor for non-inbounds GEPs.
      if (!GEP->isInBounds())
        Scale = APInt::getOneBitSet(Scale.getBitWidth(), Scale.countr_zero());

      if (!Stride)
        Stride = Scale;
      else
        Stride = APIntOps::GreatestCommonDivisor(*Stride, Scale);
    }

    PtrOp = GEP->getPointerOperand();
  }

  // Check whether the pointer arrives back at a global variable via at least
  // one GEP. Even if it doesn't, we can check by alignment.
  if (!isa<GlobalVariable>(PtrOp) || !Stride)
    return {APInt(BW, 1), APInt(BW, 0)};

  // To account for signed GEP indices, reduce the constant offset to the
  // non-negative remainder of division by the minimum GEP stride.
  ModOffset = ModOffset.srem(*Stride);
  if (ModOffset.isNegative())
    ModOffset += *Stride;

  return {*Stride, ModOffset};
}

/// If C is a constant patterned array and all valid loaded results for the
/// given alignment are equal to the same constant, fold the load to that
/// constant.
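///
/// For example (illustrative), given
///   @tbl = constant [8 x i16] [i16 7, i16 7, i16 7, i16 7,
///                              i16 7, i16 7, i16 7, i16 7], align 2
/// every 2-byte-aligned i16 load from @tbl yields 7, so such a load is replaced
/// by the constant 7.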
static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
  auto *LI = dyn_cast<LoadInst>(&I);
  if (!LI || LI->isVolatile())
    return false;

  // We can only fold the load if it is from a constant global with a definitive
  // initializer. Skip expensive logic if this is not the case.
  auto *PtrOp = LI->getPointerOperand();
  auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(PtrOp));
  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
    return false;

  // Bail for large initializers in excess of 4K to avoid too many scans.
  Constant *C = GV->getInitializer();
  uint64_t GVSize = DL.getTypeAllocSize(C->getType());
  if (!GVSize || 4096 < GVSize)
    return false;

  Type *LoadTy = LI->getType();
  unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
  auto [Stride, ConstOffset] = getStrideAndModOffsetOfGEP(PtrOp, DL);

  // Any possible offset is a multiple of the GEP stride, and any valid offset
  // is a multiple of the load alignment, so checking only multiples of the
  // bigger one is sufficient to prove that the results are equal.
  if (auto LA = LI->getAlign();
      LA <= GV->getAlign().valueOrOne() && Stride.getZExtValue() < LA.value()) {
    ConstOffset = APInt(BW, 0);
    Stride = APInt(BW, LA.value());
  }

  Constant *Ca = ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL);
  if (!Ca)
    return false;

  unsigned E = GVSize - DL.getTypeStoreSize(LoadTy);
  for (; ConstOffset.getZExtValue() <= E; ConstOffset += Stride)
    if (Ca != ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL))
      return false;

  I.replaceAllUsesWith(Ca);

  return true;
}

namespace {
class StrNCmpInliner {
public:
  StrNCmpInliner(CallInst *CI, LibFunc Func, DomTreeUpdater *DTU,
                 const DataLayout &DL)
      : CI(CI), Func(Func), DTU(DTU), DL(DL) {}

  bool optimizeStrNCmp();

private:
  void inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Swapped);

  CallInst *CI;
  LibFunc Func;
  DomTreeUpdater *DTU;
  const DataLayout &DL;
};

} // namespace

/// First we normalize calls to strncmp/strcmp to the form of
/// compare(s1, s2, N), which means comparing the first N bytes of s1 and s2
/// (without considering '\0').
///
/// Examples:
///
/// \code
///   strncmp(s, "a", 3) -> compare(s, "a", 2)
///   strncmp(s, "abc", 3) -> compare(s, "abc", 3)
///   strncmp(s, "a\0b", 3) -> compare(s, "a\0b", 2)
///   strcmp(s, "a") -> compare(s, "a", 2)
///
///   char s2[] = {'a'}
///   strncmp(s, s2, 3) -> compare(s, s2, 3)
///
///   char s2[] = {'a', 'b', 'c', 'd'}
///   strncmp(s, s2, 3) -> compare(s, s2, 3)
/// \endcode
///
/// We only handle cases where N and exactly one of s1 and s2 are constant.
/// Cases where s1 and s2 are both constant are already handled by the
/// instcombine pass.
///
/// We do not handle cases where N > StrNCmpInlineThreshold.
///
/// We also do not handle cases where N < 2, which are already
/// handled by the instcombine pass.
///
bool StrNCmpInliner::optimizeStrNCmp() {
  if (StrNCmpInlineThreshold < 2)
    return false;

  if (!isOnlyUsedInZeroComparison(CI))
    return false;

  Value *Str1P = CI->getArgOperand(0);
  Value *Str2P = CI->getArgOperand(1);
  // Should be handled elsewhere.
  if (Str1P == Str2P)
    return false;

  StringRef Str1, Str2;
  bool HasStr1 = getConstantStringInfo(Str1P, Str1, /*TrimAtNul=*/false);
  bool HasStr2 = getConstantStringInfo(Str2P, Str2, /*TrimAtNul=*/false);
  if (HasStr1 == HasStr2)
    return false;

  // Note that '\0' and characters after it are not trimmed.
  StringRef Str = HasStr1 ? Str1 : Str2;
  Value *StrP = HasStr1 ? Str2P : Str1P;

  size_t Idx = Str.find('\0');
  uint64_t N = Idx == StringRef::npos ? UINT64_MAX : Idx + 1;
  if (Func == LibFunc_strncmp) {
    if (auto *ConstInt = dyn_cast<ConstantInt>(CI->getArgOperand(2)))
      N = std::min(N, ConstInt->getZExtValue());
    else
      return false;
  }
  // Now N means how many bytes we need to compare at most.
  if (N > Str.size() || N < 2 || N > StrNCmpInlineThreshold)
    return false;

  // Cases where StrP has two or more dereferenceable bytes might be better
  // optimized elsewhere.
  bool CanBeNull = false, CanBeFreed = false;
  if (StrP->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed) > 1)
    return false;
  inlineCompare(StrP, Str, N, HasStr1);
  return true;
}

/// Convert
///
/// \code
///   ret = compare(s1, s2, N)
/// \endcode
///
/// into
///
/// \code
///   ret = (int)s1[0] - (int)s2[0]
///   if (ret != 0)
///     goto NE
///   ...
///   ret = (int)s1[N-2] - (int)s2[N-2]
///   if (ret != 0)
///     goto NE
///   ret = (int)s1[N-1] - (int)s2[N-1]
/// NE:
/// \endcode
///
/// CFG before and after the transformation:
///
/// (before)
/// BBCI
///
/// (after)
/// BBCI -> BBSubs[0] (sub,icmp) --NE-> BBNE -> BBTail
///                 |                    ^
///                 E                    |
///                 |                    |
///        BBSubs[1] (sub,icmp) --NE-----+
///                ...                   |
///        BBSubs[N-1]    (sub) ---------+
void StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
                                   bool Swapped) {
  auto &Ctx = CI->getContext();
  IRBuilder<> B(Ctx);

  BasicBlock *BBCI = CI->getParent();
  BasicBlock *BBTail =
      SplitBlock(BBCI, CI, DTU, nullptr, nullptr, BBCI->getName() + ".tail");

  SmallVector<BasicBlock *> BBSubs;
  for (uint64_t I = 0; I < N; ++I)
    BBSubs.push_back(
        BasicBlock::Create(Ctx, "sub_" + Twine(I), BBCI->getParent(), BBTail));
  BasicBlock *BBNE = BasicBlock::Create(Ctx, "ne", BBCI->getParent(), BBTail);

  cast<BranchInst>(BBCI->getTerminator())->setSuccessor(0, BBSubs[0]);

  B.SetInsertPoint(BBNE);
  PHINode *Phi = B.CreatePHI(CI->getType(), N);
  B.CreateBr(BBTail);

  Value *Base = LHS;
  for (uint64_t i = 0; i < N; ++i) {
    B.SetInsertPoint(BBSubs[i]);
    Value *VL =
        B.CreateZExt(B.CreateLoad(B.getInt8Ty(),
                                  B.CreateInBoundsPtrAdd(Base, B.getInt64(i))),
                     CI->getType());
    Value *VR =
        ConstantInt::get(CI->getType(), static_cast<unsigned char>(RHS[i]));
    Value *Sub = Swapped ? B.CreateSub(VR, VL) : B.CreateSub(VL, VR);
    if (i < N - 1)
      B.CreateCondBr(B.CreateICmpNE(Sub, ConstantInt::get(CI->getType(), 0)),
                     BBNE, BBSubs[i + 1]);
    else
      B.CreateBr(BBNE);

    Phi->addIncoming(Sub, BBSubs[i]);
  }

  CI->replaceAllUsesWith(Phi);
  CI->eraseFromParent();

  if (DTU) {
    SmallVector<DominatorTree::UpdateType, 8> Updates;
    Updates.push_back({DominatorTree::Insert, BBCI, BBSubs[0]});
    for (uint64_t i = 0; i < N; ++i) {
      if (i < N - 1)
        Updates.push_back({DominatorTree::Insert, BBSubs[i], BBSubs[i + 1]});
      Updates.push_back({DominatorTree::Insert, BBSubs[i], BBNE});
    }
    Updates.push_back({DominatorTree::Insert, BBNE, BBTail});
    Updates.push_back({DominatorTree::Delete, BBCI, BBTail});
    DTU->applyUpdates(Updates);
  }
}

/// Convert memchr with a small constant string into a switch
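///
/// For example (illustrative), a call such as `memchr(%s, %c, 5)` where %s
/// points to the constant string "aeiou" and %c is not constant becomes a
/// switch on `trunc %c to i8` with one case per distinct byte of the string;
/// each case feeds its index into a phi, and the result is either null (no
/// case taken) or an inbounds pointer to the first matching position.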
static bool foldMemChr(CallInst *Call, DomTreeUpdater *DTU,
                       const DataLayout &DL) {
  if (isa<Constant>(Call->getArgOperand(1)))
    return false;

  StringRef Str;
  Value *Base = Call->getArgOperand(0);
  if (!getConstantStringInfo(Base, Str, /*TrimAtNul=*/false))
    return false;

  uint64_t N = Str.size();
  if (auto *ConstInt = dyn_cast<ConstantInt>(Call->getArgOperand(2))) {
    uint64_t Val = ConstInt->getZExtValue();
    // Ignore the case where n is larger than the size of the string.
    if (Val > N)
      return false;
    N = Val;
  } else
    return false;

  if (N > MemChrInlineThreshold)
    return false;

  BasicBlock *BB = Call->getParent();
  BasicBlock *BBNext = SplitBlock(BB, Call, DTU);
  IRBuilder<> IRB(BB);
  IntegerType *ByteTy = IRB.getInt8Ty();
  BB->getTerminator()->eraseFromParent();
  SwitchInst *SI = IRB.CreateSwitch(
      IRB.CreateTrunc(Call->getArgOperand(1), ByteTy), BBNext, N);
  Type *IndexTy = DL.getIndexType(Call->getType());
  SmallVector<DominatorTree::UpdateType, 8> Updates;

  BasicBlock *BBSuccess = BasicBlock::Create(
      Call->getContext(), "memchr.success", BB->getParent(), BBNext);
  IRB.SetInsertPoint(BBSuccess);
  PHINode *IndexPHI = IRB.CreatePHI(IndexTy, N, "memchr.idx");
  Value *FirstOccursLocation = IRB.CreateInBoundsPtrAdd(Base, IndexPHI);
  IRB.CreateBr(BBNext);
  if (DTU)
    Updates.push_back({DominatorTree::Insert, BBSuccess, BBNext});

  SmallPtrSet<ConstantInt *, 4> Cases;
  for (uint64_t I = 0; I < N; ++I) {
    ConstantInt *CaseVal = ConstantInt::get(ByteTy, Str[I]);
    if (!Cases.insert(CaseVal).second)
      continue;

    BasicBlock *BBCase = BasicBlock::Create(Call->getContext(), "memchr.case",
                                            BB->getParent(), BBSuccess);
    SI->addCase(CaseVal, BBCase);
    IRB.SetInsertPoint(BBCase);
    IndexPHI->addIncoming(ConstantInt::get(IndexTy, I), BBCase);
    IRB.CreateBr(BBSuccess);
    if (DTU) {
      Updates.push_back({DominatorTree::Insert, BB, BBCase});
      Updates.push_back({DominatorTree::Insert, BBCase, BBSuccess});
    }
  }

  PHINode *PHI =
      PHINode::Create(Call->getType(), 2, Call->getName(), BBNext->begin());
  PHI->addIncoming(Constant::getNullValue(Call->getType()), BB);
  PHI->addIncoming(FirstOccursLocation, BBSuccess);

  Call->replaceAllUsesWith(PHI);
  Call->eraseFromParent();

  if (DTU)
    DTU->applyUpdates(Updates);

  return true;
}

static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
                         TargetLibraryInfo &TLI, AssumptionCache &AC,
                         DominatorTree &DT, const DataLayout &DL,
                         bool &MadeCFGChange) {

  auto *CI = dyn_cast<CallInst>(&I);
  if (!CI || CI->isNoBuiltin())
    return false;

  Function *CalledFunc = CI->getCalledFunction();
  if (!CalledFunc)
    return false;

  LibFunc LF;
  if (!TLI.getLibFunc(*CalledFunc, LF) ||
      !isLibFuncEmittable(CI->getModule(), &TLI, LF))
    return false;

  DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy);

  switch (LF) {
  case LibFunc_sqrt:
  case LibFunc_sqrtf:
  case LibFunc_sqrtl:
    return foldSqrt(CI, LF, TTI, TLI, AC, DT);
  case LibFunc_strcmp:
  case LibFunc_strncmp:
    if (StrNCmpInliner(CI, LF, &DTU, DL).optimizeStrNCmp()) {
      MadeCFGChange = true;
      return true;
    }
    break;
  case LibFunc_memchr:
    if (foldMemChr(CI, &DTU, DL)) {
      MadeCFGChange = true;
      return true;
    }
    break;
  default:;
  }
  return false;
}

/// This is the entry point for folds that could be implemented in regular
/// InstCombine, but they are separated because they are not expected to
/// occur frequently and/or have more than a constant-length pattern match.
static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
                                TargetTransformInfo &TTI,
                                TargetLibraryInfo &TLI, AliasAnalysis &AA,
                                AssumptionCache &AC, bool &MadeCFGChange) {
  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;

    const DataLayout &DL = F.getDataLayout();

    // Walk the block backwards for efficiency. We're matching a chain of
    // use->defs, so we're more likely to succeed by starting from the bottom.
    // Also, we want to avoid matching partial patterns.
    // TODO: It would be more efficient if we removed dead instructions
    // iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_early_inc_range(llvm::reverse(BB))) {
      MadeChange |= foldAnyOrAllBitsSet(I);
      MadeChange |= foldGuardedFunnelShift(I, DT);
      MadeChange |= tryToRecognizePopCount(I);
      MadeChange |= tryToFPToSat(I, TTI);
      MadeChange |= tryToRecognizeTableBasedCttz(I);
      MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
      MadeChange |= foldPatternedLoads(I, DL);
      // NOTE: This function may erase the instruction `I`, so it must be
      // called last in this sequence; otherwise we could introduce bugs by
      // using `I` after it has been erased.
      MadeChange |= foldLibCalls(I, TTI, TLI, AC, DT, DL, MadeCFGChange);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
                    TargetLibraryInfo &TLI, DominatorTree &DT,
                    AliasAnalysis &AA, bool &MadeCFGChange) {
  bool MadeChange = false;
  const DataLayout &DL = F.getDataLayout();
  TruncInstCombine TIC(AC, TLI, DL, DT);
  MadeChange |= TIC.run(F);
  MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, MadeCFGChange);
  return MadeChange;
}

PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &AC = AM.getResult<AssumptionAnalysis>(F);
  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &AA = AM.getResult<AAManager>(F);
  bool MadeCFGChange = false;
  if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) {
    // No changes, all analyses are preserved.
    return PreservedAnalyses::all();
  }
  // Mark all the analyses that instcombine updates as preserved.
  PreservedAnalyses PA;
  if (MadeCFGChange)
    PA.preserve<DominatorTreeAnalysis>();
  else
    PA.preserveSet<CFGAnalyses>();
  return PA;
}