1//===- AggressiveInstCombine.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
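// This file implements the aggressive expression pattern combiner classes.
// It handles expression patterns for, among others:
// * Truncate instruction
// * Guarded funnel shifts/rotates, any/all-bits-set tests, popcount idioms,
//   fptosi saturation, sqrt libcalls, table-based cttz, and consecutive loads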
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
16#include "AggressiveInstCombineInternal.h"
17#include "llvm/ADT/Statistic.h"
18#include "llvm/Analysis/AliasAnalysis.h"
19#include "llvm/Analysis/AssumptionCache.h"
20#include "llvm/Analysis/BasicAliasAnalysis.h"
21#include "llvm/Analysis/ConstantFolding.h"
22#include "llvm/Analysis/DomTreeUpdater.h"
23#include "llvm/Analysis/GlobalsModRef.h"
24#include "llvm/Analysis/TargetLibraryInfo.h"
25#include "llvm/Analysis/TargetTransformInfo.h"
26#include "llvm/Analysis/ValueTracking.h"
27#include "llvm/IR/DataLayout.h"
28#include "llvm/IR/Dominators.h"
29#include "llvm/IR/Function.h"
30#include "llvm/IR/IRBuilder.h"
31#include "llvm/IR/PatternMatch.h"
32#include "llvm/Transforms/Utils/BasicBlockUtils.h"
33#include "llvm/Transforms/Utils/BuildLibCalls.h"
34#include "llvm/Transforms/Utils/Local.h"
35
36using namespace llvm;
37using namespace PatternMatch;
38
39#define DEBUG_TYPE "aggressive-instcombine"
40
// Pass statistics. Each counter is incremented when the corresponding fold
// below fires; a guarded funnel shift whose two shifted operands are the same
// value is counted as a rotate, otherwise as a funnel shift.
41STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded");
42STATISTIC(NumGuardedRotates,
43 "Number of guarded rotates transformed into funnel shifts");
44STATISTIC(NumGuardedFunnelShifts,
45 "Number of guarded funnel shifts transformed into funnel shifts");
46STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized");
47
// Upper bound on the number of instructions foldLoadsRecursive() walks
// between two candidate loads while checking for intervening memory writes;
// exceeding it abandons the merge (keeps compile time bounded).
48static cl::opt<unsigned> MaxInstrsToScan(
49 "aggressive-instcombine-max-scan-instrs", cl::init(Val: 64), cl::Hidden,
50 cl::desc("Max number of instructions to scan for aggressive instcombine."));
51
// NOTE(review): not referenced in the visible part of this file; presumably
// consulted by a strcmp/strncmp inlining fold defined further down.
52static cl::opt<unsigned> StrNCmpInlineThreshold(
53 "strncmp-inline-threshold", cl::init(Val: 3), cl::Hidden,
54 cl::desc("The maximum length of a constant string for a builtin string cmp "
55 "call eligible for inlining. The default value is 3."));
56
// NOTE(review): not referenced in the visible part of this file; presumably
// consulted by a memchr inlining fold defined further down.
57static cl::opt<unsigned>
58 MemChrInlineThreshold("memchr-inline-threshold", cl::init(Val: 3), cl::Hidden,
59 cl::desc("The maximum length of a constant string to "
60 "inline a memchr call."));
61
62/// Match a pattern for a bitwise funnel/rotate operation that partially guards
63/// against undefined behavior by branching around the funnel-shift/rotation
64/// when the shift amount is 0.
65static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
66 if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2)
67 return false;
68
69 // As with the one-use checks below, this is not strictly necessary, but we
70 // are being cautious to avoid potential perf regressions on targets that
71 // do not actually have a funnel/rotate instruction (where the funnel shift
72 // would be expanded back into math/shift/logic ops).
73 if (!isPowerOf2_32(Value: I.getType()->getScalarSizeInBits()))
74 return false;
75
76 // Match V to funnel shift left/right and capture the source operands and
77 // shift amount.
78 auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1,
79 Value *&ShAmt) {
80 unsigned Width = V->getType()->getScalarSizeInBits();
81
82 // fshl(ShVal0, ShVal1, ShAmt)
83 // == (ShVal0 << ShAmt) | (ShVal1 >> (Width -ShAmt))
84 if (match(V, P: m_OneUse(SubPattern: m_c_Or(
85 L: m_Shl(L: m_Value(V&: ShVal0), R: m_Value(V&: ShAmt)),
86 R: m_LShr(L: m_Value(V&: ShVal1),
87 R: m_Sub(L: m_SpecificInt(V: Width), R: m_Deferred(V: ShAmt))))))) {
88 return Intrinsic::fshl;
89 }
90
91 // fshr(ShVal0, ShVal1, ShAmt)
92 // == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt))
93 if (match(V,
94 P: m_OneUse(SubPattern: m_c_Or(L: m_Shl(L: m_Value(V&: ShVal0), R: m_Sub(L: m_SpecificInt(V: Width),
95 R: m_Value(V&: ShAmt))),
96 R: m_LShr(L: m_Value(V&: ShVal1), R: m_Deferred(V: ShAmt)))))) {
97 return Intrinsic::fshr;
98 }
99
100 return Intrinsic::not_intrinsic;
101 };
102
103 // One phi operand must be a funnel/rotate operation, and the other phi
104 // operand must be the source value of that funnel/rotate operation:
105 // phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ]
106 // phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ]
107 // phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ]
108 PHINode &Phi = cast<PHINode>(Val&: I);
109 unsigned FunnelOp = 0, GuardOp = 1;
110 Value *P0 = Phi.getOperand(i_nocapture: 0), *P1 = Phi.getOperand(i_nocapture: 1);
111 Value *ShVal0, *ShVal1, *ShAmt;
112 Intrinsic::ID IID = matchFunnelShift(P0, ShVal0, ShVal1, ShAmt);
113 if (IID == Intrinsic::not_intrinsic ||
114 (IID == Intrinsic::fshl && ShVal0 != P1) ||
115 (IID == Intrinsic::fshr && ShVal1 != P1)) {
116 IID = matchFunnelShift(P1, ShVal0, ShVal1, ShAmt);
117 if (IID == Intrinsic::not_intrinsic ||
118 (IID == Intrinsic::fshl && ShVal0 != P0) ||
119 (IID == Intrinsic::fshr && ShVal1 != P0))
120 return false;
121 assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
122 "Pattern must match funnel shift left or right");
123 std::swap(a&: FunnelOp, b&: GuardOp);
124 }
125
126 // The incoming block with our source operand must be the "guard" block.
127 // That must contain a cmp+branch to avoid the funnel/rotate when the shift
128 // amount is equal to 0. The other incoming block is the block with the
129 // funnel/rotate.
130 BasicBlock *GuardBB = Phi.getIncomingBlock(i: GuardOp);
131 BasicBlock *FunnelBB = Phi.getIncomingBlock(i: FunnelOp);
132 Instruction *TermI = GuardBB->getTerminator();
133
134 // Ensure that the shift values dominate each block.
135 if (!DT.dominates(Def: ShVal0, User: TermI) || !DT.dominates(Def: ShVal1, User: TermI))
136 return false;
137
138 BasicBlock *PhiBB = Phi.getParent();
139 if (!match(V: TermI, P: m_Br(C: m_SpecificICmp(MatchPred: CmpInst::ICMP_EQ, L: m_Specific(V: ShAmt),
140 R: m_ZeroInt()),
141 T: m_SpecificBB(BB: PhiBB), F: m_SpecificBB(BB: FunnelBB))))
142 return false;
143
144 IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt());
145
146 if (ShVal0 == ShVal1)
147 ++NumGuardedRotates;
148 else
149 ++NumGuardedFunnelShifts;
150
151 // If this is not a rotate then the select was blocking poison from the
152 // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
153 bool IsFshl = IID == Intrinsic::fshl;
154 if (ShVal0 != ShVal1) {
155 if (IsFshl && !llvm::isGuaranteedNotToBePoison(V: ShVal1))
156 ShVal1 = Builder.CreateFreeze(V: ShVal1);
157 else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(V: ShVal0))
158 ShVal0 = Builder.CreateFreeze(V: ShVal0);
159 }
160
161 // We matched a variation of this IR pattern:
162 // GuardBB:
163 // %cmp = icmp eq i32 %ShAmt, 0
164 // br i1 %cmp, label %PhiBB, label %FunnelBB
165 // FunnelBB:
166 // %sub = sub i32 32, %ShAmt
167 // %shr = lshr i32 %ShVal1, %sub
168 // %shl = shl i32 %ShVal0, %ShAmt
169 // %fsh = or i32 %shr, %shl
170 // br label %PhiBB
171 // PhiBB:
172 // %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ]
173 // -->
174 // llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt)
175 Phi.replaceAllUsesWith(
176 V: Builder.CreateIntrinsic(ID: IID, Types: Phi.getType(), Args: {ShVal0, ShVal1, ShAmt}));
177 return true;
178}
179
180/// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and
181/// the bit indexes (Mask) needed by a masked compare. If we're matching a chain
182/// of 'and' ops, then we also need to capture the fact that we saw an
183/// "and X, 1", so that's an extra return value for that case.
184namespace {
185struct MaskOps {
186 Value *Root = nullptr;
187 APInt Mask;
188 bool MatchAndChain;
189 bool FoundAnd1 = false;
190
191 MaskOps(unsigned BitWidth, bool MatchAnds)
192 : Mask(APInt::getZero(numBits: BitWidth)), MatchAndChain(MatchAnds) {}
193};
194} // namespace
195
196/// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
197/// chain of 'and' or 'or' instructions looking for shift ops of a common source
198/// value. Examples:
199/// or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
200/// returns { X, 0x129 }
201/// and (and (X >> 1), 1), (X >> 4)
202/// returns { X, 0x12 }
203static bool matchAndOrChain(Value *V, MaskOps &MOps) {
204 Value *Op0, *Op1;
205 if (MOps.MatchAndChain) {
206 // Recurse through a chain of 'and' operands. This requires an extra check
207 // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere
208 // in the chain to know that all of the high bits are cleared.
209 if (match(V, P: m_And(L: m_Value(V&: Op0), R: m_One()))) {
210 MOps.FoundAnd1 = true;
211 return matchAndOrChain(V: Op0, MOps);
212 }
213 if (match(V, P: m_And(L: m_Value(V&: Op0), R: m_Value(V&: Op1))))
214 return matchAndOrChain(V: Op0, MOps) && matchAndOrChain(V: Op1, MOps);
215 } else {
216 // Recurse through a chain of 'or' operands.
217 if (match(V, P: m_Or(L: m_Value(V&: Op0), R: m_Value(V&: Op1))))
218 return matchAndOrChain(V: Op0, MOps) && matchAndOrChain(V: Op1, MOps);
219 }
220
221 // We need a shift-right or a bare value representing a compare of bit 0 of
222 // the original source operand.
223 Value *Candidate;
224 const APInt *BitIndex = nullptr;
225 if (!match(V, P: m_LShr(L: m_Value(V&: Candidate), R: m_APInt(Res&: BitIndex))))
226 Candidate = V;
227
228 // Initialize result source operand.
229 if (!MOps.Root)
230 MOps.Root = Candidate;
231
232 // The shift constant is out-of-range? This code hasn't been simplified.
233 if (BitIndex && BitIndex->uge(RHS: MOps.Mask.getBitWidth()))
234 return false;
235
236 // Fill in the mask bit derived from the shift constant.
237 MOps.Mask.setBit(BitIndex ? BitIndex->getZExtValue() : 0);
238 return MOps.Root == Candidate;
239}
240
241/// Match patterns that correspond to "any-bits-set" and "all-bits-set".
242/// These will include a chain of 'or' or 'and'-shifted bits from a
243/// common source value:
244/// and (or (lshr X, C), ...), 1 --> (X & CMask) != 0
245/// and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask
246/// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns
247/// that differ only with a final 'not' of the result. We expect that final
248/// 'not' to be folded with the compare that we create here (invert predicate).
249static bool foldAnyOrAllBitsSet(Instruction &I) {
250 // The 'any-bits-set' ('or' chain) pattern is simpler to match because the
251 // final "and X, 1" instruction must be the final op in the sequence.
252 bool MatchAllBitsSet;
253 if (match(V: &I, P: m_c_And(L: m_OneUse(SubPattern: m_And(L: m_Value(), R: m_Value())), R: m_Value())))
254 MatchAllBitsSet = true;
255 else if (match(V: &I, P: m_And(L: m_OneUse(SubPattern: m_Or(L: m_Value(), R: m_Value())), R: m_One())))
256 MatchAllBitsSet = false;
257 else
258 return false;
259
260 MaskOps MOps(I.getType()->getScalarSizeInBits(), MatchAllBitsSet);
261 if (MatchAllBitsSet) {
262 if (!matchAndOrChain(V: cast<BinaryOperator>(Val: &I), MOps) || !MOps.FoundAnd1)
263 return false;
264 } else {
265 if (!matchAndOrChain(V: cast<BinaryOperator>(Val: &I)->getOperand(i_nocapture: 0), MOps))
266 return false;
267 }
268
269 // The pattern was found. Create a masked compare that replaces all of the
270 // shift and logic ops.
271 IRBuilder<> Builder(&I);
272 Constant *Mask = ConstantInt::get(Ty: I.getType(), V: MOps.Mask);
273 Value *And = Builder.CreateAnd(LHS: MOps.Root, RHS: Mask);
274 Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(LHS: And, RHS: Mask)
275 : Builder.CreateIsNotNull(Arg: And);
276 Value *Zext = Builder.CreateZExt(V: Cmp, DestTy: I.getType());
277 I.replaceAllUsesWith(V: Zext);
278 ++NumAnyOrAllBitsSet;
279 return true;
280}
281
282// Try to recognize below function as popcount intrinsic.
283// This is the "best" algorithm from
284// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
285// Also used in TargetLowering::expandCTPOP().
286//
287// int popcount(unsigned int i) {
288// i = i - ((i >> 1) & 0x55555555);
289// i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
290// i = ((i + (i >> 4)) & 0x0F0F0F0F);
291// return (i * 0x01010101) >> 24;
292// }
293static bool tryToRecognizePopCount(Instruction &I) {
294 if (I.getOpcode() != Instruction::LShr)
295 return false;
296
297 Type *Ty = I.getType();
298 if (!Ty->isIntOrIntVectorTy())
299 return false;
300
301 unsigned Len = Ty->getScalarSizeInBits();
302 // FIXME: fix Len == 8 and other irregular type lengths.
303 if (!(Len <= 128 && Len > 8 && Len % 8 == 0))
304 return false;
305
306 APInt Mask55 = APInt::getSplat(NewLen: Len, V: APInt(8, 0x55));
307 APInt Mask33 = APInt::getSplat(NewLen: Len, V: APInt(8, 0x33));
308 APInt Mask0F = APInt::getSplat(NewLen: Len, V: APInt(8, 0x0F));
309 APInt Mask01 = APInt::getSplat(NewLen: Len, V: APInt(8, 0x01));
310 APInt MaskShift = APInt(Len, Len - 8);
311
312 Value *Op0 = I.getOperand(i: 0);
313 Value *Op1 = I.getOperand(i: 1);
314 Value *MulOp0;
315 // Matching "(i * 0x01010101...) >> 24".
316 if ((match(V: Op0, P: m_Mul(L: m_Value(V&: MulOp0), R: m_SpecificInt(V: Mask01)))) &&
317 match(V: Op1, P: m_SpecificInt(V: MaskShift))) {
318 Value *ShiftOp0;
319 // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
320 if (match(V: MulOp0, P: m_And(L: m_c_Add(L: m_LShr(L: m_Value(V&: ShiftOp0), R: m_SpecificInt(V: 4)),
321 R: m_Deferred(V: ShiftOp0)),
322 R: m_SpecificInt(V: Mask0F)))) {
323 Value *AndOp0;
324 // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)".
325 if (match(V: ShiftOp0,
326 P: m_c_Add(L: m_And(L: m_Value(V&: AndOp0), R: m_SpecificInt(V: Mask33)),
327 R: m_And(L: m_LShr(L: m_Deferred(V: AndOp0), R: m_SpecificInt(V: 2)),
328 R: m_SpecificInt(V: Mask33))))) {
329 Value *Root, *SubOp1;
330 // Matching "i - ((i >> 1) & 0x55555555...)".
331 const APInt *AndMask;
332 if (match(V: AndOp0, P: m_Sub(L: m_Value(V&: Root), R: m_Value(V&: SubOp1))) &&
333 match(V: SubOp1, P: m_And(L: m_LShr(L: m_Specific(V: Root), R: m_SpecificInt(V: 1)),
334 R: m_APInt(Res&: AndMask)))) {
335 auto CheckAndMask = [&]() {
336 if (*AndMask == Mask55)
337 return true;
338
339 // Exact match failed, see if any bits are known to be 0 where we
340 // expect a 1 in the mask.
341 if (!AndMask->isSubsetOf(RHS: Mask55))
342 return false;
343
344 APInt NeededMask = Mask55 & ~*AndMask;
345 return MaskedValueIsZero(V: cast<Instruction>(Val: SubOp1)->getOperand(i: 0),
346 Mask: NeededMask,
347 SQ: SimplifyQuery(I.getDataLayout()));
348 };
349
350 if (CheckAndMask()) {
351 LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
352 IRBuilder<> Builder(&I);
353 I.replaceAllUsesWith(
354 V: Builder.CreateIntrinsic(ID: Intrinsic::ctpop, Types: I.getType(), Args: {Root}));
355 ++NumPopCountRecognized;
356 return true;
357 }
358 }
359 }
360 }
361 }
362
363 return false;
364}
365
366/// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), providing C1 and
367/// C2 saturate the value of the fp conversion. The transform is not reversable
368/// as the fptosi.sat is more defined than the input - all values produce a
369/// valid value for the fptosi.sat, where as some produce poison for original
370/// that were out of range of the integer conversion. The reversed pattern may
371/// use fmax and fmin instead. As we cannot directly reverse the transform, and
372/// it is not always profitable, we make it conditional on the cost being
373/// reported as lower by TTI.
374static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
375 // Look for min(max(fptosi, converting to fptosi_sat.
376 Value *In;
377 const APInt *MinC, *MaxC;
378 if (!match(V: &I, P: m_SMax(L: m_OneUse(SubPattern: m_SMin(L: m_OneUse(SubPattern: m_FPToSI(Op: m_Value(V&: In))),
379 R: m_APInt(Res&: MinC))),
380 R: m_APInt(Res&: MaxC))) &&
381 !match(V: &I, P: m_SMin(L: m_OneUse(SubPattern: m_SMax(L: m_OneUse(SubPattern: m_FPToSI(Op: m_Value(V&: In))),
382 R: m_APInt(Res&: MaxC))),
383 R: m_APInt(Res&: MinC))))
384 return false;
385
386 // Check that the constants clamp a saturate.
387 if (!(*MinC + 1).isPowerOf2() || -*MaxC != *MinC + 1)
388 return false;
389
390 Type *IntTy = I.getType();
391 Type *FpTy = In->getType();
392 Type *SatTy =
393 IntegerType::get(C&: IntTy->getContext(), NumBits: (*MinC + 1).exactLogBase2() + 1);
394 if (auto *VecTy = dyn_cast<VectorType>(Val: IntTy))
395 SatTy = VectorType::get(ElementType: SatTy, EC: VecTy->getElementCount());
396
397 // Get the cost of the intrinsic, and check that against the cost of
398 // fptosi+smin+smax
399 InstructionCost SatCost = TTI.getIntrinsicInstrCost(
400 ICA: IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}),
401 CostKind: TTI::TCK_RecipThroughput);
402 SatCost += TTI.getCastInstrCost(Opcode: Instruction::SExt, Dst: IntTy, Src: SatTy,
403 CCH: TTI::CastContextHint::None,
404 CostKind: TTI::TCK_RecipThroughput);
405
406 InstructionCost MinMaxCost = TTI.getCastInstrCost(
407 Opcode: Instruction::FPToSI, Dst: IntTy, Src: FpTy, CCH: TTI::CastContextHint::None,
408 CostKind: TTI::TCK_RecipThroughput);
409 MinMaxCost += TTI.getIntrinsicInstrCost(
410 ICA: IntrinsicCostAttributes(Intrinsic::smin, IntTy, {IntTy}),
411 CostKind: TTI::TCK_RecipThroughput);
412 MinMaxCost += TTI.getIntrinsicInstrCost(
413 ICA: IntrinsicCostAttributes(Intrinsic::smax, IntTy, {IntTy}),
414 CostKind: TTI::TCK_RecipThroughput);
415
416 if (SatCost >= MinMaxCost)
417 return false;
418
419 IRBuilder<> Builder(&I);
420 Value *Sat =
421 Builder.CreateIntrinsic(ID: Intrinsic::fptosi_sat, Types: {SatTy, FpTy}, Args: In);
422 I.replaceAllUsesWith(V: Builder.CreateSExt(V: Sat, DestTy: IntTy));
423 return true;
424}
425
426/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
427/// pessimistic codegen that has to account for setting errno and can enable
428/// vectorization.
429static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI,
430 TargetLibraryInfo &TLI, AssumptionCache &AC,
431 DominatorTree &DT) {
432 // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
433 // (because NNAN or the operand arg must not be less than -0.0) and (2) we
434 // would not end up lowering to a libcall anyway (which could change the value
435 // of errno), then:
436 // (1) errno won't be set.
437 // (2) it is safe to convert this to an intrinsic call.
438 Type *Ty = Call->getType();
439 Value *Arg = Call->getArgOperand(i: 0);
440 if (TTI.haveFastSqrt(Ty) &&
441 (Call->hasNoNaNs() ||
442 cannotBeOrderedLessThanZero(
443 V: Arg, SQ: SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) {
444 IRBuilder<> Builder(Call);
445 Value *NewSqrt =
446 Builder.CreateIntrinsic(ID: Intrinsic::sqrt, Types: Ty, Args: Arg, FMFSource: Call, Name: "sqrt");
447 Call->replaceAllUsesWith(V: NewSqrt);
448
449 // Explicitly erase the old call because a call with side effects is not
450 // trivially dead.
451 Call->eraseFromParent();
452 return true;
453 }
454
455 return false;
456}
457
458// Check if this array of constants represents a cttz table.
459// Iterate over the elements from \p Table by trying to find/match all
460// the numbers from 0 to \p InputBits that should represent cttz results.
461static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
462 uint64_t Shift, uint64_t InputBits) {
463 unsigned Length = Table.getNumElements();
464 if (Length < InputBits || Length > InputBits * 2)
465 return false;
466
467 APInt Mask = APInt::getBitsSetFrom(numBits: InputBits, loBit: Shift);
468 unsigned Matched = 0;
469
470 for (unsigned i = 0; i < Length; i++) {
471 uint64_t Element = Table.getElementAsInteger(i);
472 if (Element >= InputBits)
473 continue;
474
475 // Check if \p Element matches a concrete answer. It could fail for some
476 // elements that are never accessed, so we keep iterating over each element
477 // from the table. The number of matched elements should be equal to the
478 // number of potential right answers which is \p InputBits actually.
479 if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i)
480 Matched++;
481 }
482
483 return Matched == InputBits;
484}
485
486// Try to recognize table-based ctz implementation.
487// E.g., an example in C (for more cases please see the llvm/tests):
488// int f(unsigned x) {
489// static const char table[32] =
490// {0, 1, 28, 2, 29, 14, 24, 3, 30,
491// 22, 20, 15, 25, 17, 4, 8, 31, 27,
492// 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};
493// return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27];
494// }
495// this can be lowered to `cttz` instruction.
496// There is also a special case when the element is 0.
497//
498// Here are some examples or LLVM IR for a 64-bit target:
499//
500// CASE 1:
501// %sub = sub i32 0, %x
502// %and = and i32 %sub, %x
503// %mul = mul i32 %and, 125613361
504// %shr = lshr i32 %mul, 27
505// %idxprom = zext i32 %shr to i64
506// %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0,
507// i64 %idxprom
508// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
509//
510// CASE 2:
511// %sub = sub i32 0, %x
512// %and = and i32 %sub, %x
513// %mul = mul i32 %and, 72416175
514// %shr = lshr i32 %mul, 26
515// %idxprom = zext i32 %shr to i64
516// %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table,
517// i64 0, i64 %idxprom
518// %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
519//
520// CASE 3:
521// %sub = sub i32 0, %x
522// %and = and i32 %sub, %x
523// %mul = mul i32 %and, 81224991
524// %shr = lshr i32 %mul, 27
525// %idxprom = zext i32 %shr to i64
526// %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table,
527// i64 0, i64 %idxprom
528// %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
529//
530// CASE 4:
531// %sub = sub i64 0, %x
532// %and = and i64 %sub, %x
533// %mul = mul i64 %and, 283881067100198605
534// %shr = lshr i64 %mul, 58
535// %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0,
536// i64 %shr
537// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
538//
539// All this can be lowered to @llvm.cttz.i32/64 intrinsic.
540static bool tryToRecognizeTableBasedCttz(Instruction &I) {
541 LoadInst *LI = dyn_cast<LoadInst>(Val: &I);
542 if (!LI)
543 return false;
544
545 Type *AccessType = LI->getType();
546 if (!AccessType->isIntegerTy())
547 return false;
548
549 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Val: LI->getPointerOperand());
550 if (!GEP || !GEP->hasNoUnsignedSignedWrap() || GEP->getNumIndices() != 2)
551 return false;
552
553 if (!GEP->getSourceElementType()->isArrayTy())
554 return false;
555
556 uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements();
557 if (ArraySize != 32 && ArraySize != 64)
558 return false;
559
560 GlobalVariable *GVTable = dyn_cast<GlobalVariable>(Val: GEP->getPointerOperand());
561 if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
562 return false;
563
564 ConstantDataArray *ConstData =
565 dyn_cast<ConstantDataArray>(Val: GVTable->getInitializer());
566 if (!ConstData)
567 return false;
568
569 if (!match(V: GEP->idx_begin()->get(), P: m_ZeroInt()))
570 return false;
571
572 Value *Idx2 = std::next(x: GEP->idx_begin())->get();
573 Value *X1;
574 uint64_t MulConst, ShiftConst;
575 // FIXME: 64-bit targets have `i64` type for the GEP index, so this match will
576 // probably fail for other (e.g. 32-bit) targets.
577 if (!match(V: Idx2, P: m_ZExtOrSelf(
578 Op: m_LShr(L: m_Mul(L: m_c_And(L: m_Neg(V: m_Value(V&: X1)), R: m_Deferred(V: X1)),
579 R: m_ConstantInt(V&: MulConst)),
580 R: m_ConstantInt(V&: ShiftConst)))))
581 return false;
582
583 unsigned InputBits = X1->getType()->getScalarSizeInBits();
584 if (InputBits != 32 && InputBits != 64)
585 return false;
586
587 // Shift should extract top 5..7 bits.
588 if (InputBits - Log2_32(Value: InputBits) != ShiftConst &&
589 InputBits - Log2_32(Value: InputBits) - 1 != ShiftConst)
590 return false;
591
592 if (!isCTTZTable(Table: *ConstData, Mul: MulConst, Shift: ShiftConst, InputBits))
593 return false;
594
595 auto ZeroTableElem = ConstData->getElementAsInteger(i: 0);
596 bool DefinedForZero = ZeroTableElem == InputBits;
597
598 IRBuilder<> B(LI);
599 ConstantInt *BoolConst = B.getInt1(V: !DefinedForZero);
600 Type *XType = X1->getType();
601 auto Cttz = B.CreateIntrinsic(ID: Intrinsic::cttz, Types: {XType}, Args: {X1, BoolConst});
602 Value *ZExtOrTrunc = nullptr;
603
604 if (DefinedForZero) {
605 ZExtOrTrunc = B.CreateZExtOrTrunc(V: Cttz, DestTy: AccessType);
606 } else {
607 // If the value in elem 0 isn't the same as InputBits, we still want to
608 // produce the value from the table.
609 auto Cmp = B.CreateICmpEQ(LHS: X1, RHS: ConstantInt::get(Ty: XType, V: 0));
610 auto Select =
611 B.CreateSelect(C: Cmp, True: ConstantInt::get(Ty: XType, V: ZeroTableElem), False: Cttz);
612
613 // NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target
614 // it should be handled as: `cttz(x) & (typeSize - 1)`.
615
616 ZExtOrTrunc = B.CreateZExtOrTrunc(V: Select, DestTy: AccessType);
617 }
618
619 LI->replaceAllUsesWith(V: ZExtOrTrunc);
620
621 return true;
622}
623
624/// This is used by foldLoadsRecursive() to capture a Root Load node which is
625/// of type or(load, load) and recursively build the wide load. Also capture the
626/// shift amount, zero extend type and loadSize.
627struct LoadOps {
628 LoadInst *Root = nullptr;
629 LoadInst *RootInsert = nullptr;
630 bool FoundRoot = false;
631 uint64_t LoadSize = 0;
632 const APInt *Shift = nullptr;
633 Type *ZextType;
634 AAMDNodes AATags;
635};
636
637// Identify and Merge consecutive loads recursively which is of the form
638// (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
639// (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
640static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
641 AliasAnalysis &AA) {
642 const APInt *ShAmt2 = nullptr;
643 Value *X;
644 Instruction *L1, *L2;
645
646 // Go to the last node with loads.
647 if (match(V, P: m_OneUse(SubPattern: m_c_Or(
648 L: m_Value(V&: X),
649 R: m_OneUse(SubPattern: m_Shl(L: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_Instruction(I&: L2)))),
650 R: m_APInt(Res&: ShAmt2)))))) ||
651 match(V, P: m_OneUse(SubPattern: m_Or(L: m_Value(V&: X),
652 R: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_Instruction(I&: L2)))))))) {
653 if (!foldLoadsRecursive(V: X, LOps, DL, AA) && LOps.FoundRoot)
654 // Avoid Partial chain merge.
655 return false;
656 } else
657 return false;
658
659 // Check if the pattern has loads
660 LoadInst *LI1 = LOps.Root;
661 const APInt *ShAmt1 = LOps.Shift;
662 if (LOps.FoundRoot == false &&
663 (match(V: X, P: m_OneUse(SubPattern: m_ZExt(Op: m_Instruction(I&: L1)))) ||
664 match(V: X, P: m_OneUse(SubPattern: m_Shl(L: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_Instruction(I&: L1)))),
665 R: m_APInt(Res&: ShAmt1)))))) {
666 LI1 = dyn_cast<LoadInst>(Val: L1);
667 }
668 LoadInst *LI2 = dyn_cast<LoadInst>(Val: L2);
669
670 // Check if loads are same, atomic, volatile and having same address space.
671 if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() ||
672 LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace())
673 return false;
674
675 // Check if Loads come from same BB.
676 if (LI1->getParent() != LI2->getParent())
677 return false;
678
679 // Find the data layout
680 bool IsBigEndian = DL.isBigEndian();
681
682 // Check if loads are consecutive and same size.
683 Value *Load1Ptr = LI1->getPointerOperand();
684 APInt Offset1(DL.getIndexTypeSizeInBits(Ty: Load1Ptr->getType()), 0);
685 Load1Ptr =
686 Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset&: Offset1,
687 /* AllowNonInbounds */ true);
688
689 Value *Load2Ptr = LI2->getPointerOperand();
690 APInt Offset2(DL.getIndexTypeSizeInBits(Ty: Load2Ptr->getType()), 0);
691 Load2Ptr =
692 Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset&: Offset2,
693 /* AllowNonInbounds */ true);
694
695 // Verify if both loads have same base pointers
696 uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits();
697 uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits();
698 if (Load1Ptr != Load2Ptr)
699 return false;
700
701 // Make sure that there are no padding bits.
702 if (!DL.typeSizeEqualsStoreSize(Ty: LI1->getType()) ||
703 !DL.typeSizeEqualsStoreSize(Ty: LI2->getType()))
704 return false;
705
706 // Alias Analysis to check for stores b/w the loads.
707 LoadInst *Start = LOps.FoundRoot ? LOps.RootInsert : LI1, *End = LI2;
708 MemoryLocation Loc;
709 if (!Start->comesBefore(Other: End)) {
710 std::swap(a&: Start, b&: End);
711 Loc = MemoryLocation::get(LI: End);
712 if (LOps.FoundRoot)
713 Loc = Loc.getWithNewSize(NewSize: LOps.LoadSize);
714 } else
715 Loc = MemoryLocation::get(LI: End);
716 unsigned NumScanned = 0;
717 for (Instruction &Inst :
718 make_range(x: Start->getIterator(), y: End->getIterator())) {
719 if (Inst.mayWriteToMemory() && isModSet(MRI: AA.getModRefInfo(I: &Inst, OptLoc: Loc)))
720 return false;
721
722 if (++NumScanned > MaxInstrsToScan)
723 return false;
724 }
725
726 // Make sure Load with lower Offset is at LI1
727 bool Reverse = false;
728 if (Offset2.slt(RHS: Offset1)) {
729 std::swap(a&: LI1, b&: LI2);
730 std::swap(a&: ShAmt1, b&: ShAmt2);
731 std::swap(a&: Offset1, b&: Offset2);
732 std::swap(a&: Load1Ptr, b&: Load2Ptr);
733 std::swap(a&: LoadSize1, b&: LoadSize2);
734 Reverse = true;
735 }
736
737 // Big endian swap the shifts
738 if (IsBigEndian)
739 std::swap(a&: ShAmt1, b&: ShAmt2);
740
741 // Find Shifts values.
742 uint64_t Shift1 = 0, Shift2 = 0;
743 if (ShAmt1)
744 Shift1 = ShAmt1->getZExtValue();
745 if (ShAmt2)
746 Shift2 = ShAmt2->getZExtValue();
747
748 // First load is always LI1. This is where we put the new load.
749 // Use the merged load size available from LI1 for forward loads.
750 if (LOps.FoundRoot) {
751 if (!Reverse)
752 LoadSize1 = LOps.LoadSize;
753 else
754 LoadSize2 = LOps.LoadSize;
755 }
756
757 // Verify if shift amount and load index aligns and verifies that loads
758 // are consecutive.
759 uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
760 uint64_t PrevSize =
761 DL.getTypeStoreSize(Ty: IntegerType::get(C&: LI1->getContext(), NumBits: LoadSize1));
762 if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
763 return false;
764
765 // Update LOps
766 AAMDNodes AATags1 = LOps.AATags;
767 AAMDNodes AATags2 = LI2->getAAMetadata();
768 if (LOps.FoundRoot == false) {
769 LOps.FoundRoot = true;
770 AATags1 = LI1->getAAMetadata();
771 }
772 LOps.LoadSize = LoadSize1 + LoadSize2;
773 LOps.RootInsert = Start;
774
775 // Concatenate the AATags of the Merged Loads.
776 LOps.AATags = AATags1.concat(Other: AATags2);
777
778 LOps.Root = LI1;
779 LOps.Shift = ShAmt1;
780 LOps.ZextType = X->getType();
781 return true;
782}
783
784// For a given BB instruction, evaluate all loads in the chain that form a
785// pattern which suggests that the loads can be combined. The one and only use
786// of the loads is to form a wider load.
787static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
788 TargetTransformInfo &TTI, AliasAnalysis &AA,
789 const DominatorTree &DT) {
790 // Only consider load chains of scalar values.
791 if (isa<VectorType>(Val: I.getType()))
792 return false;
793
794 LoadOps LOps;
795 if (!foldLoadsRecursive(V: &I, LOps, DL, AA) || !LOps.FoundRoot)
796 return false;
797
798 IRBuilder<> Builder(&I);
799 LoadInst *NewLoad = nullptr, *LI1 = LOps.Root;
800
801 IntegerType *WiderType = IntegerType::get(C&: I.getContext(), NumBits: LOps.LoadSize);
802 // TTI based checks if we want to proceed with wider load
803 bool Allowed = TTI.isTypeLegal(Ty: WiderType);
804 if (!Allowed)
805 return false;
806
807 unsigned AS = LI1->getPointerAddressSpace();
808 unsigned Fast = 0;
809 Allowed = TTI.allowsMisalignedMemoryAccesses(Context&: I.getContext(), BitWidth: LOps.LoadSize,
810 AddressSpace: AS, Alignment: LI1->getAlign(), Fast: &Fast);
811 if (!Allowed || !Fast)
812 return false;
813
814 // Get the Index and Ptr for the new GEP.
815 Value *Load1Ptr = LI1->getPointerOperand();
816 Builder.SetInsertPoint(LOps.RootInsert);
817 if (!DT.dominates(Def: Load1Ptr, User: LOps.RootInsert)) {
818 APInt Offset1(DL.getIndexTypeSizeInBits(Ty: Load1Ptr->getType()), 0);
819 Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets(
820 DL, Offset&: Offset1, /* AllowNonInbounds */ true);
821 Load1Ptr = Builder.CreatePtrAdd(Ptr: Load1Ptr, Offset: Builder.getInt(AI: Offset1));
822 }
823 // Generate wider load.
824 NewLoad = Builder.CreateAlignedLoad(Ty: WiderType, Ptr: Load1Ptr, Align: LI1->getAlign(),
825 isVolatile: LI1->isVolatile(), Name: "");
826 NewLoad->takeName(V: LI1);
827 // Set the New Load AATags Metadata.
828 if (LOps.AATags)
829 NewLoad->setAAMetadata(LOps.AATags);
830
831 Value *NewOp = NewLoad;
832 // Check if zero extend needed.
833 if (LOps.ZextType)
834 NewOp = Builder.CreateZExt(V: NewOp, DestTy: LOps.ZextType);
835
836 // Check if shift needed. We need to shift with the amount of load1
837 // shift if not zero.
838 if (LOps.Shift)
839 NewOp = Builder.CreateShl(LHS: NewOp, RHS: ConstantInt::get(Context&: I.getContext(), V: *LOps.Shift));
840 I.replaceAllUsesWith(V: NewOp);
841
842 return true;
843}
844
845/// Combine away instructions providing they are still equivalent when compared
846/// against 0. i.e do they have any bits set.
847static Value *optimizeShiftInOrChain(Value *V, IRBuilder<> &Builder) {
848 auto *I = dyn_cast<Instruction>(Val: V);
849 if (!I || I->getOpcode() != Instruction::Or || !I->hasOneUse())
850 return nullptr;
851
852 Value *A;
853
854 // Look deeper into the chain of or's, combining away shl (so long as they are
855 // nuw or nsw).
856 Value *Op0 = I->getOperand(i: 0);
857 if (match(V: Op0, P: m_CombineOr(L: m_NSWShl(L: m_Value(V&: A), R: m_Value()),
858 R: m_NUWShl(L: m_Value(V&: A), R: m_Value()))))
859 Op0 = A;
860 else if (auto *NOp = optimizeShiftInOrChain(V: Op0, Builder))
861 Op0 = NOp;
862
863 Value *Op1 = I->getOperand(i: 1);
864 if (match(V: Op1, P: m_CombineOr(L: m_NSWShl(L: m_Value(V&: A), R: m_Value()),
865 R: m_NUWShl(L: m_Value(V&: A), R: m_Value()))))
866 Op1 = A;
867 else if (auto *NOp = optimizeShiftInOrChain(V: Op1, Builder))
868 Op1 = NOp;
869
870 if (Op0 != I->getOperand(i: 0) || Op1 != I->getOperand(i: 1))
871 return Builder.CreateOr(LHS: Op0, RHS: Op1);
872 return nullptr;
873}
874
875static bool foldICmpOrChain(Instruction &I, const DataLayout &DL,
876 TargetTransformInfo &TTI, AliasAnalysis &AA,
877 const DominatorTree &DT) {
878 CmpPredicate Pred;
879 Value *Op0;
880 if (!match(V: &I, P: m_ICmp(Pred, L: m_Value(V&: Op0), R: m_Zero())) ||
881 !ICmpInst::isEquality(P: Pred))
882 return false;
883
884 // If the chain or or's matches a load, combine to that before attempting to
885 // remove shifts.
886 if (auto OpI = dyn_cast<Instruction>(Val: Op0))
887 if (OpI->getOpcode() == Instruction::Or)
888 if (foldConsecutiveLoads(I&: *OpI, DL, TTI, AA, DT))
889 return true;
890
891 IRBuilder<> Builder(&I);
892 // icmp eq/ne or(shl(a), b), 0 -> icmp eq/ne or(a, b), 0
893 if (auto *Res = optimizeShiftInOrChain(V: Op0, Builder)) {
894 I.replaceAllUsesWith(V: Builder.CreateICmp(P: Pred, LHS: Res, RHS: I.getOperand(i: 1)));
895 return true;
896 }
897
898 return false;
899}
900
901// Calculate GEP Stride and accumulated const ModOffset. Return Stride and
902// ModOffset
903static std::pair<APInt, APInt>
904getStrideAndModOffsetOfGEP(Value *PtrOp, const DataLayout &DL) {
905 unsigned BW = DL.getIndexTypeSizeInBits(Ty: PtrOp->getType());
906 std::optional<APInt> Stride;
907 APInt ModOffset(BW, 0);
908 // Return a minimum gep stride, greatest common divisor of consective gep
909 // index scales(c.f. Bézout's identity).
910 while (auto *GEP = dyn_cast<GEPOperator>(Val: PtrOp)) {
911 SmallMapVector<Value *, APInt, 4> VarOffsets;
912 if (!GEP->collectOffset(DL, BitWidth: BW, VariableOffsets&: VarOffsets, ConstantOffset&: ModOffset))
913 break;
914
915 for (auto [V, Scale] : VarOffsets) {
916 // Only keep a power of two factor for non-inbounds
917 if (!GEP->hasNoUnsignedSignedWrap())
918 Scale = APInt::getOneBitSet(numBits: Scale.getBitWidth(), BitNo: Scale.countr_zero());
919
920 if (!Stride)
921 Stride = Scale;
922 else
923 Stride = APIntOps::GreatestCommonDivisor(A: *Stride, B: Scale);
924 }
925
926 PtrOp = GEP->getPointerOperand();
927 }
928
929 // Check whether pointer arrives back at Global Variable via at least one GEP.
930 // Even if it doesn't, we can check by alignment.
931 if (!isa<GlobalVariable>(Val: PtrOp) || !Stride)
932 return {APInt(BW, 1), APInt(BW, 0)};
933
934 // In consideration of signed GEP indices, non-negligible offset become
935 // remainder of division by minimum GEP stride.
936 ModOffset = ModOffset.srem(RHS: *Stride);
937 if (ModOffset.isNegative())
938 ModOffset += *Stride;
939
940 return {*Stride, ModOffset};
941}
942
943/// If C is a constant patterned array and all valid loaded results for given
944/// alignment are same to a constant, return that constant.
945static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
946 auto *LI = dyn_cast<LoadInst>(Val: &I);
947 if (!LI || LI->isVolatile())
948 return false;
949
950 // We can only fold the load if it is from a constant global with definitive
951 // initializer. Skip expensive logic if this is not the case.
952 auto *PtrOp = LI->getPointerOperand();
953 auto *GV = dyn_cast<GlobalVariable>(Val: getUnderlyingObject(V: PtrOp));
954 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
955 return false;
956
957 // Bail for large initializers in excess of 4K to avoid too many scans.
958 Constant *C = GV->getInitializer();
959 uint64_t GVSize = DL.getTypeAllocSize(Ty: C->getType());
960 if (!GVSize || 4096 < GVSize)
961 return false;
962
963 Type *LoadTy = LI->getType();
964 unsigned BW = DL.getIndexTypeSizeInBits(Ty: PtrOp->getType());
965 auto [Stride, ConstOffset] = getStrideAndModOffsetOfGEP(PtrOp, DL);
966
967 // Any possible offset could be multiple of GEP stride. And any valid
968 // offset is multiple of load alignment, so checking only multiples of bigger
969 // one is sufficient to say results' equality.
970 if (auto LA = LI->getAlign();
971 LA <= GV->getAlign().valueOrOne() && Stride.getZExtValue() < LA.value()) {
972 ConstOffset = APInt(BW, 0);
973 Stride = APInt(BW, LA.value());
974 }
975
976 Constant *Ca = ConstantFoldLoadFromConst(C, Ty: LoadTy, Offset: ConstOffset, DL);
977 if (!Ca)
978 return false;
979
980 unsigned E = GVSize - DL.getTypeStoreSize(Ty: LoadTy);
981 for (; ConstOffset.getZExtValue() <= E; ConstOffset += Stride)
982 if (Ca != ConstantFoldLoadFromConst(C, Ty: LoadTy, Offset: ConstOffset, DL))
983 return false;
984
985 I.replaceAllUsesWith(V: Ca);
986
987 return true;
988}
989
990namespace {
991class StrNCmpInliner {
992public:
993 StrNCmpInliner(CallInst *CI, LibFunc Func, DomTreeUpdater *DTU,
994 const DataLayout &DL)
995 : CI(CI), Func(Func), DTU(DTU), DL(DL) {}
996
997 bool optimizeStrNCmp();
998
999private:
1000 void inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Swapped);
1001
1002 CallInst *CI;
1003 LibFunc Func;
1004 DomTreeUpdater *DTU;
1005 const DataLayout &DL;
1006};
1007
1008} // namespace
1009
1010/// First we normalize calls to strncmp/strcmp to the form of
1011/// compare(s1, s2, N), which means comparing first N bytes of s1 and s2
1012/// (without considering '\0').
1013///
1014/// Examples:
1015///
1016/// \code
1017/// strncmp(s, "a", 3) -> compare(s, "a", 2)
1018/// strncmp(s, "abc", 3) -> compare(s, "abc", 3)
1019/// strncmp(s, "a\0b", 3) -> compare(s, "a\0b", 2)
1020/// strcmp(s, "a") -> compare(s, "a", 2)
1021///
1022/// char s2[] = {'a'}
1023/// strncmp(s, s2, 3) -> compare(s, s2, 3)
1024///
1025/// char s2[] = {'a', 'b', 'c', 'd'}
1026/// strncmp(s, s2, 3) -> compare(s, s2, 3)
1027/// \endcode
1028///
1029/// We only handle cases where N and exactly one of s1 and s2 are constant.
1030/// Cases that s1 and s2 are both constant are already handled by the
1031/// instcombine pass.
1032///
1033/// We do not handle cases where N > StrNCmpInlineThreshold.
1034///
1035/// We also do not handles cases where N < 2, which are already
1036/// handled by the instcombine pass.
1037///
1038bool StrNCmpInliner::optimizeStrNCmp() {
1039 if (StrNCmpInlineThreshold < 2)
1040 return false;
1041
1042 if (!isOnlyUsedInZeroComparison(CxtI: CI))
1043 return false;
1044
1045 Value *Str1P = CI->getArgOperand(i: 0);
1046 Value *Str2P = CI->getArgOperand(i: 1);
1047 // Should be handled elsewhere.
1048 if (Str1P == Str2P)
1049 return false;
1050
1051 StringRef Str1, Str2;
1052 bool HasStr1 = getConstantStringInfo(V: Str1P, Str&: Str1, /*TrimAtNul=*/false);
1053 bool HasStr2 = getConstantStringInfo(V: Str2P, Str&: Str2, /*TrimAtNul=*/false);
1054 if (HasStr1 == HasStr2)
1055 return false;
1056
1057 // Note that '\0' and characters after it are not trimmed.
1058 StringRef Str = HasStr1 ? Str1 : Str2;
1059 Value *StrP = HasStr1 ? Str2P : Str1P;
1060
1061 size_t Idx = Str.find(C: '\0');
1062 uint64_t N = Idx == StringRef::npos ? UINT64_MAX : Idx + 1;
1063 if (Func == LibFunc_strncmp) {
1064 if (auto *ConstInt = dyn_cast<ConstantInt>(Val: CI->getArgOperand(i: 2)))
1065 N = std::min(a: N, b: ConstInt->getZExtValue());
1066 else
1067 return false;
1068 }
1069 // Now N means how many bytes we need to compare at most.
1070 if (N > Str.size() || N < 2 || N > StrNCmpInlineThreshold)
1071 return false;
1072
1073 // Cases where StrP has two or more dereferenceable bytes might be better
1074 // optimized elsewhere.
1075 bool CanBeNull = false, CanBeFreed = false;
1076 if (StrP->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed) > 1)
1077 return false;
1078 inlineCompare(LHS: StrP, RHS: Str, N, Swapped: HasStr1);
1079 return true;
1080}
1081
1082/// Convert
1083///
1084/// \code
1085/// ret = compare(s1, s2, N)
1086/// \endcode
1087///
1088/// into
1089///
1090/// \code
1091/// ret = (int)s1[0] - (int)s2[0]
1092/// if (ret != 0)
1093/// goto NE
1094/// ...
1095/// ret = (int)s1[N-2] - (int)s2[N-2]
1096/// if (ret != 0)
1097/// goto NE
1098/// ret = (int)s1[N-1] - (int)s2[N-1]
1099/// NE:
1100/// \endcode
1101///
1102/// CFG before and after the transformation:
1103///
1104/// (before)
1105/// BBCI
1106///
1107/// (after)
1108/// BBCI -> BBSubs[0] (sub,icmp) --NE-> BBNE -> BBTail
1109/// | ^
1110/// E |
1111/// | |
1112/// BBSubs[1] (sub,icmp) --NE-----+
1113/// ... |
1114/// BBSubs[N-1] (sub) ---------+
1115///
1116void StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
1117 bool Swapped) {
1118 auto &Ctx = CI->getContext();
1119 IRBuilder<> B(Ctx);
1120 // We want these instructions to be recognized as inlined instructions for the
1121 // compare call, but we don't have a source location for the definition of
1122 // that function, since we're generating that code now. Because the generated
1123 // code is a viable point for a memory access error, we make the pragmatic
1124 // choice here to directly use CI's location so that we have useful
1125 // attribution for the generated code.
1126 B.SetCurrentDebugLocation(CI->getDebugLoc());
1127
1128 BasicBlock *BBCI = CI->getParent();
1129 BasicBlock *BBTail =
1130 SplitBlock(Old: BBCI, SplitPt: CI, DTU, LI: nullptr, MSSAU: nullptr, BBName: BBCI->getName() + ".tail");
1131
1132 SmallVector<BasicBlock *> BBSubs;
1133 for (uint64_t I = 0; I < N; ++I)
1134 BBSubs.push_back(
1135 Elt: BasicBlock::Create(Context&: Ctx, Name: "sub_" + Twine(I), Parent: BBCI->getParent(), InsertBefore: BBTail));
1136 BasicBlock *BBNE = BasicBlock::Create(Context&: Ctx, Name: "ne", Parent: BBCI->getParent(), InsertBefore: BBTail);
1137
1138 cast<BranchInst>(Val: BBCI->getTerminator())->setSuccessor(idx: 0, NewSucc: BBSubs[0]);
1139
1140 B.SetInsertPoint(BBNE);
1141 PHINode *Phi = B.CreatePHI(Ty: CI->getType(), NumReservedValues: N);
1142 B.CreateBr(Dest: BBTail);
1143
1144 Value *Base = LHS;
1145 for (uint64_t i = 0; i < N; ++i) {
1146 B.SetInsertPoint(BBSubs[i]);
1147 Value *VL =
1148 B.CreateZExt(V: B.CreateLoad(Ty: B.getInt8Ty(),
1149 Ptr: B.CreateInBoundsPtrAdd(Ptr: Base, Offset: B.getInt64(C: i))),
1150 DestTy: CI->getType());
1151 Value *VR =
1152 ConstantInt::get(Ty: CI->getType(), V: static_cast<unsigned char>(RHS[i]));
1153 Value *Sub = Swapped ? B.CreateSub(LHS: VR, RHS: VL) : B.CreateSub(LHS: VL, RHS: VR);
1154 if (i < N - 1)
1155 B.CreateCondBr(Cond: B.CreateICmpNE(LHS: Sub, RHS: ConstantInt::get(Ty: CI->getType(), V: 0)),
1156 True: BBNE, False: BBSubs[i + 1]);
1157 else
1158 B.CreateBr(Dest: BBNE);
1159
1160 Phi->addIncoming(V: Sub, BB: BBSubs[i]);
1161 }
1162
1163 CI->replaceAllUsesWith(V: Phi);
1164 CI->eraseFromParent();
1165
1166 if (DTU) {
1167 SmallVector<DominatorTree::UpdateType, 8> Updates;
1168 Updates.push_back(Elt: {DominatorTree::Insert, BBCI, BBSubs[0]});
1169 for (uint64_t i = 0; i < N; ++i) {
1170 if (i < N - 1)
1171 Updates.push_back(Elt: {DominatorTree::Insert, BBSubs[i], BBSubs[i + 1]});
1172 Updates.push_back(Elt: {DominatorTree::Insert, BBSubs[i], BBNE});
1173 }
1174 Updates.push_back(Elt: {DominatorTree::Insert, BBNE, BBTail});
1175 Updates.push_back(Elt: {DominatorTree::Delete, BBCI, BBTail});
1176 DTU->applyUpdates(Updates);
1177 }
1178}
1179
1180/// Convert memchr with a small constant string into a switch
1181static bool foldMemChr(CallInst *Call, DomTreeUpdater *DTU,
1182 const DataLayout &DL) {
1183 if (isa<Constant>(Val: Call->getArgOperand(i: 1)))
1184 return false;
1185
1186 StringRef Str;
1187 Value *Base = Call->getArgOperand(i: 0);
1188 if (!getConstantStringInfo(V: Base, Str, /*TrimAtNul=*/false))
1189 return false;
1190
1191 uint64_t N = Str.size();
1192 if (auto *ConstInt = dyn_cast<ConstantInt>(Val: Call->getArgOperand(i: 2))) {
1193 uint64_t Val = ConstInt->getZExtValue();
1194 // Ignore the case that n is larger than the size of string.
1195 if (Val > N)
1196 return false;
1197 N = Val;
1198 } else
1199 return false;
1200
1201 if (N > MemChrInlineThreshold)
1202 return false;
1203
1204 BasicBlock *BB = Call->getParent();
1205 BasicBlock *BBNext = SplitBlock(Old: BB, SplitPt: Call, DTU);
1206 IRBuilder<> IRB(BB);
1207 IRB.SetCurrentDebugLocation(Call->getDebugLoc());
1208 IntegerType *ByteTy = IRB.getInt8Ty();
1209 BB->getTerminator()->eraseFromParent();
1210 SwitchInst *SI = IRB.CreateSwitch(
1211 V: IRB.CreateTrunc(V: Call->getArgOperand(i: 1), DestTy: ByteTy), Dest: BBNext, NumCases: N);
1212 Type *IndexTy = DL.getIndexType(PtrTy: Call->getType());
1213 SmallVector<DominatorTree::UpdateType, 8> Updates;
1214
1215 BasicBlock *BBSuccess = BasicBlock::Create(
1216 Context&: Call->getContext(), Name: "memchr.success", Parent: BB->getParent(), InsertBefore: BBNext);
1217 IRB.SetInsertPoint(BBSuccess);
1218 PHINode *IndexPHI = IRB.CreatePHI(Ty: IndexTy, NumReservedValues: N, Name: "memchr.idx");
1219 Value *FirstOccursLocation = IRB.CreateInBoundsPtrAdd(Ptr: Base, Offset: IndexPHI);
1220 IRB.CreateBr(Dest: BBNext);
1221 if (DTU)
1222 Updates.push_back(Elt: {DominatorTree::Insert, BBSuccess, BBNext});
1223
1224 SmallPtrSet<ConstantInt *, 4> Cases;
1225 for (uint64_t I = 0; I < N; ++I) {
1226 ConstantInt *CaseVal = ConstantInt::get(Ty: ByteTy, V: Str[I]);
1227 if (!Cases.insert(Ptr: CaseVal).second)
1228 continue;
1229
1230 BasicBlock *BBCase = BasicBlock::Create(Context&: Call->getContext(), Name: "memchr.case",
1231 Parent: BB->getParent(), InsertBefore: BBSuccess);
1232 SI->addCase(OnVal: CaseVal, Dest: BBCase);
1233 IRB.SetInsertPoint(BBCase);
1234 IndexPHI->addIncoming(V: ConstantInt::get(Ty: IndexTy, V: I), BB: BBCase);
1235 IRB.CreateBr(Dest: BBSuccess);
1236 if (DTU) {
1237 Updates.push_back(Elt: {DominatorTree::Insert, BB, BBCase});
1238 Updates.push_back(Elt: {DominatorTree::Insert, BBCase, BBSuccess});
1239 }
1240 }
1241
1242 PHINode *PHI =
1243 PHINode::Create(Ty: Call->getType(), NumReservedValues: 2, NameStr: Call->getName(), InsertBefore: BBNext->begin());
1244 PHI->addIncoming(V: Constant::getNullValue(Ty: Call->getType()), BB);
1245 PHI->addIncoming(V: FirstOccursLocation, BB: BBSuccess);
1246
1247 Call->replaceAllUsesWith(V: PHI);
1248 Call->eraseFromParent();
1249
1250 if (DTU)
1251 DTU->applyUpdates(Updates);
1252
1253 return true;
1254}
1255
1256static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
1257 TargetLibraryInfo &TLI, AssumptionCache &AC,
1258 DominatorTree &DT, const DataLayout &DL,
1259 bool &MadeCFGChange) {
1260
1261 auto *CI = dyn_cast<CallInst>(Val: &I);
1262 if (!CI || CI->isNoBuiltin())
1263 return false;
1264
1265 Function *CalledFunc = CI->getCalledFunction();
1266 if (!CalledFunc)
1267 return false;
1268
1269 LibFunc LF;
1270 if (!TLI.getLibFunc(FDecl: *CalledFunc, F&: LF) ||
1271 !isLibFuncEmittable(M: CI->getModule(), TLI: &TLI, TheLibFunc: LF))
1272 return false;
1273
1274 DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy);
1275
1276 switch (LF) {
1277 case LibFunc_sqrt:
1278 case LibFunc_sqrtf:
1279 case LibFunc_sqrtl:
1280 return foldSqrt(Call: CI, Func: LF, TTI, TLI, AC, DT);
1281 case LibFunc_strcmp:
1282 case LibFunc_strncmp:
1283 if (StrNCmpInliner(CI, LF, &DTU, DL).optimizeStrNCmp()) {
1284 MadeCFGChange = true;
1285 return true;
1286 }
1287 break;
1288 case LibFunc_memchr:
1289 if (foldMemChr(Call: CI, DTU: &DTU, DL)) {
1290 MadeCFGChange = true;
1291 return true;
1292 }
1293 break;
1294 default:;
1295 }
1296 return false;
1297}
1298
1299/// This is the entry point for folds that could be implemented in regular
1300/// InstCombine, but they are separated because they are not expected to
1301/// occur frequently and/or have more than a constant-length pattern match.
1302static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
1303 TargetTransformInfo &TTI,
1304 TargetLibraryInfo &TLI, AliasAnalysis &AA,
1305 AssumptionCache &AC, bool &MadeCFGChange) {
1306 bool MadeChange = false;
1307 for (BasicBlock &BB : F) {
1308 // Ignore unreachable basic blocks.
1309 if (!DT.isReachableFromEntry(A: &BB))
1310 continue;
1311
1312 const DataLayout &DL = F.getDataLayout();
1313
1314 // Walk the block backwards for efficiency. We're matching a chain of
1315 // use->defs, so we're more likely to succeed by starting from the bottom.
1316 // Also, we want to avoid matching partial patterns.
1317 // TODO: It would be more efficient if we removed dead instructions
1318 // iteratively in this loop rather than waiting until the end.
1319 for (Instruction &I : make_early_inc_range(Range: llvm::reverse(C&: BB))) {
1320 MadeChange |= foldAnyOrAllBitsSet(I);
1321 MadeChange |= foldGuardedFunnelShift(I, DT);
1322 MadeChange |= tryToRecognizePopCount(I);
1323 MadeChange |= tryToFPToSat(I, TTI);
1324 MadeChange |= tryToRecognizeTableBasedCttz(I);
1325 MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
1326 MadeChange |= foldPatternedLoads(I, DL);
1327 MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
1328 // NOTE: This function introduces erasing of the instruction `I`, so it
1329 // needs to be called at the end of this sequence, otherwise we may make
1330 // bugs.
1331 MadeChange |= foldLibCalls(I, TTI, TLI, AC, DT, DL, MadeCFGChange);
1332 }
1333 }
1334
1335 // We're done with transforms, so remove dead instructions.
1336 if (MadeChange)
1337 for (BasicBlock &BB : F)
1338 SimplifyInstructionsInBlock(BB: &BB);
1339
1340 return MadeChange;
1341}
1342
1343/// This is the entry point for all transforms. Pass manager differences are
1344/// handled in the callers of this function.
1345static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
1346 TargetLibraryInfo &TLI, DominatorTree &DT,
1347 AliasAnalysis &AA, bool &MadeCFGChange) {
1348 bool MadeChange = false;
1349 const DataLayout &DL = F.getDataLayout();
1350 TruncInstCombine TIC(AC, TLI, DL, DT);
1351 MadeChange |= TIC.run(F);
1352 MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, MadeCFGChange);
1353 return MadeChange;
1354}
1355
1356PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
1357 FunctionAnalysisManager &AM) {
1358 auto &AC = AM.getResult<AssumptionAnalysis>(IR&: F);
1359 auto &TLI = AM.getResult<TargetLibraryAnalysis>(IR&: F);
1360 auto &DT = AM.getResult<DominatorTreeAnalysis>(IR&: F);
1361 auto &TTI = AM.getResult<TargetIRAnalysis>(IR&: F);
1362 auto &AA = AM.getResult<AAManager>(IR&: F);
1363 bool MadeCFGChange = false;
1364 if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) {
1365 // No changes, all analyses are preserved.
1366 return PreservedAnalyses::all();
1367 }
1368 // Mark all the analyses that instcombine updates as preserved.
1369 PreservedAnalyses PA;
1370 if (MadeCFGChange)
1371 PA.preserve<DominatorTreeAnalysis>();
1372 else
1373 PA.preserveSet<CFGAnalyses>();
1374 return PA;
1375}
1376