1//===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass that recognizes certain loop idioms and
10// transforms them into more optimized versions of the same loop. In cases
11// where this happens, it can be a significant performance win.
12//
13// We currently support two loops:
14//
15// 1. A loop that finds the first mismatched byte in an array and returns the
16// index, i.e. something like:
17//
18// while (++i != n) {
19// if (a[i] != b[i])
20// break;
21// }
22//
23// In this example we can actually vectorize the loop despite the early exit,
24// although the loop vectorizer does not support it. It requires some extra
25// checks to deal with the possibility of faulting loads when crossing page
26// boundaries. However, even with these checks it is still profitable to do the
27// transformation.
28//
29// TODO List:
30//
31// * Add support for the inverse case where we scan for a matching element.
32// * Permit 64-bit induction variable types.
33// * Recognize loops that increment the IV *after* comparing bytes.
34// * Allow 32-bit sign-extends of the IV used by the GEP.
35//
36// 2. A loop that finds the first matching character in an array among a set of
37// possible matches, e.g.:
38//
39// for (; first != last; ++first)
40// for (s_it = s_first; s_it != s_last; ++s_it)
41// if (*first == *s_it)
42// return first;
43// return last;
44//
45// This corresponds to std::find_first_of (for arrays of bytes) from the C++
46// standard library. This function can be implemented efficiently for targets
47// that support @llvm.experimental.vector.match. For example, on AArch64 targets
// that implement SVE2, this lowers to a MATCH instruction, which enables us to
49// perform up to 16x16=256 comparisons in one go. This can lead to very
50// significant speedups.
51//
52// TODO:
53//
54// * Add support for `find_first_not_of' loops (i.e. with not-equal comparison).
55// * Make VF a configurable parameter (right now we assume 128-bit vectors).
56// * Potentially adjust the cost model to let the transformation kick-in even if
57// @llvm.experimental.vector.match doesn't have direct support in hardware.
58//
59//===----------------------------------------------------------------------===//
60//
61// NOTE: This Pass matches really specific loop patterns because it's only
62// supposed to be a temporary solution until our LoopVectorizer is powerful
63// enough to vectorize them automatically.
64//
65//===----------------------------------------------------------------------===//
66
67#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
68#include "llvm/Analysis/DomTreeUpdater.h"
69#include "llvm/Analysis/LoopPass.h"
70#include "llvm/Analysis/OptimizationRemarkEmitter.h"
71#include "llvm/Analysis/TargetTransformInfo.h"
72#include "llvm/IR/Dominators.h"
73#include "llvm/IR/IRBuilder.h"
74#include "llvm/IR/Intrinsics.h"
75#include "llvm/IR/MDBuilder.h"
76#include "llvm/IR/PatternMatch.h"
77#include "llvm/Transforms/Utils/BasicBlockUtils.h"
78#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
79
80using namespace llvm;
81using namespace PatternMatch;
82
83#define DEBUG_TYPE "loop-idiom-vectorize"
84
85static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
86 cl::init(Val: false),
87 cl::desc("Disable Loop Idiom Vectorize Pass."));
88
89static cl::opt<LoopIdiomVectorizeStyle>
90 LITVecStyle("loop-idiom-vectorize-style", cl::Hidden,
91 cl::desc("The vectorization style for loop idiom transform."),
92 cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked",
93 "Use masked vector intrinsics"),
94 clEnumValN(LoopIdiomVectorizeStyle::Predicated,
95 "predicated", "Use VP intrinsics")),
96 cl::init(Val: LoopIdiomVectorizeStyle::Masked));
97
98static cl::opt<bool>
99 DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden,
100 cl::init(Val: false),
101 cl::desc("Proceed with Loop Idiom Vectorize Pass, but do "
102 "not convert byte-compare loop(s)."));
103
104static cl::opt<unsigned>
105 ByteCmpVF("loop-idiom-vectorize-bytecmp-vf", cl::Hidden,
106 cl::desc("The vectorization factor for byte-compare patterns."),
107 cl::init(Val: 16));
108
109static cl::opt<bool>
110 DisableFindFirstByte("disable-loop-idiom-vectorize-find-first-byte",
111 cl::Hidden, cl::init(Val: false),
112 cl::desc("Do not convert find-first-byte loop(s)."));
113
114static cl::opt<bool>
115 VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(Val: false),
116 cl::desc("Verify loops generated Loop Idiom Vectorize Pass."));
117
118namespace {
119class LoopIdiomVectorize {
120 LoopIdiomVectorizeStyle VectorizeStyle;
121 unsigned ByteCompareVF;
122 Loop *CurLoop = nullptr;
123 DominatorTree *DT;
124 LoopInfo *LI;
125 const TargetTransformInfo *TTI;
126 const DataLayout *DL;
127
128 /// Interface to emit optimization remarks.
129 OptimizationRemarkEmitter &ORE;
130
131 // Blocks that will be used for inserting vectorized code.
132 BasicBlock *EndBlock = nullptr;
133 BasicBlock *VectorLoopPreheaderBlock = nullptr;
134 BasicBlock *VectorLoopStartBlock = nullptr;
135 BasicBlock *VectorLoopMismatchBlock = nullptr;
136 BasicBlock *VectorLoopIncBlock = nullptr;
137
138public:
139 LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
140 LoopInfo *LI, const TargetTransformInfo *TTI,
141 const DataLayout *DL, OptimizationRemarkEmitter &ORE)
142 : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL),
143 ORE(ORE) {}
144
145 bool run(Loop *L);
146
147private:
148 /// \name Countable Loop Idiom Handling
149 /// @{
150
151 bool runOnCountableLoop();
152 bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
153 SmallVectorImpl<BasicBlock *> &ExitBlocks);
154
155 bool recognizeByteCompare();
156
157 Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
158 GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
159 Instruction *Index, Value *Start, Value *MaxLen);
160
161 Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
162 GetElementPtrInst *GEPA,
163 GetElementPtrInst *GEPB, Value *ExtStart,
164 Value *ExtEnd);
165 Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
166 GetElementPtrInst *GEPA,
167 GetElementPtrInst *GEPB, Value *ExtStart,
168 Value *ExtEnd);
169
170 void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
171 PHINode *IndPhi, Value *MaxLen, Instruction *Index,
172 Value *Start, bool IncIdx, BasicBlock *FoundBB,
173 BasicBlock *EndBB);
174
175 bool recognizeFindFirstByte();
176
177 Value *expandFindFirstByte(IRBuilder<> &Builder, DomTreeUpdater &DTU,
178 unsigned VF, Type *CharTy, Value *IndPhi,
179 BasicBlock *ExitSucc, BasicBlock *ExitFail,
180 Value *SearchStart, Value *SearchEnd,
181 Value *NeedleStart, Value *NeedleEnd);
182
183 void transformFindFirstByte(PHINode *IndPhi, unsigned VF, Type *CharTy,
184 BasicBlock *ExitSucc, BasicBlock *ExitFail,
185 Value *SearchStart, Value *SearchEnd,
186 Value *NeedleStart, Value *NeedleEnd);
187 /// @}
188};
189} // anonymous namespace
190
191PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
192 LoopStandardAnalysisResults &AR,
193 LPMUpdater &) {
194 if (DisableAll)
195 return PreservedAnalyses::all();
196
197 const auto *DL = &L.getHeader()->getDataLayout();
198
199 LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
200 if (LITVecStyle.getNumOccurrences())
201 VecStyle = LITVecStyle;
202
203 unsigned BCVF = ByteCompareVF;
204 if (ByteCmpVF.getNumOccurrences())
205 BCVF = ByteCmpVF;
206
207 Function &F = *L.getHeader()->getParent();
208 auto &FAMP = AM.getResult<FunctionAnalysisManagerLoopProxy>(IR&: L, ExtraArgs&: AR);
209 auto *ORE = FAMP.getCachedResult<OptimizationRemarkEmitterAnalysis>(IR&: F);
210
211 std::optional<OptimizationRemarkEmitter> ORELocal;
212 if (!ORE) {
213 ORELocal.emplace(args: &F);
214 ORE = &*ORELocal;
215 }
216
217 LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL, *ORE);
218 if (!LIV.run(L: &L))
219 return PreservedAnalyses::all();
220
221 return PreservedAnalyses::none();
222}
223
224//===----------------------------------------------------------------------===//
225//
226// Implementation of LoopIdiomVectorize
227//
228//===----------------------------------------------------------------------===//
229
230bool LoopIdiomVectorize::run(Loop *L) {
231 CurLoop = L;
232
233 Function &F = *L->getHeader()->getParent();
234 if (DisableAll || F.hasOptSize())
235 return false;
236
237 // Bail if vectorization is disabled on loop.
238 LoopVectorizeHints Hints(L, /*InterleaveOnlyWhenForced=*/true, ORE);
239 if (!Hints.allowVectorization(F: &F, L, /*VectorizeOnlyWhenForced=*/false)) {
240 LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << L->getName()
241 << " due to vectorization hints\n");
242 return false;
243 }
244
245 if (F.hasFnAttribute(Kind: Attribute::NoImplicitFloat)) {
246 LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName()
247 << " due to its NoImplicitFloat attribute");
248 return false;
249 }
250
251 // If the loop could not be converted to canonical form, it must have an
252 // indirectbr in it, just give up.
253 if (!L->getLoopPreheader())
254 return false;
255
256 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %"
257 << CurLoop->getHeader()->getName() << "\n");
258
259 if (recognizeByteCompare())
260 return true;
261
262 if (recognizeFindFirstByte())
263 return true;
264
265 return false;
266}
267
268static void fixSuccessorPhis(Loop *L, Value *ScalarRes, Value *VectorRes,
269 BasicBlock *SuccBB, BasicBlock *IncBB) {
270 for (PHINode &PN : SuccBB->phis()) {
271 // Look through the incoming values to find ScalarRes, meaning this is a
272 // PHI collecting the results of the transformation.
273 bool ResPhi = false;
274 for (Value *Op : PN.incoming_values())
275 if (Op == ScalarRes) {
276 ResPhi = true;
277 break;
278 }
279
280 // Any PHI that depended upon the result of the transformation needs a new
281 // incoming value from IncBB.
282 if (ResPhi)
283 PN.addIncoming(V: VectorRes, BB: IncBB);
284 else {
285 // There should be no other outside uses of other values in the
286 // original loop. Any incoming values should either:
287 // 1. Be for blocks outside the loop, which aren't interesting. Or ..
288 // 2. These are from blocks in the loop with values defined outside
289 // the loop. We should a similar incoming value from CmpBB.
290 for (BasicBlock *BB : PN.blocks())
291 if (L->contains(BB)) {
292 PN.addIncoming(V: PN.getIncomingValueForBlock(BB), BB: IncBB);
293 break;
294 }
295 }
296 }
297}
298
299bool LoopIdiomVectorize::recognizeByteCompare() {
300 // Currently the transformation only works on scalable vector types, although
301 // there is no fundamental reason why it cannot be made to work for fixed
302 // width too.
303
304 // We also need to know the minimum page size for the target in order to
305 // generate runtime memory checks to ensure the vector version won't fault.
306 if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
307 DisableByteCmp)
308 return false;
309
310 BasicBlock *Header = CurLoop->getHeader();
311
312 // In LoopIdiomVectorize::run we have already checked that the loop
313 // has a preheader so we can assume it's in a canonical form.
314 if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
315 return false;
316
317 PHINode *PN = dyn_cast<PHINode>(Val: &Header->front());
318 if (!PN || PN->getNumIncomingValues() != 2)
319 return false;
320
321 auto LoopBlocks = CurLoop->getBlocks();
322 // The first block in the loop should contain only 4 instructions, e.g.
323 //
324 // while.cond:
325 // %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
326 // %inc = add i32 %res.phi, 1
327 // %cmp.not = icmp eq i32 %inc, %n
328 // br i1 %cmp.not, label %while.end, label %while.body
329 //
330 if (LoopBlocks[0]->sizeWithoutDebug() > 4)
331 return false;
332
333 // The second block should contain 7 instructions, e.g.
334 //
335 // while.body:
336 // %idx = zext i32 %inc to i64
337 // %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
338 // %load.a = load i8, ptr %idx.a
339 // %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
340 // %load.b = load i8, ptr %idx.b
341 // %cmp.not.ld = icmp eq i8 %load.a, %load.b
342 // br i1 %cmp.not.ld, label %while.cond, label %while.end
343 //
344 if (LoopBlocks[1]->sizeWithoutDebug() > 7)
345 return false;
346
347 // The incoming value to the PHI node from the loop should be an add of 1.
348 Value *StartIdx = nullptr;
349 Instruction *Index = nullptr;
350 if (!CurLoop->contains(BB: PN->getIncomingBlock(i: 0))) {
351 StartIdx = PN->getIncomingValue(i: 0);
352 Index = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 1));
353 } else {
354 StartIdx = PN->getIncomingValue(i: 1);
355 Index = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 0));
356 }
357
358 // Limit to 32-bit types for now
359 if (!Index || !Index->getType()->isIntegerTy(Bitwidth: 32) ||
360 !match(V: Index, P: m_c_Add(L: m_Specific(V: PN), R: m_One())))
361 return false;
362
363 // If we match the pattern, PN and Index will be replaced with the result of
364 // the cttz.elts intrinsic. If any other instructions are used outside of
365 // the loop, we cannot replace it.
366 for (BasicBlock *BB : LoopBlocks)
367 for (Instruction &I : *BB)
368 if (&I != PN && &I != Index)
369 for (User *U : I.users())
370 if (!CurLoop->contains(Inst: cast<Instruction>(Val: U)))
371 return false;
372
373 // Match the branch instruction for the header
374 Value *MaxLen;
375 BasicBlock *EndBB, *WhileBB;
376 if (!match(V: Header->getTerminator(),
377 P: m_Br(C: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ, L: m_Specific(V: Index),
378 R: m_Value(V&: MaxLen)),
379 T: m_BasicBlock(V&: EndBB), F: m_BasicBlock(V&: WhileBB))) ||
380 !CurLoop->contains(BB: WhileBB))
381 return false;
382
383 // WhileBB should contain the pattern of load & compare instructions. Match
384 // the pattern and find the GEP instructions used by the loads.
385 BasicBlock *FoundBB;
386 BasicBlock *TrueBB;
387 Value *LoadA, *LoadB;
388 if (!match(V: WhileBB->getTerminator(),
389 P: m_Br(C: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ, L: m_Value(V&: LoadA),
390 R: m_Value(V&: LoadB)),
391 T: m_BasicBlock(V&: TrueBB), F: m_BasicBlock(V&: FoundBB))) ||
392 !CurLoop->contains(BB: TrueBB))
393 return false;
394
395 Value *A, *B;
396 if (!match(V: LoadA, P: m_Load(Op: m_Value(V&: A))) || !match(V: LoadB, P: m_Load(Op: m_Value(V&: B))))
397 return false;
398
399 LoadInst *LoadAI = cast<LoadInst>(Val: LoadA);
400 LoadInst *LoadBI = cast<LoadInst>(Val: LoadB);
401 if (!LoadAI->isSimple() || !LoadBI->isSimple())
402 return false;
403
404 GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(Val: A);
405 GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(Val: B);
406
407 if (!GEPA || !GEPB)
408 return false;
409
410 Value *PtrA = GEPA->getPointerOperand();
411 Value *PtrB = GEPB->getPointerOperand();
412
413 // Check we are loading i8 values from two loop invariant pointers
414 if (!CurLoop->isLoopInvariant(V: PtrA) || !CurLoop->isLoopInvariant(V: PtrB) ||
415 !GEPA->getResultElementType()->isIntegerTy(Bitwidth: 8) ||
416 !GEPB->getResultElementType()->isIntegerTy(Bitwidth: 8) ||
417 !LoadAI->getType()->isIntegerTy(Bitwidth: 8) ||
418 !LoadBI->getType()->isIntegerTy(Bitwidth: 8) || PtrA == PtrB)
419 return false;
420
421 // Check that the index to the GEPs is the index we found earlier
422 if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
423 return false;
424
425 Value *IdxA = GEPA->getOperand(i_nocapture: GEPA->getNumIndices());
426 Value *IdxB = GEPB->getOperand(i_nocapture: GEPB->getNumIndices());
427 if (IdxA != IdxB || !match(V: IdxA, P: m_ZExt(Op: m_Specific(V: Index))))
428 return false;
429
430 // We only ever expect the pre-incremented index value to be used inside the
431 // loop.
432 if (!PN->hasOneUse())
433 return false;
434
435 // Ensure that when the Found and End blocks are identical the PHIs have the
436 // supported format. We don't currently allow cases like this:
437 // while.cond:
438 // ...
439 // br i1 %cmp.not, label %while.end, label %while.body
440 //
441 // while.body:
442 // ...
443 // br i1 %cmp.not2, label %while.cond, label %while.end
444 //
445 // while.end:
446 // %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
447 //
448 // Where the incoming values for %final_ptr are unique and from each of the
449 // loop blocks, but not actually defined in the loop. This requires extra
450 // work setting up the byte.compare block, i.e. by introducing a select to
451 // choose the correct value.
452 // TODO: We could add support for this in future.
453 if (FoundBB == EndBB) {
454 for (PHINode &EndPN : EndBB->phis()) {
455 Value *WhileCondVal = EndPN.getIncomingValueForBlock(BB: Header);
456 Value *WhileBodyVal = EndPN.getIncomingValueForBlock(BB: WhileBB);
457
458 // The value of the index when leaving the while.cond block is always the
459 // same as the end value (MaxLen) so we permit either. The value when
460 // leaving the while.body block should only be the index. Otherwise for
461 // any other values we only allow ones that are same for both blocks.
462 if (WhileCondVal != WhileBodyVal &&
463 ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
464 (WhileBodyVal != Index)))
465 return false;
466 }
467 }
468
469 LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
470 << *(EndBB->getParent()) << "\n\n");
471
472 // The index is incremented before the GEP/Load pair so we need to
473 // add 1 to the start value.
474 transformByteCompare(GEPA, GEPB, IndPhi: PN, MaxLen, Index, Start: StartIdx, /*IncIdx=*/true,
475 FoundBB, EndBB);
476 return true;
477}
478
479Value *LoopIdiomVectorize::createMaskedFindMismatch(
480 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
481 GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
482 Type *I64Type = Builder.getInt64Ty();
483 Type *ResType = Builder.getInt32Ty();
484 Type *LoadType = Builder.getInt8Ty();
485 Value *PtrA = GEPA->getPointerOperand();
486 Value *PtrB = GEPB->getPointerOperand();
487
488 ScalableVectorType *PredVTy =
489 ScalableVectorType::get(ElementType: Builder.getInt1Ty(), MinNumElts: ByteCompareVF);
490
491 Value *InitialPred = Builder.CreateIntrinsic(
492 ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Type}, Args: {ExtStart, ExtEnd});
493
494 Value *VecLen = Builder.CreateVScale(Ty: I64Type);
495 VecLen =
496 Builder.CreateMul(LHS: VecLen, RHS: ConstantInt::get(Ty: I64Type, V: ByteCompareVF), Name: "",
497 /*HasNUW=*/true, /*HasNSW=*/true);
498
499 Value *PFalse = Builder.CreateVectorSplat(EC: PredVTy->getElementCount(),
500 V: Builder.getInt1(V: false));
501
502 Builder.CreateBr(Dest: VectorLoopStartBlock);
503
504 DTU.applyUpdates(Updates: {{DominatorTree::Insert, VectorLoopPreheaderBlock,
505 VectorLoopStartBlock}});
506
507 // Set up the first vector loop block by creating the PHIs, doing the vector
508 // loads and comparing the vectors.
509 Builder.SetInsertPoint(VectorLoopStartBlock);
510 PHINode *LoopPred = Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 2, Name: "mismatch_vec_loop_pred");
511 LoopPred->addIncoming(V: InitialPred, BB: VectorLoopPreheaderBlock);
512 PHINode *VectorIndexPhi = Builder.CreatePHI(Ty: I64Type, NumReservedValues: 2, Name: "mismatch_vec_index");
513 VectorIndexPhi->addIncoming(V: ExtStart, BB: VectorLoopPreheaderBlock);
514 Type *VectorLoadType =
515 ScalableVectorType::get(ElementType: Builder.getInt8Ty(), MinNumElts: ByteCompareVF);
516 Value *Passthru = ConstantInt::getNullValue(Ty: VectorLoadType);
517
518 Value *VectorLhsGep =
519 Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: VectorIndexPhi, Name: "", NW: GEPA->isInBounds());
520 Value *VectorLhsLoad = Builder.CreateMaskedLoad(Ty: VectorLoadType, Ptr: VectorLhsGep,
521 Alignment: Align(1), Mask: LoopPred, PassThru: Passthru);
522
523 Value *VectorRhsGep =
524 Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: VectorIndexPhi, Name: "", NW: GEPB->isInBounds());
525 Value *VectorRhsLoad = Builder.CreateMaskedLoad(Ty: VectorLoadType, Ptr: VectorRhsGep,
526 Alignment: Align(1), Mask: LoopPred, PassThru: Passthru);
527
528 Value *VectorMatchCmp = Builder.CreateICmpNE(LHS: VectorLhsLoad, RHS: VectorRhsLoad);
529 VectorMatchCmp = Builder.CreateSelect(C: LoopPred, True: VectorMatchCmp, False: PFalse);
530 Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(Src: VectorMatchCmp);
531 Builder.CreateCondBr(Cond: VectorMatchHasActiveLanes, True: VectorLoopMismatchBlock,
532 False: VectorLoopIncBlock);
533
534 DTU.applyUpdates(
535 Updates: {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
536 {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
537
538 // Increment the index counter and calculate the predicate for the next
539 // iteration of the loop. We branch back to the start of the loop if there
540 // is at least one active lane.
541 Builder.SetInsertPoint(VectorLoopIncBlock);
542 Value *NewVectorIndexPhi =
543 Builder.CreateAdd(LHS: VectorIndexPhi, RHS: VecLen, Name: "",
544 /*HasNUW=*/true, /*HasNSW=*/true);
545 VectorIndexPhi->addIncoming(V: NewVectorIndexPhi, BB: VectorLoopIncBlock);
546 Value *NewPred =
547 Builder.CreateIntrinsic(ID: Intrinsic::get_active_lane_mask,
548 Types: {PredVTy, I64Type}, Args: {NewVectorIndexPhi, ExtEnd});
549 LoopPred->addIncoming(V: NewPred, BB: VectorLoopIncBlock);
550
551 Value *PredHasActiveLanes =
552 Builder.CreateExtractElement(Vec: NewPred, Idx: uint64_t(0));
553 Builder.CreateCondBr(Cond: PredHasActiveLanes, True: VectorLoopStartBlock, False: EndBlock);
554
555 DTU.applyUpdates(
556 Updates: {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
557 {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
558
559 // If we found a mismatch then we need to calculate which lane in the vector
560 // had a mismatch and add that on to the current loop index.
561 Builder.SetInsertPoint(VectorLoopMismatchBlock);
562 PHINode *FoundPred = Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 1, Name: "mismatch_vec_found_pred");
563 FoundPred->addIncoming(V: VectorMatchCmp, BB: VectorLoopStartBlock);
564 PHINode *LastLoopPred =
565 Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 1, Name: "mismatch_vec_last_loop_pred");
566 LastLoopPred->addIncoming(V: LoopPred, BB: VectorLoopStartBlock);
567 PHINode *VectorFoundIndex =
568 Builder.CreatePHI(Ty: I64Type, NumReservedValues: 1, Name: "mismatch_vec_found_index");
569 VectorFoundIndex->addIncoming(V: VectorIndexPhi, BB: VectorLoopStartBlock);
570
571 Value *PredMatchCmp = Builder.CreateAnd(LHS: LastLoopPred, RHS: FoundPred);
572 Value *Ctz = Builder.CreateCountTrailingZeroElems(ResTy: ResType, Mask: PredMatchCmp);
573 Ctz = Builder.CreateZExt(V: Ctz, DestTy: I64Type);
574 Value *VectorLoopRes64 = Builder.CreateAdd(LHS: VectorFoundIndex, RHS: Ctz, Name: "",
575 /*HasNUW=*/true, /*HasNSW=*/true);
576 return Builder.CreateTrunc(V: VectorLoopRes64, DestTy: ResType);
577}
578
579Value *LoopIdiomVectorize::createPredicatedFindMismatch(
580 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
581 GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
582 Type *I64Type = Builder.getInt64Ty();
583 Type *I32Type = Builder.getInt32Ty();
584 Type *ResType = I32Type;
585 Type *LoadType = Builder.getInt8Ty();
586 Value *PtrA = GEPA->getPointerOperand();
587 Value *PtrB = GEPB->getPointerOperand();
588
589 auto *JumpToVectorLoop = UncondBrInst::Create(IfTrue: VectorLoopStartBlock);
590 Builder.Insert(I: JumpToVectorLoop);
591
592 DTU.applyUpdates(Updates: {{DominatorTree::Insert, VectorLoopPreheaderBlock,
593 VectorLoopStartBlock}});
594
595 // Set up the first Vector loop block by creating the PHIs, doing the vector
596 // loads and comparing the vectors.
597 Builder.SetInsertPoint(VectorLoopStartBlock);
598 auto *VectorIndexPhi = Builder.CreatePHI(Ty: I64Type, NumReservedValues: 2, Name: "mismatch_vector_index");
599 VectorIndexPhi->addIncoming(V: ExtStart, BB: VectorLoopPreheaderBlock);
600
601 // Calculate AVL by subtracting the vector loop index from the trip count
602 Value *AVL = Builder.CreateSub(LHS: ExtEnd, RHS: VectorIndexPhi, Name: "avl", /*HasNUW=*/true,
603 /*HasNSW=*/true);
604
605 auto *VectorLoadType = ScalableVectorType::get(ElementType: LoadType, MinNumElts: ByteCompareVF);
606 auto *VF = ConstantInt::get(Ty: I32Type, V: ByteCompareVF);
607
608 Value *VL = Builder.CreateIntrinsic(ID: Intrinsic::experimental_get_vector_length,
609 Types: {I64Type}, Args: {AVL, VF, Builder.getTrue()});
610 Value *GepOffset = VectorIndexPhi;
611
612 Value *VectorLhsGep =
613 Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: GepOffset, Name: "", NW: GEPA->isInBounds());
614 VectorType *TrueMaskTy =
615 VectorType::get(ElementType: Builder.getInt1Ty(), EC: VectorLoadType->getElementCount());
616 Value *AllTrueMask = Constant::getAllOnesValue(Ty: TrueMaskTy);
617 Value *VectorLhsLoad = Builder.CreateIntrinsic(
618 ID: Intrinsic::vp_load, Types: {VectorLoadType, VectorLhsGep->getType()},
619 Args: {VectorLhsGep, AllTrueMask, VL}, FMFSource: nullptr, Name: "lhs.load");
620
621 Value *VectorRhsGep =
622 Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: GepOffset, Name: "", NW: GEPB->isInBounds());
623 Value *VectorRhsLoad = Builder.CreateIntrinsic(
624 ID: Intrinsic::vp_load, Types: {VectorLoadType, VectorLhsGep->getType()},
625 Args: {VectorRhsGep, AllTrueMask, VL}, FMFSource: nullptr, Name: "rhs.load");
626
627 Value *VectorMatchCmp =
628 Builder.CreateICmpNE(LHS: VectorLhsLoad, RHS: VectorRhsLoad, Name: "mismatch.cmp");
629 Value *CTZ = Builder.CreateIntrinsic(
630 ID: Intrinsic::vp_cttz_elts, Types: {ResType, VectorMatchCmp->getType()},
631 Args: {VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(V: false), AllTrueMask,
632 VL});
633 Value *MismatchFound = Builder.CreateICmpNE(LHS: CTZ, RHS: VL);
634 auto *VectorEarlyExit = CondBrInst::Create(
635 Cond: MismatchFound, IfTrue: VectorLoopMismatchBlock, IfFalse: VectorLoopIncBlock);
636 Builder.Insert(I: VectorEarlyExit);
637
638 DTU.applyUpdates(
639 Updates: {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
640 {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
641
642 // Increment the index counter and calculate the predicate for the next
643 // iteration of the loop. We branch back to the start of the loop if there
644 // is at least one active lane.
645 Builder.SetInsertPoint(VectorLoopIncBlock);
646 Value *VL64 = Builder.CreateZExt(V: VL, DestTy: I64Type);
647 Value *NewVectorIndexPhi =
648 Builder.CreateAdd(LHS: VectorIndexPhi, RHS: VL64, Name: "",
649 /*HasNUW=*/true, /*HasNSW=*/true);
650 VectorIndexPhi->addIncoming(V: NewVectorIndexPhi, BB: VectorLoopIncBlock);
651 Value *ExitCond = Builder.CreateICmpNE(LHS: NewVectorIndexPhi, RHS: ExtEnd);
652 auto *VectorLoopBranchBack =
653 CondBrInst::Create(Cond: ExitCond, IfTrue: VectorLoopStartBlock, IfFalse: EndBlock);
654 Builder.Insert(I: VectorLoopBranchBack);
655
656 DTU.applyUpdates(
657 Updates: {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
658 {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
659
660 // If we found a mismatch then we need to calculate which lane in the vector
661 // had a mismatch and add that on to the current loop index.
662 Builder.SetInsertPoint(VectorLoopMismatchBlock);
663
664 // Add LCSSA phis for CTZ and VectorIndexPhi.
665 auto *CTZLCSSAPhi = Builder.CreatePHI(Ty: CTZ->getType(), NumReservedValues: 1, Name: "ctz");
666 CTZLCSSAPhi->addIncoming(V: CTZ, BB: VectorLoopStartBlock);
667 auto *VectorIndexLCSSAPhi =
668 Builder.CreatePHI(Ty: VectorIndexPhi->getType(), NumReservedValues: 1, Name: "mismatch_vector_index");
669 VectorIndexLCSSAPhi->addIncoming(V: VectorIndexPhi, BB: VectorLoopStartBlock);
670
671 Value *CTZI64 = Builder.CreateZExt(V: CTZLCSSAPhi, DestTy: I64Type);
672 Value *VectorLoopRes64 = Builder.CreateAdd(LHS: VectorIndexLCSSAPhi, RHS: CTZI64, Name: "",
673 /*HasNUW=*/true, /*HasNSW=*/true);
674 return Builder.CreateTrunc(V: VectorLoopRes64, DestTy: ResType);
675}
676
677Value *LoopIdiomVectorize::expandFindMismatch(
678 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
679 GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
680 Value *PtrA = GEPA->getPointerOperand();
681 Value *PtrB = GEPB->getPointerOperand();
682
683 // Get the arguments and types for the intrinsic.
684 BasicBlock *Preheader = CurLoop->getLoopPreheader();
685 Instruction *PHBranch = Preheader->getTerminator();
686 LLVMContext &Ctx = PHBranch->getContext();
687 Type *LoadType = Type::getInt8Ty(C&: Ctx);
688 Type *ResType = Builder.getInt32Ty();
689
690 // Split block in the original loop preheader.
691 EndBlock = SplitBlock(Old: Preheader, SplitPt: PHBranch, DT, LI, MSSAU: nullptr, BBName: "mismatch_end");
692
693 // Create the blocks that we're going to need:
694 // 1. A block for checking the zero-extended length exceeds 0
695 // 2. A block to check that the start and end addresses of a given array
696 // lie on the same page.
697 // 3. The vector loop preheader.
698 // 4. The first vector loop block.
699 // 5. The vector loop increment block.
700 // 6. A block we can jump to from the vector loop when a mismatch is found.
701 // 7. The first block of the scalar loop itself, containing PHIs , loads
702 // and cmp.
703 // 8. A scalar loop increment block to increment the PHIs and go back
704 // around the loop.
705
706 BasicBlock *MinItCheckBlock = BasicBlock::Create(
707 Context&: Ctx, Name: "mismatch_min_it_check", Parent: EndBlock->getParent(), InsertBefore: EndBlock);
708
709 // Update the terminator added by SplitBlock to branch to the first block
710 Preheader->getTerminator()->setSuccessor(Idx: 0, BB: MinItCheckBlock);
711
712 BasicBlock *MemCheckBlock = BasicBlock::Create(
713 Context&: Ctx, Name: "mismatch_mem_check", Parent: EndBlock->getParent(), InsertBefore: EndBlock);
714
715 VectorLoopPreheaderBlock = BasicBlock::Create(
716 Context&: Ctx, Name: "mismatch_vec_loop_preheader", Parent: EndBlock->getParent(), InsertBefore: EndBlock);
717
718 VectorLoopStartBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop",
719 Parent: EndBlock->getParent(), InsertBefore: EndBlock);
720
721 VectorLoopIncBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop_inc",
722 Parent: EndBlock->getParent(), InsertBefore: EndBlock);
723
724 VectorLoopMismatchBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop_found",
725 Parent: EndBlock->getParent(), InsertBefore: EndBlock);
726
727 BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
728 Context&: Ctx, Name: "mismatch_loop_pre", Parent: EndBlock->getParent(), InsertBefore: EndBlock);
729
730 BasicBlock *LoopStartBlock =
731 BasicBlock::Create(Context&: Ctx, Name: "mismatch_loop", Parent: EndBlock->getParent(), InsertBefore: EndBlock);
732
733 BasicBlock *LoopIncBlock = BasicBlock::Create(
734 Context&: Ctx, Name: "mismatch_loop_inc", Parent: EndBlock->getParent(), InsertBefore: EndBlock);
735
736 DTU.applyUpdates(Updates: {{DominatorTree::Insert, Preheader, MinItCheckBlock},
737 {DominatorTree::Delete, Preheader, EndBlock}});
738
739 // Update LoopInfo with the new vector & scalar loops.
740 auto VectorLoop = LI->AllocateLoop();
741 auto ScalarLoop = LI->AllocateLoop();
742
743 if (CurLoop->getParentLoop()) {
744 CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: MinItCheckBlock, LI&: *LI);
745 CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: MemCheckBlock, LI&: *LI);
746 CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: VectorLoopPreheaderBlock,
747 LI&: *LI);
748 CurLoop->getParentLoop()->addChildLoop(NewChild: VectorLoop);
749 CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: VectorLoopMismatchBlock, LI&: *LI);
750 CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: LoopPreHeaderBlock, LI&: *LI);
751 CurLoop->getParentLoop()->addChildLoop(NewChild: ScalarLoop);
752 } else {
753 LI->addTopLevelLoop(New: VectorLoop);
754 LI->addTopLevelLoop(New: ScalarLoop);
755 }
756
757 // Add the new basic blocks to their associated loops.
758 VectorLoop->addBasicBlockToLoop(NewBB: VectorLoopStartBlock, LI&: *LI);
759 VectorLoop->addBasicBlockToLoop(NewBB: VectorLoopIncBlock, LI&: *LI);
760
761 ScalarLoop->addBasicBlockToLoop(NewBB: LoopStartBlock, LI&: *LI);
762 ScalarLoop->addBasicBlockToLoop(NewBB: LoopIncBlock, LI&: *LI);
763
764 // Set up some types and constants that we intend to reuse.
765 Type *I64Type = Builder.getInt64Ty();
766
767 // Check the zero-extended iteration count > 0
768 Builder.SetInsertPoint(MinItCheckBlock);
769 Value *ExtStart = Builder.CreateZExt(V: Start, DestTy: I64Type);
770 Value *ExtEnd = Builder.CreateZExt(V: MaxLen, DestTy: I64Type);
771 // This check doesn't really cost us very much.
772
773 Value *LimitCheck = Builder.CreateICmpULE(LHS: Start, RHS: MaxLen);
774 CondBrInst *MinItCheckBr =
775 CondBrInst::Create(Cond: LimitCheck, IfTrue: MemCheckBlock, IfFalse: LoopPreHeaderBlock);
776 MinItCheckBr->setMetadata(
777 KindID: LLVMContext::MD_prof,
778 Node: MDBuilder(MinItCheckBr->getContext()).createBranchWeights(TrueWeight: 99, FalseWeight: 1));
779 Builder.Insert(I: MinItCheckBr);
780
781 DTU.applyUpdates(
782 Updates: {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
783 {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});
784
785 // For each of the arrays, check the start/end addresses are on the same
786 // page.
787 Builder.SetInsertPoint(MemCheckBlock);
788
789 // The early exit in the original loop means that when performing vector
790 // loads we are potentially reading ahead of the early exit. So we could
791 // fault if crossing a page boundary. Therefore, we create runtime memory
792 // checks based on the minimum page size as follows:
793 // 1. Calculate the addresses of the first memory accesses in the loop,
794 // i.e. LhsStart and RhsStart.
795 // 2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd.
796 // 3. Determine which pages correspond to all the memory accesses, i.e
797 // LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage.
798 // 4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then
799 // we know we won't cross any page boundaries in the loop so we can
800 // enter the vector loop! Otherwise we fall back on the scalar loop.
801 Value *LhsStartGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: ExtStart);
802 Value *RhsStartGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: ExtStart);
803 Value *RhsStart = Builder.CreatePtrToInt(V: RhsStartGEP, DestTy: I64Type);
804 Value *LhsStart = Builder.CreatePtrToInt(V: LhsStartGEP, DestTy: I64Type);
805 Value *LhsEndGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: ExtEnd);
806 Value *RhsEndGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: ExtEnd);
807 Value *LhsEnd = Builder.CreatePtrToInt(V: LhsEndGEP, DestTy: I64Type);
808 Value *RhsEnd = Builder.CreatePtrToInt(V: RhsEndGEP, DestTy: I64Type);
809
810 const uint64_t MinPageSize = TTI->getMinPageSize().value();
811 const uint64_t AddrShiftAmt = llvm::Log2_64(Value: MinPageSize);
812 Value *LhsStartPage = Builder.CreateLShr(LHS: LhsStart, RHS: AddrShiftAmt);
813 Value *LhsEndPage = Builder.CreateLShr(LHS: LhsEnd, RHS: AddrShiftAmt);
814 Value *RhsStartPage = Builder.CreateLShr(LHS: RhsStart, RHS: AddrShiftAmt);
815 Value *RhsEndPage = Builder.CreateLShr(LHS: RhsEnd, RHS: AddrShiftAmt);
816 Value *LhsPageCmp = Builder.CreateICmpNE(LHS: LhsStartPage, RHS: LhsEndPage);
817 Value *RhsPageCmp = Builder.CreateICmpNE(LHS: RhsStartPage, RHS: RhsEndPage);
818
819 Value *CombinedPageCmp = Builder.CreateOr(LHS: LhsPageCmp, RHS: RhsPageCmp);
820 CondBrInst *CombinedPageCmpCmpBr = CondBrInst::Create(
821 Cond: CombinedPageCmp, IfTrue: LoopPreHeaderBlock, IfFalse: VectorLoopPreheaderBlock);
822 CombinedPageCmpCmpBr->setMetadata(
823 KindID: LLVMContext::MD_prof, Node: MDBuilder(CombinedPageCmpCmpBr->getContext())
824 .createBranchWeights(TrueWeight: 10, FalseWeight: 90));
825 Builder.Insert(I: CombinedPageCmpCmpBr);
826
827 DTU.applyUpdates(
828 Updates: {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock},
829 {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}});
830
831 // Set up the vector loop preheader, i.e. calculate initial loop predicate,
832 // zero-extend MaxLen to 64-bits, determine the number of vector elements
833 // processed in each iteration, etc.
834 Builder.SetInsertPoint(VectorLoopPreheaderBlock);
835
836 // At this point we know two things must be true:
837 // 1. Start <= End
838 // 2. ExtMaxLen <= MinPageSize due to the page checks.
839 // Therefore, we know that we can use a 64-bit induction variable that
840 // starts from 0 -> ExtMaxLen and it will not overflow.
841 Value *VectorLoopRes = nullptr;
842 switch (VectorizeStyle) {
843 case LoopIdiomVectorizeStyle::Masked:
844 VectorLoopRes =
845 createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
846 break;
847 case LoopIdiomVectorizeStyle::Predicated:
848 VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB,
849 ExtStart, ExtEnd);
850 break;
851 }
852
853 Builder.CreateBr(Dest: EndBlock);
854
855 DTU.applyUpdates(
856 Updates: {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}});
857
858 // Generate code for scalar loop.
859 Builder.SetInsertPoint(LoopPreHeaderBlock);
860 Builder.CreateBr(Dest: LoopStartBlock);
861
862 DTU.applyUpdates(
863 Updates: {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}});
864
865 Builder.SetInsertPoint(LoopStartBlock);
866 PHINode *IndexPhi = Builder.CreatePHI(Ty: ResType, NumReservedValues: 2, Name: "mismatch_index");
867 IndexPhi->addIncoming(V: Start, BB: LoopPreHeaderBlock);
868
869 // Otherwise compare the values
870 // Load bytes from each array and compare them.
871 Value *GepOffset = Builder.CreateZExt(V: IndexPhi, DestTy: I64Type);
872
873 Value *LhsGep =
874 Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: GepOffset, Name: "", NW: GEPA->isInBounds());
875 Value *LhsLoad = Builder.CreateLoad(Ty: LoadType, Ptr: LhsGep);
876
877 Value *RhsGep =
878 Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: GepOffset, Name: "", NW: GEPB->isInBounds());
879 Value *RhsLoad = Builder.CreateLoad(Ty: LoadType, Ptr: RhsGep);
880
881 Value *MatchCmp = Builder.CreateICmpEQ(LHS: LhsLoad, RHS: RhsLoad);
882 // If we have a mismatch then exit the loop ...
883 Builder.CreateCondBr(Cond: MatchCmp, True: LoopIncBlock, False: EndBlock);
884
885 DTU.applyUpdates(Updates: {{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
886 {DominatorTree::Insert, LoopStartBlock, EndBlock}});
887
888 // Have we reached the maximum permitted length for the loop?
889 Builder.SetInsertPoint(LoopIncBlock);
890 Value *PhiInc = Builder.CreateAdd(LHS: IndexPhi, RHS: ConstantInt::get(Ty: ResType, V: 1), Name: "",
891 /*HasNUW=*/Index->hasNoUnsignedWrap(),
892 /*HasNSW=*/Index->hasNoSignedWrap());
893 IndexPhi->addIncoming(V: PhiInc, BB: LoopIncBlock);
894 Value *IVCmp = Builder.CreateICmpEQ(LHS: PhiInc, RHS: MaxLen);
895 Builder.CreateCondBr(Cond: IVCmp, True: EndBlock, False: LoopStartBlock);
896
897 DTU.applyUpdates(Updates: {{DominatorTree::Insert, LoopIncBlock, EndBlock},
898 {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});
899
  // In the end block we need to insert a PHI node to deal with four cases:
  //  1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
  //  2. We exited the scalar loop early due to a mismatch and need to return
  //  the index that we found.
  //  3. We didn't find a mismatch in the vector loop, so we return MaxLen.
  //  4. We exited the vector loop early due to a mismatch and need to return
  //  the index that we found.
907 Builder.SetInsertPoint(TheBB: EndBlock, IP: EndBlock->getFirstInsertionPt());
908 PHINode *ResPhi = Builder.CreatePHI(Ty: ResType, NumReservedValues: 4, Name: "mismatch_result");
909 ResPhi->addIncoming(V: MaxLen, BB: LoopIncBlock);
910 ResPhi->addIncoming(V: IndexPhi, BB: LoopStartBlock);
911 ResPhi->addIncoming(V: MaxLen, BB: VectorLoopIncBlock);
912 ResPhi->addIncoming(V: VectorLoopRes, BB: VectorLoopMismatchBlock);
913
914 Value *FinalRes = Builder.CreateTrunc(V: ResPhi, DestTy: ResType);
915
916 if (VerifyLoops) {
917 ScalarLoop->verifyLoop();
918 VectorLoop->verifyLoop();
919 if (!VectorLoop->isRecursivelyLCSSAForm(DT: *DT, LI: *LI))
920 report_fatal_error(reason: "Loops must remain in LCSSA form!");
921 if (!ScalarLoop->isRecursivelyLCSSAForm(DT: *DT, LI: *LI))
922 report_fatal_error(reason: "Loops must remain in LCSSA form!");
923 }
924
925 return FinalRes;
926}
927
928void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
929 GetElementPtrInst *GEPB,
930 PHINode *IndPhi, Value *MaxLen,
931 Instruction *Index, Value *Start,
932 bool IncIdx, BasicBlock *FoundBB,
933 BasicBlock *EndBB) {
934
935 // Insert the byte compare code at the end of the preheader block
936 BasicBlock *Preheader = CurLoop->getLoopPreheader();
937 BasicBlock *Header = CurLoop->getHeader();
938 UncondBrInst *PHBranch = cast<UncondBrInst>(Val: Preheader->getTerminator());
939 IRBuilder<> Builder(PHBranch);
940 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
941 Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());
942
943 // Increment the pointer if this was done before the loads in the loop.
944 if (IncIdx)
945 Start = Builder.CreateAdd(LHS: Start, RHS: ConstantInt::get(Ty: Start->getType(), V: 1));
946
947 Value *ByteCmpRes =
948 expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);
949
950 // Replaces uses of index & induction Phi with intrinsic (we already
951 // checked that the the first instruction of Header is the Phi above).
952 assert(IndPhi->hasOneUse() && "Index phi node has more than one use!");
953 Index->replaceAllUsesWith(V: ByteCmpRes);
954
955 // If no mismatch was found, we can jump to the end block. Create a
956 // new basic block for the compare instruction.
957 auto *CmpBB = BasicBlock::Create(Context&: Preheader->getContext(), Name: "byte.compare",
958 Parent: Preheader->getParent());
959 CmpBB->moveBefore(MovePos: EndBB);
960
961 // Replace the branch in the preheader with an always-true conditional branch.
962 // This ensures there is still a reference to the original loop.
963 Builder.CreateCondBr(Cond: Builder.getTrue(), True: CmpBB, False: Header);
964 PHBranch->eraseFromParent();
965
966 BasicBlock *MismatchEnd = cast<Instruction>(Val: ByteCmpRes)->getParent();
967 DTU.applyUpdates(Updates: {{DominatorTree::Insert, MismatchEnd, CmpBB}});
968
969 // Create the branch to either the end or found block depending on the value
970 // returned by the intrinsic.
971 Builder.SetInsertPoint(CmpBB);
972 if (FoundBB != EndBB) {
973 Value *FoundCmp = Builder.CreateICmpEQ(LHS: ByteCmpRes, RHS: MaxLen);
974 Builder.CreateCondBr(Cond: FoundCmp, True: EndBB, False: FoundBB);
975 DTU.applyUpdates(Updates: {{DominatorTree::Insert, CmpBB, FoundBB},
976 {DominatorTree::Insert, CmpBB, EndBB}});
977
978 } else {
979 Builder.CreateBr(Dest: FoundBB);
980 DTU.applyUpdates(Updates: {{DominatorTree::Insert, CmpBB, FoundBB}});
981 }
982
983 // Ensure all Phis in the successors of CmpBB have an incoming value from it.
984 fixSuccessorPhis(L: CurLoop, ScalarRes: ByteCmpRes, VectorRes: ByteCmpRes, SuccBB: EndBB, IncBB: CmpBB);
985 if (EndBB != FoundBB)
986 fixSuccessorPhis(L: CurLoop, ScalarRes: ByteCmpRes, VectorRes: ByteCmpRes, SuccBB: FoundBB, IncBB: CmpBB);
987
988 // The new CmpBB block isn't part of the loop, but will need to be added to
989 // the outer loop if there is one.
990 if (!CurLoop->isOutermost())
991 CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: CmpBB, LI&: *LI);
992
993 if (VerifyLoops && CurLoop->getParentLoop()) {
994 CurLoop->getParentLoop()->verifyLoop();
995 if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(DT: *DT, LI: *LI))
996 report_fatal_error(reason: "Loops must remain in LCSSA form!");
997 }
998}
999
1000bool LoopIdiomVectorize::recognizeFindFirstByte() {
1001 // Currently the transformation only works on scalable vector types, although
1002 // there is no fundamental reason why it cannot be made to work for fixed
1003 // vectors. We also need to know the target's minimum page size in order to
1004 // generate runtime memory checks to ensure the vector version won't fault.
1005 if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
1006 DisableFindFirstByte)
1007 return false;
1008
1009 // We exclude loops with trip counts > minimum page size via runtime checks,
1010 // so make sure that the minimum page size is something sensible such that
1011 // induction variables cannot overflow.
1012 if (uint64_t(*TTI->getMinPageSize()) >
1013 (std::numeric_limits<uint64_t>::max() / 2))
1014 return false;
1015
1016 // Define some constants we need throughout.
1017 BasicBlock *Header = CurLoop->getHeader();
1018 LLVMContext &Ctx = Header->getContext();
1019
1020 // We are expecting the four blocks defined below: Header, MatchBB, InnerBB,
1021 // and OuterBB. For now, we will bail our for almost anything else. The Four
1022 // blocks contain one nested loop.
1023 if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 4 ||
1024 CurLoop->getSubLoops().size() != 1)
1025 return false;
1026
1027 auto *InnerLoop = CurLoop->getSubLoops().front();
1028 Function &F = *InnerLoop->getHeader()->getParent();
1029
1030 // Bail if vectorization is disabled on inner loop.
1031 LoopVectorizeHints Hints(InnerLoop, /*InterleaveOnlyWhenForced=*/true, ORE);
1032 if (!Hints.allowVectorization(F: &F, L: InnerLoop,
1033 /*VectorizeOnlyWhenForced=*/false)) {
1034 LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on inner loop "
1035 << InnerLoop->getName()
1036 << " due to vectorization hints\n");
1037 return false;
1038 }
1039
1040 PHINode *IndPhi = dyn_cast<PHINode>(Val: &Header->front());
1041 if (!IndPhi || IndPhi->getNumIncomingValues() != 2)
1042 return false;
1043
1044 // Check instruction counts.
1045 auto LoopBlocks = CurLoop->getBlocks();
1046 if (LoopBlocks[0]->sizeWithoutDebug() > 3 ||
1047 LoopBlocks[1]->sizeWithoutDebug() > 4 ||
1048 LoopBlocks[2]->sizeWithoutDebug() > 3 ||
1049 LoopBlocks[3]->sizeWithoutDebug() > 3)
1050 return false;
1051
1052 // Check that no instruction other than IndPhi has outside uses.
1053 for (BasicBlock *BB : LoopBlocks)
1054 for (Instruction &I : *BB)
1055 if (&I != IndPhi)
1056 for (User *U : I.users())
1057 if (!CurLoop->contains(Inst: cast<Instruction>(Val: U)))
1058 return false;
1059
1060 // Match the branch instruction in the header. We are expecting an
1061 // unconditional branch to the inner loop.
1062 //
1063 // Header:
1064 // %14 = phi ptr [ %24, %OuterBB ], [ %3, %Header.preheader ]
1065 // %15 = load i8, ptr %14, align 1
1066 // br label %MatchBB
1067 BasicBlock *MatchBB;
1068 if (!match(V: Header->getTerminator(), P: m_UnconditionalBr(Succ&: MatchBB)) ||
1069 !InnerLoop->contains(BB: MatchBB))
1070 return false;
1071
1072 // MatchBB should be the entrypoint into the inner loop containing the
1073 // comparison between a search element and a needle.
1074 //
1075 // MatchBB:
1076 // %20 = phi ptr [ %7, %Header ], [ %17, %InnerBB ]
1077 // %21 = load i8, ptr %20, align 1
1078 // %22 = icmp eq i8 %15, %21
1079 // br i1 %22, label %ExitSucc, label %InnerBB
1080 BasicBlock *ExitSucc, *InnerBB;
1081 Value *LoadSearch, *LoadNeedle;
1082 CmpPredicate MatchPred;
1083 if (!match(V: MatchBB->getTerminator(),
1084 P: m_Br(C: m_ICmp(Pred&: MatchPred, L: m_Value(V&: LoadSearch), R: m_Value(V&: LoadNeedle)),
1085 T: m_BasicBlock(V&: ExitSucc), F: m_BasicBlock(V&: InnerBB))) ||
1086 MatchPred != ICmpInst::ICMP_EQ || !InnerLoop->contains(BB: InnerBB))
1087 return false;
1088
1089 // We expect outside uses of `IndPhi' in ExitSucc (and only there).
1090 for (User *U : IndPhi->users())
1091 if (!CurLoop->contains(Inst: cast<Instruction>(Val: U))) {
1092 auto *PN = dyn_cast<PHINode>(Val: U);
1093 if (!PN || PN->getParent() != ExitSucc)
1094 return false;
1095 }
1096
1097 // Match the loads and check they are simple.
1098 Value *Search, *Needle;
1099 if (!match(V: LoadSearch, P: m_Load(Op: m_Value(V&: Search))) ||
1100 !match(V: LoadNeedle, P: m_Load(Op: m_Value(V&: Needle))) ||
1101 !cast<LoadInst>(Val: LoadSearch)->isSimple() ||
1102 !cast<LoadInst>(Val: LoadNeedle)->isSimple())
1103 return false;
1104
1105 // Check we are loading valid characters.
1106 Type *CharTy = LoadSearch->getType();
1107 if (!CharTy->isIntegerTy() || LoadNeedle->getType() != CharTy)
1108 return false;
1109
1110 // Pick the vectorisation factor based on CharTy, work out the cost of the
1111 // match intrinsic and decide if we should use it.
1112 // Note: For the time being we assume 128-bit vectors.
1113 unsigned VF = 128 / CharTy->getIntegerBitWidth();
1114 SmallVector<Type *> Args = {
1115 ScalableVectorType::get(ElementType: CharTy, MinNumElts: VF), FixedVectorType::get(ElementType: CharTy, NumElts: VF),
1116 ScalableVectorType::get(ElementType: Type::getInt1Ty(C&: Ctx), MinNumElts: VF)};
1117 IntrinsicCostAttributes Attrs(Intrinsic::experimental_vector_match, Args[2],
1118 Args);
1119 if (TTI->getIntrinsicInstrCost(ICA: Attrs, CostKind: TTI::TCK_SizeAndLatency) > 4)
1120 return false;
1121
1122 // The loads come from two PHIs, each with two incoming values.
1123 PHINode *PSearch = dyn_cast<PHINode>(Val: Search);
1124 PHINode *PNeedle = dyn_cast<PHINode>(Val: Needle);
1125 if (!PSearch || PSearch->getNumIncomingValues() != 2 || !PNeedle ||
1126 PNeedle->getNumIncomingValues() != 2)
1127 return false;
1128
1129 // One PHI comes from the outer loop (PSearch), the other one from the inner
1130 // loop (PNeedle). PSearch effectively corresponds to IndPhi.
1131 if (InnerLoop->contains(Inst: PSearch))
1132 std::swap(a&: PSearch, b&: PNeedle);
1133 if (PSearch != &Header->front() || PNeedle != &MatchBB->front())
1134 return false;
1135
1136 // The incoming values of both PHI nodes should be a gep of 1.
1137 Value *SearchStart = PSearch->getIncomingValue(i: 0);
1138 Value *SearchIndex = PSearch->getIncomingValue(i: 1);
1139 if (CurLoop->contains(BB: PSearch->getIncomingBlock(i: 0)))
1140 std::swap(a&: SearchStart, b&: SearchIndex);
1141
1142 Value *NeedleStart = PNeedle->getIncomingValue(i: 0);
1143 Value *NeedleIndex = PNeedle->getIncomingValue(i: 1);
1144 if (InnerLoop->contains(BB: PNeedle->getIncomingBlock(i: 0)))
1145 std::swap(a&: NeedleStart, b&: NeedleIndex);
1146
1147 // Match the GEPs.
1148 if (!match(V: SearchIndex, P: m_GEP(Ops: m_Specific(V: PSearch), Ops: m_One())) ||
1149 !match(V: NeedleIndex, P: m_GEP(Ops: m_Specific(V: PNeedle), Ops: m_One())))
1150 return false;
1151
1152 // Check the GEPs result type matches `CharTy'.
1153 GetElementPtrInst *GEPSearch = cast<GetElementPtrInst>(Val: SearchIndex);
1154 GetElementPtrInst *GEPNeedle = cast<GetElementPtrInst>(Val: NeedleIndex);
1155 if (GEPSearch->getResultElementType() != CharTy ||
1156 GEPNeedle->getResultElementType() != CharTy)
1157 return false;
1158
1159 // InnerBB should increment the address of the needle pointer.
1160 //
1161 // InnerBB:
1162 // %17 = getelementptr inbounds i8, ptr %20, i64 1
1163 // %18 = icmp eq ptr %17, %10
1164 // br i1 %18, label %OuterBB, label %MatchBB
1165 BasicBlock *OuterBB;
1166 Value *NeedleEnd;
1167 if (!match(V: InnerBB->getTerminator(),
1168 P: m_Br(C: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ, L: m_Specific(V: GEPNeedle),
1169 R: m_Value(V&: NeedleEnd)),
1170 T: m_BasicBlock(V&: OuterBB), F: m_Specific(V: MatchBB))) ||
1171 !CurLoop->contains(BB: OuterBB))
1172 return false;
1173
1174 // OuterBB should increment the address of the search element pointer.
1175 //
1176 // OuterBB:
1177 // %24 = getelementptr inbounds i8, ptr %14, i64 1
1178 // %25 = icmp eq ptr %24, %6
1179 // br i1 %25, label %ExitFail, label %Header
1180 BasicBlock *ExitFail;
1181 Value *SearchEnd;
1182 if (!match(V: OuterBB->getTerminator(),
1183 P: m_Br(C: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ, L: m_Specific(V: GEPSearch),
1184 R: m_Value(V&: SearchEnd)),
1185 T: m_BasicBlock(V&: ExitFail), F: m_Specific(V: Header))))
1186 return false;
1187
1188 if (!CurLoop->isLoopInvariant(V: SearchStart) ||
1189 !CurLoop->isLoopInvariant(V: SearchEnd) ||
1190 !CurLoop->isLoopInvariant(V: NeedleStart) ||
1191 !CurLoop->isLoopInvariant(V: NeedleEnd))
1192 return false;
1193
1194 LLVM_DEBUG(dbgs() << "Found idiom in loop: \n" << *CurLoop << "\n\n");
1195
1196 transformFindFirstByte(IndPhi, VF, CharTy, ExitSucc, ExitFail, SearchStart,
1197 SearchEnd, NeedleStart, NeedleEnd);
1198 return true;
1199}
1200
1201Value *LoopIdiomVectorize::expandFindFirstByte(
1202 IRBuilder<> &Builder, DomTreeUpdater &DTU, unsigned VF, Type *CharTy,
1203 Value *IndPhi, BasicBlock *ExitSucc, BasicBlock *ExitFail,
1204 Value *SearchStart, Value *SearchEnd, Value *NeedleStart,
1205 Value *NeedleEnd) {
1206 // Set up some types and constants that we intend to reuse.
1207 auto *I64Ty = Builder.getInt64Ty();
1208 auto *PredVTy = ScalableVectorType::get(ElementType: Builder.getInt1Ty(), MinNumElts: VF);
1209 auto *CharVTy = ScalableVectorType::get(ElementType: CharTy, MinNumElts: VF);
1210 auto *ConstVF = ConstantInt::get(Ty: I64Ty, V: VF);
1211
1212 // Other common arguments.
1213 BasicBlock *Preheader = CurLoop->getLoopPreheader();
1214 LLVMContext &Ctx = Preheader->getContext();
1215 Value *Passthru = ConstantInt::getNullValue(Ty: CharVTy);
1216
1217 // Split block in the original loop preheader.
1218 // SPH is the new preheader to the old scalar loop.
1219 BasicBlock *SPH = SplitBlock(Old: Preheader, SplitPt: Preheader->getTerminator(), DT, LI,
1220 MSSAU: nullptr, BBName: "scalar_preheader");
1221
1222 // Create the blocks that we're going to use.
1223 //
1224 // We will have the following loops:
1225 // (O) Outer loop where we iterate over the elements of the search array.
1226 // (I) Inner loop where we iterate over the elements of the needle array.
1227 //
1228 // Overall, the blocks do the following:
1229 // (0) Check if the arrays can't cross page boundaries. If so go to (1),
1230 // otherwise fall back to the original scalar loop.
1231 // (1) Load the search array. Go to (2).
1232 // (2) (a) Load the needle array.
1233 // (b) Splat the first element to the inactive lanes.
1234 // (c) Accumulate any matches found. If we haven't reached the end of the
1235 // needle array loop back to (2), otherwise go to (3).
1236 // (3) Test if we found any match. If so go to (4), otherwise go to (5).
1237 // (4) Compute the index of the first match and exit.
1238 // (5) Check if we've reached the end of the search array. If not loop back to
1239 // (1), otherwise exit.
1240 // Blocks (0,4) are not part of any loop. Blocks (1,3,5) and (2) belong to the
1241 // outer and inner loops, respectively.
1242 BasicBlock *BB0 = BasicBlock::Create(Context&: Ctx, Name: "mem_check", Parent: SPH->getParent(), InsertBefore: SPH);
1243 BasicBlock *BB1 =
1244 BasicBlock::Create(Context&: Ctx, Name: "find_first_vec_header", Parent: SPH->getParent(), InsertBefore: SPH);
1245 BasicBlock *BB2 =
1246 BasicBlock::Create(Context&: Ctx, Name: "needle_check_vec", Parent: SPH->getParent(), InsertBefore: SPH);
1247 BasicBlock *BB3 =
1248 BasicBlock::Create(Context&: Ctx, Name: "match_check_vec", Parent: SPH->getParent(), InsertBefore: SPH);
1249 BasicBlock *BB4 =
1250 BasicBlock::Create(Context&: Ctx, Name: "calculate_match", Parent: SPH->getParent(), InsertBefore: SPH);
1251 BasicBlock *BB5 =
1252 BasicBlock::Create(Context&: Ctx, Name: "search_check_vec", Parent: SPH->getParent(), InsertBefore: SPH);
1253
1254 // Update LoopInfo with the new loops.
1255 auto OuterLoop = LI->AllocateLoop();
1256 auto InnerLoop = LI->AllocateLoop();
1257
1258 if (auto ParentLoop = CurLoop->getParentLoop()) {
1259 ParentLoop->addBasicBlockToLoop(NewBB: BB0, LI&: *LI);
1260 ParentLoop->addChildLoop(NewChild: OuterLoop);
1261 ParentLoop->addBasicBlockToLoop(NewBB: BB4, LI&: *LI);
1262 } else {
1263 LI->addTopLevelLoop(New: OuterLoop);
1264 }
1265
1266 // Add the inner loop to the outer.
1267 OuterLoop->addChildLoop(NewChild: InnerLoop);
1268
1269 // Add the new basic blocks to the corresponding loops.
1270 OuterLoop->addBasicBlockToLoop(NewBB: BB1, LI&: *LI);
1271 OuterLoop->addBasicBlockToLoop(NewBB: BB3, LI&: *LI);
1272 OuterLoop->addBasicBlockToLoop(NewBB: BB5, LI&: *LI);
1273 InnerLoop->addBasicBlockToLoop(NewBB: BB2, LI&: *LI);
1274
1275 // Update the terminator added by SplitBlock to branch to the first block.
1276 Preheader->getTerminator()->setSuccessor(Idx: 0, BB: BB0);
1277 DTU.applyUpdates(Updates: {{DominatorTree::Delete, Preheader, SPH},
1278 {DominatorTree::Insert, Preheader, BB0}});
1279
1280 // (0) Check if we could be crossing a page boundary; if so, fallback to the
1281 // old scalar loops. Also create a predicate of VF elements to be used in the
1282 // vector loops.
1283 Builder.SetInsertPoint(BB0);
1284 Value *ISearchStart =
1285 Builder.CreatePtrToInt(V: SearchStart, DestTy: I64Ty, Name: "search_start_int");
1286 Value *ISearchEnd =
1287 Builder.CreatePtrToInt(V: SearchEnd, DestTy: I64Ty, Name: "search_end_int");
1288 Value *SearchIdxInit = Constant::getNullValue(Ty: I64Ty);
1289 Value *SearchTripCount =
1290 Builder.CreateZExt(V: Builder.CreatePtrDiff(ElemTy: CharTy, LHS: SearchEnd, RHS: SearchStart,
1291 Name: "search_trip_count"),
1292 DestTy: I64Ty);
1293 Value *INeedleStart =
1294 Builder.CreatePtrToInt(V: NeedleStart, DestTy: I64Ty, Name: "needle_start_int");
1295 Value *INeedleEnd =
1296 Builder.CreatePtrToInt(V: NeedleEnd, DestTy: I64Ty, Name: "needle_end_int");
1297 Value *NeedleIdxInit = Constant::getNullValue(Ty: I64Ty);
1298 Value *NeedleTripCount =
1299 Builder.CreateZExt(V: Builder.CreatePtrDiff(ElemTy: CharTy, LHS: NeedleEnd, RHS: NeedleStart,
1300 Name: "needle_trip_count"),
1301 DestTy: I64Ty);
1302 Value *PredVF =
1303 Builder.CreateIntrinsic(ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Ty},
1304 Args: {ConstantInt::get(Ty: I64Ty, V: 0), ConstVF});
1305
1306 const uint64_t MinPageSize = TTI->getMinPageSize().value();
1307 const uint64_t AddrShiftAmt = llvm::Log2_64(Value: MinPageSize);
1308 Value *SearchStartPage =
1309 Builder.CreateLShr(LHS: ISearchStart, RHS: AddrShiftAmt, Name: "search_start_page");
1310 Value *SearchEndPage =
1311 Builder.CreateLShr(LHS: ISearchEnd, RHS: AddrShiftAmt, Name: "search_end_page");
1312 Value *NeedleStartPage =
1313 Builder.CreateLShr(LHS: INeedleStart, RHS: AddrShiftAmt, Name: "needle_start_page");
1314 Value *NeedleEndPage =
1315 Builder.CreateLShr(LHS: INeedleEnd, RHS: AddrShiftAmt, Name: "needle_end_page");
1316 Value *SearchPageCmp =
1317 Builder.CreateICmpNE(LHS: SearchStartPage, RHS: SearchEndPage, Name: "search_page_cmp");
1318 Value *NeedlePageCmp =
1319 Builder.CreateICmpNE(LHS: NeedleStartPage, RHS: NeedleEndPage, Name: "needle_page_cmp");
1320
1321 Value *CombinedPageCmp =
1322 Builder.CreateOr(LHS: SearchPageCmp, RHS: NeedlePageCmp, Name: "combined_page_cmp");
1323 CondBrInst *CombinedPageBr = Builder.CreateCondBr(Cond: CombinedPageCmp, True: SPH, False: BB1);
1324 CombinedPageBr->setMetadata(KindID: LLVMContext::MD_prof,
1325 Node: MDBuilder(Ctx).createBranchWeights(TrueWeight: 10, FalseWeight: 90));
1326 DTU.applyUpdates(
1327 Updates: {{DominatorTree::Insert, BB0, SPH}, {DominatorTree::Insert, BB0, BB1}});
1328
1329 // (1) Load the search array and branch to the inner loop.
1330 Builder.SetInsertPoint(BB1);
1331 PHINode *SearchIdx = Builder.CreatePHI(Ty: I64Ty, NumReservedValues: 2, Name: "search_idx");
1332 Value *PredSearch = Builder.CreateIntrinsic(
1333 ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Ty},
1334 Args: {SearchIdx, SearchTripCount}, FMFSource: nullptr, Name: "search_pred");
1335 PredSearch = Builder.CreateAnd(LHS: PredVF, RHS: PredSearch, Name: "search_masked");
1336 Value *Search = Builder.CreateGEP(Ty: CharTy, Ptr: SearchStart, IdxList: SearchIdx, Name: "psearch");
1337 Value *LoadSearch = Builder.CreateMaskedLoad(
1338 Ty: CharVTy, Ptr: Search, Alignment: Align(1), Mask: PredSearch, PassThru: Passthru, Name: "search_load_vec");
1339 Value *MatchInit = Constant::getNullValue(Ty: PredVTy);
1340 Builder.CreateBr(Dest: BB2);
1341 DTU.applyUpdates(Updates: {{DominatorTree::Insert, BB1, BB2}});
1342
1343 // (2) Inner loop.
1344 Builder.SetInsertPoint(BB2);
1345 PHINode *NeedleIdx = Builder.CreatePHI(Ty: I64Ty, NumReservedValues: 2, Name: "needle_idx");
1346 PHINode *Match = Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 2, Name: "pmatch");
1347
1348 // (2.a) Load the needle array.
1349 Value *PredNeedle = Builder.CreateIntrinsic(
1350 ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Ty},
1351 Args: {NeedleIdx, NeedleTripCount}, FMFSource: nullptr, Name: "needle_pred");
1352 PredNeedle = Builder.CreateAnd(LHS: PredVF, RHS: PredNeedle, Name: "needle_masked");
1353 Value *Needle = Builder.CreateGEP(Ty: CharTy, Ptr: NeedleStart, IdxList: NeedleIdx, Name: "pneedle");
1354 Value *LoadNeedle = Builder.CreateMaskedLoad(
1355 Ty: CharVTy, Ptr: Needle, Alignment: Align(1), Mask: PredNeedle, PassThru: Passthru, Name: "needle_load_vec");
1356
1357 // (2.b) Splat the first element to the inactive lanes.
1358 Value *Needle0 =
1359 Builder.CreateExtractElement(Vec: LoadNeedle, Idx: uint64_t(0), Name: "needle0");
1360 Value *Needle0Splat = Builder.CreateVectorSplat(EC: ElementCount::getScalable(MinVal: VF),
1361 V: Needle0, Name: "needle0");
1362 LoadNeedle = Builder.CreateSelect(C: PredNeedle, True: LoadNeedle, False: Needle0Splat,
1363 Name: "needle_splat");
1364 LoadNeedle = Builder.CreateExtractVector(
1365 DstType: FixedVectorType::get(ElementType: CharTy, NumElts: VF), SrcVec: LoadNeedle, Idx: uint64_t(0), Name: "needle_vec");
1366
1367 // (2.c) Accumulate matches.
1368 Value *MatchSeg = Builder.CreateIntrinsic(
1369 ID: Intrinsic::experimental_vector_match, Types: {CharVTy, LoadNeedle->getType()},
1370 Args: {LoadSearch, LoadNeedle, PredSearch}, FMFSource: nullptr, Name: "match_segment");
1371 Value *MatchAcc = Builder.CreateOr(LHS: Match, RHS: MatchSeg, Name: "match_accumulator");
1372 Value *NextNeedleIdx =
1373 Builder.CreateAdd(LHS: NeedleIdx, RHS: ConstVF, Name: "needle_idx_next");
1374 Builder.CreateCondBr(Cond: Builder.CreateICmpULT(LHS: NextNeedleIdx, RHS: NeedleTripCount),
1375 True: BB2, False: BB3);
1376 DTU.applyUpdates(
1377 Updates: {{DominatorTree::Insert, BB2, BB2}, {DominatorTree::Insert, BB2, BB3}});
1378
1379 // (3) Check if we found a match.
1380 Builder.SetInsertPoint(BB3);
1381 PHINode *MatchPredAccLCSSA = Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 1, Name: "match_pred");
1382 Value *IfAnyMatch = Builder.CreateOrReduce(Src: MatchPredAccLCSSA);
1383 Builder.CreateCondBr(Cond: IfAnyMatch, True: BB4, False: BB5);
1384 DTU.applyUpdates(
1385 Updates: {{DominatorTree::Insert, BB3, BB4}, {DominatorTree::Insert, BB3, BB5}});
1386
1387 // (4) We found a match. Compute the index of its location and exit.
1388 Builder.SetInsertPoint(BB4);
1389 PHINode *MatchLCSSA =
1390 Builder.CreatePHI(Ty: SearchStart->getType(), NumReservedValues: 1, Name: "match_start");
1391 PHINode *MatchPredLCSSA = Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 1, Name: "match_vec");
1392 Value *MatchCnt = Builder.CreateIntrinsic(
1393 ID: Intrinsic::experimental_cttz_elts, Types: {I64Ty, PredVTy},
1394 Args: {MatchPredLCSSA, /*ZeroIsPoison=*/Builder.getInt1(V: true)}, FMFSource: nullptr,
1395 Name: "match_idx");
1396 Value *MatchVal =
1397 Builder.CreateGEP(Ty: CharTy, Ptr: MatchLCSSA, IdxList: MatchCnt, Name: "match_res");
1398 Builder.CreateBr(Dest: ExitSucc);
1399 DTU.applyUpdates(Updates: {{DominatorTree::Insert, BB4, ExitSucc}});
1400
1401 // (5) Check if we've reached the end of the search array.
1402 Builder.SetInsertPoint(BB5);
1403 Value *NextSearchIdx =
1404 Builder.CreateAdd(LHS: SearchIdx, RHS: ConstVF, Name: "search_idx_next");
1405 Builder.CreateCondBr(Cond: Builder.CreateICmpULT(LHS: NextSearchIdx, RHS: SearchTripCount),
1406 True: BB1, False: ExitFail);
1407 DTU.applyUpdates(Updates: {{DominatorTree::Insert, BB5, BB1},
1408 {DominatorTree::Insert, BB5, ExitFail}});
1409
1410 // Set up the PHI nodes.
1411 SearchIdx->addIncoming(V: SearchIdxInit, BB: BB0);
1412 SearchIdx->addIncoming(V: NextSearchIdx, BB: BB5);
1413 NeedleIdx->addIncoming(V: NeedleIdxInit, BB: BB1);
1414 NeedleIdx->addIncoming(V: NextNeedleIdx, BB: BB2);
1415 Match->addIncoming(V: MatchInit, BB: BB1);
1416 Match->addIncoming(V: MatchAcc, BB: BB2);
1417 // These are needed to retain LCSSA form.
1418 MatchPredAccLCSSA->addIncoming(V: MatchAcc, BB: BB2);
1419 MatchLCSSA->addIncoming(V: Search, BB: BB3);
1420 MatchPredLCSSA->addIncoming(V: MatchPredAccLCSSA, BB: BB3);
1421
1422 // Ensure all Phis in the successors of BB4/BB5 have an incoming value from
1423 // them.
1424 fixSuccessorPhis(L: CurLoop, ScalarRes: IndPhi, VectorRes: MatchVal, SuccBB: ExitSucc, IncBB: BB4);
1425 if (ExitSucc != ExitFail)
1426 fixSuccessorPhis(L: CurLoop, ScalarRes: IndPhi, VectorRes: MatchVal, SuccBB: ExitFail, IncBB: BB5);
1427
1428 if (VerifyLoops) {
1429 OuterLoop->verifyLoop();
1430 InnerLoop->verifyLoop();
1431 if (!OuterLoop->isRecursivelyLCSSAForm(DT: *DT, LI: *LI))
1432 report_fatal_error(reason: "Loops must remain in LCSSA form!");
1433 }
1434
1435 return MatchVal;
1436}
1437
1438void LoopIdiomVectorize::transformFindFirstByte(
1439 PHINode *IndPhi, unsigned VF, Type *CharTy, BasicBlock *ExitSucc,
1440 BasicBlock *ExitFail, Value *SearchStart, Value *SearchEnd,
1441 Value *NeedleStart, Value *NeedleEnd) {
1442 // Insert the find first byte code at the end of the preheader block.
1443 BasicBlock *Preheader = CurLoop->getLoopPreheader();
1444 UncondBrInst *PHBranch = cast<UncondBrInst>(Val: Preheader->getTerminator());
1445 IRBuilder<> Builder(PHBranch);
1446 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1447 Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());
1448
1449 expandFindFirstByte(Builder, DTU, VF, CharTy, IndPhi, ExitSucc, ExitFail,
1450 SearchStart, SearchEnd, NeedleStart, NeedleEnd);
1451
1452 if (VerifyLoops && CurLoop->getParentLoop()) {
1453 CurLoop->getParentLoop()->verifyLoop();
1454 if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(DT: *DT, LI: *LI))
1455 report_fatal_error(reason: "Loops must remain in LCSSA form!");
1456 }
1457}
1458