1//===- AggressiveInstCombine.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the aggressive expression pattern combiner classes.
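// Currently, it handles expression patterns such as:
//  * Truncate instruction
//  * Guarded funnel shifts/rotates, any/all-bits-set chains, popcount idioms,
//    saturating fp-to-int clamps, sqrt libcalls, and table-based cttz/log2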
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
16#include "AggressiveInstCombineInternal.h"
17#include "llvm/ADT/Statistic.h"
18#include "llvm/Analysis/AliasAnalysis.h"
19#include "llvm/Analysis/AssumptionCache.h"
20#include "llvm/Analysis/BasicAliasAnalysis.h"
21#include "llvm/Analysis/ConstantFolding.h"
22#include "llvm/Analysis/DomTreeUpdater.h"
23#include "llvm/Analysis/GlobalsModRef.h"
24#include "llvm/Analysis/TargetLibraryInfo.h"
25#include "llvm/Analysis/TargetTransformInfo.h"
26#include "llvm/Analysis/ValueTracking.h"
27#include "llvm/IR/DataLayout.h"
28#include "llvm/IR/Dominators.h"
29#include "llvm/IR/Function.h"
30#include "llvm/IR/IRBuilder.h"
31#include "llvm/IR/Instruction.h"
32#include "llvm/IR/MDBuilder.h"
33#include "llvm/IR/PatternMatch.h"
34#include "llvm/IR/ProfDataUtils.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/CommandLine.h"
37#include "llvm/Transforms/Utils/BasicBlockUtils.h"
38#include "llvm/Transforms/Utils/BuildLibCalls.h"
39#include "llvm/Transforms/Utils/Local.h"
40
41using namespace llvm;
42using namespace PatternMatch;
43
44#define DEBUG_TYPE "aggressive-instcombine"
45
46namespace llvm {
47extern cl::opt<bool> ProfcheckDisableMetadataFixes;
48}
49
50STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded");
51STATISTIC(NumGuardedRotates,
52 "Number of guarded rotates transformed into funnel shifts");
53STATISTIC(NumGuardedFunnelShifts,
54 "Number of guarded funnel shifts transformed into funnel shifts");
55STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized");
56
57static cl::opt<unsigned> MaxInstrsToScan(
58 "aggressive-instcombine-max-scan-instrs", cl::init(Val: 64), cl::Hidden,
59 cl::desc("Max number of instructions to scan for aggressive instcombine."));
60
61static cl::opt<unsigned> StrNCmpInlineThreshold(
62 "strncmp-inline-threshold", cl::init(Val: 3), cl::Hidden,
63 cl::desc("The maximum length of a constant string for a builtin string cmp "
64 "call eligible for inlining. The default value is 3."));
65
66static cl::opt<unsigned>
67 MemChrInlineThreshold("memchr-inline-threshold", cl::init(Val: 3), cl::Hidden,
68 cl::desc("The maximum length of a constant string to "
69 "inline a memchr call."));
70
71/// Match a pattern for a bitwise funnel/rotate operation that partially guards
72/// against undefined behavior by branching around the funnel-shift/rotation
73/// when the shift amount is 0.
74static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
75 if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2)
76 return false;
77
78 // As with the one-use checks below, this is not strictly necessary, but we
79 // are being cautious to avoid potential perf regressions on targets that
80 // do not actually have a funnel/rotate instruction (where the funnel shift
81 // would be expanded back into math/shift/logic ops).
82 if (!isPowerOf2_32(Value: I.getType()->getScalarSizeInBits()))
83 return false;
84
85 // Match V to funnel shift left/right and capture the source operands and
86 // shift amount.
87 auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1,
88 Value *&ShAmt) {
89 unsigned Width = V->getType()->getScalarSizeInBits();
90
91 // fshl(ShVal0, ShVal1, ShAmt)
92 // == (ShVal0 << ShAmt) | (ShVal1 >> (Width -ShAmt))
93 if (match(V, P: m_OneUse(SubPattern: m_c_Or(
94 L: m_Shl(L: m_Value(V&: ShVal0), R: m_Value(V&: ShAmt)),
95 R: m_LShr(L: m_Value(V&: ShVal1), R: m_Sub(L: m_SpecificInt(V: Width),
96 R: m_Deferred(V: ShAmt))))))) {
97 return Intrinsic::fshl;
98 }
99
100 // fshr(ShVal0, ShVal1, ShAmt)
101 // == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt))
102 if (match(V,
103 P: m_OneUse(SubPattern: m_c_Or(L: m_Shl(L: m_Value(V&: ShVal0), R: m_Sub(L: m_SpecificInt(V: Width),
104 R: m_Value(V&: ShAmt))),
105 R: m_LShr(L: m_Value(V&: ShVal1), R: m_Deferred(V: ShAmt)))))) {
106 return Intrinsic::fshr;
107 }
108
109 return Intrinsic::not_intrinsic;
110 };
111
112 // One phi operand must be a funnel/rotate operation, and the other phi
113 // operand must be the source value of that funnel/rotate operation:
114 // phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ]
115 // phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ]
116 // phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ]
117 PHINode &Phi = cast<PHINode>(Val&: I);
118 unsigned FunnelOp = 0, GuardOp = 1;
119 Value *P0 = Phi.getOperand(i_nocapture: 0), *P1 = Phi.getOperand(i_nocapture: 1);
120 Value *ShVal0, *ShVal1, *ShAmt;
121 Intrinsic::ID IID = matchFunnelShift(P0, ShVal0, ShVal1, ShAmt);
122 if (IID == Intrinsic::not_intrinsic ||
123 (IID == Intrinsic::fshl && ShVal0 != P1) ||
124 (IID == Intrinsic::fshr && ShVal1 != P1)) {
125 IID = matchFunnelShift(P1, ShVal0, ShVal1, ShAmt);
126 if (IID == Intrinsic::not_intrinsic ||
127 (IID == Intrinsic::fshl && ShVal0 != P0) ||
128 (IID == Intrinsic::fshr && ShVal1 != P0))
129 return false;
130 assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
131 "Pattern must match funnel shift left or right");
132 std::swap(a&: FunnelOp, b&: GuardOp);
133 }
134
135 // The incoming block with our source operand must be the "guard" block.
136 // That must contain a cmp+branch to avoid the funnel/rotate when the shift
137 // amount is equal to 0. The other incoming block is the block with the
138 // funnel/rotate.
139 BasicBlock *GuardBB = Phi.getIncomingBlock(i: GuardOp);
140 BasicBlock *FunnelBB = Phi.getIncomingBlock(i: FunnelOp);
141 Instruction *TermI = GuardBB->getTerminator();
142
143 // Ensure that the shift values dominate each block.
144 if (!DT.dominates(Def: ShVal0, User: TermI) || !DT.dominates(Def: ShVal1, User: TermI))
145 return false;
146
147 BasicBlock *PhiBB = Phi.getParent();
148 if (!match(V: TermI, P: m_Br(C: m_SpecificICmp(MatchPred: CmpInst::ICMP_EQ, L: m_Specific(V: ShAmt),
149 R: m_ZeroInt()),
150 T: m_SpecificBB(BB: PhiBB), F: m_SpecificBB(BB: FunnelBB))))
151 return false;
152
153 IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt());
154
155 if (ShVal0 == ShVal1)
156 ++NumGuardedRotates;
157 else
158 ++NumGuardedFunnelShifts;
159
160 // If this is not a rotate then the select was blocking poison from the
161 // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
162 bool IsFshl = IID == Intrinsic::fshl;
163 if (ShVal0 != ShVal1) {
164 if (IsFshl && !llvm::isGuaranteedNotToBePoison(V: ShVal1))
165 ShVal1 = Builder.CreateFreeze(V: ShVal1);
166 else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(V: ShVal0))
167 ShVal0 = Builder.CreateFreeze(V: ShVal0);
168 }
169
170 // We matched a variation of this IR pattern:
171 // GuardBB:
172 // %cmp = icmp eq i32 %ShAmt, 0
173 // br i1 %cmp, label %PhiBB, label %FunnelBB
174 // FunnelBB:
175 // %sub = sub i32 32, %ShAmt
176 // %shr = lshr i32 %ShVal1, %sub
177 // %shl = shl i32 %ShVal0, %ShAmt
178 // %fsh = or i32 %shr, %shl
179 // br label %PhiBB
180 // PhiBB:
181 // %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ]
182 // -->
183 // llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt)
184 Phi.replaceAllUsesWith(
185 V: Builder.CreateIntrinsic(ID: IID, Types: Phi.getType(), Args: {ShVal0, ShVal1, ShAmt}));
186 return true;
187}
188
189/// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and
190/// the bit indexes (Mask) needed by a masked compare. If we're matching a chain
191/// of 'and' ops, then we also need to capture the fact that we saw an
192/// "and X, 1", so that's an extra return value for that case.
193namespace {
194struct MaskOps {
195 Value *Root = nullptr;
196 APInt Mask;
197 bool MatchAndChain;
198 bool FoundAnd1 = false;
199
200 MaskOps(unsigned BitWidth, bool MatchAnds)
201 : Mask(APInt::getZero(numBits: BitWidth)), MatchAndChain(MatchAnds) {}
202};
203} // namespace
204
205/// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
206/// chain of 'and' or 'or' instructions looking for shift ops of a common source
207/// value. Examples:
208/// or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
209/// returns { X, 0x129 }
210/// and (and (X >> 1), 1), (X >> 4)
211/// returns { X, 0x12 }
212static bool matchAndOrChain(Value *V, MaskOps &MOps) {
213 Value *Op0, *Op1;
214 if (MOps.MatchAndChain) {
215 // Recurse through a chain of 'and' operands. This requires an extra check
216 // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere
217 // in the chain to know that all of the high bits are cleared.
218 if (match(V, P: m_And(L: m_Value(V&: Op0), R: m_One()))) {
219 MOps.FoundAnd1 = true;
220 return matchAndOrChain(V: Op0, MOps);
221 }
222 if (match(V, P: m_And(L: m_Value(V&: Op0), R: m_Value(V&: Op1))))
223 return matchAndOrChain(V: Op0, MOps) && matchAndOrChain(V: Op1, MOps);
224 } else {
225 // Recurse through a chain of 'or' operands.
226 if (match(V, P: m_Or(L: m_Value(V&: Op0), R: m_Value(V&: Op1))))
227 return matchAndOrChain(V: Op0, MOps) && matchAndOrChain(V: Op1, MOps);
228 }
229
230 // We need a shift-right or a bare value representing a compare of bit 0 of
231 // the original source operand.
232 Value *Candidate;
233 const APInt *BitIndex = nullptr;
234 if (!match(V, P: m_LShr(L: m_Value(V&: Candidate), R: m_APInt(Res&: BitIndex))))
235 Candidate = V;
236
237 // Initialize result source operand.
238 if (!MOps.Root)
239 MOps.Root = Candidate;
240
241 // The shift constant is out-of-range? This code hasn't been simplified.
242 if (BitIndex && BitIndex->uge(RHS: MOps.Mask.getBitWidth()))
243 return false;
244
245 // Fill in the mask bit derived from the shift constant.
246 MOps.Mask.setBit(BitIndex ? BitIndex->getZExtValue() : 0);
247 return MOps.Root == Candidate;
248}
249
250/// Match patterns that correspond to "any-bits-set" and "all-bits-set".
251/// These will include a chain of 'or' or 'and'-shifted bits from a
252/// common source value:
253/// and (or (lshr X, C), ...), 1 --> (X & CMask) != 0
254/// and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask
255/// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns
256/// that differ only with a final 'not' of the result. We expect that final
257/// 'not' to be folded with the compare that we create here (invert predicate).
258static bool foldAnyOrAllBitsSet(Instruction &I) {
259 // The 'any-bits-set' ('or' chain) pattern is simpler to match because the
260 // final "and X, 1" instruction must be the final op in the sequence.
261 bool MatchAllBitsSet;
262 bool MatchTrunc;
263 Value *X;
264 if (I.getType()->isIntOrIntVectorTy(BitWidth: 1)) {
265 if (match(V: &I, P: m_Trunc(Op: m_OneUse(SubPattern: m_And(L: m_Value(), R: m_Value())))))
266 MatchAllBitsSet = true;
267 else if (match(V: &I, P: m_Trunc(Op: m_OneUse(SubPattern: m_Or(L: m_Value(), R: m_Value())))))
268 MatchAllBitsSet = false;
269 else
270 return false;
271 MatchTrunc = true;
272 X = I.getOperand(i: 0);
273 } else {
274 if (match(V: &I, P: m_c_And(L: m_OneUse(SubPattern: m_And(L: m_Value(), R: m_Value())), R: m_Value()))) {
275 X = &I;
276 MatchAllBitsSet = true;
277 } else if (match(V: &I,
278 P: m_And(L: m_OneUse(SubPattern: m_Or(L: m_Value(), R: m_Value())), R: m_One()))) {
279 X = I.getOperand(i: 0);
280 MatchAllBitsSet = false;
281 } else
282 return false;
283 MatchTrunc = false;
284 }
285 Type *Ty = X->getType();
286
287 MaskOps MOps(Ty->getScalarSizeInBits(), MatchAllBitsSet);
288 if (!matchAndOrChain(V: X, MOps) ||
289 (MatchAllBitsSet && !MatchTrunc && !MOps.FoundAnd1))
290 return false;
291
292 // The pattern was found. Create a masked compare that replaces all of the
293 // shift and logic ops.
294 IRBuilder<> Builder(&I);
295 Constant *Mask = ConstantInt::get(Ty, V: MOps.Mask);
296 Value *And = Builder.CreateAnd(LHS: MOps.Root, RHS: Mask);
297 Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(LHS: And, RHS: Mask)
298 : Builder.CreateIsNotNull(Arg: And);
299 Value *Zext = MatchTrunc ? Cmp : Builder.CreateZExt(V: Cmp, DestTy: Ty);
300 I.replaceAllUsesWith(V: Zext);
301 ++NumAnyOrAllBitsSet;
302 return true;
303}
304
305// Try to recognize below function as popcount intrinsic.
306// This is the "best" algorithm from
307// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
308// Also used in TargetLowering::expandCTPOP().
309//
310// int popcount(unsigned int i) {
311// i = i - ((i >> 1) & 0x55555555);
312// i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
313// i = ((i + (i >> 4)) & 0x0F0F0F0F);
314// return (i * 0x01010101) >> 24;
315// }
316static bool tryToRecognizePopCount(Instruction &I) {
317 if (I.getOpcode() != Instruction::LShr)
318 return false;
319
320 Type *Ty = I.getType();
321 if (!Ty->isIntOrIntVectorTy())
322 return false;
323
324 unsigned Len = Ty->getScalarSizeInBits();
325 // FIXME: fix Len == 8 and other irregular type lengths.
326 if (!(Len <= 128 && Len > 8 && Len % 8 == 0))
327 return false;
328
329 APInt Mask55 = APInt::getSplat(NewLen: Len, V: APInt(8, 0x55));
330 APInt Mask33 = APInt::getSplat(NewLen: Len, V: APInt(8, 0x33));
331 APInt Mask0F = APInt::getSplat(NewLen: Len, V: APInt(8, 0x0F));
332 APInt Mask01 = APInt::getSplat(NewLen: Len, V: APInt(8, 0x01));
333 APInt MaskShift = APInt(Len, Len - 8);
334
335 Value *Op0 = I.getOperand(i: 0);
336 Value *Op1 = I.getOperand(i: 1);
337 Value *MulOp0;
338 // Matching "(i * 0x01010101...) >> 24".
339 if ((match(V: Op0, P: m_Mul(L: m_Value(V&: MulOp0), R: m_SpecificInt(V: Mask01)))) &&
340 match(V: Op1, P: m_SpecificInt(V: MaskShift))) {
341 Value *ShiftOp0;
342 // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
343 if (match(V: MulOp0, P: m_And(L: m_c_Add(L: m_LShr(L: m_Value(V&: ShiftOp0), R: m_SpecificInt(V: 4)),
344 R: m_Deferred(V: ShiftOp0)),
345 R: m_SpecificInt(V: Mask0F)))) {
346 Value *AndOp0;
347 // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)".
348 if (match(V: ShiftOp0,
349 P: m_c_Add(L: m_And(L: m_Value(V&: AndOp0), R: m_SpecificInt(V: Mask33)),
350 R: m_And(L: m_LShr(L: m_Deferred(V: AndOp0), R: m_SpecificInt(V: 2)),
351 R: m_SpecificInt(V: Mask33))))) {
352 Value *Root, *SubOp1;
353 // Matching "i - ((i >> 1) & 0x55555555...)".
354 const APInt *AndMask;
355 if (match(V: AndOp0, P: m_Sub(L: m_Value(V&: Root), R: m_Value(V&: SubOp1))) &&
356 match(V: SubOp1, P: m_And(L: m_LShr(L: m_Specific(V: Root), R: m_SpecificInt(V: 1)),
357 R: m_APInt(Res&: AndMask)))) {
358 auto CheckAndMask = [&]() {
359 if (*AndMask == Mask55)
360 return true;
361
362 // Exact match failed, see if any bits are known to be 0 where we
363 // expect a 1 in the mask.
364 if (!AndMask->isSubsetOf(RHS: Mask55))
365 return false;
366
367 APInt NeededMask = Mask55 & ~*AndMask;
368 return MaskedValueIsZero(V: cast<Instruction>(Val: SubOp1)->getOperand(i: 0),
369 Mask: NeededMask,
370 SQ: SimplifyQuery(I.getDataLayout()));
371 };
372
373 if (CheckAndMask()) {
374 LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
375 IRBuilder<> Builder(&I);
376 I.replaceAllUsesWith(
377 V: Builder.CreateIntrinsic(ID: Intrinsic::ctpop, Types: I.getType(), Args: {Root}));
378 ++NumPopCountRecognized;
379 return true;
380 }
381 }
382 }
383 }
384 }
385
386 return false;
387}
388
389/// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), providing C1 and
390/// C2 saturate the value of the fp conversion. The transform is not reversable
391/// as the fptosi.sat is more defined than the input - all values produce a
392/// valid value for the fptosi.sat, where as some produce poison for original
393/// that were out of range of the integer conversion. The reversed pattern may
394/// use fmax and fmin instead. As we cannot directly reverse the transform, and
395/// it is not always profitable, we make it conditional on the cost being
396/// reported as lower by TTI.
397static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
398 // Look for min(max(fptosi, converting to fptosi_sat.
399 Value *In;
400 const APInt *MinC, *MaxC;
401 if (!match(V: &I, P: m_SMax(L: m_OneUse(SubPattern: m_SMin(L: m_OneUse(SubPattern: m_FPToSI(Op: m_Value(V&: In))),
402 R: m_APInt(Res&: MinC))),
403 R: m_APInt(Res&: MaxC))) &&
404 !match(V: &I, P: m_SMin(L: m_OneUse(SubPattern: m_SMax(L: m_OneUse(SubPattern: m_FPToSI(Op: m_Value(V&: In))),
405 R: m_APInt(Res&: MaxC))),
406 R: m_APInt(Res&: MinC))))
407 return false;
408
409 // Check that the constants clamp a saturate.
410 if (!(*MinC + 1).isPowerOf2() || -*MaxC != *MinC + 1)
411 return false;
412
413 Type *IntTy = I.getType();
414 Type *FpTy = In->getType();
415 Type *SatTy =
416 IntegerType::get(C&: IntTy->getContext(), NumBits: (*MinC + 1).exactLogBase2() + 1);
417 if (auto *VecTy = dyn_cast<VectorType>(Val: IntTy))
418 SatTy = VectorType::get(ElementType: SatTy, EC: VecTy->getElementCount());
419
420 // Get the cost of the intrinsic, and check that against the cost of
421 // fptosi+smin+smax
422 InstructionCost SatCost = TTI.getIntrinsicInstrCost(
423 ICA: IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}),
424 CostKind: TTI::TCK_RecipThroughput);
425 SatCost += TTI.getCastInstrCost(Opcode: Instruction::SExt, Dst: IntTy, Src: SatTy,
426 CCH: TTI::CastContextHint::None,
427 CostKind: TTI::TCK_RecipThroughput);
428
429 InstructionCost MinMaxCost = TTI.getCastInstrCost(
430 Opcode: Instruction::FPToSI, Dst: IntTy, Src: FpTy, CCH: TTI::CastContextHint::None,
431 CostKind: TTI::TCK_RecipThroughput);
432 MinMaxCost += TTI.getIntrinsicInstrCost(
433 ICA: IntrinsicCostAttributes(Intrinsic::smin, IntTy, {IntTy}),
434 CostKind: TTI::TCK_RecipThroughput);
435 MinMaxCost += TTI.getIntrinsicInstrCost(
436 ICA: IntrinsicCostAttributes(Intrinsic::smax, IntTy, {IntTy}),
437 CostKind: TTI::TCK_RecipThroughput);
438
439 if (SatCost >= MinMaxCost)
440 return false;
441
442 IRBuilder<> Builder(&I);
443 Value *Sat =
444 Builder.CreateIntrinsic(ID: Intrinsic::fptosi_sat, Types: {SatTy, FpTy}, Args: In);
445 I.replaceAllUsesWith(V: Builder.CreateSExt(V: Sat, DestTy: IntTy));
446 return true;
447}
448
449/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
450/// pessimistic codegen that has to account for setting errno and can enable
451/// vectorization.
452static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI,
453 TargetLibraryInfo &TLI, AssumptionCache &AC,
454 DominatorTree &DT) {
455 // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
456 // (because NNAN or the operand arg must not be less than -0.0) and (2) we
457 // would not end up lowering to a libcall anyway (which could change the value
458 // of errno), then:
459 // (1) errno won't be set.
460 // (2) it is safe to convert this to an intrinsic call.
461 Type *Ty = Call->getType();
462 Value *Arg = Call->getArgOperand(i: 0);
463 if (TTI.haveFastSqrt(Ty) &&
464 (Call->hasNoNaNs() ||
465 cannotBeOrderedLessThanZero(
466 V: Arg, SQ: SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) {
467 IRBuilder<> Builder(Call);
468 Value *NewSqrt =
469 Builder.CreateIntrinsic(ID: Intrinsic::sqrt, Types: Ty, Args: Arg, FMFSource: Call, Name: "sqrt");
470 Call->replaceAllUsesWith(V: NewSqrt);
471
472 // Explicitly erase the old call because a call with side effects is not
473 // trivially dead.
474 Call->eraseFromParent();
475 return true;
476 }
477
478 return false;
479}
480
481// Check if this array of constants represents a cttz table.
482// Iterate over the elements from \p Table by trying to find/match all
483// the numbers from 0 to \p InputBits that should represent cttz results.
484static bool isCTTZTable(Constant *Table, const APInt &Mul, const APInt &Shift,
485 const APInt &AndMask, Type *AccessTy,
486 unsigned InputBits, const APInt &GEPIdxFactor,
487 const DataLayout &DL) {
488 for (unsigned Idx = 0; Idx < InputBits; Idx++) {
489 APInt Index =
490 (APInt::getOneBitSet(numBits: InputBits, BitNo: Idx) * Mul).lshr(ShiftAmt: Shift) & AndMask;
491 ConstantInt *C = dyn_cast_or_null<ConstantInt>(
492 Val: ConstantFoldLoadFromConst(C: Table, Ty: AccessTy, Offset: Index * GEPIdxFactor, DL));
493 if (!C || C->getValue() != Idx)
494 return false;
495 }
496
497 return true;
498}
499
500// Try to recognize table-based ctz implementation.
501// E.g., an example in C (for more cases please see the llvm/tests):
502// int f(unsigned x) {
503// static const char table[32] =
504// {0, 1, 28, 2, 29, 14, 24, 3, 30,
505// 22, 20, 15, 25, 17, 4, 8, 31, 27,
506// 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};
507// return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27];
508// }
509// this can be lowered to `cttz` instruction.
510// There is also a special case when the element is 0.
511//
512// The (x & -x) sets the lowest non-zero bit to 1. The multiply is a de-bruijn
513// sequence that contains each pattern of bits in it. The shift extracts
514// the top bits after the multiply, and that index into the table should
515// represent the number of trailing zeros in the original number.
516//
517// Here are some examples or LLVM IR for a 64-bit target:
518//
519// CASE 1:
520// %sub = sub i32 0, %x
521// %and = and i32 %sub, %x
522// %mul = mul i32 %and, 125613361
523// %shr = lshr i32 %mul, 27
524// %idxprom = zext i32 %shr to i64
525// %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0,
526// i64 %idxprom
527// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
528//
529// CASE 2:
530// %sub = sub i32 0, %x
531// %and = and i32 %sub, %x
532// %mul = mul i32 %and, 72416175
533// %shr = lshr i32 %mul, 26
534// %idxprom = zext i32 %shr to i64
535// %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table,
536// i64 0, i64 %idxprom
537// %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
538//
539// CASE 3:
540// %sub = sub i32 0, %x
541// %and = and i32 %sub, %x
542// %mul = mul i32 %and, 81224991
543// %shr = lshr i32 %mul, 27
544// %idxprom = zext i32 %shr to i64
545// %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table,
546// i64 0, i64 %idxprom
547// %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
548//
549// CASE 4:
550// %sub = sub i64 0, %x
551// %and = and i64 %sub, %x
552// %mul = mul i64 %and, 283881067100198605
553// %shr = lshr i64 %mul, 58
554// %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0,
555// i64 %shr
556// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
557//
558// All these can be lowered to @llvm.cttz.i32/64 intrinsics.
559static bool tryToRecognizeTableBasedCttz(Instruction &I, const DataLayout &DL) {
560 LoadInst *LI = dyn_cast<LoadInst>(Val: &I);
561 if (!LI)
562 return false;
563
564 Type *AccessType = LI->getType();
565 if (!AccessType->isIntegerTy())
566 return false;
567
568 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Val: LI->getPointerOperand());
569 if (!GEP || !GEP->hasNoUnsignedSignedWrap())
570 return false;
571
572 GlobalVariable *GVTable = dyn_cast<GlobalVariable>(Val: GEP->getPointerOperand());
573 if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
574 return false;
575
576 unsigned BW = DL.getIndexTypeSizeInBits(Ty: GEP->getType());
577 APInt ModOffset(BW, 0);
578 SmallMapVector<Value *, APInt, 4> VarOffsets;
579 if (!GEP->collectOffset(DL, BitWidth: BW, VariableOffsets&: VarOffsets, ConstantOffset&: ModOffset) ||
580 VarOffsets.size() != 1 || ModOffset != 0)
581 return false;
582 auto [GepIdx, GEPScale] = VarOffsets.front();
583
584 Value *X1;
585 const APInt *MulConst, *ShiftConst, *AndCst = nullptr;
586 // Check that the gep variable index is ((x & -x) * MulConst) >> ShiftConst.
587 // This might be extended to the pointer index type, and if the gep index type
588 // has been replaced with an i8 then a new And (and different ShiftConst) will
589 // be present.
590 auto MatchInner = m_LShr(
591 L: m_Mul(L: m_c_And(L: m_Neg(V: m_Value(V&: X1)), R: m_Deferred(V: X1)), R: m_APInt(Res&: MulConst)),
592 R: m_APInt(Res&: ShiftConst));
593 if (!match(V: GepIdx, P: m_CastOrSelf(Op: MatchInner)) &&
594 !match(V: GepIdx, P: m_CastOrSelf(Op: m_And(L: MatchInner, R: m_APInt(Res&: AndCst)))))
595 return false;
596
597 unsigned InputBits = X1->getType()->getScalarSizeInBits();
598 if (InputBits != 16 && InputBits != 32 && InputBits != 64 && InputBits != 128)
599 return false;
600
601 if (!GEPScale.isIntN(N: InputBits) ||
602 !isCTTZTable(Table: GVTable->getInitializer(), Mul: *MulConst, Shift: *ShiftConst,
603 AndMask: AndCst ? *AndCst : APInt::getAllOnes(numBits: InputBits), AccessTy: AccessType,
604 InputBits, GEPIdxFactor: GEPScale.zextOrTrunc(width: InputBits), DL))
605 return false;
606
607 ConstantInt *ZeroTableElem = cast<ConstantInt>(
608 Val: ConstantFoldLoadFromConst(C: GVTable->getInitializer(), Ty: AccessType, DL));
609 bool DefinedForZero = ZeroTableElem->getZExtValue() == InputBits;
610
611 IRBuilder<> B(LI);
612 ConstantInt *BoolConst = B.getInt1(V: !DefinedForZero);
613 Type *XType = X1->getType();
614 auto Cttz = B.CreateIntrinsic(ID: Intrinsic::cttz, Types: {XType}, Args: {X1, BoolConst});
615 Value *ZExtOrTrunc = nullptr;
616
617 if (DefinedForZero) {
618 ZExtOrTrunc = B.CreateZExtOrTrunc(V: Cttz, DestTy: AccessType);
619 } else {
620 // If the value in elem 0 isn't the same as InputBits, we still want to
621 // produce the value from the table.
622 auto Cmp = B.CreateICmpEQ(LHS: X1, RHS: ConstantInt::get(Ty: XType, V: 0));
623 auto Select = B.CreateSelect(C: Cmp, True: B.CreateZExt(V: ZeroTableElem, DestTy: XType), False: Cttz);
624
625 // The true branch of select handles the cttz(0) case, which is rare.
626 if (!ProfcheckDisableMetadataFixes) {
627 if (Instruction *SelectI = dyn_cast<Instruction>(Val: Select))
628 SelectI->setMetadata(
629 KindID: LLVMContext::MD_prof,
630 Node: MDBuilder(SelectI->getContext()).createUnlikelyBranchWeights());
631 }
632
633 // NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target
634 // it should be handled as: `cttz(x) & (typeSize - 1)`.
635
636 ZExtOrTrunc = B.CreateZExtOrTrunc(V: Select, DestTy: AccessType);
637 }
638
639 LI->replaceAllUsesWith(V: ZExtOrTrunc);
640
641 return true;
642}
643
644// Check if this array of constants represents a log2 table.
645// Iterate over the elements from \p Table by trying to find/match all
646// the numbers from 0 to \p InputBits that should represent log2 results.
647static bool isLog2Table(Constant *Table, const APInt &Mul, const APInt &Shift,
648 Type *AccessTy, unsigned InputBits,
649 const APInt &GEPIdxFactor, const DataLayout &DL) {
650 for (unsigned Idx = 0; Idx < InputBits; Idx++) {
651 APInt Index = (APInt::getLowBitsSet(numBits: InputBits, loBitsSet: Idx + 1) * Mul).lshr(ShiftAmt: Shift);
652 ConstantInt *C = dyn_cast_or_null<ConstantInt>(
653 Val: ConstantFoldLoadFromConst(C: Table, Ty: AccessTy, Offset: Index * GEPIdxFactor, DL));
654 if (!C || C->getValue() != Idx)
655 return false;
656 }
657
658 // Verify that an input of zero will select table index 0.
659 APInt ZeroIndex = Mul.lshr(ShiftAmt: Shift);
660 if (!ZeroIndex.isZero())
661 return false;
662
663 return true;
664}
665
666// Try to recognize table-based log2 implementation.
667// E.g., an example in C (for more cases please the llvm/tests):
668// int f(unsigned v) {
669// static const char table[32] =
670// {0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30,
671// 8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31};
672//
673// v |= v >> 1; // first round down to one less than a power of 2
674// v |= v >> 2;
675// v |= v >> 4;
676// v |= v >> 8;
677// v |= v >> 16;
678//
679// return table[(unsigned)(v * 0x07C4ACDDU) >> 27];
680// }
681// this can be lowered to `ctlz` instruction.
682// There is also a special case when the element is 0.
683//
684// The >> and |= sequence sets all bits below the most significant set bit. The
685// multiply is a de-bruijn sequence that contains each pattern of bits in it.
686// The shift extracts the top bits after the multiply, and that index into the
687// table should represent the floor log base 2 of the original number.
688//
689// Here are some examples of LLVM IR for a 64-bit target.
690//
691// CASE 1:
692// %shr = lshr i32 %v, 1
693// %or = or i32 %shr, %v
694// %shr1 = lshr i32 %or, 2
695// %or2 = or i32 %shr1, %or
696// %shr3 = lshr i32 %or2, 4
697// %or4 = or i32 %shr3, %or2
698// %shr5 = lshr i32 %or4, 8
699// %or6 = or i32 %shr5, %or4
700// %shr7 = lshr i32 %or6, 16
701// %or8 = or i32 %shr7, %or6
702// %mul = mul i32 %or8, 130329821
703// %shr9 = lshr i32 %mul, 27
704// %idxprom = zext nneg i32 %shr9 to i64
705// %arrayidx = getelementptr inbounds i8, ptr @table, i64 %idxprom
706// %0 = load i8, ptr %arrayidx, align 1
707//
708// CASE 2:
709// %shr = lshr i64 %v, 1
710// %or = or i64 %shr, %v
711// %shr1 = lshr i64 %or, 2
712// %or2 = or i64 %shr1, %or
713// %shr3 = lshr i64 %or2, 4
714// %or4 = or i64 %shr3, %or2
715// %shr5 = lshr i64 %or4, 8
716// %or6 = or i64 %shr5, %or4
717// %shr7 = lshr i64 %or6, 16
718// %or8 = or i64 %shr7, %or6
719// %shr9 = lshr i64 %or8, 32
720// %or10 = or i64 %shr9, %or8
721// %mul = mul i64 %or10, 285870213051386505
722// %shr11 = lshr i64 %mul, 58
723// %arrayidx = getelementptr inbounds i8, ptr @table, i64 %shr11
724// %0 = load i8, ptr %arrayidx, align 1
725//
726// All these can be lowered to @llvm.ctlz.i32/64 intrinsics and a subtract.
727static bool tryToRecognizeTableBasedLog2(Instruction &I, const DataLayout &DL,
728 TargetTransformInfo &TTI) {
729 LoadInst *LI = dyn_cast<LoadInst>(Val: &I);
730 if (!LI)
731 return false;
732
733 Type *AccessType = LI->getType();
734 if (!AccessType->isIntegerTy())
735 return false;
736
737 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Val: LI->getPointerOperand());
738 if (!GEP || !GEP->hasNoUnsignedSignedWrap())
739 return false;
740
741 GlobalVariable *GVTable = dyn_cast<GlobalVariable>(Val: GEP->getPointerOperand());
742 if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
743 return false;
744
745 unsigned BW = DL.getIndexTypeSizeInBits(Ty: GEP->getType());
746 APInt ModOffset(BW, 0);
747 SmallMapVector<Value *, APInt, 4> VarOffsets;
748 if (!GEP->collectOffset(DL, BitWidth: BW, VariableOffsets&: VarOffsets, ConstantOffset&: ModOffset) ||
749 VarOffsets.size() != 1 || ModOffset != 0)
750 return false;
751 auto [GepIdx, GEPScale] = VarOffsets.front();
752
753 Value *X;
754 const APInt *MulConst, *ShiftConst;
755 // Check that the gep variable index is (x * MulConst) >> ShiftConst.
756 auto MatchInner =
757 m_LShr(L: m_Mul(L: m_Value(V&: X), R: m_APInt(Res&: MulConst)), R: m_APInt(Res&: ShiftConst));
758 if (!match(V: GepIdx, P: m_CastOrSelf(Op: MatchInner)))
759 return false;
760
761 unsigned InputBits = X->getType()->getScalarSizeInBits();
762 if (InputBits != 16 && InputBits != 32 && InputBits != 64 && InputBits != 128)
763 return false;
764
765 // Verify shift amount.
766 // TODO: Allow other shift amounts when we have proper test coverage.
767 if (*ShiftConst != InputBits - Log2_32(Value: InputBits))
768 return false;
769
770 // Match the sequence of OR operations with right shifts by powers of 2.
771 for (unsigned ShiftAmt = InputBits / 2; ShiftAmt != 0; ShiftAmt /= 2) {
772 Value *Y;
773 if (!match(V: X, P: m_c_Or(L: m_LShr(L: m_Value(V&: Y), R: m_SpecificInt(V: ShiftAmt)),
774 R: m_Deferred(V: Y))))
775 return false;
776 X = Y;
777 }
778
779 if (!GEPScale.isIntN(N: InputBits) ||
780 !isLog2Table(Table: GVTable->getInitializer(), Mul: *MulConst, Shift: *ShiftConst,
781 AccessTy: AccessType, InputBits, GEPIdxFactor: GEPScale.zextOrTrunc(width: InputBits), DL))
782 return false;
783
784 ConstantInt *ZeroTableElem = cast<ConstantInt>(
785 Val: ConstantFoldLoadFromConst(C: GVTable->getInitializer(), Ty: AccessType, DL));
786
787 // Use InputBits - 1 - ctlz(X) to compute log2(X).
788 IRBuilder<> B(LI);
789 ConstantInt *BoolConst = B.getTrue();
790 Type *XType = X->getType();
791
792 // Check the the backend has an efficient ctlz instruction.
793 // FIXME: Teach the backend to emit the original code when ctlz isn't
794 // supported like we do for cttz.
795 IntrinsicCostAttributes Attrs(
796 Intrinsic::ctlz, XType,
797 {PoisonValue::get(T: XType), /*is_zero_poison=*/BoolConst});
798 InstructionCost Cost =
799 TTI.getIntrinsicInstrCost(ICA: Attrs, CostKind: TargetTransformInfo::TCK_SizeAndLatency);
800 if (Cost > TargetTransformInfo::TCC_Basic)
801 return false;
802
803 Value *Ctlz = B.CreateIntrinsic(ID: Intrinsic::ctlz, Types: {XType}, Args: {X, BoolConst});
804
805 Constant *InputBitsM1 = ConstantInt::get(Ty: XType, V: InputBits - 1);
806 Value *Sub = B.CreateSub(LHS: InputBitsM1, RHS: Ctlz);
807
808 // The table won't produce a sensible result for 0.
809 Value *Cmp = B.CreateICmpEQ(LHS: X, RHS: ConstantInt::get(Ty: XType, V: 0));
810 Value *Select = B.CreateSelect(C: Cmp, True: B.CreateZExt(V: ZeroTableElem, DestTy: XType), False: Sub);
811
812 // The true branch of select handles the log2(0) case, which is rare.
813 if (!ProfcheckDisableMetadataFixes) {
814 if (Instruction *SelectI = dyn_cast<Instruction>(Val: Select))
815 SelectI->setMetadata(
816 KindID: LLVMContext::MD_prof,
817 Node: MDBuilder(SelectI->getContext()).createUnlikelyBranchWeights());
818 }
819
820 Value *ZExtOrTrunc = B.CreateZExtOrTrunc(V: Select, DestTy: AccessType);
821
822 LI->replaceAllUsesWith(V: ZExtOrTrunc);
823
824 return true;
825}
826
827/// This is used by foldLoadsRecursive() to capture a Root Load node which is
828/// of type or(load, load) and recursively build the wide load. Also capture the
829/// shift amount, zero extend type and loadSize.
830struct LoadOps {
831 LoadInst *Root = nullptr;
832 LoadInst *RootInsert = nullptr;
833 bool FoundRoot = false;
834 uint64_t LoadSize = 0;
835 uint64_t Shift = 0;
836 Type *ZextType;
837 AAMDNodes AATags;
838};
839
// Identify and Merge consecutive loads recursively which is of the form
// (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
// (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
//
// V is the current `or` node of the chain. The non-load operand X is processed
// recursively first, so on return LOps describes everything merged below V;
// the load operand of V is then merged into that state. Returns true when V's
// loads were merged into LOps. IsRoot lifts the single-use requirement for the
// outermost `or` only.
static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
                               AliasAnalysis &AA, bool IsRoot = false) {
  uint64_t ShAmt2;
  Value *X;
  Instruction *L1, *L2;

  // For the root instruction, allow multiple uses since the final result
  // may legitimately be used in multiple places. For intermediate values,
  // require single use to avoid creating duplicate loads.
  if (!IsRoot && !V->hasOneUse())
    return false;

  // Match V = or(X, [shl] zext(L2)) with the operands in either order. The
  // shl is optional (see the second form in the header comment) and both the
  // shl and the zext must be single-use.
  if (!match(V, P: m_c_Or(L: m_Value(V&: X),
                         R: m_OneUse(SubPattern: m_ShlOrSelf(L: m_OneUse(SubPattern: m_ZExt(Op: m_Instruction(I&: L2))),
                                                  R&: ShAmt2)))))
    return false;

  // Recurse into X first. If a root was already found further down but X
  // itself could not be merged, stop here instead of merging only part of
  // the chain.
  if (!foldLoadsRecursive(V: X, LOps, DL, AA, /*IsRoot=*/false) && LOps.FoundRoot)
    // Avoid Partial chain merge.
    return false;

  // Check if the pattern has loads.
  // If nothing has been merged yet, X itself must be a (shifted) zext of a
  // load L1. Otherwise the previously merged load (LOps.Root, with shift
  // LOps.Shift) plays the role of LI1.
  LoadInst *LI1 = LOps.Root;
  uint64_t ShAmt1 = LOps.Shift;
  if (LOps.FoundRoot == false &&
      match(V: X, P: m_OneUse(
                     SubPattern: m_ShlOrSelf(L: m_OneUse(SubPattern: m_ZExt(Op: m_Instruction(I&: L1))), R&: ShAmt1)))) {
    LI1 = dyn_cast<LoadInst>(Val: L1);
  }
  LoadInst *LI2 = dyn_cast<LoadInst>(Val: L2);

  // Reject identical loads, non-simple (atomic or volatile) loads, and loads
  // from different address spaces.
  if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() ||
      LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace())
    return false;

  // Check if Loads come from same BB.
  if (LI1->getParent() != LI2->getParent())
    return false;

  // Query the target's byte order; it decides which load's shift must be the
  // larger one (see the shift swap and ShiftDiff below).
  bool IsBigEndian = DL.isBigEndian();

  // Check if loads are consecutive and same size.
  // Strip constant offsets so both pointers can be compared as Base + Offset.
  Value *Load1Ptr = LI1->getPointerOperand();
  APInt Offset1(DL.getIndexTypeSizeInBits(Ty: Load1Ptr->getType()), 0);
  Load1Ptr =
      Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset&: Offset1,
                                                  /* AllowNonInbounds */ true);

  Value *Load2Ptr = LI2->getPointerOperand();
  APInt Offset2(DL.getIndexTypeSizeInBits(Ty: Load2Ptr->getType()), 0);
  Load2Ptr =
      Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset&: Offset2,
                                                  /* AllowNonInbounds */ true);

  // Verify if both loads have same base pointers
  uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits();
  uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits();
  if (Load1Ptr != Load2Ptr)
    return false;

  // Make sure that there are no padding bits.
  if (!DL.typeSizeEqualsStoreSize(Ty: LI1->getType()) ||
      !DL.typeSizeEqualsStoreSize(Ty: LI2->getType()))
    return false;

  // Alias Analysis to check for stores b/w the loads.
  // The scanned range runs from the earlier to the later of (the current
  // insertion point, or LI1 when nothing is merged yet) and LI2. Any
  // instruction in that range that may modify Loc blocks the merge.
  LoadInst *Start = LOps.FoundRoot ? LOps.RootInsert : LI1, *End = LI2;
  MemoryLocation Loc;
  if (!Start->comesBefore(Other: End)) {
    std::swap(a&: Start, b&: End);
    // If LOps.RootInsert comes after LI2, since we use LI2 as the new insert
    // point, we should make sure whether the memory region accessed by LOps
    // isn't modified.
    if (LOps.FoundRoot)
      Loc = MemoryLocation(
          LOps.Root->getPointerOperand(),
          LocationSize::precise(Value: DL.getTypeStoreSize(
              Ty: IntegerType::get(C&: LI1->getContext(), NumBits: LOps.LoadSize))),
          LOps.AATags);
    else
      Loc = MemoryLocation::get(LI: End);
  } else
    Loc = MemoryLocation::get(LI: End);
  unsigned NumScanned = 0;
  for (Instruction &Inst :
       make_range(x: Start->getIterator(), y: End->getIterator())) {
    if (Inst.mayWriteToMemory() && isModSet(MRI: AA.getModRefInfo(I: &Inst, OptLoc: Loc)))
      return false;

    // Give up once more than MaxInstrsToScan instructions have been scanned.
    if (++NumScanned > MaxInstrsToScan)
      return false;
  }

  // Make sure Load with lower Offset is at LI1
  bool Reverse = false;
  if (Offset2.slt(RHS: Offset1)) {
    std::swap(a&: LI1, b&: LI2);
    std::swap(a&: ShAmt1, b&: ShAmt2);
    std::swap(a&: Offset1, b&: Offset2);
    std::swap(a&: Load1Ptr, b&: Load2Ptr);
    std::swap(a&: LoadSize1, b&: LoadSize2);
    Reverse = true;
  }

  // Big endian swap the shifts
  if (IsBigEndian)
    std::swap(a&: ShAmt1, b&: ShAmt2);

  // First load is always LI1. This is where we put the new load.
  // Use the merged load size available from LI1 for forward loads.
  // When a root already exists, the side that stands for it (LI1 unless the
  // operands were swapped) really covers LOps.LoadSize bits.
  if (LOps.FoundRoot) {
    if (!Reverse)
      LoadSize1 = LOps.LoadSize;
    else
      LoadSize2 = LOps.LoadSize;
  }

  // Verify if shift amount and load index aligns and verifies that loads
  // are consecutive.
  // The shift gap must equal the width of the part with the smaller shift.
  // The byte-offset gap must equal the store size of the lower-addressed part.
  uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
  uint64_t PrevSize =
      DL.getTypeStoreSize(Ty: IntegerType::get(C&: LI1->getContext(), NumBits: LoadSize1));
  if ((ShAmt2 - ShAmt1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
    return false;

  // Update LOps
  AAMDNodes AATags1 = LOps.AATags;
  AAMDNodes AATags2 = LI2->getAAMetadata();
  if (LOps.FoundRoot == false) {
    LOps.FoundRoot = true;
    AATags1 = LI1->getAAMetadata();
  }
  LOps.LoadSize = LoadSize1 + LoadSize2;
  LOps.RootInsert = Start;

  // Concatenate the AATags of the Merged Loads.
  LOps.AATags = AATags1.concat(Other: AATags2);

  LOps.Root = LI1;
  LOps.Shift = ShAmt1;
  LOps.ZextType = X->getType();
  return true;
}
988
989// For a given BB instruction, evaluate all loads in the chain that form a
990// pattern which suggests that the loads can be combined. The one and only use
991// of the loads is to form a wider load.
992static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
993 TargetTransformInfo &TTI, AliasAnalysis &AA,
994 const DominatorTree &DT) {
995 // Only consider load chains of scalar values.
996 if (isa<VectorType>(Val: I.getType()))
997 return false;
998
999 LoadOps LOps;
1000 if (!foldLoadsRecursive(V: &I, LOps, DL, AA, /*IsRoot=*/true) || !LOps.FoundRoot)
1001 return false;
1002
1003 IRBuilder<> Builder(&I);
1004 LoadInst *NewLoad = nullptr, *LI1 = LOps.Root;
1005
1006 IntegerType *WiderType = IntegerType::get(C&: I.getContext(), NumBits: LOps.LoadSize);
1007 // TTI based checks if we want to proceed with wider load
1008 bool Allowed = TTI.isTypeLegal(Ty: WiderType);
1009 if (!Allowed)
1010 return false;
1011
1012 unsigned AS = LI1->getPointerAddressSpace();
1013 unsigned Fast = 0;
1014 Allowed = TTI.allowsMisalignedMemoryAccesses(Context&: I.getContext(), BitWidth: LOps.LoadSize,
1015 AddressSpace: AS, Alignment: LI1->getAlign(), Fast: &Fast);
1016 if (!Allowed || !Fast)
1017 return false;
1018
1019 // Get the Index and Ptr for the new GEP.
1020 Value *Load1Ptr = LI1->getPointerOperand();
1021 Builder.SetInsertPoint(LOps.RootInsert);
1022 if (!DT.dominates(Def: Load1Ptr, User: LOps.RootInsert)) {
1023 APInt Offset1(DL.getIndexTypeSizeInBits(Ty: Load1Ptr->getType()), 0);
1024 Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets(
1025 DL, Offset&: Offset1, /* AllowNonInbounds */ true);
1026 Load1Ptr = Builder.CreatePtrAdd(Ptr: Load1Ptr, Offset: Builder.getInt(AI: Offset1));
1027 }
1028 // Generate wider load.
1029 NewLoad = Builder.CreateAlignedLoad(Ty: WiderType, Ptr: Load1Ptr, Align: LI1->getAlign(),
1030 isVolatile: LI1->isVolatile(), Name: "");
1031 NewLoad->takeName(V: LI1);
1032 // Set the New Load AATags Metadata.
1033 if (LOps.AATags)
1034 NewLoad->setAAMetadata(LOps.AATags);
1035
1036 Value *NewOp = NewLoad;
1037 // Check if zero extend needed.
1038 if (LOps.ZextType)
1039 NewOp = Builder.CreateZExt(V: NewOp, DestTy: LOps.ZextType);
1040
1041 // Check if shift needed. We need to shift with the amount of load1
1042 // shift if not zero.
1043 if (LOps.Shift)
1044 NewOp = Builder.CreateShl(LHS: NewOp, RHS: LOps.Shift);
1045 I.replaceAllUsesWith(V: NewOp);
1046
1047 return true;
1048}
1049
1050/// ValWidth bits starting at ValOffset of Val stored at PtrBase+PtrOffset.
1051struct PartStore {
1052 Value *PtrBase;
1053 APInt PtrOffset;
1054 Value *Val;
1055 uint64_t ValOffset;
1056 uint64_t ValWidth;
1057 StoreInst *Store;
1058
1059 bool isCompatibleWith(const PartStore &Other) const {
1060 return PtrBase == Other.PtrBase && Val == Other.Val;
1061 }
1062
1063 bool operator<(const PartStore &Other) const {
1064 return PtrOffset.slt(RHS: Other.PtrOffset);
1065 }
1066};
1067
1068static std::optional<PartStore> matchPartStore(Instruction &I,
1069 const DataLayout &DL) {
1070 auto *Store = dyn_cast<StoreInst>(Val: &I);
1071 if (!Store || !Store->isSimple())
1072 return std::nullopt;
1073
1074 Value *StoredVal = Store->getValueOperand();
1075 Type *StoredTy = StoredVal->getType();
1076 if (!StoredTy->isIntegerTy() || !DL.typeSizeEqualsStoreSize(Ty: StoredTy))
1077 return std::nullopt;
1078
1079 uint64_t ValWidth = StoredTy->getPrimitiveSizeInBits();
1080 uint64_t ValOffset;
1081 Value *Val;
1082 if (!match(V: StoredVal, P: m_Trunc(Op: m_LShrOrSelf(L: m_Value(V&: Val), R&: ValOffset))))
1083 return std::nullopt;
1084
1085 Value *Ptr = Store->getPointerOperand();
1086 APInt PtrOffset(DL.getIndexTypeSizeInBits(Ty: Ptr->getType()), 0);
1087 Value *PtrBase = Ptr->stripAndAccumulateConstantOffsets(
1088 DL, Offset&: PtrOffset, /*AllowNonInbounds=*/true);
1089 return {{.PtrBase: PtrBase, .PtrOffset: PtrOffset, .Val: Val, .ValOffset: ValOffset, .ValWidth: ValWidth, .Store: Store}};
1090}
1091
1092static bool mergeConsecutivePartStores(ArrayRef<PartStore> Parts,
1093 unsigned Width, const DataLayout &DL,
1094 TargetTransformInfo &TTI) {
1095 if (Parts.size() < 2)
1096 return false;
1097
1098 // Check whether combining the stores is profitable.
1099 // FIXME: We could generate smaller stores if we can't produce a large one.
1100 const PartStore &First = Parts.front();
1101 LLVMContext &Ctx = First.Store->getContext();
1102 Type *NewTy = Type::getIntNTy(C&: Ctx, N: Width);
1103 unsigned Fast = 0;
1104 if (!TTI.isTypeLegal(Ty: NewTy) ||
1105 !TTI.allowsMisalignedMemoryAccesses(Context&: Ctx, BitWidth: Width,
1106 AddressSpace: First.Store->getPointerAddressSpace(),
1107 Alignment: First.Store->getAlign(), Fast: &Fast) ||
1108 !Fast)
1109 return false;
1110
1111 // Generate the combined store.
1112 IRBuilder<> Builder(First.Store);
1113 Value *Val = First.Val;
1114 if (First.ValOffset != 0)
1115 Val = Builder.CreateLShr(LHS: Val, RHS: First.ValOffset);
1116 Val = Builder.CreateZExtOrTrunc(V: Val, DestTy: NewTy);
1117 StoreInst *Store = Builder.CreateAlignedStore(
1118 Val, Ptr: First.Store->getPointerOperand(), Align: First.Store->getAlign());
1119
1120 // Merge various metadata onto the new store.
1121 AAMDNodes AATags = First.Store->getAAMetadata();
1122 SmallVector<Instruction *> Stores = {First.Store};
1123 Stores.reserve(N: Parts.size());
1124 SmallVector<DebugLoc> DbgLocs = {First.Store->getDebugLoc()};
1125 DbgLocs.reserve(N: Parts.size());
1126 for (const PartStore &Part : drop_begin(RangeOrContainer&: Parts)) {
1127 AATags = AATags.concat(Other: Part.Store->getAAMetadata());
1128 Stores.push_back(Elt: Part.Store);
1129 DbgLocs.push_back(Elt: Part.Store->getDebugLoc());
1130 }
1131 Store->setAAMetadata(AATags);
1132 Store->mergeDIAssignID(SourceInstructions: Stores);
1133 Store->setDebugLoc(DebugLoc::getMergedLocations(Locs: DbgLocs));
1134
1135 // Remove the old stores.
1136 for (const PartStore &Part : Parts)
1137 Part.Store->eraseFromParent();
1138
1139 return true;
1140}
1141
1142static bool mergePartStores(SmallVectorImpl<PartStore> &Parts,
1143 const DataLayout &DL, TargetTransformInfo &TTI) {
1144 if (Parts.size() < 2)
1145 return false;
1146
1147 // We now have multiple parts of the same value stored to the same pointer.
1148 // Sort the parts by pointer offset, and make sure they are consistent with
1149 // the value offsets. Also check that the value is fully covered without
1150 // overlaps.
1151 bool Changed = false;
1152 llvm::sort(C&: Parts);
1153 int64_t LastEndOffsetFromFirst = 0;
1154 const PartStore *First = &Parts[0];
1155 for (const PartStore &Part : Parts) {
1156 APInt PtrOffsetFromFirst = Part.PtrOffset - First->PtrOffset;
1157 int64_t ValOffsetFromFirst = Part.ValOffset - First->ValOffset;
1158 if (PtrOffsetFromFirst * 8 != ValOffsetFromFirst ||
1159 LastEndOffsetFromFirst != ValOffsetFromFirst) {
1160 Changed |= mergeConsecutivePartStores(Parts: ArrayRef(First, &Part),
1161 Width: LastEndOffsetFromFirst, DL, TTI);
1162 First = &Part;
1163 LastEndOffsetFromFirst = Part.ValWidth;
1164 continue;
1165 }
1166
1167 LastEndOffsetFromFirst = ValOffsetFromFirst + Part.ValWidth;
1168 }
1169
1170 Changed |= mergeConsecutivePartStores(Parts: ArrayRef(First, Parts.end()),
1171 Width: LastEndOffsetFromFirst, DL, TTI);
1172 return Changed;
1173}
1174
1175static bool foldConsecutiveStores(BasicBlock &BB, const DataLayout &DL,
1176 TargetTransformInfo &TTI, AliasAnalysis &AA) {
1177 // FIXME: Add big endian support.
1178 if (DL.isBigEndian())
1179 return false;
1180
1181 BatchAAResults BatchAA(AA);
1182 SmallVector<PartStore, 8> Parts;
1183 bool MadeChange = false;
1184 for (Instruction &I : make_early_inc_range(Range&: BB)) {
1185 if (std::optional<PartStore> Part = matchPartStore(I, DL)) {
1186 if (Parts.empty() || Part->isCompatibleWith(Other: Parts[0])) {
1187 Parts.push_back(Elt: std::move(*Part));
1188 continue;
1189 }
1190
1191 MadeChange |= mergePartStores(Parts, DL, TTI);
1192 Parts.clear();
1193 Parts.push_back(Elt: std::move(*Part));
1194 continue;
1195 }
1196
1197 if (Parts.empty())
1198 continue;
1199
1200 if (I.mayThrow() ||
1201 (I.mayReadOrWriteMemory() &&
1202 isModOrRefSet(MRI: BatchAA.getModRefInfo(
1203 I: &I, OptLoc: MemoryLocation::getBeforeOrAfter(Ptr: Parts[0].PtrBase))))) {
1204 MadeChange |= mergePartStores(Parts, DL, TTI);
1205 Parts.clear();
1206 continue;
1207 }
1208 }
1209
1210 MadeChange |= mergePartStores(Parts, DL, TTI);
1211 return MadeChange;
1212}
1213
1214/// Combine away instructions providing they are still equivalent when compared
1215/// against 0. i.e do they have any bits set.
1216static Value *optimizeShiftInOrChain(Value *V, IRBuilder<> &Builder) {
1217 auto *I = dyn_cast<Instruction>(Val: V);
1218 if (!I || I->getOpcode() != Instruction::Or || !I->hasOneUse())
1219 return nullptr;
1220
1221 Value *A;
1222
1223 // Look deeper into the chain of or's, combining away shl (so long as they are
1224 // nuw or nsw).
1225 Value *Op0 = I->getOperand(i: 0);
1226 if (match(V: Op0, P: m_CombineOr(L: m_NSWShl(L: m_Value(V&: A), R: m_Value()),
1227 R: m_NUWShl(L: m_Value(V&: A), R: m_Value()))))
1228 Op0 = A;
1229 else if (auto *NOp = optimizeShiftInOrChain(V: Op0, Builder))
1230 Op0 = NOp;
1231
1232 Value *Op1 = I->getOperand(i: 1);
1233 if (match(V: Op1, P: m_CombineOr(L: m_NSWShl(L: m_Value(V&: A), R: m_Value()),
1234 R: m_NUWShl(L: m_Value(V&: A), R: m_Value()))))
1235 Op1 = A;
1236 else if (auto *NOp = optimizeShiftInOrChain(V: Op1, Builder))
1237 Op1 = NOp;
1238
1239 if (Op0 != I->getOperand(i: 0) || Op1 != I->getOperand(i: 1))
1240 return Builder.CreateOr(LHS: Op0, RHS: Op1);
1241 return nullptr;
1242}
1243
1244static bool foldICmpOrChain(Instruction &I, const DataLayout &DL,
1245 TargetTransformInfo &TTI, AliasAnalysis &AA,
1246 const DominatorTree &DT) {
1247 CmpPredicate Pred;
1248 Value *Op0;
1249 if (!match(V: &I, P: m_ICmp(Pred, L: m_Value(V&: Op0), R: m_Zero())) ||
1250 !ICmpInst::isEquality(P: Pred))
1251 return false;
1252
1253 // If the chain or or's matches a load, combine to that before attempting to
1254 // remove shifts.
1255 if (auto OpI = dyn_cast<Instruction>(Val: Op0))
1256 if (OpI->getOpcode() == Instruction::Or)
1257 if (foldConsecutiveLoads(I&: *OpI, DL, TTI, AA, DT))
1258 return true;
1259
1260 IRBuilder<> Builder(&I);
1261 // icmp eq/ne or(shl(a), b), 0 -> icmp eq/ne or(a, b), 0
1262 if (auto *Res = optimizeShiftInOrChain(V: Op0, Builder)) {
1263 I.replaceAllUsesWith(V: Builder.CreateICmp(P: Pred, LHS: Res, RHS: I.getOperand(i: 1)));
1264 return true;
1265 }
1266
1267 return false;
1268}
1269
1270// Calculate GEP Stride and accumulated const ModOffset. Return Stride and
1271// ModOffset
1272static std::pair<APInt, APInt>
1273getStrideAndModOffsetOfGEP(Value *PtrOp, const DataLayout &DL) {
1274 unsigned BW = DL.getIndexTypeSizeInBits(Ty: PtrOp->getType());
1275 std::optional<APInt> Stride;
1276 APInt ModOffset(BW, 0);
1277 // Return a minimum gep stride, greatest common divisor of consective gep
1278 // index scales(c.f. Bézout's identity).
1279 while (auto *GEP = dyn_cast<GEPOperator>(Val: PtrOp)) {
1280 SmallMapVector<Value *, APInt, 4> VarOffsets;
1281 if (!GEP->collectOffset(DL, BitWidth: BW, VariableOffsets&: VarOffsets, ConstantOffset&: ModOffset))
1282 break;
1283
1284 for (auto [V, Scale] : VarOffsets) {
1285 // Only keep a power of two factor for non-inbounds
1286 if (!GEP->hasNoUnsignedSignedWrap())
1287 Scale = APInt::getOneBitSet(numBits: Scale.getBitWidth(), BitNo: Scale.countr_zero());
1288
1289 if (!Stride)
1290 Stride = Scale;
1291 else
1292 Stride = APIntOps::GreatestCommonDivisor(A: *Stride, B: Scale);
1293 }
1294
1295 PtrOp = GEP->getPointerOperand();
1296 }
1297
1298 // Check whether pointer arrives back at Global Variable via at least one GEP.
1299 // Even if it doesn't, we can check by alignment.
1300 if (!isa<GlobalVariable>(Val: PtrOp) || !Stride)
1301 return {APInt(BW, 1), APInt(BW, 0)};
1302
1303 // In consideration of signed GEP indices, non-negligible offset become
1304 // remainder of division by minimum GEP stride.
1305 ModOffset = ModOffset.srem(RHS: *Stride);
1306 if (ModOffset.isNegative())
1307 ModOffset += *Stride;
1308
1309 return {*Stride, ModOffset};
1310}
1311
1312/// If C is a constant patterned array and all valid loaded results for given
1313/// alignment are same to a constant, return that constant.
1314static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
1315 auto *LI = dyn_cast<LoadInst>(Val: &I);
1316 if (!LI || LI->isVolatile())
1317 return false;
1318
1319 // We can only fold the load if it is from a constant global with definitive
1320 // initializer. Skip expensive logic if this is not the case.
1321 auto *PtrOp = LI->getPointerOperand();
1322 auto *GV = dyn_cast<GlobalVariable>(Val: getUnderlyingObject(V: PtrOp));
1323 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
1324 return false;
1325
1326 // Bail for large initializers in excess of 4K to avoid too many scans.
1327 Constant *C = GV->getInitializer();
1328 uint64_t GVSize = DL.getTypeAllocSize(Ty: C->getType());
1329 if (!GVSize || 4096 < GVSize)
1330 return false;
1331
1332 Type *LoadTy = LI->getType();
1333 unsigned BW = DL.getIndexTypeSizeInBits(Ty: PtrOp->getType());
1334 auto [Stride, ConstOffset] = getStrideAndModOffsetOfGEP(PtrOp, DL);
1335
1336 // Any possible offset could be multiple of GEP stride. And any valid
1337 // offset is multiple of load alignment, so checking only multiples of bigger
1338 // one is sufficient to say results' equality.
1339 if (auto LA = LI->getAlign();
1340 LA <= GV->getAlign().valueOrOne() && Stride.getZExtValue() < LA.value()) {
1341 ConstOffset = APInt(BW, 0);
1342 Stride = APInt(BW, LA.value());
1343 }
1344
1345 Constant *Ca = ConstantFoldLoadFromConst(C, Ty: LoadTy, Offset: ConstOffset, DL);
1346 if (!Ca)
1347 return false;
1348
1349 unsigned E = GVSize - DL.getTypeStoreSize(Ty: LoadTy);
1350 for (; ConstOffset.getZExtValue() <= E; ConstOffset += Stride)
1351 if (Ca != ConstantFoldLoadFromConst(C, Ty: LoadTy, Offset: ConstOffset, DL))
1352 return false;
1353
1354 I.replaceAllUsesWith(V: Ca);
1355
1356 return true;
1357}
1358
namespace {
/// Expands strcmp/strncmp calls whose result is only compared against zero,
/// and where exactly one string argument is a known constant, into an explicit
/// byte-by-byte comparison chain (see optimizeStrNCmp and inlineCompare).
class StrNCmpInliner {
public:
  StrNCmpInliner(CallInst *CI, LibFunc Func, DomTreeUpdater *DTU,
                 const DataLayout &DL)
      : CI(CI), Func(Func), DTU(DTU), DL(DL) {}

  /// Attempt the expansion. Returns true if CI was replaced and erased.
  bool optimizeStrNCmp();

private:
  /// Emit the comparison of the first N bytes of LHS against the constant RHS.
  /// Swapped is true when the constant string was the call's first argument.
  void inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Swapped);

  CallInst *CI;        // The strcmp/strncmp call being rewritten.
  LibFunc Func;        // Which of the two library functions CI calls.
  DomTreeUpdater *DTU; // Optional; checked for null before updates are applied.
  const DataLayout &DL;
};

} // namespace
1378
1379/// First we normalize calls to strncmp/strcmp to the form of
1380/// compare(s1, s2, N), which means comparing first N bytes of s1 and s2
1381/// (without considering '\0').
1382///
1383/// Examples:
1384///
1385/// \code
1386/// strncmp(s, "a", 3) -> compare(s, "a", 2)
1387/// strncmp(s, "abc", 3) -> compare(s, "abc", 3)
1388/// strncmp(s, "a\0b", 3) -> compare(s, "a\0b", 2)
1389/// strcmp(s, "a") -> compare(s, "a", 2)
1390///
1391/// char s2[] = {'a'}
1392/// strncmp(s, s2, 3) -> compare(s, s2, 3)
1393///
1394/// char s2[] = {'a', 'b', 'c', 'd'}
1395/// strncmp(s, s2, 3) -> compare(s, s2, 3)
1396/// \endcode
1397///
1398/// We only handle cases where N and exactly one of s1 and s2 are constant.
1399/// Cases that s1 and s2 are both constant are already handled by the
1400/// instcombine pass.
1401///
1402/// We do not handle cases where N > StrNCmpInlineThreshold.
1403///
1404/// We also do not handles cases where N < 2, which are already
1405/// handled by the instcombine pass.
1406///
1407bool StrNCmpInliner::optimizeStrNCmp() {
1408 if (StrNCmpInlineThreshold < 2)
1409 return false;
1410
1411 if (!isOnlyUsedInZeroComparison(CxtI: CI))
1412 return false;
1413
1414 Value *Str1P = CI->getArgOperand(i: 0);
1415 Value *Str2P = CI->getArgOperand(i: 1);
1416 // Should be handled elsewhere.
1417 if (Str1P == Str2P)
1418 return false;
1419
1420 StringRef Str1, Str2;
1421 bool HasStr1 = getConstantStringInfo(V: Str1P, Str&: Str1, /*TrimAtNul=*/false);
1422 bool HasStr2 = getConstantStringInfo(V: Str2P, Str&: Str2, /*TrimAtNul=*/false);
1423 if (HasStr1 == HasStr2)
1424 return false;
1425
1426 // Note that '\0' and characters after it are not trimmed.
1427 StringRef Str = HasStr1 ? Str1 : Str2;
1428 Value *StrP = HasStr1 ? Str2P : Str1P;
1429
1430 size_t Idx = Str.find(C: '\0');
1431 uint64_t N = Idx == StringRef::npos ? UINT64_MAX : Idx + 1;
1432 if (Func == LibFunc_strncmp) {
1433 if (auto *ConstInt = dyn_cast<ConstantInt>(Val: CI->getArgOperand(i: 2)))
1434 N = std::min(a: N, b: ConstInt->getZExtValue());
1435 else
1436 return false;
1437 }
1438 // Now N means how many bytes we need to compare at most.
1439 if (N > Str.size() || N < 2 || N > StrNCmpInlineThreshold)
1440 return false;
1441
1442 // Cases where StrP has two or more dereferenceable bytes might be better
1443 // optimized elsewhere.
1444 bool CanBeNull = false, CanBeFreed = false;
1445 if (StrP->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed) > 1)
1446 return false;
1447 inlineCompare(LHS: StrP, RHS: Str, N, Swapped: HasStr1);
1448 return true;
1449}
1450
1451/// Convert
1452///
1453/// \code
1454/// ret = compare(s1, s2, N)
1455/// \endcode
1456///
1457/// into
1458///
1459/// \code
1460/// ret = (int)s1[0] - (int)s2[0]
1461/// if (ret != 0)
1462/// goto NE
1463/// ...
1464/// ret = (int)s1[N-2] - (int)s2[N-2]
1465/// if (ret != 0)
1466/// goto NE
1467/// ret = (int)s1[N-1] - (int)s2[N-1]
1468/// NE:
1469/// \endcode
1470///
1471/// CFG before and after the transformation:
1472///
1473/// (before)
1474/// BBCI
1475///
1476/// (after)
1477/// BBCI -> BBSubs[0] (sub,icmp) --NE-> BBNE -> BBTail
1478/// | ^
1479/// E |
1480/// | |
1481/// BBSubs[1] (sub,icmp) --NE-----+
1482/// ... |
1483/// BBSubs[N-1] (sub) ---------+
1484///
1485void StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
1486 bool Swapped) {
1487 auto &Ctx = CI->getContext();
1488 IRBuilder<> B(Ctx);
1489 // We want these instructions to be recognized as inlined instructions for the
1490 // compare call, but we don't have a source location for the definition of
1491 // that function, since we're generating that code now. Because the generated
1492 // code is a viable point for a memory access error, we make the pragmatic
1493 // choice here to directly use CI's location so that we have useful
1494 // attribution for the generated code.
1495 B.SetCurrentDebugLocation(CI->getDebugLoc());
1496
1497 BasicBlock *BBCI = CI->getParent();
1498 BasicBlock *BBTail =
1499 SplitBlock(Old: BBCI, SplitPt: CI, DTU, LI: nullptr, MSSAU: nullptr, BBName: BBCI->getName() + ".tail");
1500
1501 SmallVector<BasicBlock *> BBSubs;
1502 for (uint64_t I = 0; I < N; ++I)
1503 BBSubs.push_back(
1504 Elt: BasicBlock::Create(Context&: Ctx, Name: "sub_" + Twine(I), Parent: BBCI->getParent(), InsertBefore: BBTail));
1505 BasicBlock *BBNE = BasicBlock::Create(Context&: Ctx, Name: "ne", Parent: BBCI->getParent(), InsertBefore: BBTail);
1506
1507 cast<UncondBrInst>(Val: BBCI->getTerminator())->setSuccessor(BBSubs[0]);
1508
1509 B.SetInsertPoint(BBNE);
1510 PHINode *Phi = B.CreatePHI(Ty: CI->getType(), NumReservedValues: N);
1511 B.CreateBr(Dest: BBTail);
1512
1513 Value *Base = LHS;
1514 for (uint64_t i = 0; i < N; ++i) {
1515 B.SetInsertPoint(BBSubs[i]);
1516 Value *VL =
1517 B.CreateZExt(V: B.CreateLoad(Ty: B.getInt8Ty(),
1518 Ptr: B.CreateInBoundsPtrAdd(Ptr: Base, Offset: B.getInt64(C: i))),
1519 DestTy: CI->getType());
1520 Value *VR =
1521 ConstantInt::get(Ty: CI->getType(), V: static_cast<unsigned char>(RHS[i]));
1522 Value *Sub = Swapped ? B.CreateSub(LHS: VR, RHS: VL) : B.CreateSub(LHS: VL, RHS: VR);
1523 if (i < N - 1) {
1524 CondBrInst *CondBrInst = B.CreateCondBr(
1525 Cond: B.CreateICmpNE(LHS: Sub, RHS: ConstantInt::get(Ty: CI->getType(), V: 0)), True: BBNE,
1526 False: BBSubs[i + 1]);
1527
1528 Function *F = CI->getFunction();
1529 assert(F && "Instruction does not belong to a function!");
1530 std::optional<Function::ProfileCount> EC = F->getEntryCount();
1531 if (EC && EC->getCount() > 0)
1532 setExplicitlyUnknownBranchWeights(I&: *CondBrInst, DEBUG_TYPE);
1533 } else {
1534 B.CreateBr(Dest: BBNE);
1535 }
1536
1537 Phi->addIncoming(V: Sub, BB: BBSubs[i]);
1538 }
1539
1540 CI->replaceAllUsesWith(V: Phi);
1541 CI->eraseFromParent();
1542
1543 if (DTU) {
1544 SmallVector<DominatorTree::UpdateType, 8> Updates;
1545 Updates.push_back(Elt: {DominatorTree::Insert, BBCI, BBSubs[0]});
1546 for (uint64_t i = 0; i < N; ++i) {
1547 if (i < N - 1)
1548 Updates.push_back(Elt: {DominatorTree::Insert, BBSubs[i], BBSubs[i + 1]});
1549 Updates.push_back(Elt: {DominatorTree::Insert, BBSubs[i], BBNE});
1550 }
1551 Updates.push_back(Elt: {DominatorTree::Insert, BBNE, BBTail});
1552 Updates.push_back(Elt: {DominatorTree::Delete, BBCI, BBTail});
1553 DTU->applyUpdates(Updates);
1554 }
1555}
1556
/// Convert memchr with a small constant string into a switch
///
/// Handles memchr(Str, C, N) where:
/// - Str is a constant string;
/// - C is not a constant;
/// - N is a constant no larger than Str's length or MemChrInlineThreshold.
/// The call becomes a switch on C truncated to i8, with one case per distinct
/// byte among the first N bytes of Str. Each case yields Str plus the index of
/// that byte's first occurrence; the default destination yields null. Returns
/// true if the call was replaced.
static bool foldMemChr(CallInst *Call, DomTreeUpdater *DTU,
                       const DataLayout &DL) {
  // Only a non-constant search character is handled here.
  if (isa<Constant>(Val: Call->getArgOperand(i: 1)))
    return false;

  StringRef Str;
  Value *Base = Call->getArgOperand(i: 0);
  if (!getConstantStringInfo(V: Base, Str, /*TrimAtNul=*/false))
    return false;

  uint64_t N = Str.size();
  if (auto *ConstInt = dyn_cast<ConstantInt>(Val: Call->getArgOperand(i: 2))) {
    uint64_t Val = ConstInt->getZExtValue();
    // Ignore the case that n is larger than the size of string.
    if (Val > N)
      return false;
    N = Val;
  } else
    // A non-constant length is not handled.
    return false;

  if (N > MemChrInlineThreshold)
    return false;

  // Split at the call. BB keeps the code before it and ends in the new switch;
  // BBNext starts at the call and receives the result PHI.
  BasicBlock *BB = Call->getParent();
  BasicBlock *BBNext = SplitBlock(Old: BB, SplitPt: Call, DTU);
  IRBuilder<> IRB(BB);
  IRB.SetCurrentDebugLocation(Call->getDebugLoc());
  IntegerType *ByteTy = IRB.getInt8Ty();
  // Drop the unconditional branch SplitBlock left so the switch can end BB.
  BB->getTerminator()->eraseFromParent();
  SwitchInst *SI = IRB.CreateSwitch(
      V: IRB.CreateTrunc(V: Call->getArgOperand(i: 1), DestTy: ByteTy), Dest: BBNext, NumCases: N);
  // We can't know the precise weights here, as they would depend on the value
  // distribution of Call->getArgOperand(1). So we just mark it as "unknown".
  setExplicitlyUnknownBranchWeightsIfProfiled(I&: *SI, DEBUG_TYPE);
  Type *IndexTy = DL.getIndexType(PtrTy: Call->getType());
  SmallVector<DominatorTree::UpdateType, 8> Updates;

  // Shared success block: turn the matched index into a pointer into Str.
  BasicBlock *BBSuccess = BasicBlock::Create(
      Context&: Call->getContext(), Name: "memchr.success", Parent: BB->getParent(), InsertBefore: BBNext);
  IRB.SetInsertPoint(BBSuccess);
  PHINode *IndexPHI = IRB.CreatePHI(Ty: IndexTy, NumReservedValues: N, Name: "memchr.idx");
  Value *FirstOccursLocation = IRB.CreateInBoundsPtrAdd(Ptr: Base, Offset: IndexPHI);
  IRB.CreateBr(Dest: BBNext);
  if (DTU)
    Updates.push_back(Elt: {DominatorTree::Insert, BBSuccess, BBNext});

  // One case per distinct byte. Indices are visited in increasing order and
  // repeated bytes are skipped, so each case records the first occurrence.
  SmallPtrSet<ConstantInt *, 4> Cases;
  for (uint64_t I = 0; I < N; ++I) {
    ConstantInt *CaseVal =
        ConstantInt::get(Ty: ByteTy, V: static_cast<unsigned char>(Str[I]));
    if (!Cases.insert(Ptr: CaseVal).second)
      continue;

    BasicBlock *BBCase = BasicBlock::Create(Context&: Call->getContext(), Name: "memchr.case",
                                            Parent: BB->getParent(), InsertBefore: BBSuccess);
    SI->addCase(OnVal: CaseVal, Dest: BBCase);
    IRB.SetInsertPoint(BBCase);
    IndexPHI->addIncoming(V: ConstantInt::get(Ty: IndexTy, V: I), BB: BBCase);
    IRB.CreateBr(Dest: BBSuccess);
    if (DTU) {
      Updates.push_back(Elt: {DominatorTree::Insert, BB, BBCase});
      Updates.push_back(Elt: {DominatorTree::Insert, BBCase, BBSuccess});
    }
  }

  // Result: null when reached through the switch default (from BB), the
  // found pointer when reached from BBSuccess.
  PHINode *PHI =
      PHINode::Create(Ty: Call->getType(), NumReservedValues: 2, NameStr: Call->getName(), InsertBefore: BBNext->begin());
  PHI->addIncoming(V: Constant::getNullValue(Ty: Call->getType()), BB);
  PHI->addIncoming(V: FirstOccursLocation, BB: BBSuccess);

  Call->replaceAllUsesWith(V: PHI);
  Call->eraseFromParent();

  if (DTU)
    DTU->applyUpdates(Updates);

  return true;
}
1636
1637static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
1638 TargetLibraryInfo &TLI, AssumptionCache &AC,
1639 DominatorTree &DT, const DataLayout &DL,
1640 bool &MadeCFGChange) {
1641
1642 auto *CI = dyn_cast<CallInst>(Val: &I);
1643 if (!CI || CI->isNoBuiltin())
1644 return false;
1645
1646 Function *CalledFunc = CI->getCalledFunction();
1647 if (!CalledFunc)
1648 return false;
1649
1650 LibFunc LF;
1651 if (!TLI.getLibFunc(FDecl: *CalledFunc, F&: LF) ||
1652 !isLibFuncEmittable(M: CI->getModule(), TLI: &TLI, TheLibFunc: LF))
1653 return false;
1654
1655 DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy);
1656
1657 switch (LF) {
1658 case LibFunc_sqrt:
1659 case LibFunc_sqrtf:
1660 case LibFunc_sqrtl:
1661 return foldSqrt(Call: CI, Func: LF, TTI, TLI, AC, DT);
1662 case LibFunc_strcmp:
1663 case LibFunc_strncmp:
1664 if (StrNCmpInliner(CI, LF, &DTU, DL).optimizeStrNCmp()) {
1665 MadeCFGChange = true;
1666 return true;
1667 }
1668 break;
1669 case LibFunc_memchr:
1670 if (foldMemChr(Call: CI, DTU: &DTU, DL)) {
1671 MadeCFGChange = true;
1672 return true;
1673 }
1674 break;
1675 default:;
1676 }
1677 return false;
1678}
1679
1680/// Match high part of long multiplication.
1681///
1682/// Considering a multiply made up of high and low parts, we can split the
1683/// multiply into:
1684/// x * y == (xh*T + xl) * (yh*T + yl)
1685/// where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
1686/// This expands to
1687/// xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
1688/// which can be drawn as
1689/// [ xh*yh ]
1690/// [ xh*yl ]
1691/// [ xl*yh ]
1692/// [ xl*yl ]
1693/// We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
1694/// some carrys. The carry makes this difficult and there are multiple ways of
1695/// representing it. The ones we attempt to support here are:
1696/// Carry: xh*yh + carry + lowsum
1697/// carry = lowsum < xh*yl ? 0x1000000 : 0
1698/// lowsum = xh*yl + xl*yh + (xl*yl>>32)
1699/// Ladder: xh*yh + c2>>32 + c3>>32
1700/// c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
1701/// or c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32); c3 = xl*yh
1702/// Carry4: xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
1703/// crosssum = xh*yl + xl*yh
1704/// carry = crosssum < xh*yl ? 0x1000000 : 0
1705/// Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
1706/// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
1707///
1708/// They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
1709/// tree is xh*yh, xh*yl, xl*yh and xl*yl.
1710static bool foldMulHigh(Instruction &I) {
1711 Type *Ty = I.getType();
1712 if (!Ty->isIntOrIntVectorTy())
1713 return false;
1714
1715 unsigned BitWidth = Ty->getScalarSizeInBits();
1716 APInt LowMask = APInt::getLowBitsSet(numBits: BitWidth, loBitsSet: BitWidth / 2);
1717 if (BitWidth % 2 != 0)
1718 return false;
1719
1720 auto CreateMulHigh = [&](Value *X, Value *Y) {
1721 IRBuilder<> Builder(&I);
1722 Type *NTy = Ty->getWithNewBitWidth(NewBitWidth: BitWidth * 2);
1723 Value *XExt = Builder.CreateZExt(V: X, DestTy: NTy);
1724 Value *YExt = Builder.CreateZExt(V: Y, DestTy: NTy);
1725 Value *Mul = Builder.CreateMul(LHS: XExt, RHS: YExt, Name: "", /*HasNUW=*/true);
1726 Value *High = Builder.CreateLShr(LHS: Mul, RHS: BitWidth);
1727 Value *Res = Builder.CreateTrunc(V: High, DestTy: Ty, Name: "", /*HasNUW=*/IsNUW: true);
1728 Res->takeName(V: &I);
1729 I.replaceAllUsesWith(V: Res);
1730 LLVM_DEBUG(dbgs() << "Created long multiply from parts of " << *X << " and "
1731 << *Y << "\n");
1732 return true;
1733 };
1734
1735 // Common check routines for X_lo*Y_lo and X_hi*Y_lo
1736 auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) {
1737 return match(V: XlYl, P: m_c_Mul(L: m_And(L: m_Specific(V: X), R: m_SpecificInt(V: LowMask)),
1738 R: m_And(L: m_Specific(V: Y), R: m_SpecificInt(V: LowMask))));
1739 };
1740 auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) {
1741 return match(V: XhYl,
1742 P: m_c_Mul(L: m_LShr(L: m_Specific(V: X), R: m_SpecificInt(V: BitWidth / 2)),
1743 R: m_And(L: m_Specific(V: Y), R: m_SpecificInt(V: LowMask))));
1744 };
1745
1746 auto FoldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry,
1747 Instruction *B) {
1748 // Looking for LowSum >> 32 and carry (select)
1749 if (Carry->getOpcode() != Instruction::Select)
1750 std::swap(a&: Carry, b&: B);
1751
1752 // Carry = LowSum < XhYl ? 0x100000000 : 0
1753 Value *LowSum, *XhYl;
1754 if (!match(V: Carry,
1755 P: m_OneUse(SubPattern: m_Select(
1756 C: m_OneUse(SubPattern: m_SpecificICmp(MatchPred: ICmpInst::ICMP_ULT, L: m_Value(V&: LowSum),
1757 R: m_Value(V&: XhYl))),
1758 L: m_SpecificInt(V: APInt::getOneBitSet(numBits: BitWidth, BitNo: BitWidth / 2)),
1759 R: m_Zero()))))
1760 return false;
1761
1762 // XhYl can be Xh*Yl or Xl*Yh
1763 if (!CheckHiLo(XhYl, X, Y)) {
1764 if (CheckHiLo(XhYl, Y, X))
1765 std::swap(a&: X, b&: Y);
1766 else
1767 return false;
1768 }
1769 if (XhYl->hasNUsesOrMore(N: 3))
1770 return false;
1771
1772 // B = LowSum >> 32
1773 if (!match(V: B, P: m_OneUse(SubPattern: m_LShr(L: m_Specific(V: LowSum),
1774 R: m_SpecificInt(V: BitWidth / 2)))) ||
1775 LowSum->hasNUsesOrMore(N: 3))
1776 return false;
1777
1778 // LowSum = XhYl + XlYh + XlYl>>32
1779 Value *XlYh, *XlYl;
1780 auto XlYlHi = m_LShr(L: m_Value(V&: XlYl), R: m_SpecificInt(V: BitWidth / 2));
1781 if (!match(V: LowSum,
1782 P: m_c_Add(L: m_Specific(V: XhYl),
1783 R: m_OneUse(SubPattern: m_c_Add(L: m_OneUse(SubPattern: m_Value(V&: XlYh)), R: XlYlHi)))) &&
1784 !match(V: LowSum, P: m_c_Add(L: m_OneUse(SubPattern: m_Value(V&: XlYh)),
1785 R: m_OneUse(SubPattern: m_c_Add(L: m_Specific(V: XhYl), R: XlYlHi)))) &&
1786 !match(V: LowSum,
1787 P: m_c_Add(L: XlYlHi, R: m_OneUse(SubPattern: m_c_Add(L: m_Specific(V: XhYl),
1788 R: m_OneUse(SubPattern: m_Value(V&: XlYh)))))))
1789 return false;
1790
1791 // Check XlYl and XlYh
1792 if (!CheckLoLo(XlYl, X, Y))
1793 return false;
1794 if (!CheckHiLo(XlYh, Y, X))
1795 return false;
1796
1797 return CreateMulHigh(X, Y);
1798 };
1799
1800 auto FoldMulHighLadder = [&](Value *X, Value *Y, Instruction *A,
1801 Instruction *B) {
1802 // xh*yh + c2>>32 + c3>>32
1803 // c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
1804 // or c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32); c3 = xh*yl
1805 Value *XlYh, *XhYl, *XlYl, *C2, *C3;
1806 // Strip off the two expected shifts.
1807 if (!match(V: A, P: m_LShr(L: m_Value(V&: C2), R: m_SpecificInt(V: BitWidth / 2))) ||
1808 !match(V: B, P: m_LShr(L: m_Value(V&: C3), R: m_SpecificInt(V: BitWidth / 2))))
1809 return false;
1810
1811 if (match(V: C3, P: m_c_Add(L: m_Add(L: m_Value(), R: m_Value()), R: m_Value())))
1812 std::swap(a&: C2, b&: C3);
1813 // Try to match c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32)
1814 if (match(V: C2,
1815 P: m_c_Add(L: m_c_Add(L: m_And(L: m_Specific(V: C3), R: m_SpecificInt(V: LowMask)),
1816 R: m_Value(V&: XlYh)),
1817 R: m_LShr(L: m_Value(V&: XlYl), R: m_SpecificInt(V: BitWidth / 2)))) ||
1818 match(V: C2, P: m_c_Add(L: m_c_Add(L: m_And(L: m_Specific(V: C3), R: m_SpecificInt(V: LowMask)),
1819 R: m_LShr(L: m_Value(V&: XlYl),
1820 R: m_SpecificInt(V: BitWidth / 2))),
1821 R: m_Value(V&: XlYh))) ||
1822 match(V: C2, P: m_c_Add(L: m_c_Add(L: m_LShr(L: m_Value(V&: XlYl),
1823 R: m_SpecificInt(V: BitWidth / 2)),
1824 R: m_Value(V&: XlYh)),
1825 R: m_And(L: m_Specific(V: C3), R: m_SpecificInt(V: LowMask))))) {
1826 XhYl = C3;
1827 } else {
1828 // Match c3 = c2&0xffffffff + xl*yh
1829 if (!match(V: C3, P: m_c_Add(L: m_And(L: m_Specific(V: C2), R: m_SpecificInt(V: LowMask)),
1830 R: m_Value(V&: XlYh))))
1831 std::swap(a&: C2, b&: C3);
1832 if (!match(V: C3, P: m_c_Add(L: m_OneUse(
1833 SubPattern: m_And(L: m_Specific(V: C2), R: m_SpecificInt(V: LowMask))),
1834 R: m_Value(V&: XlYh))) ||
1835 !C3->hasOneUse() || C2->hasNUsesOrMore(N: 3))
1836 return false;
1837
1838 // Match c2 = xh*yl + (xl*yl >> 32)
1839 if (!match(V: C2, P: m_c_Add(L: m_LShr(L: m_Value(V&: XlYl), R: m_SpecificInt(V: BitWidth / 2)),
1840 R: m_Value(V&: XhYl))))
1841 return false;
1842 }
1843
1844 // Match XhYl and XlYh - they can appear either way around.
1845 if (!CheckHiLo(XlYh, Y, X))
1846 std::swap(a&: XlYh, b&: XhYl);
1847 if (!CheckHiLo(XlYh, Y, X))
1848 return false;
1849 if (!CheckHiLo(XhYl, X, Y))
1850 return false;
1851 if (!CheckLoLo(XlYl, X, Y))
1852 return false;
1853
1854 return CreateMulHigh(X, Y);
1855 };
1856
1857 auto FoldMulHighLadder4 = [&](Value *X, Value *Y, Instruction *A,
1858 Instruction *B, Instruction *C) {
1859 /// Ladder4: xh*yh + (xl*yh)>>32 + (xh+yl)>>32 + low>>32;
1860 /// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
1861
1862 // Find A = Low >> 32 and B/C = XhYl>>32, XlYh>>32.
1863 auto ShiftAdd =
1864 m_LShr(L: m_Add(L: m_Value(), R: m_Value()), R: m_SpecificInt(V: BitWidth / 2));
1865 if (!match(V: A, P: ShiftAdd))
1866 std::swap(a&: A, b&: B);
1867 if (!match(V: A, P: ShiftAdd))
1868 std::swap(a&: A, b&: C);
1869 Value *Low;
1870 if (!match(V: A, P: m_LShr(L: m_OneUse(SubPattern: m_Value(V&: Low)), R: m_SpecificInt(V: BitWidth / 2))))
1871 return false;
1872
1873 // Match B == XhYl>>32 and C == XlYh>>32
1874 Value *XhYl, *XlYh;
1875 if (!match(V: B, P: m_LShr(L: m_Value(V&: XhYl), R: m_SpecificInt(V: BitWidth / 2))) ||
1876 !match(V: C, P: m_LShr(L: m_Value(V&: XlYh), R: m_SpecificInt(V: BitWidth / 2))))
1877 return false;
1878 if (!CheckHiLo(XhYl, X, Y))
1879 std::swap(a&: XhYl, b&: XlYh);
1880 if (!CheckHiLo(XhYl, X, Y) || XhYl->hasNUsesOrMore(N: 3))
1881 return false;
1882 if (!CheckHiLo(XlYh, Y, X) || XlYh->hasNUsesOrMore(N: 3))
1883 return false;
1884
1885 // Match Low as XlYl>>32 + XhYl&0xffffffff + XlYh&0xffffffff
1886 Value *XlYl;
1887 if (!match(
1888 V: Low,
1889 P: m_c_Add(
1890 L: m_OneUse(SubPattern: m_c_Add(
1891 L: m_OneUse(SubPattern: m_And(L: m_Specific(V: XhYl), R: m_SpecificInt(V: LowMask))),
1892 R: m_OneUse(SubPattern: m_And(L: m_Specific(V: XlYh), R: m_SpecificInt(V: LowMask))))),
1893 R: m_OneUse(
1894 SubPattern: m_LShr(L: m_Value(V&: XlYl), R: m_SpecificInt(V: BitWidth / 2))))) &&
1895 !match(
1896 V: Low,
1897 P: m_c_Add(
1898 L: m_OneUse(SubPattern: m_c_Add(
1899 L: m_OneUse(SubPattern: m_And(L: m_Specific(V: XhYl), R: m_SpecificInt(V: LowMask))),
1900 R: m_OneUse(
1901 SubPattern: m_LShr(L: m_Value(V&: XlYl), R: m_SpecificInt(V: BitWidth / 2))))),
1902 R: m_OneUse(SubPattern: m_And(L: m_Specific(V: XlYh), R: m_SpecificInt(V: LowMask))))) &&
1903 !match(
1904 V: Low,
1905 P: m_c_Add(
1906 L: m_OneUse(SubPattern: m_c_Add(
1907 L: m_OneUse(SubPattern: m_And(L: m_Specific(V: XlYh), R: m_SpecificInt(V: LowMask))),
1908 R: m_OneUse(
1909 SubPattern: m_LShr(L: m_Value(V&: XlYl), R: m_SpecificInt(V: BitWidth / 2))))),
1910 R: m_OneUse(SubPattern: m_And(L: m_Specific(V: XhYl), R: m_SpecificInt(V: LowMask))))))
1911 return false;
1912 if (!CheckLoLo(XlYl, X, Y))
1913 return false;
1914
1915 return CreateMulHigh(X, Y);
1916 };
1917
1918 auto FoldMulHighCarry4 = [&](Value *X, Value *Y, Instruction *Carry,
1919 Instruction *B, Instruction *C) {
1920 // xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
1921 // crosssum = xh*yl+xl*yh
1922 // carry = crosssum < xh*yl ? 0x1000000 : 0
1923 if (Carry->getOpcode() != Instruction::Select)
1924 std::swap(a&: Carry, b&: B);
1925 if (Carry->getOpcode() != Instruction::Select)
1926 std::swap(a&: Carry, b&: C);
1927
1928 // Carry = CrossSum < XhYl ? 0x100000000 : 0
1929 Value *CrossSum, *XhYl;
1930 if (!match(V: Carry,
1931 P: m_OneUse(SubPattern: m_Select(
1932 C: m_OneUse(SubPattern: m_SpecificICmp(MatchPred: ICmpInst::ICMP_ULT,
1933 L: m_Value(V&: CrossSum), R: m_Value(V&: XhYl))),
1934 L: m_SpecificInt(V: APInt::getOneBitSet(numBits: BitWidth, BitNo: BitWidth / 2)),
1935 R: m_Zero()))))
1936 return false;
1937
1938 if (!match(V: B, P: m_LShr(L: m_Specific(V: CrossSum), R: m_SpecificInt(V: BitWidth / 2))))
1939 std::swap(a&: B, b&: C);
1940 if (!match(V: B, P: m_LShr(L: m_Specific(V: CrossSum), R: m_SpecificInt(V: BitWidth / 2))))
1941 return false;
1942
1943 Value *XlYl, *LowAccum;
1944 if (!match(V: C, P: m_LShr(L: m_Value(V&: LowAccum), R: m_SpecificInt(V: BitWidth / 2))) ||
1945 !match(V: LowAccum, P: m_c_Add(L: m_OneUse(SubPattern: m_LShr(L: m_Value(V&: XlYl),
1946 R: m_SpecificInt(V: BitWidth / 2))),
1947 R: m_OneUse(SubPattern: m_And(L: m_Specific(V: CrossSum),
1948 R: m_SpecificInt(V: LowMask))))) ||
1949 LowAccum->hasNUsesOrMore(N: 3))
1950 return false;
1951 if (!CheckLoLo(XlYl, X, Y))
1952 return false;
1953
1954 if (!CheckHiLo(XhYl, X, Y))
1955 std::swap(a&: X, b&: Y);
1956 if (!CheckHiLo(XhYl, X, Y))
1957 return false;
1958 Value *XlYh;
1959 if (!match(V: CrossSum, P: m_c_Add(L: m_Specific(V: XhYl), R: m_OneUse(SubPattern: m_Value(V&: XlYh)))) ||
1960 !CheckHiLo(XlYh, Y, X) || CrossSum->hasNUsesOrMore(N: 4) ||
1961 XhYl->hasNUsesOrMore(N: 3))
1962 return false;
1963
1964 return CreateMulHigh(X, Y);
1965 };
1966
1967 // X and Y are the two inputs, A, B and C are other parts of the pattern
1968 // (crosssum>>32, carry, etc).
1969 Value *X, *Y;
1970 Instruction *A, *B, *C;
1971 auto HiHi = m_OneUse(SubPattern: m_Mul(L: m_LShr(L: m_Value(V&: X), R: m_SpecificInt(V: BitWidth / 2)),
1972 R: m_LShr(L: m_Value(V&: Y), R: m_SpecificInt(V: BitWidth / 2))));
1973 if ((match(V: &I, P: m_c_Add(L: HiHi, R: m_OneUse(SubPattern: m_Add(L: m_Instruction(I&: A),
1974 R: m_Instruction(I&: B))))) ||
1975 match(V: &I, P: m_c_Add(L: m_Instruction(I&: A),
1976 R: m_OneUse(SubPattern: m_c_Add(L: HiHi, R: m_Instruction(I&: B)))))) &&
1977 A->hasOneUse() && B->hasOneUse())
1978 if (FoldMulHighCarry(X, Y, A, B) || FoldMulHighLadder(X, Y, A, B))
1979 return true;
1980
1981 if ((match(V: &I, P: m_c_Add(L: HiHi, R: m_OneUse(SubPattern: m_c_Add(
1982 L: m_Instruction(I&: A),
1983 R: m_OneUse(SubPattern: m_Add(L: m_Instruction(I&: B),
1984 R: m_Instruction(I&: C))))))) ||
1985 match(V: &I, P: m_c_Add(L: m_Instruction(I&: A),
1986 R: m_OneUse(SubPattern: m_c_Add(
1987 L: HiHi, R: m_OneUse(SubPattern: m_Add(L: m_Instruction(I&: B),
1988 R: m_Instruction(I&: C))))))) ||
1989 match(V: &I, P: m_c_Add(L: m_Instruction(I&: A),
1990 R: m_OneUse(SubPattern: m_c_Add(
1991 L: m_Instruction(I&: B),
1992 R: m_OneUse(SubPattern: m_c_Add(L: HiHi, R: m_Instruction(I&: C))))))) ||
1993 match(V: &I,
1994 P: m_c_Add(L: m_OneUse(SubPattern: m_c_Add(L: HiHi, R: m_Instruction(I&: A))),
1995 R: m_OneUse(SubPattern: m_Add(L: m_Instruction(I&: B), R: m_Instruction(I&: C)))))) &&
1996 A->hasOneUse() && B->hasOneUse() && C->hasOneUse())
1997 return FoldMulHighCarry4(X, Y, A, B, C) ||
1998 FoldMulHighLadder4(X, Y, A, B, C);
1999
2000 return false;
2001}
2002
2003/// This is the entry point for folds that could be implemented in regular
2004/// InstCombine, but they are separated because they are not expected to
2005/// occur frequently and/or have more than a constant-length pattern match.
2006static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
2007 TargetTransformInfo &TTI,
2008 TargetLibraryInfo &TLI, AliasAnalysis &AA,
2009 AssumptionCache &AC, bool &MadeCFGChange) {
2010 bool MadeChange = false;
2011 for (BasicBlock &BB : F) {
2012 // Ignore unreachable basic blocks.
2013 if (!DT.isReachableFromEntry(A: &BB))
2014 continue;
2015
2016 const DataLayout &DL = F.getDataLayout();
2017
2018 // Walk the block backwards for efficiency. We're matching a chain of
2019 // use->defs, so we're more likely to succeed by starting from the bottom.
2020 // Also, we want to avoid matching partial patterns.
2021 // TODO: It would be more efficient if we removed dead instructions
2022 // iteratively in this loop rather than waiting until the end.
2023 for (Instruction &I : make_early_inc_range(Range: llvm::reverse(C&: BB))) {
2024 MadeChange |= foldAnyOrAllBitsSet(I);
2025 MadeChange |= foldGuardedFunnelShift(I, DT);
2026 MadeChange |= tryToRecognizePopCount(I);
2027 MadeChange |= tryToFPToSat(I, TTI);
2028 MadeChange |= tryToRecognizeTableBasedCttz(I, DL);
2029 MadeChange |= tryToRecognizeTableBasedLog2(I, DL, TTI);
2030 MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
2031 MadeChange |= foldPatternedLoads(I, DL);
2032 MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
2033 MadeChange |= foldMulHigh(I);
2034 // NOTE: This function introduces erasing of the instruction `I`, so it
2035 // needs to be called at the end of this sequence, otherwise we may make
2036 // bugs.
2037 MadeChange |= foldLibCalls(I, TTI, TLI, AC, DT, DL, MadeCFGChange);
2038 }
2039
2040 // Do this separately to avoid redundantly scanning stores multiple times.
2041 MadeChange |= foldConsecutiveStores(BB, DL, TTI, AA);
2042 }
2043
2044 // We're done with transforms, so remove dead instructions.
2045 if (MadeChange)
2046 for (BasicBlock &BB : F)
2047 SimplifyInstructionsInBlock(BB: &BB);
2048
2049 return MadeChange;
2050}
2051
2052/// This is the entry point for all transforms. Pass manager differences are
2053/// handled in the callers of this function.
2054static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
2055 TargetLibraryInfo &TLI, DominatorTree &DT,
2056 AliasAnalysis &AA, bool &MadeCFGChange) {
2057 bool MadeChange = false;
2058 const DataLayout &DL = F.getDataLayout();
2059 TruncInstCombine TIC(AC, TLI, DL, DT);
2060 MadeChange |= TIC.run(F);
2061 MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, MadeCFGChange);
2062 return MadeChange;
2063}
2064
2065PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
2066 FunctionAnalysisManager &AM) {
2067 auto &AC = AM.getResult<AssumptionAnalysis>(IR&: F);
2068 auto &TLI = AM.getResult<TargetLibraryAnalysis>(IR&: F);
2069 auto &DT = AM.getResult<DominatorTreeAnalysis>(IR&: F);
2070 auto &TTI = AM.getResult<TargetIRAnalysis>(IR&: F);
2071 auto &AA = AM.getResult<AAManager>(IR&: F);
2072 bool MadeCFGChange = false;
2073 if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) {
2074 // No changes, all analyses are preserved.
2075 return PreservedAnalyses::all();
2076 }
2077 // Mark all the analyses that instcombine updates as preserved.
2078 PreservedAnalyses PA;
2079 if (MadeCFGChange)
2080 PA.preserve<DominatorTreeAnalysis>();
2081 else
2082 PA.preserveSet<CFGAnalyses>();
2083 return PA;
2084}
2085