//===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass recognizes certain loop idioms and transforms them into more
// optimized versions of the same loop. In cases where this happens, it can be
// a significant performance win.
//
// We currently support two loop idioms:
//
// 1. A loop that finds the first mismatched byte in an array and returns the
// index, i.e. something like:
//
//  while (++i != n) {
//    if (a[i] != b[i])
//      break;
//  }
//
// In this example we can vectorize the loop despite the early exit, even
// though the loop vectorizer does not yet support it. Doing so requires some
// extra checks to deal with the possibility of faulting loads when crossing
// page boundaries. However, even with these checks it is still profitable to
// do the transformation.
//
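// As a rough sketch (assuming the default masked style and simplified names),
// the vector loop we emit for this idiom looks like:
//
//   mismatch_vec_loop:
//     %pred = phi <vscale x 16 x i1> ...
//     %idx  = phi i64 ...
//     %lhs  = @llvm.masked.load(%gep.a, align 1, %pred, zeroinitializer)
//     %rhs  = @llvm.masked.load(%gep.b, align 1, %pred, zeroinitializer)
//     %ne   = icmp ne %lhs, %rhs
//     %any  = @llvm.vector.reduce.or(%ne)
//     br i1 %any, label %mismatch_vec_loop_found, label %mismatch_vec_loop_inc
//
// with @llvm.experimental.cttz.elts used in the found block to turn the first
// mismatching lane into a scalar index.
//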
// TODO List:
//
// * Add support for the inverse case where we scan for a matching element.
// * Permit 64-bit induction variable types.
// * Recognize loops that increment the IV *after* comparing bytes.
// * Allow 32-bit sign-extends of the IV used by the GEP.
//
// 2. A loop that finds the first matching character in an array among a set of
// possible matches, e.g.:
//
//  for (; first != last; ++first)
//    for (s_it = s_first; s_it != s_last; ++s_it)
//      if (*first == *s_it)
//        return first;
//  return last;
//
// This corresponds to std::find_first_of (for arrays of bytes) from the C++
// standard library. This function can be implemented efficiently for targets
// that support @llvm.experimental.vector.match. For example, on AArch64
// targets that implement SVE2, this lowers to a MATCH instruction, which
// enables us to perform up to 16x16=256 comparisons in one go. This can lead
// to very significant speedups.
//
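// As a rough sketch (simplified), the inner vector loop we emit loads a
// predicated chunk of the search array and a fixed-width chunk of the needle
// array (with inactive lanes splatted with the first needle element so they
// cannot produce false matches) and compares them in one step:
//
//   %search = @llvm.masked.load(%psearch, align 1, %search_pred, zeroinitializer)
//   %match  = @llvm.experimental.vector.match(%search, %needle, %search_pred)
//
// and then branches on whether any lane matched.
//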
// TODO:
//
// * Add support for `find_first_not_of' loops (i.e. with not-equal comparison).
// * Make VF a configurable parameter (right now we assume 128-bit vectors).
// * Potentially adjust the cost model to let the transformation kick in even
//   if @llvm.experimental.vector.match doesn't have direct support in
//   hardware.
//
//===----------------------------------------------------------------------===//
//
// NOTE: This pass matches very specific loop patterns because it's only
// supposed to be a temporary solution until our LoopVectorizer is powerful
// enough to vectorize them automatically.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "loop-idiom-vectorize"

static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
                                cl::init(false),
                                cl::desc("Disable Loop Idiom Vectorize Pass."));

static cl::opt<LoopIdiomVectorizeStyle>
    LITVecStyle("loop-idiom-vectorize-style", cl::Hidden,
                cl::desc("The vectorization style for loop idiom transform."),
                cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked",
                                      "Use masked vector intrinsics"),
                           clEnumValN(LoopIdiomVectorizeStyle::Predicated,
                                      "predicated", "Use VP intrinsics")),
                cl::init(LoopIdiomVectorizeStyle::Masked));

static cl::opt<bool>
    DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden,
                   cl::init(false),
                   cl::desc("Proceed with Loop Idiom Vectorize Pass, but do "
                            "not convert byte-compare loop(s)."));

static cl::opt<unsigned>
    ByteCmpVF("loop-idiom-vectorize-bytecmp-vf", cl::Hidden,
              cl::desc("The vectorization factor for byte-compare patterns."),
              cl::init(16));

static cl::opt<bool>
    DisableFindFirstByte("disable-loop-idiom-vectorize-find-first-byte",
                         cl::Hidden, cl::init(false),
                         cl::desc("Do not convert find-first-byte loop(s)."));

static cl::opt<bool>
    VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false),
                cl::desc("Verify loops generated by Loop Idiom Vectorize "
                         "Pass."));
namespace {
class LoopIdiomVectorize {
  LoopIdiomVectorizeStyle VectorizeStyle;
  unsigned ByteCompareVF;
  Loop *CurLoop = nullptr;
  DominatorTree *DT;
  LoopInfo *LI;
  const TargetTransformInfo *TTI;
  const DataLayout *DL;

  // Blocks that will be used for inserting vectorized code.
  BasicBlock *EndBlock = nullptr;
  BasicBlock *VectorLoopPreheaderBlock = nullptr;
  BasicBlock *VectorLoopStartBlock = nullptr;
  BasicBlock *VectorLoopMismatchBlock = nullptr;
  BasicBlock *VectorLoopIncBlock = nullptr;

public:
  LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
                     LoopInfo *LI, const TargetTransformInfo *TTI,
                     const DataLayout *DL)
      : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
  }

  bool run(Loop *L);

private:
  /// \name Countable Loop Idiom Handling
  /// @{

  bool runOnCountableLoop();
  bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
                      SmallVectorImpl<BasicBlock *> &ExitBlocks);

  bool recognizeByteCompare();

  Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                            GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            Instruction *Index, Value *Start, Value *MaxLen);

  Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                                  GetElementPtrInst *GEPA,
                                  GetElementPtrInst *GEPB, Value *ExtStart,
                                  Value *ExtEnd);
  Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                                      GetElementPtrInst *GEPA,
                                      GetElementPtrInst *GEPB, Value *ExtStart,
                                      Value *ExtEnd);

  void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            PHINode *IndPhi, Value *MaxLen, Instruction *Index,
                            Value *Start, bool IncIdx, BasicBlock *FoundBB,
                            BasicBlock *EndBB);

  bool recognizeFindFirstByte();

  Value *expandFindFirstByte(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                             unsigned VF, Type *CharTy, Value *IndPhi,
                             BasicBlock *ExitSucc, BasicBlock *ExitFail,
                             Value *SearchStart, Value *SearchEnd,
                             Value *NeedleStart, Value *NeedleEnd);

  void transformFindFirstByte(PHINode *IndPhi, unsigned VF, Type *CharTy,
                              BasicBlock *ExitSucc, BasicBlock *ExitFail,
                              Value *SearchStart, Value *SearchEnd,
                              Value *NeedleStart, Value *NeedleEnd);
  /// @}
};
} // anonymous namespace

PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
                                              LoopStandardAnalysisResults &AR,
                                              LPMUpdater &) {
  if (DisableAll)
    return PreservedAnalyses::all();

  const auto *DL = &L.getHeader()->getDataLayout();

  LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
  if (LITVecStyle.getNumOccurrences())
    VecStyle = LITVecStyle;

  unsigned BCVF = ByteCompareVF;
  if (ByteCmpVF.getNumOccurrences())
    BCVF = ByteCmpVF;

  LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL);
  if (!LIV.run(&L))
    return PreservedAnalyses::all();

  return PreservedAnalyses::none();
}

//===----------------------------------------------------------------------===//
//
// Implementation of LoopIdiomVectorize
//
//===----------------------------------------------------------------------===//

bool LoopIdiomVectorize::run(Loop *L) {
  CurLoop = L;

  Function &F = *L->getHeader()->getParent();
  if (DisableAll || F.hasOptSize())
    return false;

  if (F.hasFnAttribute(Attribute::NoImplicitFloat)) {
    LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName()
                      << " due to its NoImplicitFloat attribute");
    return false;
  }

  // If the loop could not be converted to canonical form it must have an
  // indirectbr in it, so just give up.
  if (!L->getLoopPreheader())
    return false;

  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %"
                    << CurLoop->getHeader()->getName() << "\n");

  if (recognizeByteCompare())
    return true;

  if (recognizeFindFirstByte())
    return true;

  return false;
}

static void fixSuccessorPhis(Loop *L, Value *ScalarRes, Value *VectorRes,
                             BasicBlock *SuccBB, BasicBlock *IncBB) {
  for (PHINode &PN : SuccBB->phis()) {
    // Look through the incoming values to find ScalarRes, meaning this is a
    // PHI collecting the results of the transformation.
    bool ResPhi = false;
    for (Value *Op : PN.incoming_values())
      if (Op == ScalarRes) {
        ResPhi = true;
        break;
      }

    // Any PHI that depended upon the result of the transformation needs a new
    // incoming value from IncBB.
    if (ResPhi)
      PN.addIncoming(VectorRes, IncBB);
    else {
      // There should be no other outside uses of other values in the
      // original loop. Any incoming values should either:
      // 1. Be for blocks outside the loop, which aren't interesting. Or ...
      // 2. Be from blocks in the loop with values defined outside the loop,
      //    in which case we should add a similar incoming value from IncBB.
      for (BasicBlock *BB : PN.blocks())
        if (L->contains(BB)) {
          PN.addIncoming(PN.getIncomingValueForBlock(BB), IncBB);
          break;
        }
    }
  }
}

bool LoopIdiomVectorize::recognizeByteCompare() {
  // Currently the transformation only works on scalable vector types, although
  // there is no fundamental reason why it cannot be made to work for fixed
  // width too.

  // We also need to know the minimum page size for the target in order to
  // generate runtime memory checks to ensure the vector version won't fault.
  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
      DisableByteCmp)
    return false;

  BasicBlock *Header = CurLoop->getHeader();

  // In LoopIdiomVectorize::run we have already checked that the loop
  // has a preheader so we can assume it's in a canonical form.
  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
    return false;

  PHINode *PN = dyn_cast<PHINode>(&Header->front());
  if (!PN || PN->getNumIncomingValues() != 2)
    return false;

  auto LoopBlocks = CurLoop->getBlocks();
  // The first block in the loop should contain only 4 instructions, e.g.
  //
  //  while.cond:
  //   %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
  //   %inc = add i32 %res.phi, 1
  //   %cmp.not = icmp eq i32 %inc, %n
  //   br i1 %cmp.not, label %while.end, label %while.body
  //
  if (LoopBlocks[0]->sizeWithoutDebug() > 4)
    return false;

  // The second block should contain 7 instructions, e.g.
  //
  //  while.body:
  //   %idx = zext i32 %inc to i64
  //   %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
  //   %load.a = load i8, ptr %idx.a
  //   %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
  //   %load.b = load i8, ptr %idx.b
  //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
  //   br i1 %cmp.not.ld, label %while.cond, label %while.end
  //
  if (LoopBlocks[1]->sizeWithoutDebug() > 7)
    return false;

  // The incoming value to the PHI node from the loop should be an add of 1.
  Value *StartIdx = nullptr;
  Instruction *Index = nullptr;
  if (!CurLoop->contains(PN->getIncomingBlock(0))) {
    StartIdx = PN->getIncomingValue(0);
    Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
  } else {
    StartIdx = PN->getIncomingValue(1);
    Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
  }

  // Limit to 32-bit types for now.
  if (!Index || !Index->getType()->isIntegerTy(32) ||
      !match(Index, m_c_Add(m_Specific(PN), m_One())))
    return false;

  // If we match the pattern, PN and Index will be replaced with the result of
  // the cttz.elts intrinsic. If any other instructions are used outside of
  // the loop, we cannot replace it.
  for (BasicBlock *BB : LoopBlocks)
    for (Instruction &I : *BB)
      if (&I != PN && &I != Index)
        for (User *U : I.users())
          if (!CurLoop->contains(cast<Instruction>(U)))
            return false;

  // Match the branch instruction for the header.
  Value *MaxLen;
  BasicBlock *EndBB, *WhileBB;
  if (!match(Header->getTerminator(),
             m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(Index),
                                 m_Value(MaxLen)),
                  m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
      !CurLoop->contains(WhileBB))
    return false;

  // WhileBB should contain the pattern of load & compare instructions. Match
  // the pattern and find the GEP instructions used by the loads.
  BasicBlock *FoundBB;
  BasicBlock *TrueBB;
  Value *LoadA, *LoadB;
  if (!match(WhileBB->getTerminator(),
             m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(LoadA),
                                 m_Value(LoadB)),
                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
      !CurLoop->contains(TrueBB))
    return false;

  Value *A, *B;
  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
    return false;

  LoadInst *LoadAI = cast<LoadInst>(LoadA);
  LoadInst *LoadBI = cast<LoadInst>(LoadB);
  if (!LoadAI->isSimple() || !LoadBI->isSimple())
    return false;

  GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
  GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);

  if (!GEPA || !GEPB)
    return false;

  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  // Check we are loading i8 values from two loop invariant pointers.
  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
      !GEPA->getResultElementType()->isIntegerTy(8) ||
      !GEPB->getResultElementType()->isIntegerTy(8) ||
      !LoadAI->getType()->isIntegerTy(8) ||
      !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
    return false;

  // Check that the index to the GEPs is the index we found earlier.
  if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
    return false;

  Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
  Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
    return false;

  // We only ever expect the pre-incremented index value to be used inside the
  // loop.
  if (!PN->hasOneUse())
    return false;

  // Ensure that when the Found and End blocks are identical the PHIs have the
  // supported format. We don't currently allow cases like this:
  //
  //  while.cond:
  //   ...
  //   br i1 %cmp.not, label %while.end, label %while.body
  //
  //  while.body:
  //   ...
  //   br i1 %cmp.not2, label %while.cond, label %while.end
  //
  //  while.end:
  //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
  //
  // where the incoming values for %final_ptr are unique and from each of the
  // loop blocks, but not actually defined in the loop. This requires extra
  // work setting up the byte.compare block, i.e. by introducing a select to
  // choose the correct value.
  // TODO: We could add support for this in the future.
  if (FoundBB == EndBB) {
    for (PHINode &EndPN : EndBB->phis()) {
      Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
      Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);

      // The value of the index when leaving the while.cond block is always the
      // same as the end value (MaxLen) so we permit either. The value when
      // leaving the while.body block should only be the index. Otherwise for
      // any other values we only allow ones that are the same for both blocks.
      if (WhileCondVal != WhileBodyVal &&
          ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
           (WhileBodyVal != Index)))
        return false;
    }
  }

  LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
                    << *(EndBB->getParent()) << "\n\n");

  // The index is incremented before the GEP/Load pair so we need to
  // add 1 to the start value.
  transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx,
                       /*IncIdx=*/true, FoundBB, EndBB);
  return true;
}

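// Emit the masked-style vector loop that searches for the first mismatching
// byte between PtrA and PtrB in [ExtStart, ExtEnd). The loop predicate comes
// from @llvm.get.active.lane.mask, both sides are read with masked loads, and
// a mismatch is detected via an or-reduction of the lane-wise inequality.
// Returns the mismatch index as an i32 value defined in the mismatch block.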
Value *LoopIdiomVectorize::createMaskedFindMismatch(
    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
    GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
  Type *I64Type = Builder.getInt64Ty();
  Type *ResType = Builder.getInt32Ty();
  Type *LoadType = Builder.getInt8Ty();
  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  ScalableVectorType *PredVTy =
      ScalableVectorType::get(Builder.getInt1Ty(), ByteCompareVF);

  Value *InitialPred = Builder.CreateIntrinsic(
      Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});

  Value *VecLen = Builder.CreateVScale(I64Type);
  VecLen =
      Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
                        /*HasNUW=*/true, /*HasNSW=*/true);

  Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
                                            Builder.getInt1(false));

  BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
  Builder.Insert(JumpToVectorLoop);

  DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
                     VectorLoopStartBlock}});

  // Set up the first vector loop block by creating the PHIs, doing the vector
  // loads and comparing the vectors.
  Builder.SetInsertPoint(VectorLoopStartBlock);
  PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
  LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
  PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
  VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
  Type *VectorLoadType =
      ScalableVectorType::get(Builder.getInt8Ty(), ByteCompareVF);
  Value *Passthru = ConstantInt::getNullValue(VectorLoadType);

  Value *VectorLhsGep =
      Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
  Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
                                                  Align(1), LoopPred, Passthru);

  Value *VectorRhsGep =
      Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
  Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
                                                  Align(1), LoopPred, Passthru);

  Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
  VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
  Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
  BranchInst *VectorEarlyExit = BranchInst::Create(
      VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
  Builder.Insert(VectorEarlyExit);

  DTU.applyUpdates(
      {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
       {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});

  // Increment the index counter and calculate the predicate for the next
  // iteration of the loop. We branch back to the start of the loop if there
  // is at least one active lane.
  Builder.SetInsertPoint(VectorLoopIncBlock);
  Value *NewVectorIndexPhi =
      Builder.CreateAdd(VectorIndexPhi, VecLen, "",
                        /*HasNUW=*/true, /*HasNSW=*/true);
  VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
  Value *NewPred =
      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
                              {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
  LoopPred->addIncoming(NewPred, VectorLoopIncBlock);

  Value *PredHasActiveLanes =
      Builder.CreateExtractElement(NewPred, uint64_t(0));
  BranchInst *VectorLoopBranchBack =
      BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
  Builder.Insert(VectorLoopBranchBack);

  DTU.applyUpdates(
      {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
       {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});

  // If we found a mismatch then we need to calculate which lane in the vector
  // had a mismatch and add that on to the current loop index.
  Builder.SetInsertPoint(VectorLoopMismatchBlock);
  PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
  FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
  PHINode *LastLoopPred =
      Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
  LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
  PHINode *VectorFoundIndex =
      Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
  VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);

  Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
  Value *Ctz = Builder.CreateCountTrailingZeroElems(ResType, PredMatchCmp);
  Ctz = Builder.CreateZExt(Ctz, I64Type);
  Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
                                             /*HasNUW=*/true, /*HasNSW=*/true);
  return Builder.CreateTrunc(VectorLoopRes64, ResType);
}

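// Emit the VP-intrinsic (predicated) variant of the vector loop: the number of
// elements processed per iteration comes from
// @llvm.experimental.get.vector.length, the loads are @llvm.vp.load, and the
// first mismatching lane is found with @llvm.vp.cttz.elts. Returns the
// mismatch index as an i32 value defined in the mismatch block.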
Value *LoopIdiomVectorize::createPredicatedFindMismatch(
    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
    GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
  Type *I64Type = Builder.getInt64Ty();
  Type *I32Type = Builder.getInt32Ty();
  Type *ResType = I32Type;
  Type *LoadType = Builder.getInt8Ty();
  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  auto *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
  Builder.Insert(JumpToVectorLoop);

  DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
                     VectorLoopStartBlock}});

  // Set up the first vector loop block by creating the PHIs, doing the vector
  // loads and comparing the vectors.
  Builder.SetInsertPoint(VectorLoopStartBlock);
  auto *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vector_index");
  VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);

  // Calculate AVL by subtracting the vector loop index from the trip count.
  Value *AVL = Builder.CreateSub(ExtEnd, VectorIndexPhi, "avl", /*HasNUW=*/true,
                                 /*HasNSW=*/true);

  auto *VectorLoadType = ScalableVectorType::get(LoadType, ByteCompareVF);
  auto *VF = ConstantInt::get(I32Type, ByteCompareVF);

  Value *VL =
      Builder.CreateIntrinsic(Intrinsic::experimental_get_vector_length,
                              {I64Type}, {AVL, VF, Builder.getTrue()});
  Value *GepOffset = VectorIndexPhi;

  Value *VectorLhsGep =
      Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds());
  VectorType *TrueMaskTy =
      VectorType::get(Builder.getInt1Ty(), VectorLoadType->getElementCount());
  Value *AllTrueMask = Constant::getAllOnesValue(TrueMaskTy);
  Value *VectorLhsLoad = Builder.CreateIntrinsic(
      Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
      {VectorLhsGep, AllTrueMask, VL}, nullptr, "lhs.load");

  Value *VectorRhsGep =
      Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds());
  Value *VectorRhsLoad = Builder.CreateIntrinsic(
      Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
      {VectorRhsGep, AllTrueMask, VL}, nullptr, "rhs.load");

  Value *VectorMatchCmp =
      Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad, "mismatch.cmp");
  Value *CTZ = Builder.CreateIntrinsic(
      Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType()},
      {VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(false), AllTrueMask,
       VL});
  Value *MismatchFound = Builder.CreateICmpNE(CTZ, VL);
  auto *VectorEarlyExit = BranchInst::Create(VectorLoopMismatchBlock,
                                             VectorLoopIncBlock, MismatchFound);
  Builder.Insert(VectorEarlyExit);

  DTU.applyUpdates(
      {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
       {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});

  // Increment the index counter and calculate the predicate for the next
  // iteration of the loop. We branch back to the start of the loop if there
  // is at least one active lane.
  Builder.SetInsertPoint(VectorLoopIncBlock);
  Value *VL64 = Builder.CreateZExt(VL, I64Type);
  Value *NewVectorIndexPhi =
      Builder.CreateAdd(VectorIndexPhi, VL64, "",
                        /*HasNUW=*/true, /*HasNSW=*/true);
  VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
  Value *ExitCond = Builder.CreateICmpNE(NewVectorIndexPhi, ExtEnd);
  auto *VectorLoopBranchBack =
      BranchInst::Create(VectorLoopStartBlock, EndBlock, ExitCond);
  Builder.Insert(VectorLoopBranchBack);

  DTU.applyUpdates(
      {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
       {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});

  // If we found a mismatch then we need to calculate which lane in the vector
  // had a mismatch and add that on to the current loop index.
  Builder.SetInsertPoint(VectorLoopMismatchBlock);

  // Add LCSSA phis for CTZ and VectorIndexPhi.
  auto *CTZLCSSAPhi = Builder.CreatePHI(CTZ->getType(), 1, "ctz");
  CTZLCSSAPhi->addIncoming(CTZ, VectorLoopStartBlock);
  auto *VectorIndexLCSSAPhi =
      Builder.CreatePHI(VectorIndexPhi->getType(), 1, "mismatch_vector_index");
  VectorIndexLCSSAPhi->addIncoming(VectorIndexPhi, VectorLoopStartBlock);

  Value *CTZI64 = Builder.CreateZExt(CTZLCSSAPhi, I64Type);
  Value *VectorLoopRes64 = Builder.CreateAdd(VectorIndexLCSSAPhi, CTZI64, "",
                                             /*HasNUW=*/true, /*HasNSW=*/true);
  return Builder.CreateTrunc(VectorLoopRes64, ResType);
}

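// Expand the whole find-mismatch skeleton in the preheader: the minimum
// iteration check, the page-crossing memory check, the vector loop (masked or
// predicated depending on VectorizeStyle) and a scalar fallback loop. Returns
// the i32 index of the first mismatch (or MaxLen if there is none), defined in
// the new "mismatch_end" block.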
Value *LoopIdiomVectorize::expandFindMismatch(
    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
    GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  // Get the arguments and types for the intrinsic.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  LLVMContext &Ctx = PHBranch->getContext();
  Type *LoadType = Type::getInt8Ty(Ctx);
  Type *ResType = Builder.getInt32Ty();

  // Split block in the original loop preheader.
  EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");

  // Create the blocks that we're going to need:
  //  1. A block for checking the zero-extended length exceeds 0.
  //  2. A block to check that the start and end addresses of a given array
  //     lie on the same page.
  //  3. The vector loop preheader.
  //  4. The first vector loop block.
  //  5. The vector loop increment block.
  //  6. A block we can jump to from the vector loop when a mismatch is found.
  //  7. The first block of the scalar loop itself, containing PHIs, loads
  //     and cmp.
  //  8. A scalar loop increment block to increment the PHIs and go back
  //     around the loop.
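  //
  // A rough sketch of how these blocks are wired up (the scalar loop is the
  // fallback whenever one of the runtime checks fails):
  //
  //   mismatch_min_it_check -> mismatch_mem_check          or mismatch_loop_pre
  //   mismatch_mem_check    -> mismatch_vec_loop_preheader or mismatch_loop_pre
  //   mismatch_vec_loop     -> mismatch_vec_loop_found     or mismatch_vec_loop_inc
  //   mismatch_vec_loop_inc -> mismatch_vec_loop           or mismatch_end
  //   mismatch_loop         -> mismatch_loop_inc           or mismatch_end
  //   mismatch_loop_inc     -> mismatch_loop               or mismatch_end
  //
  // with mismatch_vec_loop_found branching to mismatch_end as well.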

  BasicBlock *MinItCheckBlock = BasicBlock::Create(
      Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock);

  // Update the terminator added by SplitBlock to branch to the first block.
  Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock);

  BasicBlock *MemCheckBlock = BasicBlock::Create(
      Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);

  VectorLoopPreheaderBlock = BasicBlock::Create(
      Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock);

  VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop",
                                            EndBlock->getParent(), EndBlock);

  VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc",
                                          EndBlock->getParent(), EndBlock);

  VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found",
                                               EndBlock->getParent(), EndBlock);

  BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
      Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopStartBlock =
      BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopIncBlock = BasicBlock::Create(
      Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock);

  DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock},
                    {DominatorTree::Delete, Preheader, EndBlock}});

  // Update LoopInfo with the new vector & scalar loops.
  auto VectorLoop = LI->AllocateLoop();
  auto ScalarLoop = LI->AllocateLoop();

  if (CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopPreheaderBlock,
                                                  *LI);
    CurLoop->getParentLoop()->addChildLoop(VectorLoop);
    CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopMismatchBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI);
    CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
  } else {
    LI->addTopLevelLoop(VectorLoop);
    LI->addTopLevelLoop(ScalarLoop);
  }

  // Add the new basic blocks to their associated loops.
  VectorLoop->addBasicBlockToLoop(VectorLoopStartBlock, *LI);
  VectorLoop->addBasicBlockToLoop(VectorLoopIncBlock, *LI);

  ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI);
  ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI);

  // Set up some types and constants that we intend to reuse.
  Type *I64Type = Builder.getInt64Ty();

  // Check the zero-extended iteration count > 0.
  Builder.SetInsertPoint(MinItCheckBlock);
  Value *ExtStart = Builder.CreateZExt(Start, I64Type);
  Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
  // This check doesn't really cost us very much.

  Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
  BranchInst *MinItCheckBr =
      BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
  MinItCheckBr->setMetadata(
      LLVMContext::MD_prof,
      MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1));
  Builder.Insert(MinItCheckBr);

  DTU.applyUpdates(
      {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
       {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});

  // For each of the arrays, check the start/end addresses are on the same
  // page.
  Builder.SetInsertPoint(MemCheckBlock);

  // The early exit in the original loop means that when performing vector
  // loads we are potentially reading ahead of the early exit. So we could
  // fault if crossing a page boundary. Therefore, we create runtime memory
  // checks based on the minimum page size as follows:
  //  1. Calculate the addresses of the first memory accesses in the loop,
  //     i.e. LhsStart and RhsStart.
  //  2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd.
  //  3. Determine which pages correspond to all the memory accesses, i.e.
  //     LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage.
  //  4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then
  //     we know we won't cross any page boundaries in the loop so we can
  //     enter the vector loop! Otherwise we fall back on the scalar loop.
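  //
  // For example, with a minimum page size of 4096 bytes AddrShiftAmt is 12, so
  // two addresses lie on the same page exactly when they agree in all bits
  // above bit 11, i.e. when (Addr1 >> 12) == (Addr2 >> 12).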
  Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart);
  Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart);
  Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type);
  Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type);
  Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd);
  Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd);
  Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type);
  Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type);

  const uint64_t MinPageSize = TTI->getMinPageSize().value();
  const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
  Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt);
  Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt);
  Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt);
  Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt);
  Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage);
  Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage);

  Value *CombinedPageCmp = Builder.CreateOr(LhsPageCmp, RhsPageCmp);
  BranchInst *CombinedPageCmpCmpBr = BranchInst::Create(
      LoopPreHeaderBlock, VectorLoopPreheaderBlock, CombinedPageCmp);
  CombinedPageCmpCmpBr->setMetadata(
      LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext())
                                .createBranchWeights(10, 90));
  Builder.Insert(CombinedPageCmpCmpBr);

  DTU.applyUpdates(
      {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock},
       {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}});

  // Set up the vector loop preheader, i.e. calculate the initial loop
  // predicate, zero-extend MaxLen to 64 bits, determine the number of vector
  // elements processed in each iteration, etc.
  Builder.SetInsertPoint(VectorLoopPreheaderBlock);

  // At this point we know two things must be true:
  //  1. Start <= End
  //  2. ExtMaxLen <= MinPageSize due to the page checks.
  // Therefore, we know that we can use a 64-bit induction variable that
  // starts from 0 -> ExtMaxLen and it will not overflow.
  Value *VectorLoopRes = nullptr;
  switch (VectorizeStyle) {
  case LoopIdiomVectorizeStyle::Masked:
    VectorLoopRes =
        createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
    break;
  case LoopIdiomVectorizeStyle::Predicated:
    VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB,
                                                 ExtStart, ExtEnd);
    break;
  }

  Builder.Insert(BranchInst::Create(EndBlock));

  DTU.applyUpdates(
      {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}});

  // Generate code for the scalar loop.
  Builder.SetInsertPoint(LoopPreHeaderBlock);
  Builder.Insert(BranchInst::Create(LoopStartBlock));

  DTU.applyUpdates(
      {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}});

  Builder.SetInsertPoint(LoopStartBlock);
  PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index");
  IndexPhi->addIncoming(Start, LoopPreHeaderBlock);

  // Otherwise compare the values: load bytes from each array and compare them.
  Value *GepOffset = Builder.CreateZExt(IndexPhi, I64Type);

  Value *LhsGep =
      Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds());
  Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep);

  Value *RhsGep =
      Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds());
  Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep);

  Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
  // If we have a mismatch then exit the loop ...
  BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
  Builder.Insert(MatchCmpBr);

  DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
                    {DominatorTree::Insert, LoopStartBlock, EndBlock}});

  // Have we reached the maximum permitted length for the loop?
  Builder.SetInsertPoint(LoopIncBlock);
  Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "",
                                    /*HasNUW=*/Index->hasNoUnsignedWrap(),
                                    /*HasNSW=*/Index->hasNoSignedWrap());
  IndexPhi->addIncoming(PhiInc, LoopIncBlock);
  Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen);
  BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp);
  Builder.Insert(IVCmpBr);

  DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock},
                    {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});

  // In the end block we need to insert a PHI node to deal with four cases:
  //  1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
  //  2. We exited the scalar loop early due to a mismatch and need to return
  //     the index that we found.
  //  3. We didn't find a mismatch in the vector loop, so we return MaxLen.
  //  4. We exited the vector loop early due to a mismatch and need to return
  //     the index that we found.
  Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt());
  PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result");
  ResPhi->addIncoming(MaxLen, LoopIncBlock);
  ResPhi->addIncoming(IndexPhi, LoopStartBlock);
  ResPhi->addIncoming(MaxLen, VectorLoopIncBlock);
  ResPhi->addIncoming(VectorLoopRes, VectorLoopMismatchBlock);

  Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType);

  if (VerifyLoops) {
    ScalarLoop->verifyLoop();
    VectorLoop->verifyLoop();
    if (!VectorLoop->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
    if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }

  return FinalRes;
}

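// Replace the recognized byte-compare loop: expand the mismatch computation in
// the preheader (see expandFindMismatch), replace all uses of the incremented
// index with its result, and route control flow to FoundBB/EndBB through a new
// "byte.compare" block.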
void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
                                              GetElementPtrInst *GEPB,
                                              PHINode *IndPhi, Value *MaxLen,
                                              Instruction *Index, Value *Start,
                                              bool IncIdx, BasicBlock *FoundBB,
                                              BasicBlock *EndBB) {

  // Insert the byte compare code at the end of the preheader block.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BasicBlock *Header = CurLoop->getHeader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  IRBuilder<> Builder(PHBranch);
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());

  // Increment the pointer if this was done before the loads in the loop.
  if (IncIdx)
    Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));

  Value *ByteCmpRes =
      expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);

  // Replace uses of the index & induction PHI with the intrinsic result (we
  // already checked above that the first instruction of Header is the PHI).
  assert(IndPhi->hasOneUse() && "Index phi node has more than one use!");
  Index->replaceAllUsesWith(ByteCmpRes);

  assert(PHBranch->isUnconditional() &&
         "Expected preheader to terminate with an unconditional branch.");

  // If no mismatch was found, we can jump to the end block. Create a
  // new basic block for the compare instruction.
  auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare",
                                   Preheader->getParent());
  CmpBB->moveBefore(EndBB);

  // Replace the branch in the preheader with an always-true conditional
  // branch. This ensures there is still a reference to the original loop.
  Builder.CreateCondBr(Builder.getTrue(), CmpBB, Header);
  PHBranch->eraseFromParent();

  BasicBlock *MismatchEnd = cast<Instruction>(ByteCmpRes)->getParent();
  DTU.applyUpdates({{DominatorTree::Insert, MismatchEnd, CmpBB}});

  // Create the branch to either the end or found block depending on the value
  // returned by the intrinsic.
  Builder.SetInsertPoint(CmpBB);
  if (FoundBB != EndBB) {
    Value *FoundCmp = Builder.CreateICmpEQ(ByteCmpRes, MaxLen);
    Builder.CreateCondBr(FoundCmp, EndBB, FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB},
                      {DominatorTree::Insert, CmpBB, EndBB}});

  } else {
    Builder.CreateBr(FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}});
  }

  // Ensure all PHIs in the successors of CmpBB have an incoming value from it.
  fixSuccessorPhis(CurLoop, ByteCmpRes, ByteCmpRes, EndBB, CmpBB);
  if (EndBB != FoundBB)
    fixSuccessorPhis(CurLoop, ByteCmpRes, ByteCmpRes, FoundBB, CmpBB);

  // The new CmpBB block isn't part of the loop, but will need to be added to
  // the outer loop if there is one.
  if (!CurLoop->isOutermost())
    CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI);

  if (VerifyLoops && CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->verifyLoop();
    if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }
}

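// Match a std::find_first_of-style nested loop (idiom 2 in the file header):
// an outer loop walking the search array with an inner loop walking the needle
// array, exiting to ExitSucc on the first match and to ExitFail once the
// search array is exhausted.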
bool LoopIdiomVectorize::recognizeFindFirstByte() {
  // Currently the transformation only works on scalable vector types, although
  // there is no fundamental reason why it cannot be made to work for fixed
  // vectors. We also need to know the target's minimum page size in order to
  // generate runtime memory checks to ensure the vector version won't fault.
  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
      DisableFindFirstByte)
    return false;

  // Define some constants we need throughout.
  BasicBlock *Header = CurLoop->getHeader();
  LLVMContext &Ctx = Header->getContext();

  // We are expecting the four blocks defined below: Header, MatchBB, InnerBB,
  // and OuterBB. For now, we will bail out for almost anything else. The four
  // blocks contain one nested loop.
  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 4 ||
      CurLoop->getSubLoops().size() != 1)
    return false;

  auto *InnerLoop = CurLoop->getSubLoops().front();
  PHINode *IndPhi = dyn_cast<PHINode>(&Header->front());
  if (!IndPhi || IndPhi->getNumIncomingValues() != 2)
    return false;

  // Check instruction counts.
  auto LoopBlocks = CurLoop->getBlocks();
  if (LoopBlocks[0]->sizeWithoutDebug() > 3 ||
      LoopBlocks[1]->sizeWithoutDebug() > 4 ||
      LoopBlocks[2]->sizeWithoutDebug() > 3 ||
      LoopBlocks[3]->sizeWithoutDebug() > 3)
    return false;

  // Check that no instruction other than IndPhi has outside uses.
  for (BasicBlock *BB : LoopBlocks)
    for (Instruction &I : *BB)
      if (&I != IndPhi)
        for (User *U : I.users())
          if (!CurLoop->contains(cast<Instruction>(U)))
            return false;

  // Match the branch instruction in the header. We are expecting an
  // unconditional branch to the inner loop.
  //
  // Header:
  //   %14 = phi ptr [ %24, %OuterBB ], [ %3, %Header.preheader ]
  //   %15 = load i8, ptr %14, align 1
  //   br label %MatchBB
  BasicBlock *MatchBB;
  if (!match(Header->getTerminator(), m_UnconditionalBr(MatchBB)) ||
      !InnerLoop->contains(MatchBB))
    return false;

  // MatchBB should be the entry point into the inner loop containing the
  // comparison between a search element and a needle.
  //
  // MatchBB:
  //   %20 = phi ptr [ %7, %Header ], [ %17, %InnerBB ]
  //   %21 = load i8, ptr %20, align 1
  //   %22 = icmp eq i8 %15, %21
  //   br i1 %22, label %ExitSucc, label %InnerBB
  BasicBlock *ExitSucc, *InnerBB;
  Value *LoadSearch, *LoadNeedle;
  CmpPredicate MatchPred;
  if (!match(MatchBB->getTerminator(),
             m_Br(m_ICmp(MatchPred, m_Value(LoadSearch), m_Value(LoadNeedle)),
                  m_BasicBlock(ExitSucc), m_BasicBlock(InnerBB))) ||
      MatchPred != ICmpInst::ICMP_EQ || !InnerLoop->contains(InnerBB))
    return false;

  // We expect outside uses of `IndPhi' in ExitSucc (and only there).
  for (User *U : IndPhi->users())
    if (!CurLoop->contains(cast<Instruction>(U))) {
      auto *PN = dyn_cast<PHINode>(U);
      if (!PN || PN->getParent() != ExitSucc)
        return false;
    }

  // Match the loads and check they are simple.
  Value *Search, *Needle;
  if (!match(LoadSearch, m_Load(m_Value(Search))) ||
      !match(LoadNeedle, m_Load(m_Value(Needle))) ||
      !cast<LoadInst>(LoadSearch)->isSimple() ||
      !cast<LoadInst>(LoadNeedle)->isSimple())
    return false;

  // Check we are loading valid characters.
  Type *CharTy = LoadSearch->getType();
  if (!CharTy->isIntegerTy() || LoadNeedle->getType() != CharTy)
    return false;

  // Pick the vectorization factor based on CharTy, work out the cost of the
  // match intrinsic and decide if we should use it.
  // Note: For the time being we assume 128-bit vectors.
  unsigned VF = 128 / CharTy->getIntegerBitWidth();
  SmallVector<Type *> Args = {
      ScalableVectorType::get(CharTy, VF), FixedVectorType::get(CharTy, VF),
      ScalableVectorType::get(Type::getInt1Ty(Ctx), VF)};
  IntrinsicCostAttributes Attrs(Intrinsic::experimental_vector_match, Args[2],
                                Args);
  if (TTI->getIntrinsicInstrCost(Attrs, TTI::TCK_SizeAndLatency) > 4)
    return false;

  // The loads come from two PHIs, each with two incoming values.
  PHINode *PSearch = dyn_cast<PHINode>(Search);
  PHINode *PNeedle = dyn_cast<PHINode>(Needle);
  if (!PSearch || PSearch->getNumIncomingValues() != 2 || !PNeedle ||
      PNeedle->getNumIncomingValues() != 2)
    return false;

  // One PHI comes from the outer loop (PSearch), the other one from the inner
  // loop (PNeedle). PSearch effectively corresponds to IndPhi.
  if (InnerLoop->contains(PSearch))
    std::swap(PSearch, PNeedle);
  if (PSearch != &Header->front() || PNeedle != &MatchBB->front())
    return false;

  // The incoming values of both PHI nodes should be a gep of 1.
  Value *SearchStart = PSearch->getIncomingValue(0);
  Value *SearchIndex = PSearch->getIncomingValue(1);
  if (CurLoop->contains(PSearch->getIncomingBlock(0)))
    std::swap(SearchStart, SearchIndex);

  Value *NeedleStart = PNeedle->getIncomingValue(0);
  Value *NeedleIndex = PNeedle->getIncomingValue(1);
  if (InnerLoop->contains(PNeedle->getIncomingBlock(0)))
    std::swap(NeedleStart, NeedleIndex);

  // Match the GEPs.
  if (!match(SearchIndex, m_GEP(m_Specific(PSearch), m_One())) ||
      !match(NeedleIndex, m_GEP(m_Specific(PNeedle), m_One())))
    return false;

  // Check the GEPs' result type matches `CharTy'.
  GetElementPtrInst *GEPSearch = cast<GetElementPtrInst>(SearchIndex);
  GetElementPtrInst *GEPNeedle = cast<GetElementPtrInst>(NeedleIndex);
  if (GEPSearch->getResultElementType() != CharTy ||
      GEPNeedle->getResultElementType() != CharTy)
    return false;

  // InnerBB should increment the address of the needle pointer.
  //
  // InnerBB:
  //   %17 = getelementptr inbounds i8, ptr %20, i64 1
  //   %18 = icmp eq ptr %17, %10
  //   br i1 %18, label %OuterBB, label %MatchBB
  BasicBlock *OuterBB;
  Value *NeedleEnd;
  if (!match(InnerBB->getTerminator(),
             m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(GEPNeedle),
                                 m_Value(NeedleEnd)),
                  m_BasicBlock(OuterBB), m_Specific(MatchBB))) ||
      !CurLoop->contains(OuterBB))
    return false;

  // OuterBB should increment the address of the search element pointer.
  //
  // OuterBB:
  //   %24 = getelementptr inbounds i8, ptr %14, i64 1
  //   %25 = icmp eq ptr %24, %6
  //   br i1 %25, label %ExitFail, label %Header
  BasicBlock *ExitFail;
  Value *SearchEnd;
  if (!match(OuterBB->getTerminator(),
             m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(GEPSearch),
                                 m_Value(SearchEnd)),
                  m_BasicBlock(ExitFail), m_Specific(Header))))
    return false;

  if (!CurLoop->isLoopInvariant(SearchStart) ||
      !CurLoop->isLoopInvariant(SearchEnd) ||
      !CurLoop->isLoopInvariant(NeedleStart) ||
      !CurLoop->isLoopInvariant(NeedleEnd))
    return false;

  LLVM_DEBUG(dbgs() << "Found idiom in loop: \n" << *CurLoop << "\n\n");

  transformFindFirstByte(IndPhi, VF, CharTy, ExitSucc, ExitFail, SearchStart,
                         SearchEnd, NeedleStart, NeedleEnd);
  return true;
}

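// Emit the vectorized find-first-byte code: a page-crossing check up front,
// then a nested vector loop that loads up to VF search elements and VF needle
// elements at a time and uses @llvm.experimental.vector.match to look for a
// match (see the block-by-block description below).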
Value *LoopIdiomVectorize::expandFindFirstByte(
    IRBuilder<> &Builder, DomTreeUpdater &DTU, unsigned VF, Type *CharTy,
    Value *IndPhi, BasicBlock *ExitSucc, BasicBlock *ExitFail,
    Value *SearchStart, Value *SearchEnd, Value *NeedleStart,
    Value *NeedleEnd) {
  // Set up some types and constants that we intend to reuse.
  auto *PtrTy = Builder.getPtrTy();
  auto *I64Ty = Builder.getInt64Ty();
  auto *PredVTy = ScalableVectorType::get(Builder.getInt1Ty(), VF);
  auto *CharVTy = ScalableVectorType::get(CharTy, VF);
  auto *ConstVF = ConstantInt::get(I64Ty, VF);

  // Other common arguments.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  LLVMContext &Ctx = Preheader->getContext();
  Value *Passthru = ConstantInt::getNullValue(CharVTy);

  // Split block in the original loop preheader.
  // SPH is the new preheader to the old scalar loop.
  BasicBlock *SPH = SplitBlock(Preheader, Preheader->getTerminator(), DT, LI,
                               nullptr, "scalar_preheader");

  // Create the blocks that we're going to use.
  //
  // We will have the following loops:
  // (O) Outer loop where we iterate over the elements of the search array.
  // (I) Inner loop where we iterate over the elements of the needle array.
  //
  // Overall, the blocks do the following:
  // (0) Check if the arrays can't cross page boundaries. If so go to (1),
  //     otherwise fall back to the original scalar loop.
  // (1) Load the search array. Go to (2).
  // (2) (a) Load the needle array.
  //     (b) Splat the first element to the inactive lanes.
  //     (c) Check if any elements match. If so go to (3), otherwise go to (4).
  // (3) Compute the index of the first match and exit.
  // (4) Check if we've reached the end of the needle array. If not loop back
  //     to (2), otherwise go to (5).
  // (5) Check if we've reached the end of the search array. If not loop back
  //     to (1), otherwise exit.
  // Blocks (0,3) are not part of any loop. Blocks (1,5) and (2,4) belong to
  // the outer and inner loops, respectively.
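  //
  // In terms of the blocks created below, step (0) is BB0 ("mem_check"), (1)
  // is BB1 ("find_first_vec_header"), (2) is BB2 ("match_check_vec"), (3) is
  // BB3 ("calculate_match"), (4) is BB4 ("needle_check_vec") and (5) is BB5
  // ("search_check_vec").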
1210 BasicBlock *BB0 = BasicBlock::Create(Context&: Ctx, Name: "mem_check", Parent: SPH->getParent(), InsertBefore: SPH);
1211 BasicBlock *BB1 =
1212 BasicBlock::Create(Context&: Ctx, Name: "find_first_vec_header", Parent: SPH->getParent(), InsertBefore: SPH);
1213 BasicBlock *BB2 =
1214 BasicBlock::Create(Context&: Ctx, Name: "match_check_vec", Parent: SPH->getParent(), InsertBefore: SPH);
1215 BasicBlock *BB3 =
1216 BasicBlock::Create(Context&: Ctx, Name: "calculate_match", Parent: SPH->getParent(), InsertBefore: SPH);
1217 BasicBlock *BB4 =
1218 BasicBlock::Create(Context&: Ctx, Name: "needle_check_vec", Parent: SPH->getParent(), InsertBefore: SPH);
1219 BasicBlock *BB5 =
1220 BasicBlock::Create(Context&: Ctx, Name: "search_check_vec", Parent: SPH->getParent(), InsertBefore: SPH);
1221
1222 // Update LoopInfo with the new loops.
1223 auto OuterLoop = LI->AllocateLoop();
1224 auto InnerLoop = LI->AllocateLoop();
1225
1226 if (auto ParentLoop = CurLoop->getParentLoop()) {
1227 ParentLoop->addBasicBlockToLoop(NewBB: BB0, LI&: *LI);
1228 ParentLoop->addChildLoop(NewChild: OuterLoop);
1229 ParentLoop->addBasicBlockToLoop(NewBB: BB3, LI&: *LI);
1230 } else {
1231 LI->addTopLevelLoop(New: OuterLoop);
1232 }
1233
1234 // Add the inner loop to the outer.
1235 OuterLoop->addChildLoop(NewChild: InnerLoop);
1236
1237 // Add the new basic blocks to the corresponding loops.
1238 OuterLoop->addBasicBlockToLoop(NewBB: BB1, LI&: *LI);
1239 OuterLoop->addBasicBlockToLoop(NewBB: BB5, LI&: *LI);
1240 InnerLoop->addBasicBlockToLoop(NewBB: BB2, LI&: *LI);
1241 InnerLoop->addBasicBlockToLoop(NewBB: BB4, LI&: *LI);
1242
1243 // Update the terminator added by SplitBlock to branch to the first block.
1244 Preheader->getTerminator()->setSuccessor(Idx: 0, BB: BB0);
1245 DTU.applyUpdates(Updates: {{DominatorTree::Delete, Preheader, SPH},
1246 {DominatorTree::Insert, Preheader, BB0}});
1247
1248 // (0) Check if we could be crossing a page boundary; if so, fallback to the
1249 // old scalar loops. Also create a predicate of VF elements to be used in the
1250 // vector loops.
  Builder.SetInsertPoint(BB0);
  Value *ISearchStart =
      Builder.CreatePtrToInt(SearchStart, I64Ty, "search_start_int");
  Value *ISearchEnd =
      Builder.CreatePtrToInt(SearchEnd, I64Ty, "search_end_int");
  Value *INeedleStart =
      Builder.CreatePtrToInt(NeedleStart, I64Ty, "needle_start_int");
  Value *INeedleEnd =
      Builder.CreatePtrToInt(NeedleEnd, I64Ty, "needle_end_int");
  Value *PredVF =
      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
                              {PredVTy, I64Ty},
                              {ConstantInt::get(I64Ty, 0), ConstVF});

  const uint64_t MinPageSize = TTI->getMinPageSize().value();
  const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
  Value *SearchStartPage =
      Builder.CreateLShr(ISearchStart, AddrShiftAmt, "search_start_page");
  Value *SearchEndPage =
      Builder.CreateLShr(ISearchEnd, AddrShiftAmt, "search_end_page");
  Value *NeedleStartPage =
      Builder.CreateLShr(INeedleStart, AddrShiftAmt, "needle_start_page");
  Value *NeedleEndPage =
      Builder.CreateLShr(INeedleEnd, AddrShiftAmt, "needle_end_page");
  Value *SearchPageCmp =
      Builder.CreateICmpNE(SearchStartPage, SearchEndPage, "search_page_cmp");
  Value *NeedlePageCmp =
      Builder.CreateICmpNE(NeedleStartPage, NeedleEndPage, "needle_page_cmp");

  Value *CombinedPageCmp =
      Builder.CreateOr(SearchPageCmp, NeedlePageCmp, "combined_page_cmp");
  BranchInst *CombinedPageBr = Builder.CreateCondBr(CombinedPageCmp, SPH, BB1);
  CombinedPageBr->setMetadata(LLVMContext::MD_prof,
                              MDBuilder(Ctx).createBranchWeights(10, 90));
  DTU.applyUpdates(
      {{DominatorTree::Insert, BB0, SPH}, {DominatorTree::Insert, BB0, BB1}});
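
  // For example, with a minimum page size of 4KiB (AddrShiftAmt == 12), the
  // addresses 0x1ff0 and 0x2010 map to pages 0x1 and 0x2, so an array
  // spanning them fails the check and we take the scalar fallback. The branch
  // weights above encode the expectation that each array will usually sit
  // within a single page.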

  // (1) Load the search array and branch to the inner loop.
  Builder.SetInsertPoint(BB1);
  PHINode *Search = Builder.CreatePHI(PtrTy, 2, "psearch");
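  // Build a predicate that is active only for lanes that lie within the first
  // VF lanes of the scalable vector (PredVF) and address bytes strictly below
  // SearchEnd; the masked load then yields the passthrough value for inactive
  // lanes instead of reading (and possibly faulting) past the end.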
  Value *PredSearch = Builder.CreateIntrinsic(
      Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
      {Builder.CreatePtrToInt(Search, I64Ty), ISearchEnd}, nullptr,
      "search_pred");
  PredSearch = Builder.CreateAnd(PredVF, PredSearch, "search_masked");
  Value *LoadSearch = Builder.CreateMaskedLoad(
      CharVTy, Search, Align(1), PredSearch, Passthru, "search_load_vec");
  Builder.CreateBr(BB2);
  DTU.applyUpdates({{DominatorTree::Insert, BB1, BB2}});

  // (2) Inner loop.
  Builder.SetInsertPoint(BB2);
  PHINode *Needle = Builder.CreatePHI(PtrTy, 2, "pneedle");

  // (2.a) Load the needle array.
  Value *PredNeedle = Builder.CreateIntrinsic(
      Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
      {Builder.CreatePtrToInt(Needle, I64Ty), INeedleEnd}, nullptr,
      "needle_pred");
  PredNeedle = Builder.CreateAnd(PredVF, PredNeedle, "needle_masked");
  Value *LoadNeedle = Builder.CreateMaskedLoad(
      CharVTy, Needle, Align(1), PredNeedle, Passthru, "needle_load_vec");

  // (2.b) Splat the first element to the inactive lanes.
  Value *Needle0 =
      Builder.CreateExtractElement(LoadNeedle, uint64_t(0), "needle0");
  Value *Needle0Splat = Builder.CreateVectorSplat(
      ElementCount::getScalable(VF), Needle0, "needle0");
  LoadNeedle = Builder.CreateSelect(PredNeedle, LoadNeedle, Needle0Splat,
                                    "needle_splat");
  LoadNeedle = Builder.CreateExtractVector(FixedVectorType::get(CharTy, VF),
                                           LoadNeedle, uint64_t(0),
                                           "needle_vec");
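  // For example, with VF == 16 and a 3-byte needle "abc", the 13 inactive
  // lanes are filled with 'a'. This is safe: any search lane matching one of
  // the padded lanes would also have matched the genuine 'a' lane, so no
  // spurious matches are introduced, and the fixed-width vector extracted
  // above contains only well-defined needle bytes.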

  // (2.c) Test if there's a match.
  Value *MatchPred = Builder.CreateIntrinsic(
      Intrinsic::experimental_vector_match, {CharVTy, LoadNeedle->getType()},
      {LoadSearch, LoadNeedle, PredSearch}, nullptr, "match_pred");
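  // MatchPred has a lane set for every active search lane whose byte equals
  // some element of the needle vector, so a single OR reduction tells us
  // whether this chunk of the search array contains any match at all.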
  Value *IfAnyMatch = Builder.CreateOrReduce(MatchPred);
  Builder.CreateCondBr(IfAnyMatch, BB3, BB4);
  DTU.applyUpdates(
      {{DominatorTree::Insert, BB2, BB3}, {DominatorTree::Insert, BB2, BB4}});

  // (3) We found a match. Compute the index of its location and exit.
  Builder.SetInsertPoint(BB3);
  PHINode *MatchLCSSA = Builder.CreatePHI(PtrTy, 1, "match_start");
  PHINode *MatchPredLCSSA =
      Builder.CreatePHI(MatchPred->getType(), 1, "match_vec");
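  // experimental.cttz.elts counts the trailing (lowest-numbered) zero lanes
  // of the match predicate, i.e. it yields the index of the first matching
  // lane. ZeroIsPoison is fine here because this block is only reached when
  // at least one lane matched.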
  Value *MatchCnt = Builder.CreateIntrinsic(
      Intrinsic::experimental_cttz_elts, {I64Ty, MatchPred->getType()},
      {MatchPredLCSSA, /*ZeroIsPoison=*/Builder.getInt1(true)}, nullptr,
      "match_idx");
  Value *MatchVal =
      Builder.CreateGEP(CharTy, MatchLCSSA, MatchCnt, "match_res");
  Builder.CreateBr(ExitSucc);
  DTU.applyUpdates({{DominatorTree::Insert, BB3, ExitSucc}});

  // (4) Check if we've reached the end of the needle array.
  Builder.SetInsertPoint(BB4);
  Value *NextNeedle =
      Builder.CreateGEP(CharTy, Needle, ConstVF, "needle_next_vec");
  Builder.CreateCondBr(Builder.CreateICmpULT(NextNeedle, NeedleEnd), BB2, BB5);
  DTU.applyUpdates(
      {{DominatorTree::Insert, BB4, BB2}, {DominatorTree::Insert, BB4, BB5}});

  // (5) Check if we've reached the end of the search array.
  Builder.SetInsertPoint(BB5);
  Value *NextSearch =
      Builder.CreateGEP(CharTy, Search, ConstVF, "search_next_vec");
  Builder.CreateCondBr(Builder.CreateICmpULT(NextSearch, SearchEnd), BB1,
                       ExitFail);
  DTU.applyUpdates({{DominatorTree::Insert, BB5, BB1},
                    {DominatorTree::Insert, BB5, ExitFail}});
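  // Note that both latches advance by a full VF-element chunk; the
  // active-lane-mask predicates computed in (1) and (2.a) handle the final,
  // partial chunk, so no scalar epilogue is required.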

  // Set up the PHI nodes.
  Search->addIncoming(SearchStart, BB0);
  Search->addIncoming(NextSearch, BB5);
  Needle->addIncoming(NeedleStart, BB1);
  Needle->addIncoming(NextNeedle, BB4);
  // These are needed to retain LCSSA form.
  MatchLCSSA->addIncoming(Search, BB2);
  MatchPredLCSSA->addIncoming(MatchPred, BB2);

  // Ensure all PHIs in the successors of BB3/BB5 have an incoming value from
  // them.
  fixSuccessorPhis(CurLoop, IndPhi, MatchVal, ExitSucc, BB3);
  if (ExitSucc != ExitFail)
    fixSuccessorPhis(CurLoop, IndPhi, MatchVal, ExitFail, BB5);

  if (VerifyLoops) {
    OuterLoop->verifyLoop();
    InnerLoop->verifyLoop();
    if (!OuterLoop->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }

  return MatchVal;
}

void LoopIdiomVectorize::transformFindFirstByte(
    PHINode *IndPhi, unsigned VF, Type *CharTy, BasicBlock *ExitSucc,
    BasicBlock *ExitFail, Value *SearchStart, Value *SearchEnd,
    Value *NeedleStart, Value *NeedleEnd) {
  // Insert the find first byte code at the end of the preheader block.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  IRBuilder<> Builder(PHBranch);
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
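  // A lazy DomTreeUpdater batches the CFG updates made during expansion and
  // only flushes them when the dominator tree is next queried.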
  Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());

  expandFindFirstByte(Builder, DTU, VF, CharTy, IndPhi, ExitSucc, ExitFail,
                      SearchStart, SearchEnd, NeedleStart, NeedleEnd);

  assert(PHBranch->isUnconditional() &&
         "Expected preheader to terminate with an unconditional branch.");

  if (VerifyLoops && CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->verifyLoop();
    if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }
}
