1//===- AggressiveInstCombine.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
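// This file implements the aggressive expression pattern combiner classes.
// Currently, it handles expression patterns for, among others:
// * Truncate instruction
// * Guarded funnel shifts and rotates
// * Any/all-bits-set tests built from shift and logic chains
// * Popcount idioms
// * Select-based split cttz/ctlz
// * Saturating fptosi expressed with smin/smax clamps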
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
16#include "AggressiveInstCombineInternal.h"
17#include "llvm/ADT/Statistic.h"
18#include "llvm/Analysis/AliasAnalysis.h"
19#include "llvm/Analysis/AssumptionCache.h"
20#include "llvm/Analysis/BasicAliasAnalysis.h"
21#include "llvm/Analysis/ConstantFolding.h"
22#include "llvm/Analysis/DomTreeUpdater.h"
23#include "llvm/Analysis/GlobalsModRef.h"
24#include "llvm/Analysis/TargetLibraryInfo.h"
25#include "llvm/Analysis/TargetTransformInfo.h"
26#include "llvm/Analysis/ValueTracking.h"
27#include "llvm/IR/DataLayout.h"
28#include "llvm/IR/Dominators.h"
29#include "llvm/IR/Function.h"
30#include "llvm/IR/IRBuilder.h"
31#include "llvm/IR/Instruction.h"
32#include "llvm/IR/MDBuilder.h"
33#include "llvm/IR/PatternMatch.h"
34#include "llvm/IR/ProfDataUtils.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/CommandLine.h"
37#include "llvm/Transforms/Utils/BasicBlockUtils.h"
38#include "llvm/Transforms/Utils/BuildLibCalls.h"
39#include "llvm/Transforms/Utils/Local.h"
40
41using namespace llvm;
42using namespace PatternMatch;
43
44#define DEBUG_TYPE "aggressive-instcombine"
45
46namespace llvm {
47extern cl::opt<bool> ProfcheckDisableMetadataFixes;
48}
49
50STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded");
51STATISTIC(NumGuardedRotates,
52 "Number of guarded rotates transformed into funnel shifts");
53STATISTIC(NumGuardedFunnelShifts,
54 "Number of guarded funnel shifts transformed into funnel shifts");
55STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized");
56STATISTIC(NumSelectCTTZFolded,
57 "Number of select-based split cttz patterns folded");
58STATISTIC(NumSelectCTLZFolded,
59 "Number of select-based split ctlz patterns folded");
60
61static cl::opt<unsigned> MaxInstrsToScan(
62 "aggressive-instcombine-max-scan-instrs", cl::init(Val: 64), cl::Hidden,
63 cl::desc("Max number of instructions to scan for aggressive instcombine."));
64
65static cl::opt<unsigned> StrNCmpInlineThreshold(
66 "strncmp-inline-threshold", cl::init(Val: 3), cl::Hidden,
67 cl::desc("The maximum length of a constant string for a builtin string cmp "
68 "call eligible for inlining. The default value is 3."));
69
70static cl::opt<unsigned>
71 MemChrInlineThreshold("memchr-inline-threshold", cl::init(Val: 3), cl::Hidden,
72 cl::desc("The maximum length of a constant string to "
73 "inline a memchr call."));
74
75/// Try to fold a select-based split cttz pattern into a single full-width cttz.
76///
77/// %lo = trunc iN %val to i(N/2)
78/// %cmp = icmp eq i(N/2) %lo, 0
79/// %shr = lshr iN %val, N/2
80/// %hi = trunc iN %shr to i(N/2)
81/// %cttz_hi = call i(N/2) @llvm.cttz.i(N/2)(i(N/2) %hi, ...)
82/// %hi_plus = add/or_disjoint i(N/2) %cttz_hi, N/2
83/// %cttz_lo = call i(N/2) @llvm.cttz.i(N/2)(i(N/2) %lo, ...)
84/// %result = select i1 %cmp, i(N/2) %hi_plus, i(N/2) %cttz_lo
85/// -->
86/// %cttz_wide = call iN @llvm.cttz.iN(iN %val, i1 false)
87/// %result = trunc iN %cttz_wide to i(N/2)
88/// Alive proof (for i64/i32): https://alive2.llvm.org/ce/z/-s14-s
89// TrueVal/FalseVal are pre-normalized by the caller to the EQ/NE cases.
90static bool foldSelectSplitCTTZ(Instruction &I, Value *LoTrunc, Value *HiResult,
91 Value *LoResult, Type *HalfTy) {
92 unsigned HalfWidth = HalfTy->getIntegerBitWidth();
93 unsigned FullWidth = HalfWidth * 2;
94
95 // LoTrunc: trunc iN SrcVal to i(N/2)
96 Value *SrcVal;
97 if (!match(V: LoTrunc, P: m_Trunc(Op: m_Value(V&: SrcVal))))
98 return false;
99 if (!SrcVal->getType()->isIntegerTy(BitWidth: FullWidth))
100 return false;
101
102 // LoResult: cttz(trunc(SrcVal), _), must use same truncated value
103 if (!match(V: LoResult, P: m_OneUse(SubPattern: m_Cttz(Op0: m_Specific(V: LoTrunc), Op1: m_Value()))))
104 return false;
105
106 // HiResult: add/or_disjoint(cttz(trunc(lshr(SrcVal, N/2)), _), N/2)
107 Value *CttzHiCall;
108 if (!match(V: HiResult, P: m_OneUse(SubPattern: m_AddLike(L: m_Value(V&: CttzHiCall),
109 R: m_SpecificInt(V: HalfWidth)))))
110 return false;
111
112 Value *HiCttzArg;
113 if (!match(V: CttzHiCall, P: m_OneUse(SubPattern: m_Cttz(Op0: m_Value(V&: HiCttzArg), Op1: m_Value()))))
114 return false;
115
116 if (!match(V: HiCttzArg,
117 P: m_Trunc(Op: m_LShr(L: m_Specific(V: SrcVal), R: m_SpecificInt(V: HalfWidth)))))
118 return false;
119
120 // Match successful.
121 IRBuilder<> Builder(&I);
122 Value *CttzWide = Builder.CreateIntrinsic(
123 ID: Intrinsic::cttz, OverloadTypes: {SrcVal->getType()}, Args: {SrcVal, Builder.getFalse()});
124 Value *Trunc = Builder.CreateTrunc(V: CttzWide, DestTy: HalfTy);
125
126 I.replaceAllUsesWith(V: Trunc);
127 ++NumSelectCTTZFolded;
128 return true;
129}
130
131/// Same as foldSelectSplitCTTZ but for leading zeros (ctlz).
132///
133/// %shr = lshr iN %val, N/2
134/// %hi = trunc iN %shr to i(N/2)
135/// %cmp = icmp eq i(N/2) %hi, 0 (or icmp eq iN %shr, 0)
136/// %lo = trunc iN %val to i(N/2)
137/// %ctlz_lo = call i(N/2) @llvm.ctlz.i(N/2)(i(N/2) %lo, ...)
138/// %lo_plus = add/or_disjoint i(N/2) %ctlz_lo, N/2
139/// %ctlz_hi = call i(N/2) @llvm.ctlz.i(N/2)(i(N/2) %hi, ...)
140/// %result = select i1 %cmp, i(N/2) %lo_plus, i(N/2) %ctlz_hi
141/// -->
142/// %ctlz_wide = call iN @llvm.ctlz.iN(iN %val, i1 false)
143/// %result = trunc iN %ctlz_wide to i(N/2)
144///
145/// Alive proof (for i64/i32): https://alive2.llvm.org/ce/z/WfQepH
146// TrueVal/FalseVal are pre-normalized by the caller to the EQ/NE cases.
147static bool foldSelectSplitCTLZ(Instruction &I, Value *HiPart, Value *LoResult,
148 Value *HiResult, Type *HalfTy) {
149 unsigned HalfWidth = HalfTy->getIntegerBitWidth();
150 unsigned FullWidth = HalfWidth * 2;
151
152 // Extract SrcVal from HiPart: either trunc(lshr(SrcVal, N/2)) or
153 // lshr(SrcVal, N/2)
154 Value *SrcVal;
155 if (match(V: HiPart, P: m_Trunc(Op: m_Value(V&: SrcVal))))
156 HiPart = SrcVal;
157
158 if (!match(V: HiPart, P: m_LShr(L: m_Value(V&: SrcVal), R: m_SpecificInt(V: HalfWidth))))
159 return false;
160 if (!SrcVal->getType()->isIntegerTy(BitWidth: FullWidth))
161 return false;
162
163 // HiResult: ctlz(trunc(lshr(SrcVal, N/2)), _)
164 Value *HiCtlzArg;
165 if (!match(V: HiResult, P: m_OneUse(SubPattern: m_Ctlz(Op0: m_Value(V&: HiCtlzArg), Op1: m_Value()))))
166 return false;
167
168 if (!match(V: HiCtlzArg,
169 P: m_Trunc(Op: m_LShr(L: m_Specific(V: SrcVal), R: m_SpecificInt(V: HalfWidth)))))
170 return false;
171
172 // LoResult: add/or_disjoint(ctlz(trunc(SrcVal), _), N/2)
173 Value *CtlzLoCall;
174 if (!match(V: LoResult, P: m_OneUse(SubPattern: m_AddLike(L: m_Value(V&: CtlzLoCall),
175 R: m_SpecificInt(V: HalfWidth)))))
176 return false;
177
178 Value *LoCtlzArg;
179 if (!match(V: CtlzLoCall, P: m_OneUse(SubPattern: m_Ctlz(Op0: m_Value(V&: LoCtlzArg), Op1: m_Value()))))
180 return false;
181
182 if (!match(V: LoCtlzArg, P: m_Trunc(Op: m_Specific(V: SrcVal))))
183 return false;
184
185 // Match successful.
186 IRBuilder<> Builder(&I);
187 Value *CtlzWide = Builder.CreateIntrinsic(
188 ID: Intrinsic::ctlz, OverloadTypes: {SrcVal->getType()}, Args: {SrcVal, Builder.getFalse()});
189 Value *Trunc = Builder.CreateTrunc(V: CtlzWide, DestTy: HalfTy);
190
191 I.replaceAllUsesWith(V: Trunc);
192 ++NumSelectCTLZFolded;
193 return true;
194}
195
196/// Common entry point for folding select-based split cttz/ctlz patterns.
197/// Performs the initial select and type matching shared by both transforms,
198/// then delegates to foldSelectSplitCTTZ and foldSelectSplitCTLZ.
199static bool foldSelectSplitCTLZCTTZ(Instruction &I) {
200 Value *Cond, *TrueVal, *FalseVal;
201 if (!match(V: &I, P: m_Select(C: m_Value(V&: Cond), L: m_Value(V&: TrueVal), R: m_Value(V&: FalseVal))))
202 return false;
203
204 Type *Ty = I.getType();
205 if (!Ty->isIntegerTy())
206 return false;
207
208 // Bail out on very small types (i1, i2): the full-width cttz/ctlz can return
209 // values not representable in the half type (e.g., cttz.i4 can return 4,
210 // which doesn't fit in i2).
211 if (Ty->getIntegerBitWidth() <= 2)
212 return false;
213
214 CmpPredicate Pred;
215 Value *CmpOp;
216 if (!match(V: Cond, P: m_ICmp(Pred, L: m_Value(V&: CmpOp), R: m_ZeroInt())) ||
217 !ICmpInst::isEquality(P: Pred))
218 return false;
219
220 // Canonicalize select operands.
221 if (Pred == CmpInst::ICMP_NE)
222 std::swap(a&: TrueVal, b&: FalseVal);
223
224 return foldSelectSplitCTTZ(I, LoTrunc: CmpOp, HiResult: TrueVal, LoResult: FalseVal, HalfTy: Ty) ||
225 foldSelectSplitCTLZ(I, HiPart: CmpOp, LoResult: TrueVal, HiResult: FalseVal, HalfTy: Ty);
226}
227
228/// Match a pattern for a bitwise funnel/rotate operation that partially guards
229/// against undefined behavior by branching around the funnel-shift/rotation
230/// when the shift amount is 0.
231static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
232 if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2)
233 return false;
234
235 // As with the one-use checks below, this is not strictly necessary, but we
236 // are being cautious to avoid potential perf regressions on targets that
237 // do not actually have a funnel/rotate instruction (where the funnel shift
238 // would be expanded back into math/shift/logic ops).
239 if (!isPowerOf2_32(Value: I.getType()->getScalarSizeInBits()))
240 return false;
241
242 // Match V to funnel shift left/right and capture the source operands and
243 // shift amount.
244 auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1,
245 Value *&ShAmt) {
246 unsigned Width = V->getType()->getScalarSizeInBits();
247
248 // fshl(ShVal0, ShVal1, ShAmt)
249 // == (ShVal0 << ShAmt) | (ShVal1 >> (Width -ShAmt))
250 if (match(V, P: m_OneUse(SubPattern: m_c_Or(
251 L: m_Shl(L: m_Value(V&: ShVal0), R: m_Value(V&: ShAmt)),
252 R: m_LShr(L: m_Value(V&: ShVal1), R: m_Sub(L: m_SpecificInt(V: Width),
253 R: m_Deferred(V: ShAmt))))))) {
254 return Intrinsic::fshl;
255 }
256
257 // fshr(ShVal0, ShVal1, ShAmt)
258 // == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt))
259 if (match(V,
260 P: m_OneUse(SubPattern: m_c_Or(L: m_Shl(L: m_Value(V&: ShVal0), R: m_Sub(L: m_SpecificInt(V: Width),
261 R: m_Value(V&: ShAmt))),
262 R: m_LShr(L: m_Value(V&: ShVal1), R: m_Deferred(V: ShAmt)))))) {
263 return Intrinsic::fshr;
264 }
265
266 return Intrinsic::not_intrinsic;
267 };
268
269 // One phi operand must be a funnel/rotate operation, and the other phi
270 // operand must be the source value of that funnel/rotate operation:
271 // phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ]
272 // phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ]
273 // phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ]
274 PHINode &Phi = cast<PHINode>(Val&: I);
275 unsigned FunnelOp = 0, GuardOp = 1;
276 Value *P0 = Phi.getOperand(i_nocapture: 0), *P1 = Phi.getOperand(i_nocapture: 1);
277 Value *ShVal0, *ShVal1, *ShAmt;
278 Intrinsic::ID IID = matchFunnelShift(P0, ShVal0, ShVal1, ShAmt);
279 if (IID == Intrinsic::not_intrinsic ||
280 (IID == Intrinsic::fshl && ShVal0 != P1) ||
281 (IID == Intrinsic::fshr && ShVal1 != P1)) {
282 IID = matchFunnelShift(P1, ShVal0, ShVal1, ShAmt);
283 if (IID == Intrinsic::not_intrinsic ||
284 (IID == Intrinsic::fshl && ShVal0 != P0) ||
285 (IID == Intrinsic::fshr && ShVal1 != P0))
286 return false;
287 assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
288 "Pattern must match funnel shift left or right");
289 std::swap(a&: FunnelOp, b&: GuardOp);
290 }
291
292 // The incoming block with our source operand must be the "guard" block.
293 // That must contain a cmp+branch to avoid the funnel/rotate when the shift
294 // amount is equal to 0. The other incoming block is the block with the
295 // funnel/rotate.
296 BasicBlock *GuardBB = Phi.getIncomingBlock(i: GuardOp);
297 BasicBlock *FunnelBB = Phi.getIncomingBlock(i: FunnelOp);
298 Instruction *TermI = GuardBB->getTerminator();
299
300 // Ensure that the shift values dominate each block.
301 if (!DT.dominates(Def: ShVal0, User: TermI) || !DT.dominates(Def: ShVal1, User: TermI))
302 return false;
303
304 BasicBlock *PhiBB = Phi.getParent();
305 if (!match(V: TermI, P: m_Br(C: m_SpecificICmp(MatchPred: CmpInst::ICMP_EQ, L: m_Specific(V: ShAmt),
306 R: m_ZeroInt()),
307 T: m_SpecificBB(BB: PhiBB), F: m_SpecificBB(BB: FunnelBB))))
308 return false;
309
310 IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt());
311
312 if (ShVal0 == ShVal1)
313 ++NumGuardedRotates;
314 else
315 ++NumGuardedFunnelShifts;
316
317 // If this is not a rotate then the select was blocking poison from the
318 // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
319 bool IsFshl = IID == Intrinsic::fshl;
320 if (ShVal0 != ShVal1) {
321 if (IsFshl && !llvm::isGuaranteedNotToBePoison(V: ShVal1))
322 ShVal1 = Builder.CreateFreeze(V: ShVal1);
323 else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(V: ShVal0))
324 ShVal0 = Builder.CreateFreeze(V: ShVal0);
325 }
326
327 // We matched a variation of this IR pattern:
328 // GuardBB:
329 // %cmp = icmp eq i32 %ShAmt, 0
330 // br i1 %cmp, label %PhiBB, label %FunnelBB
331 // FunnelBB:
332 // %sub = sub i32 32, %ShAmt
333 // %shr = lshr i32 %ShVal1, %sub
334 // %shl = shl i32 %ShVal0, %ShAmt
335 // %fsh = or i32 %shr, %shl
336 // br label %PhiBB
337 // PhiBB:
338 // %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ]
339 // -->
340 // llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt)
341 Phi.replaceAllUsesWith(
342 V: Builder.CreateIntrinsic(ID: IID, OverloadTypes: Phi.getType(), Args: {ShVal0, ShVal1, ShAmt}));
343 return true;
344}
345
346/// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and
347/// the bit indexes (Mask) needed by a masked compare. If we're matching a chain
348/// of 'and' ops, then we also need to capture the fact that we saw an
349/// "and X, 1", so that's an extra return value for that case.
350namespace {
351struct MaskOps {
352 Value *Root = nullptr;
353 APInt Mask;
354 bool MatchAndChain;
355 bool FoundAnd1 = false;
356
357 MaskOps(unsigned BitWidth, bool MatchAnds)
358 : Mask(APInt::getZero(numBits: BitWidth)), MatchAndChain(MatchAnds) {}
359};
360} // namespace
361
362/// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
363/// chain of 'and' or 'or' instructions looking for shift ops of a common source
364/// value. Examples:
365/// or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
366/// returns { X, 0x129 }
367/// and (and (X >> 1), 1), (X >> 4)
368/// returns { X, 0x12 }
369static bool matchAndOrChain(Value *V, MaskOps &MOps) {
370 Value *Op0, *Op1;
371 if (MOps.MatchAndChain) {
372 // Recurse through a chain of 'and' operands. This requires an extra check
373 // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere
374 // in the chain to know that all of the high bits are cleared.
375 if (match(V, P: m_And(L: m_Value(V&: Op0), R: m_One()))) {
376 MOps.FoundAnd1 = true;
377 return matchAndOrChain(V: Op0, MOps);
378 }
379 if (match(V, P: m_And(L: m_Value(V&: Op0), R: m_Value(V&: Op1))))
380 return matchAndOrChain(V: Op0, MOps) && matchAndOrChain(V: Op1, MOps);
381 } else {
382 // Recurse through a chain of 'or' operands.
383 if (match(V, P: m_Or(L: m_Value(V&: Op0), R: m_Value(V&: Op1))))
384 return matchAndOrChain(V: Op0, MOps) && matchAndOrChain(V: Op1, MOps);
385 }
386
387 // We need a shift-right or a bare value representing a compare of bit 0 of
388 // the original source operand.
389 Value *Candidate;
390 const APInt *BitIndex = nullptr;
391 if (!match(V, P: m_LShr(L: m_Value(V&: Candidate), R: m_APInt(Res&: BitIndex))))
392 Candidate = V;
393
394 // Initialize result source operand.
395 if (!MOps.Root)
396 MOps.Root = Candidate;
397
398 // The shift constant is out-of-range? This code hasn't been simplified.
399 if (BitIndex && BitIndex->uge(RHS: MOps.Mask.getBitWidth()))
400 return false;
401
402 // Fill in the mask bit derived from the shift constant.
403 MOps.Mask.setBit(BitIndex ? BitIndex->getZExtValue() : 0);
404 return MOps.Root == Candidate;
405}
406
407/// Match patterns that correspond to "any-bits-set" and "all-bits-set".
408/// These will include a chain of 'or' or 'and'-shifted bits from a
409/// common source value:
410/// and (or (lshr X, C), ...), 1 --> (X & CMask) != 0
411/// and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask
412/// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns
413/// that differ only with a final 'not' of the result. We expect that final
414/// 'not' to be folded with the compare that we create here (invert predicate).
415static bool foldAnyOrAllBitsSet(Instruction &I) {
416 // The 'any-bits-set' ('or' chain) pattern is simpler to match because the
417 // final "and X, 1" instruction must be the final op in the sequence.
418 bool MatchAllBitsSet;
419 bool MatchTrunc;
420 Value *X;
421 if (I.getType()->isIntOrIntVectorTy(BitWidth: 1)) {
422 if (match(V: &I, P: m_Trunc(Op: m_OneUse(SubPattern: m_And(L: m_Value(), R: m_Value())))))
423 MatchAllBitsSet = true;
424 else if (match(V: &I, P: m_Trunc(Op: m_OneUse(SubPattern: m_Or(L: m_Value(), R: m_Value())))))
425 MatchAllBitsSet = false;
426 else
427 return false;
428 MatchTrunc = true;
429 X = I.getOperand(i: 0);
430 } else {
431 if (match(V: &I, P: m_c_And(L: m_OneUse(SubPattern: m_And(L: m_Value(), R: m_Value())), R: m_Value()))) {
432 X = &I;
433 MatchAllBitsSet = true;
434 } else if (match(V: &I,
435 P: m_And(L: m_OneUse(SubPattern: m_Or(L: m_Value(), R: m_Value())), R: m_One()))) {
436 X = I.getOperand(i: 0);
437 MatchAllBitsSet = false;
438 } else
439 return false;
440 MatchTrunc = false;
441 }
442 Type *Ty = X->getType();
443
444 MaskOps MOps(Ty->getScalarSizeInBits(), MatchAllBitsSet);
445 if (!matchAndOrChain(V: X, MOps) ||
446 (MatchAllBitsSet && !MatchTrunc && !MOps.FoundAnd1))
447 return false;
448
449 // The pattern was found. Create a masked compare that replaces all of the
450 // shift and logic ops.
451 IRBuilder<> Builder(&I);
452 Constant *Mask = ConstantInt::get(Ty, V: MOps.Mask);
453 Value *And = Builder.CreateAnd(LHS: MOps.Root, RHS: Mask);
454 Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(LHS: And, RHS: Mask)
455 : Builder.CreateIsNotNull(Arg: And);
456 Value *Zext = MatchTrunc ? Cmp : Builder.CreateZExt(V: Cmp, DestTy: Ty);
457 I.replaceAllUsesWith(V: Zext);
458 ++NumAnyOrAllBitsSet;
459 return true;
460}
461
462/// Helper function to replace an instruction with a popcount intrinsic.
463/// This creates the ctpop intrinsic with an optional truncation appended at the
464/// end, and replaces all uses of the instruction.
465static void replaceWithPopCount(Instruction &I, Value *Root) {
466 LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
467 Type *RootTy = Root->getType();
468 Type *OrigTy = I.getType();
469
470 IRBuilder<> Builder(&I);
471 Value *NewVal = Builder.CreateIntrinsic(ID: Intrinsic::ctpop, OverloadTypes: RootTy, Args: {Root});
472 if (OrigTy != RootTy) {
473 assert(RootTy->getScalarSizeInBits() > OrigTy->getScalarSizeInBits() &&
474 "Only truncation is supported for now");
475 NewVal = Builder.CreateTrunc(V: NewVal, DestTy: OrigTy);
476 }
477 I.replaceAllUsesWith(V: NewVal);
478 ++NumPopCountRecognized;
479}
480
481// Matches the common innermost steps of the Hacker's Delight popcount idiom:
482// V = ((x + (x >> 4)) & 0x0F...)
483// x = (y & 0x33...) + ((y >> 2) & 0x33...) [or y - 3*((y>>2)&0x33...)]
484// y = Root - ((Root >> 1) & 0x55...)
485// This computes the popcount for each byte.
486// Returns Root on success, nullptr on failure.
487static Value *matchPopCountBytes(Value *V, unsigned Len, const DataLayout &DL) {
488 APInt Mask55 = APInt::getSplat(NewLen: Len, V: APInt(8, 0x55));
489 APInt Mask33 = APInt::getSplat(NewLen: Len, V: APInt(8, 0x33));
490 APInt Mask0F = APInt::getSplat(NewLen: Len, V: APInt(8, 0x0F));
491
492 Value *Add2;
493 // Matching "((x + (x >> 4)) & 0x0F...)".
494 if (!match(V, P: m_And(L: m_c_Add(L: m_LShr(L: m_Value(V&: Add2), R: m_SpecificInt(V: 4)),
495 R: m_Deferred(V: Add2)),
496 R: m_SpecificInt(V: Mask0F))))
497 return nullptr;
498
499 Value *Sub1;
500 APInt NegThree(Len, -3, /*isSigned=*/true);
501 // Match
502 // x = (x & 0x33333333) + ((x >> 2) & 0x33333333)"
503 // Or
504 // x = x - 3*((x >> 2) & 0x33333333)
505 if (!match(V: Add2, P: m_c_Add(L: m_And(L: m_LShr(L: m_Value(V&: Sub1), R: m_SpecificInt(V: 2)),
506 R: m_SpecificInt(V: Mask33)),
507 R: m_And(L: m_Deferred(V: Sub1), R: m_SpecificInt(V: Mask33)))) &&
508 !match(V: Add2, P: m_Add(L: m_Mul(L: m_And(L: m_LShr(L: m_Value(V&: Sub1), R: m_SpecificInt(V: 2)),
509 R: m_SpecificInt(V: Mask33)),
510 R: m_SpecificInt(V: NegThree)),
511 R: m_Deferred(V: Sub1))))
512 return nullptr;
513
514 Value *Root, *LShr;
515 const APInt *AndMask;
516 // Matching "x - ((x >> 1) & 0x55...)".
517 if (!match(V: Sub1,
518 P: m_Sub(L: m_Value(V&: Root), R: m_And(L: m_Value(V&: LShr, P: m_LShr(L: m_Deferred(V: Root),
519 R: m_SpecificInt(V: 1))),
520 R: m_APInt(Res&: AndMask)))))
521 return nullptr;
522
523 if (*AndMask != Mask55) {
524 // Accept a narrowed mask if missing bits are known zero in Root>>1.
525 if (!AndMask->isSubsetOf(RHS: Mask55))
526 return nullptr;
527 APInt NeededMask = Mask55 & ~*AndMask;
528 if (!MaskedValueIsZero(V: LShr, Mask: NeededMask, SQ: SimplifyQuery(DL)))
529 return nullptr;
530 }
531
532 return Root;
533}
534
535// Try to recognize below function as popcount intrinsic.
536// This is the "best" algorithm from
537// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
538// Also used in TargetLowering::expandCTPOP().
539//
540// int popcount(unsigned int i) {
541// i = i - ((i >> 1) & 0x55555555);
542// i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
543// i = ((i + (i >> 4)) & 0x0F0F0F0F);
544// return (i * 0x01010101) >> 24;
545// }
546static bool tryToRecognizePopCount(Instruction &I) {
547 if (I.getOpcode() != Instruction::LShr)
548 return false;
549
550 Type *Ty = I.getType();
551 if (!Ty->isIntOrIntVectorTy())
552 return false;
553
554 unsigned Len = Ty->getScalarSizeInBits();
555 // Len==8 is handled by tryToRecognizePopCount2n3.
556 // FIXME: other irregular type lengths.
557 if (Len > 128 || Len <= 8 || Len % 8 != 0)
558 return false;
559
560 APInt Mask01 = APInt::getSplat(NewLen: Len, V: APInt(8, 0x01));
561
562 Value *Op0 = I.getOperand(i: 0);
563 Value *Op1 = I.getOperand(i: 1);
564 Value *MulOp0;
565 // Matching "(i * 0x01010101...) >> 24".
566 if (!match(V: Op0, P: m_Mul(L: m_Value(V&: MulOp0), R: m_SpecificInt(V: Mask01))) ||
567 !match(V: Op1, P: m_SpecificInt(V: Len - 8)))
568 return false;
569
570 Value *Root = matchPopCountBytes(V: MulOp0, Len, DL: I.getDataLayout());
571 if (!Root)
572 return false;
573
574 replaceWithPopCount(I, Root);
575 return true;
576}
577
578// Try to recognize below function as popcount intrinsic.
579// Ref. Hacker Delights
580// int popcount32(unsigned int i) {
581// uWord = (uWord & 0x55555555) + ((uWord>>1) & 0x55555555);
582// uWord = (uWord & 0x33333333) + ((uWord>>2) & 0x33333333);
583// uWord = (uWord & 0x0F0F0F0F) + ((uWord>>4) & 0x0F0F0F0F);
584// uWord = (uWord & 0x00FF00FF) + ((uWord>>8) & 0x00FF00FF);
585// return (uWord & 0x0000FFFF) + (uWord>>16);
586// }
587// int popcount64(unsigned long i) {
588// uWord = (uWord & 0x5555555555555555) + ((uWord>>1) & 0x5555555555555555);
589// uWord = (uWord & 0x3333333333333333) + ((uWord>>2) & 0x3333333333333333);
590// uWord = (uWord & 0x0F0F0F0F0F0F0F0F) + ((uWord>>4) & 0x0F0F0F0F0F0F0F0F);
591// uWord = (uWord & 0x00FF00FF00FF00FF) + ((uWord>>8) & 0x00FF00FF00FF00FF);
592// uWord = (uWord & 0x0000FFFF0000FFFF) + ((uWord>>16) & 0x0000FFFF0000FFFF);
593// return (uWord & 0x00000000FFFFFFFF) + (uWord>>32) & 0x00000000FFFFFFFF;
594// }
595//
596// InstCombine may narrow AND masks when it can prove the removed bits are
597// known zero (e.g. 0x0F0F0F0F -> 0x07070707). We accept such narrowed masks
598// by checking they are subsets of the expected masks and verifying the missing
599// bits are known zero via MaskedValueIsZero.
600static bool tryToRecognizePopCount1(Instruction &I) {
601 if (I.getOpcode() != Instruction::Add)
602 return false;
603
604 Type *Ty = I.getType();
605 if (!Ty->isIntOrIntVectorTy())
606 return false;
607
608 unsigned Len = Ty->getScalarSizeInBits();
609 if (Len > 64 || Len <= 8 || Len % 8 != 0)
610 return false;
611
612 // Len should be a power of 2 for the loop to work correctly
613 if (!isPowerOf2_32(Value: Len))
614 return false;
615
616 APInt Mask55 = APInt::getSplat(NewLen: Len, V: APInt(8, 0x55));
617 APInt Mask33 = APInt::getSplat(NewLen: Len, V: APInt(8, 0x33));
618
619 SimplifyQuery SQ(I.getDataLayout());
620
621 // Check if CapturedMask is a valid (possibly narrowed) version of
622 // ExpectedMask for the given Operand. Returns true if the masks match
623 // exactly, or if CapturedMask is a subset and the missing bits are
624 // known zero in the Operand.
625 auto isValidNarrowedMask = [&](const APInt &CapturedMask,
626 const APInt &ExpectedMask,
627 Value *Operand) -> bool {
628 if (CapturedMask == ExpectedMask)
629 return true;
630 if (!CapturedMask.isSubsetOf(RHS: ExpectedMask))
631 return false;
632 APInt NeededMask = ExpectedMask & ~CapturedMask;
633 return MaskedValueIsZero(V: Operand, Mask: NeededMask, SQ);
634 };
635
636 // For "(x & M) + ((x >> S) & M)" patterns, both AND masks may be narrowed.
637 // Require subsets of BaseMask and prove any implied missing bits are zero.
638 auto narrowAddPairMasksOk = [&](const APInt &BaseMask, unsigned ShiftAmt,
639 Value *Val, const APInt &AndMask1,
640 const APInt &AndMask2) -> bool {
641 if (!AndMask1.isSubsetOf(RHS: BaseMask) || !AndMask2.isSubsetOf(RHS: BaseMask))
642 return false;
643 APInt NeededShifted = (BaseMask & ~AndMask1).shl(shiftAmt: ShiftAmt);
644 APInt NeededUnshifted = BaseMask & ~AndMask2;
645 APInt AllNeeded = NeededShifted | NeededUnshifted;
646 return AllNeeded.isZero() || MaskedValueIsZero(V: Val, Mask: AllNeeded, SQ);
647 };
648
649 Value *ShiftOp;
650 Value *Start = &I;
651 for (unsigned I = Len; I >= 8; I = I / 2) {
652 APInt Mask = APInt::getSplat(NewLen: Len, V: APInt::getLowBitsSet(numBits: I, loBitsSet: I / 2));
653 const APInt *AndMask1 = nullptr, *AndMask2 = nullptr;
654
655 // Matching "(uWord & Mask) + ((uWord>>I/2) & Mask)".
656 // Both masks might have been narrowed by InstCombine.
657 if (match(V: Start,
658 P: m_c_Add(L: m_And(L: m_LShr(L: m_Value(V&: ShiftOp), R: m_SpecificInt(V: I / 2)),
659 R: m_APInt(Res&: AndMask1)),
660 R: m_And(L: m_Deferred(V: ShiftOp), R: m_APInt(Res&: AndMask2))))) {
661 if (!narrowAddPairMasksOk(Mask, I / 2, ShiftOp, *AndMask1, *AndMask2))
662 return false;
663 }
664 // Matching "(uWord & Mask) + (uWord>>I/2)".
665 // The mask might have been narrowed by InstCombine.
666 else if (match(V: Start,
667 P: m_c_Add(L: m_LShr(L: m_Value(V&: ShiftOp), R: m_SpecificInt(V: I / 2)),
668 R: m_And(L: m_Deferred(V: ShiftOp), R: m_APInt(Res&: AndMask1))))) {
669 if (!isValidNarrowedMask(*AndMask1, Mask, ShiftOp))
670 return false;
671 } else
672 return false;
673 Start = ShiftOp;
674 }
675
676 // Matching "uWord = (uWord & Mask33) + ((uWord>>2) & Mask33)".
677 const APInt *AndMask1 = nullptr, *AndMask2 = nullptr;
678 if (!match(V: Start, P: m_c_Add(L: m_And(L: m_LShr(L: m_Value(V&: ShiftOp), R: m_SpecificInt(V: 2)),
679 R: m_APInt(Res&: AndMask1)),
680 R: m_And(L: m_Deferred(V: ShiftOp), R: m_APInt(Res&: AndMask2)))))
681 return false;
682 if (!narrowAddPairMasksOk(Mask33, 2, ShiftOp, *AndMask1, *AndMask2))
683 return false;
684
685 Start = ShiftOp;
686 Value *Root;
687 // Matching "uWord = (uWord & Mask55) + ((uWord>>1) & Mask55)".
688 AndMask1 = nullptr;
689 AndMask2 = nullptr;
690 if (!match(V: Start, P: m_c_Add(L: m_And(L: m_LShr(L: m_Value(V&: Root), R: m_SpecificInt(V: 1)),
691 R: m_APInt(Res&: AndMask1)),
692 R: m_And(L: m_Deferred(V: Root), R: m_APInt(Res&: AndMask2)))))
693 return false;
694 if (!narrowAddPairMasksOk(Mask55, 1, Root, *AndMask1, *AndMask2))
695 return false;
696
697 replaceWithPopCount(I, Root);
698 return true;
699}
700
701// Try to recognize below function as popcount intrinsic.
702// Ref. Hackers Delight
703// int popcnt(unsigned x) {
704// x = x - ((x >> 1) & 0x55555555);
705// x = (x & 0x33333333) + ((x >> 2) & 0x33333333);
706// x = (x + (x >> 4)) & 0x0F0F0F0F;
707// x = x + (x >> 8);
708// x = x + (x >> 16);
709// return x & 0x0000003F;
710// }
711
712// int popcnt(unsigned x) {
713// x = x - ((x >> 1) & 0x55555555);
714// x = x - 3*((x >> 2) & 0x33333333);
715// x = (x + (x >> 4)) & 0x0F0F0F0F;
716// x = x + (x >> 8);
717// x = x + (x >> 16);
718// return x & 0x0000003F;
719// }
720static bool tryToRecognizePopCount2n3(Instruction &I) {
721 if (I.getOpcode() != Instruction::And)
722 return false;
723
724 Type *Ty = I.getType();
725 if (!Ty->isIntOrIntVectorTy())
726 return false;
727
728 unsigned Len = Ty->getScalarSizeInBits();
729 Value *Add1;
730 if (Len == 8) {
731 // Special case for Len == 8, we only need to match the And at the end of
732 // matchPopCountBytes.
733 Add1 = &I;
734 } else {
735 const APInt *MaskRes;
736 if (!match(V: &I, P: m_And(L: m_Value(V&: Add1), R: m_APInt(Res&: MaskRes))))
737 return false;
738
739 // Since `(trunc (and x, C))` might be canonicalized into `(and (trunc x),
740 // C)` we might loose the opportunity to recognize `(trunc (popcount y))`.
741 // The following block tries to capture such truncation, update `Len`, and
742 // append the truncation at the end of the emitting popcount, if there is
743 // any.
744 Value *TruncSrc;
745 if (match(V: Add1, P: m_OneUse(SubPattern: m_Trunc(Op: m_Value(V&: TruncSrc))))) {
746 Add1 = TruncSrc;
747 Len = Add1->getType()->getScalarSizeInBits();
748 }
749
750 if (Len > 64 || Len <= 8 || Len % 8 != 0)
751 return false;
752
753 // Len should be a power of 2 for the loop to work correctly
754 if (!isPowerOf2_32(Value: Len))
755 return false;
756
757 // Number of bits needed to represent Len.
758 unsigned NumLenBits = Log2_32(Value: Len) + 1;
759 // The "mask" here really only needs to fulfill two conditions:
760 // (1) All ones for the lower NumLenBits-bits
761 // (2) Zeros from bit 8 and onward.
762 // Condition (1) is straightforward. The reason behind condition
763 // (2) is that we don't care any 8-bit chunks but the first one
764 // in the original divide-and-conquer algorithm.
765 if (MaskRes->countTrailingOnes() < NumLenBits ||
766 MaskRes->getActiveBits() > 8)
767 return false;
768
769 for (unsigned I = Len; I >= 16; I = I / 2) {
770 Value *Add2;
771 // Matching "x = x + (x >> I/2)" for I-bit.
772 if (!match(V: Add1, P: m_c_Add(L: m_LShr(L: m_Value(V&: Add2), R: m_SpecificInt(V: I / 2)),
773 R: m_Deferred(V: Add2))))
774 return false;
775 Add1 = Add2;
776 }
777 }
778
779 Value *Root = matchPopCountBytes(V: Add1, Len, DL: I.getDataLayout());
780 if (!Root)
781 return false;
782
783 replaceWithPopCount(I, Root);
784 return true;
785}
786
/// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), provided C1 and
/// C2 saturate the value of the fp conversion. The transform is not reversible
/// as the fptosi.sat is more defined than the input - all values produce a
/// valid value for the fptosi.sat, whereas some produce poison for the original
/// that were out of range of the integer conversion. The reversed pattern may
792/// use fmax and fmin instead. As we cannot directly reverse the transform, and
793/// it is not always profitable, we make it conditional on the cost being
794/// reported as lower by TTI.
795static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
796 // Look for min(max(fptosi, converting to fptosi_sat.
797 Value *In;
798 const APInt *MinC, *MaxC;
799 if (!match(V: &I, P: m_SMax(Op0: m_OneUse(SubPattern: m_SMin(Op0: m_OneUse(SubPattern: m_FPToSI(Op: m_Value(V&: In))),
800 Op1: m_APInt(Res&: MinC))),
801 Op1: m_APInt(Res&: MaxC))) &&
802 !match(V: &I, P: m_SMin(Op0: m_OneUse(SubPattern: m_SMax(Op0: m_OneUse(SubPattern: m_FPToSI(Op: m_Value(V&: In))),
803 Op1: m_APInt(Res&: MaxC))),
804 Op1: m_APInt(Res&: MinC))))
805 return false;
806
807 // Check that the constants clamp a saturate.
808 if (!(*MinC + 1).isPowerOf2() || -*MaxC != *MinC + 1)
809 return false;
810
811 Type *IntTy = I.getType();
812 Type *FpTy = In->getType();
813 Type *SatTy =
814 IntegerType::get(C&: IntTy->getContext(), NumBits: (*MinC + 1).exactLogBase2() + 1);
815 if (auto *VecTy = dyn_cast<VectorType>(Val: IntTy))
816 SatTy = VectorType::get(ElementType: SatTy, EC: VecTy->getElementCount());
817
818 // Get the cost of the intrinsic, and check that against the cost of
819 // fptosi+smin+smax
820 InstructionCost SatCost = TTI.getIntrinsicInstrCost(
821 ICA: IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}),
822 CostKind: TTI::TCK_RecipThroughput);
823 SatCost += TTI.getCastInstrCost(Opcode: Instruction::SExt, Dst: IntTy, Src: SatTy,
824 CCH: TTI::CastContextHint::None,
825 CostKind: TTI::TCK_RecipThroughput);
826
827 InstructionCost MinMaxCost = TTI.getCastInstrCost(
828 Opcode: Instruction::FPToSI, Dst: IntTy, Src: FpTy, CCH: TTI::CastContextHint::None,
829 CostKind: TTI::TCK_RecipThroughput);
830 MinMaxCost += TTI.getIntrinsicInstrCost(
831 ICA: IntrinsicCostAttributes(Intrinsic::smin, IntTy, {IntTy}),
832 CostKind: TTI::TCK_RecipThroughput);
833 MinMaxCost += TTI.getIntrinsicInstrCost(
834 ICA: IntrinsicCostAttributes(Intrinsic::smax, IntTy, {IntTy}),
835 CostKind: TTI::TCK_RecipThroughput);
836
837 if (SatCost >= MinMaxCost)
838 return false;
839
840 IRBuilder<> Builder(&I);
841 Value *Sat =
842 Builder.CreateIntrinsic(ID: Intrinsic::fptosi_sat, OverloadTypes: {SatTy, FpTy}, Args: In);
843 I.replaceAllUsesWith(V: Builder.CreateSExt(V: Sat, DestTy: IntTy));
844 return true;
845}
846
847/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
848/// pessimistic codegen that has to account for setting errno and can enable
849/// vectorization.
850static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI,
851 TargetLibraryInfo &TLI, AssumptionCache &AC,
852 DominatorTree &DT) {
853 // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
854 // (because NNAN or the operand arg must not be less than -0.0) and (2) we
855 // would not end up lowering to a libcall anyway (which could change the value
856 // of errno), then:
857 // (1) errno won't be set.
858 // (2) it is safe to convert this to an intrinsic call.
859 Type *Ty = Call->getType();
860 Value *Arg = Call->getArgOperand(i: 0);
861 if (TTI.haveFastSqrt(Ty) &&
862 (Call->hasNoNaNs() ||
863 cannotBeOrderedLessThanZero(
864 V: Arg, SQ: SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) {
865 IRBuilder<> Builder(Call);
866 Value *NewSqrt =
867 Builder.CreateIntrinsic(ID: Intrinsic::sqrt, OverloadTypes: Ty, Args: Arg, FMFSource: Call, Name: "sqrt");
868 Call->replaceAllUsesWith(V: NewSqrt);
869
870 // Explicitly erase the old call because a call with side effects is not
871 // trivially dead.
872 Call->eraseFromParent();
873 return true;
874 }
875
876 return false;
877}
878
879// Check if this array of constants represents a cttz table.
880// Iterate over the elements from \p Table by trying to find/match all
881// the numbers from 0 to \p InputBits that should represent cttz results.
882static bool isCTTZTable(Constant *Table, const APInt &Mul, const APInt &Shift,
883 const APInt &AndMask, Type *AccessTy,
884 unsigned InputBits, const APInt &GEPIdxFactor,
885 const DataLayout &DL) {
886 for (unsigned Idx = 0; Idx < InputBits; Idx++) {
887 APInt Index =
888 (APInt::getOneBitSet(numBits: InputBits, BitNo: Idx) * Mul).lshr(ShiftAmt: Shift) & AndMask;
889 ConstantInt *C = dyn_cast_or_null<ConstantInt>(
890 Val: ConstantFoldLoadFromConst(C: Table, Ty: AccessTy, Offset: Index * GEPIdxFactor, DL));
891 if (!C || C->getValue() != Idx)
892 return false;
893 }
894
895 return true;
896}
897
898// Try to recognize table-based ctz implementation.
899// E.g., an example in C (for more cases please see the llvm/tests):
900// int f(unsigned x) {
901// static const char table[32] =
902// {0, 1, 28, 2, 29, 14, 24, 3, 30,
903// 22, 20, 15, 25, 17, 4, 8, 31, 27,
904// 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};
905// return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27];
906// }
907// this can be lowered to `cttz` instruction.
908// There is also a special case when the element is 0.
909//
910// The (x & -x) sets the lowest non-zero bit to 1. The multiply is a de-bruijn
911// sequence that contains each pattern of bits in it. The shift extracts
912// the top bits after the multiply, and that index into the table should
913// represent the number of trailing zeros in the original number.
914//
// Here are some examples of LLVM IR for a 64-bit target:
916//
917// CASE 1:
918// %sub = sub i32 0, %x
919// %and = and i32 %sub, %x
920// %mul = mul i32 %and, 125613361
921// %shr = lshr i32 %mul, 27
922// %idxprom = zext i32 %shr to i64
923// %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0,
924// i64 %idxprom
925// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
926//
927// CASE 2:
928// %sub = sub i32 0, %x
929// %and = and i32 %sub, %x
930// %mul = mul i32 %and, 72416175
931// %shr = lshr i32 %mul, 26
932// %idxprom = zext i32 %shr to i64
933// %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table,
934// i64 0, i64 %idxprom
935// %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
936//
937// CASE 3:
938// %sub = sub i32 0, %x
939// %and = and i32 %sub, %x
940// %mul = mul i32 %and, 81224991
941// %shr = lshr i32 %mul, 27
942// %idxprom = zext i32 %shr to i64
943// %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table,
944// i64 0, i64 %idxprom
945// %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
946//
947// CASE 4:
948// %sub = sub i64 0, %x
949// %and = and i64 %sub, %x
950// %mul = mul i64 %and, 283881067100198605
951// %shr = lshr i64 %mul, 58
952// %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0,
953// i64 %shr
954// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
955//
956// All these can be lowered to @llvm.cttz.i32/64 intrinsics.
957static bool tryToRecognizeTableBasedCttz(Instruction &I, const DataLayout &DL) {
958 LoadInst *LI = dyn_cast<LoadInst>(Val: &I);
959 if (!LI)
960 return false;
961
962 Type *AccessType = LI->getType();
963 if (!AccessType->isIntegerTy())
964 return false;
965
966 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Val: LI->getPointerOperand());
967 if (!GEP || !GEP->hasNoUnsignedSignedWrap())
968 return false;
969
970 GlobalVariable *GVTable = dyn_cast<GlobalVariable>(Val: GEP->getPointerOperand());
971 if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
972 return false;
973
974 unsigned BW = DL.getIndexTypeSizeInBits(Ty: GEP->getType());
975 APInt ModOffset(BW, 0);
976 SmallMapVector<Value *, APInt, 4> VarOffsets;
977 if (!GEP->collectOffset(DL, BitWidth: BW, VariableOffsets&: VarOffsets, ConstantOffset&: ModOffset) ||
978 VarOffsets.size() != 1 || ModOffset != 0)
979 return false;
980 auto [GepIdx, GEPScale] = VarOffsets.front();
981
982 Value *X1;
983 const APInt *MulConst, *ShiftConst, *AndCst = nullptr;
984 // Check that the gep variable index is ((x & -x) * MulConst) >> ShiftConst.
985 // This might be extended to the pointer index type, and if the gep index type
986 // has been replaced with an i8 then a new And (and different ShiftConst) will
987 // be present.
988 auto MatchInner = m_LShr(
989 L: m_Mul(L: m_c_And(L: m_Neg(V: m_Value(V&: X1)), R: m_Deferred(V: X1)), R: m_APInt(Res&: MulConst)),
990 R: m_APInt(Res&: ShiftConst));
991 if (!match(V: GepIdx, P: m_CastOrSelf(Op: MatchInner)) &&
992 !match(V: GepIdx, P: m_CastOrSelf(Op: m_And(L: MatchInner, R: m_APInt(Res&: AndCst)))))
993 return false;
994
995 unsigned InputBits = X1->getType()->getScalarSizeInBits();
996 if (InputBits != 16 && InputBits != 32 && InputBits != 64 && InputBits != 128)
997 return false;
998
999 if (!GEPScale.isIntN(N: InputBits) ||
1000 !isCTTZTable(Table: GVTable->getInitializer(), Mul: *MulConst, Shift: *ShiftConst,
1001 AndMask: AndCst ? *AndCst : APInt::getAllOnes(numBits: InputBits), AccessTy: AccessType,
1002 InputBits, GEPIdxFactor: GEPScale.zextOrTrunc(width: InputBits), DL))
1003 return false;
1004
1005 ConstantInt *ZeroTableElem = cast<ConstantInt>(
1006 Val: ConstantFoldLoadFromConst(C: GVTable->getInitializer(), Ty: AccessType, DL));
1007 bool DefinedForZero = ZeroTableElem->getZExtValue() == InputBits;
1008
1009 IRBuilder<> B(LI);
1010 ConstantInt *BoolConst = B.getInt1(V: !DefinedForZero);
1011 Type *XType = X1->getType();
1012 auto Cttz = B.CreateIntrinsic(ID: Intrinsic::cttz, OverloadTypes: {XType}, Args: {X1, BoolConst});
1013 Value *ZExtOrTrunc = nullptr;
1014
1015 if (DefinedForZero) {
1016 ZExtOrTrunc = B.CreateZExtOrTrunc(V: Cttz, DestTy: AccessType);
1017 } else {
1018 // If the value in elem 0 isn't the same as InputBits, we still want to
1019 // produce the value from the table.
1020 auto Cmp = B.CreateICmpEQ(LHS: X1, RHS: ConstantInt::get(Ty: XType, V: 0));
1021 auto Select = B.CreateSelect(C: Cmp, True: B.CreateZExt(V: ZeroTableElem, DestTy: XType), False: Cttz);
1022
1023 // The true branch of select handles the cttz(0) case, which is rare.
1024 if (!ProfcheckDisableMetadataFixes) {
1025 if (Instruction *SelectI = dyn_cast<Instruction>(Val: Select))
1026 SelectI->setMetadata(
1027 KindID: LLVMContext::MD_prof,
1028 Node: MDBuilder(SelectI->getContext()).createUnlikelyBranchWeights());
1029 }
1030
1031 // NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target
1032 // it should be handled as: `cttz(x) & (typeSize - 1)`.
1033
1034 ZExtOrTrunc = B.CreateZExtOrTrunc(V: Select, DestTy: AccessType);
1035 }
1036
1037 LI->replaceAllUsesWith(V: ZExtOrTrunc);
1038
1039 return true;
1040}
1041
1042// Check if this array of constants represents a log2 table.
1043// Iterate over the elements from \p Table by trying to find/match all
1044// the numbers from 0 to \p InputBits that should represent log2 results.
1045static bool isLog2Table(Constant *Table, const APInt &Mul, const APInt &Shift,
1046 Type *AccessTy, unsigned InputBits,
1047 const APInt &GEPIdxFactor, const DataLayout &DL) {
1048 for (unsigned Idx = 0; Idx < InputBits; Idx++) {
1049 APInt Index = (APInt::getLowBitsSet(numBits: InputBits, loBitsSet: Idx + 1) * Mul).lshr(ShiftAmt: Shift);
1050 ConstantInt *C = dyn_cast_or_null<ConstantInt>(
1051 Val: ConstantFoldLoadFromConst(C: Table, Ty: AccessTy, Offset: Index * GEPIdxFactor, DL));
1052 if (!C || C->getValue() != Idx)
1053 return false;
1054 }
1055
1056 // Verify that an input of zero will select table index 0.
1057 APInt ZeroIndex = Mul.lshr(ShiftAmt: Shift);
1058 if (!ZeroIndex.isZero())
1059 return false;
1060
1061 return true;
1062}
1063
1064// Try to recognize table-based log2 implementation.
// E.g., an example in C (for more cases please see the llvm/tests):
1066// int f(unsigned v) {
1067// static const char table[32] =
1068// {0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30,
1069// 8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31};
1070//
1071// v |= v >> 1; // first round down to one less than a power of 2
1072// v |= v >> 2;
1073// v |= v >> 4;
1074// v |= v >> 8;
1075// v |= v >> 16;
1076//
1077// return table[(unsigned)(v * 0x07C4ACDDU) >> 27];
1078// }
1079// this can be lowered to `ctlz` instruction.
1080// There is also a special case when the element is 0.
1081//
1082// The >> and |= sequence sets all bits below the most significant set bit. The
1083// multiply is a de-bruijn sequence that contains each pattern of bits in it.
1084// The shift extracts the top bits after the multiply, and that index into the
1085// table should represent the floor log base 2 of the original number.
1086//
1087// Here are some examples of LLVM IR for a 64-bit target.
1088//
1089// CASE 1:
1090// %shr = lshr i32 %v, 1
1091// %or = or i32 %shr, %v
1092// %shr1 = lshr i32 %or, 2
1093// %or2 = or i32 %shr1, %or
1094// %shr3 = lshr i32 %or2, 4
1095// %or4 = or i32 %shr3, %or2
1096// %shr5 = lshr i32 %or4, 8
1097// %or6 = or i32 %shr5, %or4
1098// %shr7 = lshr i32 %or6, 16
1099// %or8 = or i32 %shr7, %or6
1100// %mul = mul i32 %or8, 130329821
1101// %shr9 = lshr i32 %mul, 27
1102// %idxprom = zext nneg i32 %shr9 to i64
1103// %arrayidx = getelementptr inbounds i8, ptr @table, i64 %idxprom
1104// %0 = load i8, ptr %arrayidx, align 1
1105//
1106// CASE 2:
1107// %shr = lshr i64 %v, 1
1108// %or = or i64 %shr, %v
1109// %shr1 = lshr i64 %or, 2
1110// %or2 = or i64 %shr1, %or
1111// %shr3 = lshr i64 %or2, 4
1112// %or4 = or i64 %shr3, %or2
1113// %shr5 = lshr i64 %or4, 8
1114// %or6 = or i64 %shr5, %or4
1115// %shr7 = lshr i64 %or6, 16
1116// %or8 = or i64 %shr7, %or6
1117// %shr9 = lshr i64 %or8, 32
1118// %or10 = or i64 %shr9, %or8
1119// %mul = mul i64 %or10, 285870213051386505
1120// %shr11 = lshr i64 %mul, 58
1121// %arrayidx = getelementptr inbounds i8, ptr @table, i64 %shr11
1122// %0 = load i8, ptr %arrayidx, align 1
1123//
1124// All these can be lowered to @llvm.ctlz.i32/64 intrinsics and a subtract.
1125static bool tryToRecognizeTableBasedLog2(Instruction &I, const DataLayout &DL,
1126 TargetTransformInfo &TTI) {
1127 LoadInst *LI = dyn_cast<LoadInst>(Val: &I);
1128 if (!LI)
1129 return false;
1130
1131 Type *AccessType = LI->getType();
1132 if (!AccessType->isIntegerTy())
1133 return false;
1134
1135 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Val: LI->getPointerOperand());
1136 if (!GEP || !GEP->hasNoUnsignedSignedWrap())
1137 return false;
1138
1139 GlobalVariable *GVTable = dyn_cast<GlobalVariable>(Val: GEP->getPointerOperand());
1140 if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
1141 return false;
1142
1143 unsigned BW = DL.getIndexTypeSizeInBits(Ty: GEP->getType());
1144 APInt ModOffset(BW, 0);
1145 SmallMapVector<Value *, APInt, 4> VarOffsets;
1146 if (!GEP->collectOffset(DL, BitWidth: BW, VariableOffsets&: VarOffsets, ConstantOffset&: ModOffset) ||
1147 VarOffsets.size() != 1 || ModOffset != 0)
1148 return false;
1149 auto [GepIdx, GEPScale] = VarOffsets.front();
1150
1151 Value *X;
1152 const APInt *MulConst, *ShiftConst;
1153 // Check that the gep variable index is (x * MulConst) >> ShiftConst.
1154 auto MatchInner =
1155 m_LShr(L: m_Mul(L: m_Value(V&: X), R: m_APInt(Res&: MulConst)), R: m_APInt(Res&: ShiftConst));
1156 if (!match(V: GepIdx, P: m_CastOrSelf(Op: MatchInner)))
1157 return false;
1158
1159 unsigned InputBits = X->getType()->getScalarSizeInBits();
1160 if (InputBits != 16 && InputBits != 32 && InputBits != 64 && InputBits != 128)
1161 return false;
1162
1163 // Verify shift amount.
1164 // TODO: Allow other shift amounts when we have proper test coverage.
1165 if (*ShiftConst != InputBits - Log2_32(Value: InputBits))
1166 return false;
1167
1168 // Match the sequence of OR operations with right shifts by powers of 2.
1169 for (unsigned ShiftAmt = InputBits / 2; ShiftAmt != 0; ShiftAmt /= 2) {
1170 Value *Y;
1171 if (!match(V: X, P: m_c_Or(L: m_LShr(L: m_Value(V&: Y), R: m_SpecificInt(V: ShiftAmt)),
1172 R: m_Deferred(V: Y))))
1173 return false;
1174 X = Y;
1175 }
1176
1177 if (!GEPScale.isIntN(N: InputBits) ||
1178 !isLog2Table(Table: GVTable->getInitializer(), Mul: *MulConst, Shift: *ShiftConst,
1179 AccessTy: AccessType, InputBits, GEPIdxFactor: GEPScale.zextOrTrunc(width: InputBits), DL))
1180 return false;
1181
1182 ConstantInt *ZeroTableElem = cast<ConstantInt>(
1183 Val: ConstantFoldLoadFromConst(C: GVTable->getInitializer(), Ty: AccessType, DL));
1184
1185 // Use InputBits - 1 - ctlz(X) to compute log2(X).
1186 IRBuilder<> B(LI);
1187 ConstantInt *BoolConst = B.getTrue();
1188 Type *XType = X->getType();
1189
1190 // Check the the backend has an efficient ctlz instruction.
1191 // FIXME: Teach the backend to emit the original code when ctlz isn't
1192 // supported like we do for cttz.
1193 IntrinsicCostAttributes Attrs(
1194 Intrinsic::ctlz, XType,
1195 {PoisonValue::get(T: XType), /*is_zero_poison=*/BoolConst});
1196 InstructionCost Cost =
1197 TTI.getIntrinsicInstrCost(ICA: Attrs, CostKind: TargetTransformInfo::TCK_SizeAndLatency);
1198 if (Cost > TargetTransformInfo::TCC_Basic)
1199 return false;
1200
1201 Value *Ctlz = B.CreateIntrinsic(ID: Intrinsic::ctlz, OverloadTypes: {XType}, Args: {X, BoolConst});
1202
1203 Constant *InputBitsM1 = ConstantInt::get(Ty: XType, V: InputBits - 1);
1204 Value *Sub = B.CreateSub(LHS: InputBitsM1, RHS: Ctlz);
1205
1206 // The table won't produce a sensible result for 0.
1207 Value *Cmp = B.CreateICmpEQ(LHS: X, RHS: ConstantInt::get(Ty: XType, V: 0));
1208 Value *Select = B.CreateSelect(C: Cmp, True: B.CreateZExt(V: ZeroTableElem, DestTy: XType), False: Sub);
1209
1210 // The true branch of select handles the log2(0) case, which is rare.
1211 if (!ProfcheckDisableMetadataFixes) {
1212 if (Instruction *SelectI = dyn_cast<Instruction>(Val: Select))
1213 SelectI->setMetadata(
1214 KindID: LLVMContext::MD_prof,
1215 Node: MDBuilder(SelectI->getContext()).createUnlikelyBranchWeights());
1216 }
1217
1218 Value *ZExtOrTrunc = B.CreateZExtOrTrunc(V: Select, DestTy: AccessType);
1219
1220 LI->replaceAllUsesWith(V: ZExtOrTrunc);
1221
1222 return true;
1223}
1224
1225/// This is used by foldLoadsRecursive() to capture a Root Load node which is
1226/// of type or(load, load) and recursively build the wide load. Also capture the
1227/// shift amount, zero extend type and loadSize.
1228struct LoadOps {
1229 LoadInst *Root = nullptr;
1230 LoadInst *RootInsert = nullptr;
1231 bool FoundRoot = false;
1232 uint64_t LoadSize = 0;
1233 uint64_t Shift = 0;
1234 Type *ZextType;
1235 AAMDNodes AATags;
1236};
1237
1238// Identify and Merge consecutive loads recursively which is of the form
1239// (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
1240// (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
1241static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
1242 AliasAnalysis &AA, bool IsRoot = false) {
1243 uint64_t ShAmt2;
1244 Value *X;
1245 Instruction *L1, *L2;
1246
1247 // For the root instruction, allow multiple uses since the final result
1248 // may legitimately be used in multiple places. For intermediate values,
1249 // require single use to avoid creating duplicate loads.
1250 if (!IsRoot && !V->hasOneUse())
1251 return false;
1252
1253 if (!match(V, P: m_c_Or(L: m_Value(V&: X),
1254 R: m_OneUse(SubPattern: m_ShlOrSelf(L: m_OneUse(SubPattern: m_ZExt(Op: m_Instruction(I&: L2))),
1255 R&: ShAmt2)))))
1256 return false;
1257
1258 if (!foldLoadsRecursive(V: X, LOps, DL, AA, /*IsRoot=*/false) && LOps.FoundRoot)
1259 // Avoid Partial chain merge.
1260 return false;
1261
1262 // Check if the pattern has loads
1263 LoadInst *LI1 = LOps.Root;
1264 uint64_t ShAmt1 = LOps.Shift;
1265 if (LOps.FoundRoot == false &&
1266 match(V: X, P: m_OneUse(
1267 SubPattern: m_ShlOrSelf(L: m_OneUse(SubPattern: m_ZExt(Op: m_Instruction(I&: L1))), R&: ShAmt1)))) {
1268 LI1 = dyn_cast<LoadInst>(Val: L1);
1269 }
1270 LoadInst *LI2 = dyn_cast<LoadInst>(Val: L2);
1271
1272 // Check if loads are same, atomic, volatile and having same address space.
1273 if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() ||
1274 LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace())
1275 return false;
1276
1277 // Check if Loads come from same BB.
1278 if (LI1->getParent() != LI2->getParent())
1279 return false;
1280
1281 // Find the data layout
1282 bool IsBigEndian = DL.isBigEndian();
1283
1284 // Check if loads are consecutive and same size.
1285 Value *Load1Ptr = LI1->getPointerOperand();
1286 APInt Offset1(DL.getIndexTypeSizeInBits(Ty: Load1Ptr->getType()), 0);
1287 Load1Ptr =
1288 Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset&: Offset1,
1289 /* AllowNonInbounds */ true);
1290
1291 Value *Load2Ptr = LI2->getPointerOperand();
1292 APInt Offset2(DL.getIndexTypeSizeInBits(Ty: Load2Ptr->getType()), 0);
1293 Load2Ptr =
1294 Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset&: Offset2,
1295 /* AllowNonInbounds */ true);
1296
1297 // Verify if both loads have same base pointers
1298 uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits();
1299 uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits();
1300 if (Load1Ptr != Load2Ptr)
1301 return false;
1302
1303 // Make sure that there are no padding bits.
1304 if (!DL.typeSizeEqualsStoreSize(Ty: LI1->getType()) ||
1305 !DL.typeSizeEqualsStoreSize(Ty: LI2->getType()))
1306 return false;
1307
1308 // Alias Analysis to check for stores b/w the loads.
1309 LoadInst *Start = LOps.FoundRoot ? LOps.RootInsert : LI1, *End = LI2;
1310 MemoryLocation Loc;
1311 if (!Start->comesBefore(Other: End)) {
1312 std::swap(a&: Start, b&: End);
1313 // If LOps.RootInsert comes after LI2, since we use LI2 as the new insert
1314 // point, we should make sure whether the memory region accessed by LOps
1315 // isn't modified.
1316 if (LOps.FoundRoot)
1317 Loc = MemoryLocation(
1318 LOps.Root->getPointerOperand(),
1319 LocationSize::precise(Value: DL.getTypeStoreSize(
1320 Ty: IntegerType::get(C&: LI1->getContext(), NumBits: LOps.LoadSize))),
1321 LOps.AATags);
1322 else
1323 Loc = MemoryLocation::get(LI: End);
1324 } else
1325 Loc = MemoryLocation::get(LI: End);
1326 unsigned NumScanned = 0;
1327 for (Instruction &Inst :
1328 make_range(x: Start->getIterator(), y: End->getIterator())) {
1329 if (Inst.mayWriteToMemory() && isModSet(MRI: AA.getModRefInfo(I: &Inst, OptLoc: Loc)))
1330 return false;
1331
1332 if (++NumScanned > MaxInstrsToScan)
1333 return false;
1334 }
1335
1336 // Make sure Load with lower Offset is at LI1
1337 bool Reverse = false;
1338 if (Offset2.slt(RHS: Offset1)) {
1339 std::swap(a&: LI1, b&: LI2);
1340 std::swap(a&: ShAmt1, b&: ShAmt2);
1341 std::swap(a&: Offset1, b&: Offset2);
1342 std::swap(a&: Load1Ptr, b&: Load2Ptr);
1343 std::swap(a&: LoadSize1, b&: LoadSize2);
1344 Reverse = true;
1345 }
1346
1347 // Big endian swap the shifts
1348 if (IsBigEndian)
1349 std::swap(a&: ShAmt1, b&: ShAmt2);
1350
1351 // First load is always LI1. This is where we put the new load.
1352 // Use the merged load size available from LI1 for forward loads.
1353 if (LOps.FoundRoot) {
1354 if (!Reverse)
1355 LoadSize1 = LOps.LoadSize;
1356 else
1357 LoadSize2 = LOps.LoadSize;
1358 }
1359
1360 // Verify if shift amount and load index aligns and verifies that loads
1361 // are consecutive.
1362 uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
1363 uint64_t PrevSize =
1364 DL.getTypeStoreSize(Ty: IntegerType::get(C&: LI1->getContext(), NumBits: LoadSize1));
1365 if ((ShAmt2 - ShAmt1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
1366 return false;
1367
1368 // Update LOps
1369 AAMDNodes AATags1 = LOps.AATags;
1370 AAMDNodes AATags2 = LI2->getAAMetadata();
1371 if (LOps.FoundRoot == false) {
1372 LOps.FoundRoot = true;
1373 AATags1 = LI1->getAAMetadata();
1374 }
1375 LOps.LoadSize = LoadSize1 + LoadSize2;
1376 LOps.RootInsert = Start;
1377
1378 // Concatenate the AATags of the Merged Loads.
1379 LOps.AATags = AATags1.concat(Other: AATags2);
1380
1381 LOps.Root = LI1;
1382 LOps.Shift = ShAmt1;
1383 LOps.ZextType = X->getType();
1384 return true;
1385}
1386
1387// For a given BB instruction, evaluate all loads in the chain that form a
1388// pattern which suggests that the loads can be combined. The one and only use
1389// of the loads is to form a wider load.
1390static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
1391 TargetTransformInfo &TTI, AliasAnalysis &AA,
1392 const DominatorTree &DT) {
1393 // Only consider load chains of scalar values.
1394 if (isa<VectorType>(Val: I.getType()))
1395 return false;
1396
1397 LoadOps LOps;
1398 if (!foldLoadsRecursive(V: &I, LOps, DL, AA, /*IsRoot=*/true) || !LOps.FoundRoot)
1399 return false;
1400
1401 IRBuilder<> Builder(&I);
1402 LoadInst *NewLoad = nullptr, *LI1 = LOps.Root;
1403
1404 IntegerType *WiderType = IntegerType::get(C&: I.getContext(), NumBits: LOps.LoadSize);
1405 // TTI based checks if we want to proceed with wider load
1406 bool Allowed = TTI.isTypeLegal(Ty: WiderType);
1407 if (!Allowed)
1408 return false;
1409
1410 unsigned AS = LI1->getPointerAddressSpace();
1411 unsigned Fast = 0;
1412 Allowed = TTI.allowsMisalignedMemoryAccesses(Context&: I.getContext(), BitWidth: LOps.LoadSize,
1413 AddressSpace: AS, Alignment: LI1->getAlign(), Fast: &Fast);
1414 if (!Allowed || !Fast)
1415 return false;
1416
1417 // Get the Index and Ptr for the new GEP.
1418 Value *Load1Ptr = LI1->getPointerOperand();
1419 Builder.SetInsertPoint(LOps.RootInsert);
1420 if (!DT.dominates(Def: Load1Ptr, User: LOps.RootInsert)) {
1421 APInt Offset1(DL.getIndexTypeSizeInBits(Ty: Load1Ptr->getType()), 0);
1422 Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets(
1423 DL, Offset&: Offset1, /* AllowNonInbounds */ true);
1424 Load1Ptr = Builder.CreatePtrAdd(Ptr: Load1Ptr, Offset: Builder.getInt(AI: Offset1));
1425 }
1426 // Generate wider load.
1427 NewLoad = Builder.CreateAlignedLoad(Ty: WiderType, Ptr: Load1Ptr, Align: LI1->getAlign(),
1428 isVolatile: LI1->isVolatile(), Name: "");
1429 NewLoad->takeName(V: LI1);
1430 // Set the New Load AATags Metadata.
1431 if (LOps.AATags)
1432 NewLoad->setAAMetadata(LOps.AATags);
1433
1434 Value *NewOp = NewLoad;
1435 // Check if zero extend needed.
1436 if (LOps.ZextType)
1437 NewOp = Builder.CreateZExt(V: NewOp, DestTy: LOps.ZextType);
1438
1439 // Check if shift needed. We need to shift with the amount of load1
1440 // shift if not zero.
1441 if (LOps.Shift)
1442 NewOp = Builder.CreateShl(LHS: NewOp, RHS: LOps.Shift);
1443 I.replaceAllUsesWith(V: NewOp);
1444
1445 return true;
1446}
1447
1448/// ValWidth bits starting at ValOffset of Val stored at PtrBase+PtrOffset.
1449struct PartStore {
1450 Value *PtrBase;
1451 APInt PtrOffset;
1452 Value *Val;
1453 uint64_t ValOffset;
1454 uint64_t ValWidth;
1455 StoreInst *Store;
1456
1457 bool isCompatibleWith(const PartStore &Other) const {
1458 return PtrBase == Other.PtrBase && Val == Other.Val;
1459 }
1460
1461 bool operator<(const PartStore &Other) const {
1462 return PtrOffset.slt(RHS: Other.PtrOffset);
1463 }
1464};
1465
1466static std::optional<PartStore> matchPartStore(Instruction &I,
1467 const DataLayout &DL) {
1468 auto *Store = dyn_cast<StoreInst>(Val: &I);
1469 if (!Store || !Store->isSimple())
1470 return std::nullopt;
1471
1472 Value *StoredVal = Store->getValueOperand();
1473 Type *StoredTy = StoredVal->getType();
1474 if (!StoredTy->isIntegerTy() || !DL.typeSizeEqualsStoreSize(Ty: StoredTy))
1475 return std::nullopt;
1476
1477 uint64_t ValWidth = StoredTy->getPrimitiveSizeInBits();
1478 uint64_t ValOffset;
1479 Value *Val;
1480 if (!match(V: StoredVal, P: m_Trunc(Op: m_LShrOrSelf(L: m_Value(V&: Val), R&: ValOffset))))
1481 return std::nullopt;
1482
1483 Value *Ptr = Store->getPointerOperand();
1484 APInt PtrOffset(DL.getIndexTypeSizeInBits(Ty: Ptr->getType()), 0);
1485 Value *PtrBase = Ptr->stripAndAccumulateConstantOffsets(
1486 DL, Offset&: PtrOffset, /*AllowNonInbounds=*/true);
1487 return {{.PtrBase: PtrBase, .PtrOffset: PtrOffset, .Val: Val, .ValOffset: ValOffset, .ValWidth: ValWidth, .Store: Store}};
1488}
1489
1490static bool mergeConsecutivePartStores(ArrayRef<PartStore> Parts,
1491 unsigned Width, const DataLayout &DL,
1492 TargetTransformInfo &TTI) {
1493 if (Parts.size() < 2)
1494 return false;
1495
1496 // Check whether combining the stores is profitable.
1497 // FIXME: We could generate smaller stores if we can't produce a large one.
1498 const PartStore &First = Parts.front();
1499 LLVMContext &Ctx = First.Store->getContext();
1500 Type *NewTy = Type::getIntNTy(C&: Ctx, N: Width);
1501 unsigned Fast = 0;
1502 if (!TTI.isTypeLegal(Ty: NewTy) ||
1503 !TTI.allowsMisalignedMemoryAccesses(Context&: Ctx, BitWidth: Width,
1504 AddressSpace: First.Store->getPointerAddressSpace(),
1505 Alignment: First.Store->getAlign(), Fast: &Fast) ||
1506 !Fast)
1507 return false;
1508
1509 // Generate the combined store.
1510 IRBuilder<> Builder(First.Store);
1511 Value *Val = First.Val;
1512 if (First.ValOffset != 0)
1513 Val = Builder.CreateLShr(LHS: Val, RHS: First.ValOffset);
1514 Val = Builder.CreateZExtOrTrunc(V: Val, DestTy: NewTy);
1515 StoreInst *Store = Builder.CreateAlignedStore(
1516 Val, Ptr: First.Store->getPointerOperand(), Align: First.Store->getAlign());
1517
1518 // Merge various metadata onto the new store.
1519 AAMDNodes AATags = First.Store->getAAMetadata();
1520 SmallVector<Instruction *> Stores = {First.Store};
1521 Stores.reserve(N: Parts.size());
1522 SmallVector<DebugLoc> DbgLocs = {First.Store->getDebugLoc()};
1523 DbgLocs.reserve(N: Parts.size());
1524 for (const PartStore &Part : drop_begin(RangeOrContainer&: Parts)) {
1525 AATags = AATags.concat(Other: Part.Store->getAAMetadata());
1526 Stores.push_back(Elt: Part.Store);
1527 DbgLocs.push_back(Elt: Part.Store->getDebugLoc());
1528 }
1529 Store->setAAMetadata(AATags);
1530 Store->mergeDIAssignID(SourceInstructions: Stores);
1531 Store->setDebugLoc(DebugLoc::getMergedLocations(Locs: DbgLocs));
1532
1533 // Remove the old stores.
1534 for (const PartStore &Part : Parts)
1535 Part.Store->eraseFromParent();
1536
1537 return true;
1538}
1539
1540static bool mergePartStores(SmallVectorImpl<PartStore> &Parts,
1541 const DataLayout &DL, TargetTransformInfo &TTI) {
1542 if (Parts.size() < 2)
1543 return false;
1544
1545 // We now have multiple parts of the same value stored to the same pointer.
1546 // Sort the parts by pointer offset, and make sure they are consistent with
1547 // the value offsets. Also check that the value is fully covered without
1548 // overlaps.
1549 bool Changed = false;
1550 llvm::sort(C&: Parts);
1551 int64_t LastEndOffsetFromFirst = 0;
1552 const PartStore *First = &Parts[0];
1553 for (const PartStore &Part : Parts) {
1554 APInt PtrOffsetFromFirst = Part.PtrOffset - First->PtrOffset;
1555 int64_t ValOffsetFromFirst = Part.ValOffset - First->ValOffset;
1556 if (PtrOffsetFromFirst * 8 != ValOffsetFromFirst ||
1557 LastEndOffsetFromFirst != ValOffsetFromFirst) {
1558 Changed |= mergeConsecutivePartStores(Parts: ArrayRef(First, &Part),
1559 Width: LastEndOffsetFromFirst, DL, TTI);
1560 First = &Part;
1561 LastEndOffsetFromFirst = Part.ValWidth;
1562 continue;
1563 }
1564
1565 LastEndOffsetFromFirst = ValOffsetFromFirst + Part.ValWidth;
1566 }
1567
1568 Changed |= mergeConsecutivePartStores(Parts: ArrayRef(First, Parts.end()),
1569 Width: LastEndOffsetFromFirst, DL, TTI);
1570 return Changed;
1571}
1572
1573static bool foldConsecutiveStores(BasicBlock &BB, const DataLayout &DL,
1574 TargetTransformInfo &TTI, AliasAnalysis &AA) {
1575 // FIXME: Add big endian support.
1576 if (DL.isBigEndian())
1577 return false;
1578
1579 BatchAAResults BatchAA(AA);
1580 SmallVector<PartStore, 8> Parts;
1581 bool MadeChange = false;
1582 for (Instruction &I : make_early_inc_range(Range&: BB)) {
1583 if (std::optional<PartStore> Part = matchPartStore(I, DL)) {
1584 if (Parts.empty() || Part->isCompatibleWith(Other: Parts[0])) {
1585 Parts.push_back(Elt: std::move(*Part));
1586 continue;
1587 }
1588
1589 MadeChange |= mergePartStores(Parts, DL, TTI);
1590 Parts.clear();
1591 Parts.push_back(Elt: std::move(*Part));
1592 continue;
1593 }
1594
1595 if (Parts.empty())
1596 continue;
1597
1598 if (I.mayThrow() ||
1599 (I.mayReadOrWriteMemory() &&
1600 isModOrRefSet(MRI: BatchAA.getModRefInfo(
1601 I: &I, OptLoc: MemoryLocation::getBeforeOrAfter(Ptr: Parts[0].PtrBase))))) {
1602 MadeChange |= mergePartStores(Parts, DL, TTI);
1603 Parts.clear();
1604 continue;
1605 }
1606 }
1607
1608 MadeChange |= mergePartStores(Parts, DL, TTI);
1609 return MadeChange;
1610}
1611
1612/// Combine away instructions providing they are still equivalent when compared
1613/// against 0. i.e do they have any bits set.
1614static Value *optimizeShiftInOrChain(Value *V, IRBuilder<> &Builder) {
1615 auto *I = dyn_cast<Instruction>(Val: V);
1616 if (!I || I->getOpcode() != Instruction::Or || !I->hasOneUse())
1617 return nullptr;
1618
1619 Value *A;
1620
1621 // Look deeper into the chain of or's, combining away shl (so long as they are
1622 // nuw or nsw).
1623 Value *Op0 = I->getOperand(i: 0);
1624 if (match(V: Op0, P: m_CombineOr(Ps: m_NSWShl(L: m_Value(V&: A), R: m_Value()),
1625 Ps: m_NUWShl(L: m_Value(V&: A), R: m_Value()))))
1626 Op0 = A;
1627 else if (auto *NOp = optimizeShiftInOrChain(V: Op0, Builder))
1628 Op0 = NOp;
1629
1630 Value *Op1 = I->getOperand(i: 1);
1631 if (match(V: Op1, P: m_CombineOr(Ps: m_NSWShl(L: m_Value(V&: A), R: m_Value()),
1632 Ps: m_NUWShl(L: m_Value(V&: A), R: m_Value()))))
1633 Op1 = A;
1634 else if (auto *NOp = optimizeShiftInOrChain(V: Op1, Builder))
1635 Op1 = NOp;
1636
1637 if (Op0 != I->getOperand(i: 0) || Op1 != I->getOperand(i: 1))
1638 return Builder.CreateOr(LHS: Op0, RHS: Op1);
1639 return nullptr;
1640}
1641
1642static bool foldICmpOrChain(Instruction &I, const DataLayout &DL,
1643 TargetTransformInfo &TTI, AliasAnalysis &AA,
1644 const DominatorTree &DT) {
1645 CmpPredicate Pred;
1646 Value *Op0;
1647 if (!match(V: &I, P: m_ICmp(Pred, L: m_Value(V&: Op0), R: m_Zero())) ||
1648 !ICmpInst::isEquality(P: Pred))
1649 return false;
1650
1651 // If the chain or or's matches a load, combine to that before attempting to
1652 // remove shifts.
1653 if (auto OpI = dyn_cast<Instruction>(Val: Op0))
1654 if (OpI->getOpcode() == Instruction::Or)
1655 if (foldConsecutiveLoads(I&: *OpI, DL, TTI, AA, DT))
1656 return true;
1657
1658 IRBuilder<> Builder(&I);
1659 // icmp eq/ne or(shl(a), b), 0 -> icmp eq/ne or(a, b), 0
1660 if (auto *Res = optimizeShiftInOrChain(V: Op0, Builder)) {
1661 I.replaceAllUsesWith(V: Builder.CreateICmp(P: Pred, LHS: Res, RHS: I.getOperand(i: 1)));
1662 return true;
1663 }
1664
1665 return false;
1666}
1667
1668// Calculate GEP Stride and accumulated const ModOffset. Return Stride and
1669// ModOffset
1670static std::pair<APInt, APInt>
1671getStrideAndModOffsetOfGEP(Value *PtrOp, const DataLayout &DL) {
1672 unsigned BW = DL.getIndexTypeSizeInBits(Ty: PtrOp->getType());
1673 std::optional<APInt> Stride;
1674 APInt ModOffset(BW, 0);
1675 // Return a minimum gep stride, greatest common divisor of consective gep
1676 // index scales(c.f. Bézout's identity).
1677 while (auto *GEP = dyn_cast<GEPOperator>(Val: PtrOp)) {
1678 SmallMapVector<Value *, APInt, 4> VarOffsets;
1679 if (!GEP->collectOffset(DL, BitWidth: BW, VariableOffsets&: VarOffsets, ConstantOffset&: ModOffset))
1680 break;
1681
1682 for (auto [V, Scale] : VarOffsets) {
1683 // Only keep a power of two factor for non-inbounds
1684 if (!GEP->hasNoUnsignedSignedWrap())
1685 Scale = APInt::getOneBitSet(numBits: Scale.getBitWidth(), BitNo: Scale.countr_zero());
1686
1687 if (!Stride)
1688 Stride = Scale;
1689 else
1690 Stride = APIntOps::GreatestCommonDivisor(A: *Stride, B: Scale);
1691 }
1692
1693 PtrOp = GEP->getPointerOperand();
1694 }
1695
1696 // Check whether pointer arrives back at Global Variable via at least one GEP.
1697 // Even if it doesn't, we can check by alignment.
1698 if (!isa<GlobalVariable>(Val: PtrOp) || !Stride)
1699 return {APInt(BW, 1), APInt(BW, 0)};
1700
1701 // In consideration of signed GEP indices, non-negligible offset become
1702 // remainder of division by minimum GEP stride.
1703 ModOffset = ModOffset.srem(RHS: *Stride);
1704 if (ModOffset.isNegative())
1705 ModOffset += *Stride;
1706
1707 return {*Stride, ModOffset};
1708}
1709
1710/// If C is a constant patterned array and all valid loaded results for given
1711/// alignment are same to a constant, return that constant.
1712static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
1713 auto *LI = dyn_cast<LoadInst>(Val: &I);
1714 if (!LI || LI->isVolatile())
1715 return false;
1716
1717 // We can only fold the load if it is from a constant global with definitive
1718 // initializer. Skip expensive logic if this is not the case.
1719 auto *PtrOp = LI->getPointerOperand();
1720 auto *GV = dyn_cast<GlobalVariable>(Val: getUnderlyingObject(V: PtrOp));
1721 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
1722 return false;
1723
1724 // Bail for large initializers in excess of 4K to avoid too many scans.
1725 Constant *C = GV->getInitializer();
1726 uint64_t GVSize = DL.getTypeAllocSize(Ty: C->getType());
1727 if (!GVSize || 4096 < GVSize)
1728 return false;
1729
1730 Type *LoadTy = LI->getType();
1731 unsigned BW = DL.getIndexTypeSizeInBits(Ty: PtrOp->getType());
1732 auto [Stride, ConstOffset] = getStrideAndModOffsetOfGEP(PtrOp, DL);
1733
1734 // Any possible offset could be multiple of GEP stride. And any valid
1735 // offset is multiple of load alignment, so checking only multiples of bigger
1736 // one is sufficient to say results' equality.
1737 if (auto LA = LI->getAlign();
1738 LA <= GV->getAlign().valueOrOne() && Stride.getZExtValue() < LA.value()) {
1739 ConstOffset = APInt(BW, 0);
1740 Stride = APInt(BW, LA.value());
1741 }
1742
1743 Constant *Ca = ConstantFoldLoadFromConst(C, Ty: LoadTy, Offset: ConstOffset, DL);
1744 if (!Ca)
1745 return false;
1746
1747 unsigned E = GVSize - DL.getTypeStoreSize(Ty: LoadTy);
1748 for (; ConstOffset.getZExtValue() <= E; ConstOffset += Stride)
1749 if (Ca != ConstantFoldLoadFromConst(C, Ty: LoadTy, Offset: ConstOffset, DL))
1750 return false;
1751
1752 I.replaceAllUsesWith(V: Ca);
1753
1754 return true;
1755}
1756
namespace {
/// Replaces a strcmp/strncmp call, whose result is only compared against
/// zero, with an inline byte-by-byte comparison when exactly one argument is
/// a constant string and the number of bytes that must be compared is a small
/// known constant.
class StrNCmpInliner {
public:
  /// \p CI is the call to rewrite and \p Func says whether it is strcmp or
  /// strncmp. \p DTU may be null; when it is not, it receives the dominator
  /// tree updates for the control flow created by the rewrite.
  StrNCmpInliner(CallInst *CI, LibFunc Func, DomTreeUpdater *DTU,
                 const DataLayout &DL)
      : CI(CI), Func(Func), DTU(DTU), DL(DL) {}

  /// Performs the rewrite if it applies. Returns true if CI was replaced.
  bool optimizeStrNCmp();

private:
  /// Emit the comparison of the first \p N bytes of the non-constant string
  /// \p LHS against the constant bytes \p RHS. \p Swapped is true when the
  /// constant string was the first argument of the original call.
  void inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Swapped);

  CallInst *CI;
  LibFunc Func;
  DomTreeUpdater *DTU;
  const DataLayout &DL;
};

} // namespace
1776
1777/// First we normalize calls to strncmp/strcmp to the form of
1778/// compare(s1, s2, N), which means comparing first N bytes of s1 and s2
1779/// (without considering '\0').
1780///
1781/// Examples:
1782///
1783/// \code
1784/// strncmp(s, "a", 3) -> compare(s, "a", 2)
1785/// strncmp(s, "abc", 3) -> compare(s, "abc", 3)
1786/// strncmp(s, "a\0b", 3) -> compare(s, "a\0b", 2)
1787/// strcmp(s, "a") -> compare(s, "a", 2)
1788///
1789/// char s2[] = {'a'}
1790/// strncmp(s, s2, 3) -> compare(s, s2, 3)
1791///
1792/// char s2[] = {'a', 'b', 'c', 'd'}
1793/// strncmp(s, s2, 3) -> compare(s, s2, 3)
1794/// \endcode
1795///
1796/// We only handle cases where N and exactly one of s1 and s2 are constant.
1797/// Cases that s1 and s2 are both constant are already handled by the
1798/// instcombine pass.
1799///
1800/// We do not handle cases where N > StrNCmpInlineThreshold.
1801///
1802/// We also do not handles cases where N < 2, which are already
1803/// handled by the instcombine pass.
1804///
1805bool StrNCmpInliner::optimizeStrNCmp() {
1806 if (StrNCmpInlineThreshold < 2)
1807 return false;
1808
1809 if (!isOnlyUsedInZeroComparison(CxtI: CI))
1810 return false;
1811
1812 Value *Str1P = CI->getArgOperand(i: 0);
1813 Value *Str2P = CI->getArgOperand(i: 1);
1814 // Should be handled elsewhere.
1815 if (Str1P == Str2P)
1816 return false;
1817
1818 StringRef Str1, Str2;
1819 bool HasStr1 = getConstantStringInfo(V: Str1P, Str&: Str1, /*TrimAtNul=*/false);
1820 bool HasStr2 = getConstantStringInfo(V: Str2P, Str&: Str2, /*TrimAtNul=*/false);
1821 if (HasStr1 == HasStr2)
1822 return false;
1823
1824 // Note that '\0' and characters after it are not trimmed.
1825 StringRef Str = HasStr1 ? Str1 : Str2;
1826 Value *StrP = HasStr1 ? Str2P : Str1P;
1827
1828 size_t Idx = Str.find(C: '\0');
1829 uint64_t N = Idx == StringRef::npos ? UINT64_MAX : Idx + 1;
1830 if (Func == LibFunc_strncmp) {
1831 if (auto *ConstInt = dyn_cast<ConstantInt>(Val: CI->getArgOperand(i: 2)))
1832 N = std::min(a: N, b: ConstInt->getZExtValue());
1833 else
1834 return false;
1835 }
1836 // Now N means how many bytes we need to compare at most.
1837 if (N > Str.size() || N < 2 || N > StrNCmpInlineThreshold)
1838 return false;
1839
1840 // Cases where StrP has two or more dereferenceable bytes might be better
1841 // optimized elsewhere.
1842 bool CanBeNull = false;
1843 if (StrP->getPointerDereferenceableBytes(DL, CanBeNull,
1844 /*CanBeFreed=*/nullptr) > 1)
1845 return false;
1846 inlineCompare(LHS: StrP, RHS: Str, N, Swapped: HasStr1);
1847 return true;
1848}
1849
1850/// Convert
1851///
1852/// \code
1853/// ret = compare(s1, s2, N)
1854/// \endcode
1855///
1856/// into
1857///
1858/// \code
1859/// ret = (int)s1[0] - (int)s2[0]
1860/// if (ret != 0)
1861/// goto NE
1862/// ...
1863/// ret = (int)s1[N-2] - (int)s2[N-2]
1864/// if (ret != 0)
1865/// goto NE
1866/// ret = (int)s1[N-1] - (int)s2[N-1]
1867/// NE:
1868/// \endcode
1869///
1870/// CFG before and after the transformation:
1871///
1872/// (before)
1873/// BBCI
1874///
1875/// (after)
1876/// BBCI -> BBSubs[0] (sub,icmp) --NE-> BBNE -> BBTail
1877/// | ^
1878/// E |
1879/// | |
1880/// BBSubs[1] (sub,icmp) --NE-----+
1881/// ... |
1882/// BBSubs[N-1] (sub) ---------+
1883///
1884void StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
1885 bool Swapped) {
1886 auto &Ctx = CI->getContext();
1887 IRBuilder<> B(Ctx);
1888 // We want these instructions to be recognized as inlined instructions for the
1889 // compare call, but we don't have a source location for the definition of
1890 // that function, since we're generating that code now. Because the generated
1891 // code is a viable point for a memory access error, we make the pragmatic
1892 // choice here to directly use CI's location so that we have useful
1893 // attribution for the generated code.
1894 B.SetCurrentDebugLocation(CI->getDebugLoc());
1895
1896 BasicBlock *BBCI = CI->getParent();
1897 BasicBlock *BBTail =
1898 SplitBlock(Old: BBCI, SplitPt: CI, DTU, LI: nullptr, MSSAU: nullptr, BBName: BBCI->getName() + ".tail");
1899
1900 SmallVector<BasicBlock *> BBSubs;
1901 for (uint64_t I = 0; I < N; ++I)
1902 BBSubs.push_back(
1903 Elt: BasicBlock::Create(Context&: Ctx, Name: "sub_" + Twine(I), Parent: BBCI->getParent(), InsertBefore: BBTail));
1904 BasicBlock *BBNE = BasicBlock::Create(Context&: Ctx, Name: "ne", Parent: BBCI->getParent(), InsertBefore: BBTail);
1905
1906 cast<UncondBrInst>(Val: BBCI->getTerminator())->setSuccessor(BBSubs[0]);
1907
1908 B.SetInsertPoint(BBNE);
1909 PHINode *Phi = B.CreatePHI(Ty: CI->getType(), NumReservedValues: N);
1910 B.CreateBr(Dest: BBTail);
1911
1912 Value *Base = LHS;
1913 for (uint64_t i = 0; i < N; ++i) {
1914 B.SetInsertPoint(BBSubs[i]);
1915 Value *VL =
1916 B.CreateZExt(V: B.CreateLoad(Ty: B.getInt8Ty(),
1917 Ptr: B.CreateInBoundsPtrAdd(Ptr: Base, Offset: B.getInt64(C: i))),
1918 DestTy: CI->getType());
1919 Value *VR =
1920 ConstantInt::get(Ty: CI->getType(), V: static_cast<unsigned char>(RHS[i]));
1921 Value *Sub = Swapped ? B.CreateSub(LHS: VR, RHS: VL) : B.CreateSub(LHS: VL, RHS: VR);
1922 if (i < N - 1) {
1923 CondBrInst *CondBrInst = B.CreateCondBr(
1924 Cond: B.CreateICmpNE(LHS: Sub, RHS: ConstantInt::get(Ty: CI->getType(), V: 0)), True: BBNE,
1925 False: BBSubs[i + 1]);
1926
1927 Function *F = CI->getFunction();
1928 assert(F && "Instruction does not belong to a function!");
1929 std::optional<uint64_t> EC = F->getEntryCount();
1930 if (EC && *EC > 0)
1931 setExplicitlyUnknownBranchWeights(I&: *CondBrInst, DEBUG_TYPE);
1932 } else {
1933 B.CreateBr(Dest: BBNE);
1934 }
1935
1936 Phi->addIncoming(V: Sub, BB: BBSubs[i]);
1937 }
1938
1939 CI->replaceAllUsesWith(V: Phi);
1940 CI->eraseFromParent();
1941
1942 if (DTU) {
1943 SmallVector<DominatorTree::UpdateType, 8> Updates;
1944 Updates.push_back(Elt: {DominatorTree::Insert, BBCI, BBSubs[0]});
1945 for (uint64_t i = 0; i < N; ++i) {
1946 if (i < N - 1)
1947 Updates.push_back(Elt: {DominatorTree::Insert, BBSubs[i], BBSubs[i + 1]});
1948 Updates.push_back(Elt: {DominatorTree::Insert, BBSubs[i], BBNE});
1949 }
1950 Updates.push_back(Elt: {DominatorTree::Insert, BBNE, BBTail});
1951 Updates.push_back(Elt: {DominatorTree::Delete, BBCI, BBTail});
1952 DTU->applyUpdates(Updates);
1953 }
1954}
1955
1956/// Convert memchr with a small constant string into a switch
1957static bool foldMemChr(CallInst *Call, DomTreeUpdater *DTU,
1958 const DataLayout &DL) {
1959 if (isa<Constant>(Val: Call->getArgOperand(i: 1)))
1960 return false;
1961
1962 StringRef Str;
1963 Value *Base = Call->getArgOperand(i: 0);
1964 if (!getConstantStringInfo(V: Base, Str, /*TrimAtNul=*/false))
1965 return false;
1966
1967 uint64_t N = Str.size();
1968 if (auto *ConstInt = dyn_cast<ConstantInt>(Val: Call->getArgOperand(i: 2))) {
1969 uint64_t Val = ConstInt->getZExtValue();
1970 // Ignore the case that n is larger than the size of string.
1971 if (Val > N)
1972 return false;
1973 N = Val;
1974 } else
1975 return false;
1976
1977 if (N > MemChrInlineThreshold)
1978 return false;
1979
1980 BasicBlock *BB = Call->getParent();
1981 BasicBlock *BBNext = SplitBlock(Old: BB, SplitPt: Call, DTU);
1982 IRBuilder<> IRB(BB);
1983 IRB.SetCurrentDebugLocation(Call->getDebugLoc());
1984 IntegerType *ByteTy = IRB.getInt8Ty();
1985 BB->getTerminator()->eraseFromParent();
1986 SwitchInst *SI = IRB.CreateSwitch(
1987 V: IRB.CreateTrunc(V: Call->getArgOperand(i: 1), DestTy: ByteTy), Dest: BBNext, NumCases: N);
1988 // We can't know the precise weights here, as they would depend on the value
1989 // distribution of Call->getArgOperand(1). So we just mark it as "unknown".
1990 setExplicitlyUnknownBranchWeightsIfProfiled(I&: *SI, DEBUG_TYPE);
1991 Type *IndexTy = DL.getIndexType(PtrTy: Call->getType());
1992 SmallVector<DominatorTree::UpdateType, 8> Updates;
1993
1994 BasicBlock *BBSuccess = BasicBlock::Create(
1995 Context&: Call->getContext(), Name: "memchr.success", Parent: BB->getParent(), InsertBefore: BBNext);
1996 IRB.SetInsertPoint(BBSuccess);
1997 PHINode *IndexPHI = IRB.CreatePHI(Ty: IndexTy, NumReservedValues: N, Name: "memchr.idx");
1998 Value *FirstOccursLocation = IRB.CreateInBoundsPtrAdd(Ptr: Base, Offset: IndexPHI);
1999 IRB.CreateBr(Dest: BBNext);
2000 if (DTU)
2001 Updates.push_back(Elt: {DominatorTree::Insert, BBSuccess, BBNext});
2002
2003 SmallPtrSet<ConstantInt *, 4> Cases;
2004 for (uint64_t I = 0; I < N; ++I) {
2005 ConstantInt *CaseVal =
2006 ConstantInt::get(Ty: ByteTy, V: static_cast<unsigned char>(Str[I]));
2007 if (!Cases.insert(Ptr: CaseVal).second)
2008 continue;
2009
2010 BasicBlock *BBCase = BasicBlock::Create(Context&: Call->getContext(), Name: "memchr.case",
2011 Parent: BB->getParent(), InsertBefore: BBSuccess);
2012 SI->addCase(OnVal: CaseVal, Dest: BBCase);
2013 IRB.SetInsertPoint(BBCase);
2014 IndexPHI->addIncoming(V: ConstantInt::get(Ty: IndexTy, V: I), BB: BBCase);
2015 IRB.CreateBr(Dest: BBSuccess);
2016 if (DTU) {
2017 Updates.push_back(Elt: {DominatorTree::Insert, BB, BBCase});
2018 Updates.push_back(Elt: {DominatorTree::Insert, BBCase, BBSuccess});
2019 }
2020 }
2021
2022 PHINode *PHI =
2023 PHINode::Create(Ty: Call->getType(), NumReservedValues: 2, NameStr: Call->getName(), InsertBefore: BBNext->begin());
2024 PHI->addIncoming(V: Constant::getNullValue(Ty: Call->getType()), BB);
2025 PHI->addIncoming(V: FirstOccursLocation, BB: BBSuccess);
2026
2027 Call->replaceAllUsesWith(V: PHI);
2028 Call->eraseFromParent();
2029
2030 if (DTU)
2031 DTU->applyUpdates(Updates);
2032
2033 return true;
2034}
2035
2036static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
2037 TargetLibraryInfo &TLI, AssumptionCache &AC,
2038 DominatorTree &DT, const DataLayout &DL,
2039 bool &MadeCFGChange) {
2040
2041 auto *CI = dyn_cast<CallInst>(Val: &I);
2042 if (!CI || CI->isNoBuiltin())
2043 return false;
2044
2045 Function *CalledFunc = CI->getCalledFunction();
2046 if (!CalledFunc)
2047 return false;
2048
2049 LibFunc LF;
2050 if (!TLI.getLibFunc(FDecl: *CalledFunc, F&: LF) ||
2051 !isLibFuncEmittable(M: CI->getModule(), TLI: &TLI, TheLibFunc: LF))
2052 return false;
2053
2054 DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy);
2055
2056 switch (LF) {
2057 case LibFunc_sqrt:
2058 case LibFunc_sqrtf:
2059 case LibFunc_sqrtl:
2060 return foldSqrt(Call: CI, Func: LF, TTI, TLI, AC, DT);
2061 case LibFunc_strcmp:
2062 case LibFunc_strncmp:
2063 if (StrNCmpInliner(CI, LF, &DTU, DL).optimizeStrNCmp()) {
2064 MadeCFGChange = true;
2065 return true;
2066 }
2067 break;
2068 case LibFunc_memchr:
2069 if (foldMemChr(Call: CI, DTU: &DTU, DL)) {
2070 MadeCFGChange = true;
2071 return true;
2072 }
2073 break;
2074 default:;
2075 }
2076 return false;
2077}
2078
/// Match high part of long multiplication.
///
/// Considering a multiply made up of high and low parts, we can split the
/// multiply into:
///  x * y == (xh*T + xl) * (yh*T + yl)
/// where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
/// (The formulas use a 64-bit type for illustration; the code works for any
/// even bit width, with 32 standing for BitWidth / 2 and 0xffffffff for
/// LowMask.)
/// This expands to
///  xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
/// which can be drawn as
/// [  xh*yh  ]
///      [  xh*yl  ]
///      [  xl*yh  ]
///           [  xl*yl  ]
/// We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
/// some carries. The carry makes this difficult and there are multiple ways of
/// representing it. The ones we attempt to support here are:
///  Carry:  xh*yh + carry + lowsum>>32
///          carry = lowsum < xh*yl ? 0x100000000 : 0
///          lowsum = xh*yl + xl*yh + (xl*yl>>32)
///  Ladder: xh*yh + c2>>32 + c3>>32
///          c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
///       or c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32); c3 = xl*yh
///  Carry4: xh*yh + carry + crosssum>>32 + ((xl*yl>>32) + crosssum&0xffffffff) >> 32
///          crosssum = xh*yl + xl*yh
///          carry = crosssum < xh*yl ? 0x100000000 : 0
///  Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
///          low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
///
/// They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
/// tree is xh*yh, xh*yl, xl*yh and xl*yl.
///
/// On a match, the uses of \p I are replaced by
/// trunc(lshr(zext(x) * zext(y), BitWidth)) computed in a type of twice the
/// width, and true is returned.
static bool foldMulHigh(Instruction &I) {
  Type *Ty = I.getType();
  if (!Ty->isIntOrIntVectorTy())
    return false;

  unsigned BitWidth = Ty->getScalarSizeInBits();
  APInt LowMask = APInt::getLowBitsSet(BitWidth, BitWidth / 2);
  // The pattern splits each operand into two equal halves.
  if (BitWidth % 2 != 0)
    return false;

  // Emit the full-width product in a type of twice the width and take its
  // high half.
  auto CreateMulHigh = [&](Value *X, Value *Y) {
    IRBuilder<> Builder(&I);
    Type *NTy = Ty->getWithNewBitWidth(BitWidth * 2);
    Value *XExt = Builder.CreateZExt(X, NTy);
    Value *YExt = Builder.CreateZExt(Y, NTy);
    Value *Mul = Builder.CreateMul(XExt, YExt, "", /*HasNUW=*/true);
    Value *High = Builder.CreateLShr(Mul, BitWidth);
    Value *Res = Builder.CreateTrunc(High, Ty, "", /*HasNUW=*/true);
    Res->takeName(&I);
    I.replaceAllUsesWith(Res);
    LLVM_DEBUG(dbgs() << "Created long multiply from parts of " << *X << " and "
                      << *Y << "\n");
    return true;
  };

  // Common check routines for X_lo*Y_lo and X_hi*Y_lo
  auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) {
    return match(XlYl, m_c_Mul(m_And(m_Specific(X), m_SpecificInt(LowMask)),
                               m_And(m_Specific(Y), m_SpecificInt(LowMask))));
  };
  auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) {
    return match(XhYl,
                 m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(BitWidth / 2)),
                         m_And(m_Specific(Y), m_SpecificInt(LowMask))));
  };

  auto FoldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry,
                              Instruction *B) {
    // Looking for LowSum >> 32 and carry (select)
    if (Carry->getOpcode() != Instruction::Select)
      std::swap(Carry, B);

    // Carry = LowSum < XhYl ? 0x100000000 : 0
    Value *LowSum, *XhYl;
    if (!match(Carry,
               m_OneUse(m_Select(
                   m_OneUse(m_SpecificICmp(ICmpInst::ICMP_ULT, m_Value(LowSum),
                                           m_Value(XhYl))),
                   m_SpecificInt(APInt::getOneBitSet(BitWidth, BitWidth / 2)),
                   m_Zero()))))
      return false;

    // XhYl can be Xh*Yl or Xl*Yh
    if (!CheckHiLo(XhYl, X, Y)) {
      if (CheckHiLo(XhYl, Y, X))
        std::swap(X, Y);
      else
        return false;
    }
    // XhYl is expected to feed only the carry compare and LowSum.
    if (XhYl->hasNUsesOrMore(3))
      return false;

    // B = LowSum >> 32
    // LowSum is expected to feed only the carry compare and this shift.
    if (!match(B, m_OneUse(m_LShr(m_Specific(LowSum),
                                  m_SpecificInt(BitWidth / 2)))) ||
        LowSum->hasNUsesOrMore(3))
      return false;

    // LowSum = XhYl + XlYh + XlYl>>32 (in any association order)
    Value *XlYh, *XlYl;
    auto XlYlHi = m_LShr(m_Value(XlYl), m_SpecificInt(BitWidth / 2));
    if (!match(LowSum,
               m_c_Add(m_Specific(XhYl),
                       m_OneUse(m_c_Add(m_OneUse(m_Value(XlYh)), XlYlHi)))) &&
        !match(LowSum, m_c_Add(m_OneUse(m_Value(XlYh)),
                               m_OneUse(m_c_Add(m_Specific(XhYl), XlYlHi)))) &&
        !match(LowSum,
               m_c_Add(XlYlHi, m_OneUse(m_c_Add(m_Specific(XhYl),
                                                m_OneUse(m_Value(XlYh)))))))
      return false;

    // Check XlYl and XlYh
    if (!CheckLoLo(XlYl, X, Y))
      return false;
    if (!CheckHiLo(XlYh, Y, X))
      return false;

    return CreateMulHigh(X, Y);
  };

  auto FoldMulHighLadder = [&](Value *X, Value *Y, Instruction *A,
                               Instruction *B) {
    //  xh*yh + c2>>32 + c3>>32
    //  c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
    //  or c2 = (c3&0xffffffff) + xh*yl + (xl*yl>>32); c3 = xl*yh
    //  (the two cross products may appear either way around)
    Value *XlYh, *XhYl, *XlYl, *C2, *C3;
    // Strip off the two expected shifts.
    if (!match(A, m_LShr(m_Value(C2), m_SpecificInt(BitWidth / 2))) ||
        !match(B, m_LShr(m_Value(C3), m_SpecificInt(BitWidth / 2))))
      return false;

    // Make C2 the three-term sum if one of them looks like one.
    if (match(C3, m_c_Add(m_Add(m_Value(), m_Value()), m_Value())))
      std::swap(C2, C3);
    // Try to match c2 = (c3&0xffffffff) + <cross product> + (xl*yl>>32)
    if (match(C2,
              m_c_Add(m_c_Add(m_And(m_Specific(C3), m_SpecificInt(LowMask)),
                              m_Value(XlYh)),
                      m_LShr(m_Value(XlYl), m_SpecificInt(BitWidth / 2)))) ||
        match(C2, m_c_Add(m_c_Add(m_And(m_Specific(C3), m_SpecificInt(LowMask)),
                                  m_LShr(m_Value(XlYl),
                                         m_SpecificInt(BitWidth / 2))),
                          m_Value(XlYh))) ||
        match(C2, m_c_Add(m_c_Add(m_LShr(m_Value(XlYl),
                                         m_SpecificInt(BitWidth / 2)),
                                  m_Value(XlYh)),
                          m_And(m_Specific(C3), m_SpecificInt(LowMask))))) {
      XhYl = C3;
    } else {
      // Match c3 = c2&0xffffffff + xl*yh
      if (!match(C3, m_c_Add(m_And(m_Specific(C2), m_SpecificInt(LowMask)),
                             m_Value(XlYh))))
        std::swap(C2, C3);
      if (!match(C3, m_c_Add(m_OneUse(
                                 m_And(m_Specific(C2), m_SpecificInt(LowMask))),
                             m_Value(XlYh))) ||
          !C3->hasOneUse() || C2->hasNUsesOrMore(3))
        return false;

      // Match c2 = xh*yl + (xl*yl >> 32)
      if (!match(C2, m_c_Add(m_LShr(m_Value(XlYl), m_SpecificInt(BitWidth / 2)),
                             m_Value(XhYl))))
        return false;
    }

    // Match XhYl and XlYh - they can appear either way around.
    if (!CheckHiLo(XlYh, Y, X))
      std::swap(XlYh, XhYl);
    if (!CheckHiLo(XlYh, Y, X))
      return false;
    if (!CheckHiLo(XhYl, X, Y))
      return false;
    if (!CheckLoLo(XlYl, X, Y))
      return false;

    return CreateMulHigh(X, Y);
  };

  auto FoldMulHighLadder4 = [&](Value *X, Value *Y, Instruction *A,
                                Instruction *B, Instruction *C) {
    //  Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
    //           low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff

    // Find A = Low >> 32 and B/C = XhYl>>32, XlYh>>32.
    auto ShiftAdd =
        m_LShr(m_Add(m_Value(), m_Value()), m_SpecificInt(BitWidth / 2));
    if (!match(A, ShiftAdd))
      std::swap(A, B);
    if (!match(A, ShiftAdd))
      std::swap(A, C);
    Value *Low;
    if (!match(A, m_LShr(m_OneUse(m_Value(Low)), m_SpecificInt(BitWidth / 2))))
      return false;

    // Match B == XhYl>>32 and C == XlYh>>32
    Value *XhYl, *XlYh;
    if (!match(B, m_LShr(m_Value(XhYl), m_SpecificInt(BitWidth / 2))) ||
        !match(C, m_LShr(m_Value(XlYh), m_SpecificInt(BitWidth / 2))))
      return false;
    if (!CheckHiLo(XhYl, X, Y))
      std::swap(XhYl, XlYh);
    // Each cross product is expected to feed only its shift and its mask.
    if (!CheckHiLo(XhYl, X, Y) || XhYl->hasNUsesOrMore(3))
      return false;
    if (!CheckHiLo(XlYh, Y, X) || XlYh->hasNUsesOrMore(3))
      return false;

    // Match Low as XlYl>>32 + XhYl&0xffffffff + XlYh&0xffffffff
    Value *XlYl;
    if (!match(
            Low,
            m_c_Add(
                m_OneUse(m_c_Add(
                    m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
                    m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))),
                m_OneUse(
                    m_LShr(m_Value(XlYl), m_SpecificInt(BitWidth / 2))))) &&
        !match(
            Low,
            m_c_Add(
                m_OneUse(m_c_Add(
                    m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
                    m_OneUse(
                        m_LShr(m_Value(XlYl), m_SpecificInt(BitWidth / 2))))),
                m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))) &&
        !match(
            Low,
            m_c_Add(
                m_OneUse(m_c_Add(
                    m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))),
                    m_OneUse(
                        m_LShr(m_Value(XlYl), m_SpecificInt(BitWidth / 2))))),
                m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))))))
      return false;
    if (!CheckLoLo(XlYl, X, Y))
      return false;

    return CreateMulHigh(X, Y);
  };

  auto FoldMulHighCarry4 = [&](Value *X, Value *Y, Instruction *Carry,
                               Instruction *B, Instruction *C) {
    //  xh*yh + carry + crosssum>>32 + ((xl*yl>>32) + crosssum&0xffffffff) >> 32
    //  crosssum = xh*yl+xl*yh
    //  carry = crosssum < xh*yl ? 0x100000000 : 0
    if (Carry->getOpcode() != Instruction::Select)
      std::swap(Carry, B);
    if (Carry->getOpcode() != Instruction::Select)
      std::swap(Carry, C);

    // Carry = CrossSum < XhYl ? 0x100000000 : 0
    Value *CrossSum, *XhYl;
    if (!match(Carry,
               m_OneUse(m_Select(
                   m_OneUse(m_SpecificICmp(ICmpInst::ICMP_ULT,
                                           m_Value(CrossSum), m_Value(XhYl))),
                   m_SpecificInt(APInt::getOneBitSet(BitWidth, BitWidth / 2)),
                   m_Zero()))))
      return false;

    if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BitWidth / 2))))
      std::swap(B, C);
    if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BitWidth / 2))))
      return false;

    // C = ((XlYl >> 32) + (CrossSum & 0xffffffff)) >> 32
    Value *XlYl, *LowAccum;
    if (!match(C, m_LShr(m_Value(LowAccum), m_SpecificInt(BitWidth / 2))) ||
        !match(LowAccum, m_c_Add(m_OneUse(m_LShr(m_Value(XlYl),
                                                 m_SpecificInt(BitWidth / 2))),
                                 m_OneUse(m_And(m_Specific(CrossSum),
                                                m_SpecificInt(LowMask))))) ||
        LowAccum->hasNUsesOrMore(3))
      return false;
    if (!CheckLoLo(XlYl, X, Y))
      return false;

    if (!CheckHiLo(XhYl, X, Y))
      std::swap(X, Y);
    if (!CheckHiLo(XhYl, X, Y))
      return false;
    // CrossSum is expected to feed the carry compare, B's shift and the mask
    // in LowAccum; XhYl only the carry compare and CrossSum.
    Value *XlYh;
    if (!match(CrossSum, m_c_Add(m_Specific(XhYl), m_OneUse(m_Value(XlYh)))) ||
        !CheckHiLo(XlYh, Y, X) || CrossSum->hasNUsesOrMore(4) ||
        XhYl->hasNUsesOrMore(3))
      return false;

    return CreateMulHigh(X, Y);
  };

  // X and Y are the two inputs, A, B and C are other parts of the pattern
  // (crosssum>>32, carry, etc).
  Value *X, *Y;
  Instruction *A, *B, *C;
  auto HiHi = m_OneUse(m_Mul(m_LShr(m_Value(X), m_SpecificInt(BitWidth / 2)),
                             m_LShr(m_Value(Y), m_SpecificInt(BitWidth / 2))));
  // xh*yh plus two further terms: the Carry and Ladder forms.
  if ((match(&I, m_c_Add(HiHi, m_OneUse(m_Add(m_Instruction(A),
                                              m_Instruction(B))))) ||
       match(&I, m_c_Add(m_Instruction(A),
                         m_OneUse(m_c_Add(HiHi, m_Instruction(B)))))) &&
      A->hasOneUse() && B->hasOneUse())
    if (FoldMulHighCarry(X, Y, A, B) || FoldMulHighLadder(X, Y, A, B))
      return true;

  // xh*yh plus three further terms: the Carry4 and Ladder4 forms.
  if ((match(&I, m_c_Add(HiHi, m_OneUse(m_c_Add(
                                   m_Instruction(A),
                                   m_OneUse(m_Add(m_Instruction(B),
                                                  m_Instruction(C))))))) ||
       match(&I, m_c_Add(m_Instruction(A),
                         m_OneUse(m_c_Add(
                             HiHi, m_OneUse(m_Add(m_Instruction(B),
                                                  m_Instruction(C))))))) ||
       match(&I, m_c_Add(m_Instruction(A),
                         m_OneUse(m_c_Add(
                             m_Instruction(B),
                             m_OneUse(m_c_Add(HiHi, m_Instruction(C))))))) ||
       match(&I,
             m_c_Add(m_OneUse(m_c_Add(HiHi, m_Instruction(A))),
                     m_OneUse(m_Add(m_Instruction(B), m_Instruction(C)))))) &&
      A->hasOneUse() && B->hasOneUse() && C->hasOneUse())
    return FoldMulHighCarry4(X, Y, A, B, C) ||
           FoldMulHighLadder4(X, Y, A, B, C);

  return false;
}
2401
2402/// This is the entry point for folds that could be implemented in regular
2403/// InstCombine, but they are separated because they are not expected to
2404/// occur frequently and/or have more than a constant-length pattern match.
2405static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
2406 TargetTransformInfo &TTI,
2407 TargetLibraryInfo &TLI, AliasAnalysis &AA,
2408 AssumptionCache &AC, bool &MadeCFGChange) {
2409 bool MadeChange = false;
2410 for (BasicBlock &BB : F) {
2411 // Ignore unreachable basic blocks.
2412 if (!DT.isReachableFromEntry(A: &BB))
2413 continue;
2414
2415 const DataLayout &DL = F.getDataLayout();
2416
2417 // Walk the block backwards for efficiency. We're matching a chain of
2418 // use->defs, so we're more likely to succeed by starting from the bottom.
2419 // Also, we want to avoid matching partial patterns.
2420 // TODO: It would be more efficient if we removed dead instructions
2421 // iteratively in this loop rather than waiting until the end.
2422 for (Instruction &I : make_early_inc_range(Range: llvm::reverse(C&: BB))) {
2423 MadeChange |= foldAnyOrAllBitsSet(I);
2424 MadeChange |= foldGuardedFunnelShift(I, DT);
2425 MadeChange |= foldSelectSplitCTLZCTTZ(I);
2426 MadeChange |= tryToRecognizePopCount(I);
2427 MadeChange |= tryToRecognizePopCount1(I);
2428 MadeChange |= tryToRecognizePopCount2n3(I);
2429 MadeChange |= tryToFPToSat(I, TTI);
2430 MadeChange |= tryToRecognizeTableBasedCttz(I, DL);
2431 MadeChange |= tryToRecognizeTableBasedLog2(I, DL, TTI);
2432 MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
2433 MadeChange |= foldPatternedLoads(I, DL);
2434 MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
2435 MadeChange |= foldMulHigh(I);
2436 // NOTE: This function introduces erasing of the instruction `I`, so it
2437 // needs to be called at the end of this sequence, otherwise we may make
2438 // bugs.
2439 MadeChange |= foldLibCalls(I, TTI, TLI, AC, DT, DL, MadeCFGChange);
2440 }
2441
2442 // Do this separately to avoid redundantly scanning stores multiple times.
2443 MadeChange |= foldConsecutiveStores(BB, DL, TTI, AA);
2444 }
2445
2446 // We're done with transforms, so remove dead instructions.
2447 if (MadeChange)
2448 for (BasicBlock &BB : F)
2449 SimplifyInstructionsInBlock(BB: &BB);
2450
2451 return MadeChange;
2452}
2453
2454/// This is the entry point for all transforms. Pass manager differences are
2455/// handled in the callers of this function.
2456static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
2457 TargetLibraryInfo &TLI, DominatorTree &DT,
2458 AliasAnalysis &AA, bool &MadeCFGChange) {
2459 bool MadeChange = false;
2460 const DataLayout &DL = F.getDataLayout();
2461 TruncInstCombine TIC(AC, TLI, DL, DT);
2462 MadeChange |= TIC.run(F);
2463 MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, MadeCFGChange);
2464 return MadeChange;
2465}
2466
2467PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
2468 FunctionAnalysisManager &AM) {
2469 auto &AC = AM.getResult<AssumptionAnalysis>(IR&: F);
2470 auto &TLI = AM.getResult<TargetLibraryAnalysis>(IR&: F);
2471 auto &DT = AM.getResult<DominatorTreeAnalysis>(IR&: F);
2472 auto &TTI = AM.getResult<TargetIRAnalysis>(IR&: F);
2473 auto &AA = AM.getResult<AAManager>(IR&: F);
2474 bool MadeCFGChange = false;
2475 if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) {
2476 // No changes, all analyses are preserved.
2477 return PreservedAnalyses::all();
2478 }
2479 // Mark all the analyses that instcombine updates as preserved.
2480 PreservedAnalyses PA;
2481 if (MadeCFGChange)
2482 PA.preserve<DominatorTreeAnalysis>();
2483 else
2484 PA.preserveSet<CFGAnalyses>();
2485 return PA;
2486}
2487