//===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass recognizes certain loop idioms and transforms them into more
// optimized versions of the same loop. In cases where this happens, it can be
// a significant performance win.
//
// We currently only recognize one idiom: a loop that finds the first
// mismatched byte in an array and returns the index, i.e. something like:
//
//   while (++i != n) {
//     if (a[i] != b[i])
//       break;
//   }
//
// In this example we can actually vectorize the loop despite the early exit,
// although the loop vectorizer does not support it. It requires some extra
// checks to deal with the possibility of faulting loads when crossing page
// boundaries. However, even with these checks it is still profitable to do the
// transformation.
//
//===----------------------------------------------------------------------===//
//
// NOTE: This pass matches a really specific loop pattern because it's only
// supposed to be a temporary solution until our LoopVectorizer is powerful
// enough to vectorize it automatically.
//
// TODO List:
//
// * Add support for the inverse case where we scan for a matching element.
// * Permit 64-bit induction variable types.
// * Recognize loops that increment the IV *after* comparing bytes.
// * Allow 32-bit sign-extends of the IV used by the GEP.
//
//===----------------------------------------------------------------------===//
41
#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
52
53using namespace llvm;
54using namespace PatternMatch;
55
56#define DEBUG_TYPE "loop-idiom-vectorize"
57
58static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
59 cl::init(Val: false),
60 cl::desc("Disable Loop Idiom Vectorize Pass."));
61
62static cl::opt<LoopIdiomVectorizeStyle>
63 LITVecStyle("loop-idiom-vectorize-style", cl::Hidden,
64 cl::desc("The vectorization style for loop idiom transform."),
65 cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked",
66 "Use masked vector intrinsics"),
67 clEnumValN(LoopIdiomVectorizeStyle::Predicated,
68 "predicated", "Use VP intrinsics")),
69 cl::init(Val: LoopIdiomVectorizeStyle::Masked));
70
71static cl::opt<bool>
72 DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden,
73 cl::init(Val: false),
74 cl::desc("Proceed with Loop Idiom Vectorize Pass, but do "
75 "not convert byte-compare loop(s)."));
76
77static cl::opt<unsigned>
78 ByteCmpVF("loop-idiom-vectorize-bytecmp-vf", cl::Hidden,
79 cl::desc("The vectorization factor for byte-compare patterns."),
80 cl::init(Val: 16));
81
82static cl::opt<bool>
83 VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(Val: false),
84 cl::desc("Verify loops generated Loop Idiom Vectorize Pass."));
85
86namespace {
87class LoopIdiomVectorize {
88 LoopIdiomVectorizeStyle VectorizeStyle;
89 unsigned ByteCompareVF;
90 Loop *CurLoop = nullptr;
91 DominatorTree *DT;
92 LoopInfo *LI;
93 const TargetTransformInfo *TTI;
94 const DataLayout *DL;
95
96 // Blocks that will be used for inserting vectorized code.
97 BasicBlock *EndBlock = nullptr;
98 BasicBlock *VectorLoopPreheaderBlock = nullptr;
99 BasicBlock *VectorLoopStartBlock = nullptr;
100 BasicBlock *VectorLoopMismatchBlock = nullptr;
101 BasicBlock *VectorLoopIncBlock = nullptr;
102
103public:
104 LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
105 LoopInfo *LI, const TargetTransformInfo *TTI,
106 const DataLayout *DL)
107 : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
108 }
109
110 bool run(Loop *L);
111
112private:
113 /// \name Countable Loop Idiom Handling
114 /// @{
115
116 bool runOnCountableLoop();
117 bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
118 SmallVectorImpl<BasicBlock *> &ExitBlocks);
119
120 bool recognizeByteCompare();
121
122 Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
123 GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
124 Instruction *Index, Value *Start, Value *MaxLen);
125
126 Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
127 GetElementPtrInst *GEPA,
128 GetElementPtrInst *GEPB, Value *ExtStart,
129 Value *ExtEnd);
130 Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
131 GetElementPtrInst *GEPA,
132 GetElementPtrInst *GEPB, Value *ExtStart,
133 Value *ExtEnd);
134
135 void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
136 PHINode *IndPhi, Value *MaxLen, Instruction *Index,
137 Value *Start, bool IncIdx, BasicBlock *FoundBB,
138 BasicBlock *EndBB);
139 /// @}
140};
141} // anonymous namespace
142
143PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
144 LoopStandardAnalysisResults &AR,
145 LPMUpdater &) {
146 if (DisableAll)
147 return PreservedAnalyses::all();
148
149 const auto *DL = &L.getHeader()->getDataLayout();
150
151 LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
152 if (LITVecStyle.getNumOccurrences())
153 VecStyle = LITVecStyle;
154
155 unsigned BCVF = ByteCompareVF;
156 if (ByteCmpVF.getNumOccurrences())
157 BCVF = ByteCmpVF;
158
159 LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL);
160 if (!LIV.run(L: &L))
161 return PreservedAnalyses::all();
162
163 return PreservedAnalyses::none();
164}
165
//===----------------------------------------------------------------------===//
//
//          Implementation of LoopIdiomVectorize
//
//===----------------------------------------------------------------------===//
171
172bool LoopIdiomVectorize::run(Loop *L) {
173 CurLoop = L;
174
175 Function &F = *L->getHeader()->getParent();
176 if (DisableAll || F.hasOptSize())
177 return false;
178
179 if (F.hasFnAttribute(Kind: Attribute::NoImplicitFloat)) {
180 LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName()
181 << " due to its NoImplicitFloat attribute");
182 return false;
183 }
184
185 // If the loop could not be converted to canonical form, it must have an
186 // indirectbr in it, just give up.
187 if (!L->getLoopPreheader())
188 return false;
189
190 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %"
191 << CurLoop->getHeader()->getName() << "\n");
192
193 return recognizeByteCompare();
194}
195
196bool LoopIdiomVectorize::recognizeByteCompare() {
197 // Currently the transformation only works on scalable vector types, although
198 // there is no fundamental reason why it cannot be made to work for fixed
199 // width too.
200
201 // We also need to know the minimum page size for the target in order to
202 // generate runtime memory checks to ensure the vector version won't fault.
203 if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
204 DisableByteCmp)
205 return false;
206
207 BasicBlock *Header = CurLoop->getHeader();
208
209 // In LoopIdiomVectorize::run we have already checked that the loop
210 // has a preheader so we can assume it's in a canonical form.
211 if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
212 return false;
213
214 PHINode *PN = dyn_cast<PHINode>(Val: &Header->front());
215 if (!PN || PN->getNumIncomingValues() != 2)
216 return false;
217
218 auto LoopBlocks = CurLoop->getBlocks();
219 // The first block in the loop should contain only 4 instructions, e.g.
220 //
221 // while.cond:
222 // %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
223 // %inc = add i32 %res.phi, 1
224 // %cmp.not = icmp eq i32 %inc, %n
225 // br i1 %cmp.not, label %while.end, label %while.body
226 //
227 if (LoopBlocks[0]->sizeWithoutDebug() > 4)
228 return false;
229
230 // The second block should contain 7 instructions, e.g.
231 //
232 // while.body:
233 // %idx = zext i32 %inc to i64
234 // %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
235 // %load.a = load i8, ptr %idx.a
236 // %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
237 // %load.b = load i8, ptr %idx.b
238 // %cmp.not.ld = icmp eq i8 %load.a, %load.b
239 // br i1 %cmp.not.ld, label %while.cond, label %while.end
240 //
241 if (LoopBlocks[1]->sizeWithoutDebug() > 7)
242 return false;
243
244 // The incoming value to the PHI node from the loop should be an add of 1.
245 Value *StartIdx = nullptr;
246 Instruction *Index = nullptr;
247 if (!CurLoop->contains(BB: PN->getIncomingBlock(i: 0))) {
248 StartIdx = PN->getIncomingValue(i: 0);
249 Index = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 1));
250 } else {
251 StartIdx = PN->getIncomingValue(i: 1);
252 Index = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 0));
253 }
254
255 // Limit to 32-bit types for now
256 if (!Index || !Index->getType()->isIntegerTy(Bitwidth: 32) ||
257 !match(V: Index, P: m_c_Add(L: m_Specific(V: PN), R: m_One())))
258 return false;
259
260 // If we match the pattern, PN and Index will be replaced with the result of
261 // the cttz.elts intrinsic. If any other instructions are used outside of
262 // the loop, we cannot replace it.
263 for (BasicBlock *BB : LoopBlocks)
264 for (Instruction &I : *BB)
265 if (&I != PN && &I != Index)
266 for (User *U : I.users())
267 if (!CurLoop->contains(Inst: cast<Instruction>(Val: U)))
268 return false;
269
270 // Match the branch instruction for the header
271 ICmpInst::Predicate Pred;
272 Value *MaxLen;
273 BasicBlock *EndBB, *WhileBB;
274 if (!match(V: Header->getTerminator(),
275 P: m_Br(C: m_ICmp(Pred, L: m_Specific(V: Index), R: m_Value(V&: MaxLen)),
276 T: m_BasicBlock(V&: EndBB), F: m_BasicBlock(V&: WhileBB))) ||
277 Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(BB: WhileBB))
278 return false;
279
280 // WhileBB should contain the pattern of load & compare instructions. Match
281 // the pattern and find the GEP instructions used by the loads.
282 ICmpInst::Predicate WhilePred;
283 BasicBlock *FoundBB;
284 BasicBlock *TrueBB;
285 Value *LoadA, *LoadB;
286 if (!match(V: WhileBB->getTerminator(),
287 P: m_Br(C: m_ICmp(Pred&: WhilePred, L: m_Value(V&: LoadA), R: m_Value(V&: LoadB)),
288 T: m_BasicBlock(V&: TrueBB), F: m_BasicBlock(V&: FoundBB))) ||
289 WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(BB: TrueBB))
290 return false;
291
292 Value *A, *B;
293 if (!match(V: LoadA, P: m_Load(Op: m_Value(V&: A))) || !match(V: LoadB, P: m_Load(Op: m_Value(V&: B))))
294 return false;
295
296 LoadInst *LoadAI = cast<LoadInst>(Val: LoadA);
297 LoadInst *LoadBI = cast<LoadInst>(Val: LoadB);
298 if (!LoadAI->isSimple() || !LoadBI->isSimple())
299 return false;
300
301 GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(Val: A);
302 GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(Val: B);
303
304 if (!GEPA || !GEPB)
305 return false;
306
307 Value *PtrA = GEPA->getPointerOperand();
308 Value *PtrB = GEPB->getPointerOperand();
309
310 // Check we are loading i8 values from two loop invariant pointers
311 if (!CurLoop->isLoopInvariant(V: PtrA) || !CurLoop->isLoopInvariant(V: PtrB) ||
312 !GEPA->getResultElementType()->isIntegerTy(Bitwidth: 8) ||
313 !GEPB->getResultElementType()->isIntegerTy(Bitwidth: 8) ||
314 !LoadAI->getType()->isIntegerTy(Bitwidth: 8) ||
315 !LoadBI->getType()->isIntegerTy(Bitwidth: 8) || PtrA == PtrB)
316 return false;
317
318 // Check that the index to the GEPs is the index we found earlier
319 if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
320 return false;
321
322 Value *IdxA = GEPA->getOperand(i_nocapture: GEPA->getNumIndices());
323 Value *IdxB = GEPB->getOperand(i_nocapture: GEPB->getNumIndices());
324 if (IdxA != IdxB || !match(V: IdxA, P: m_ZExt(Op: m_Specific(V: Index))))
325 return false;
326
327 // We only ever expect the pre-incremented index value to be used inside the
328 // loop.
329 if (!PN->hasOneUse())
330 return false;
331
332 // Ensure that when the Found and End blocks are identical the PHIs have the
333 // supported format. We don't currently allow cases like this:
334 // while.cond:
335 // ...
336 // br i1 %cmp.not, label %while.end, label %while.body
337 //
338 // while.body:
339 // ...
340 // br i1 %cmp.not2, label %while.cond, label %while.end
341 //
342 // while.end:
343 // %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
344 //
345 // Where the incoming values for %final_ptr are unique and from each of the
346 // loop blocks, but not actually defined in the loop. This requires extra
347 // work setting up the byte.compare block, i.e. by introducing a select to
348 // choose the correct value.
349 // TODO: We could add support for this in future.
350 if (FoundBB == EndBB) {
351 for (PHINode &EndPN : EndBB->phis()) {
352 Value *WhileCondVal = EndPN.getIncomingValueForBlock(BB: Header);
353 Value *WhileBodyVal = EndPN.getIncomingValueForBlock(BB: WhileBB);
354
355 // The value of the index when leaving the while.cond block is always the
356 // same as the end value (MaxLen) so we permit either. The value when
357 // leaving the while.body block should only be the index. Otherwise for
358 // any other values we only allow ones that are same for both blocks.
359 if (WhileCondVal != WhileBodyVal &&
360 ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
361 (WhileBodyVal != Index)))
362 return false;
363 }
364 }
365
366 LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
367 << *(EndBB->getParent()) << "\n\n");
368
369 // The index is incremented before the GEP/Load pair so we need to
370 // add 1 to the start value.
371 transformByteCompare(GEPA, GEPB, IndPhi: PN, MaxLen, Index, Start: StartIdx, /*IncIdx=*/true,
372 FoundBB, EndBB);
373 return true;
374}
375
376Value *LoopIdiomVectorize::createMaskedFindMismatch(
377 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
378 GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
379 Type *I64Type = Builder.getInt64Ty();
380 Type *ResType = Builder.getInt32Ty();
381 Type *LoadType = Builder.getInt8Ty();
382 Value *PtrA = GEPA->getPointerOperand();
383 Value *PtrB = GEPB->getPointerOperand();
384
385 ScalableVectorType *PredVTy =
386 ScalableVectorType::get(ElementType: Builder.getInt1Ty(), MinNumElts: ByteCompareVF);
387
388 Value *InitialPred = Builder.CreateIntrinsic(
389 ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Type}, Args: {ExtStart, ExtEnd});
390
391 Value *VecLen = Builder.CreateIntrinsic(ID: Intrinsic::vscale, Types: {I64Type}, Args: {});
392 VecLen =
393 Builder.CreateMul(LHS: VecLen, RHS: ConstantInt::get(Ty: I64Type, V: ByteCompareVF), Name: "",
394 /*HasNUW=*/true, /*HasNSW=*/true);
395
396 Value *PFalse = Builder.CreateVectorSplat(EC: PredVTy->getElementCount(),
397 V: Builder.getInt1(V: false));
398
399 BranchInst *JumpToVectorLoop = BranchInst::Create(IfTrue: VectorLoopStartBlock);
400 Builder.Insert(I: JumpToVectorLoop);
401
402 DTU.applyUpdates(Updates: {{DominatorTree::Insert, VectorLoopPreheaderBlock,
403 VectorLoopStartBlock}});
404
405 // Set up the first vector loop block by creating the PHIs, doing the vector
406 // loads and comparing the vectors.
407 Builder.SetInsertPoint(VectorLoopStartBlock);
408 PHINode *LoopPred = Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 2, Name: "mismatch_vec_loop_pred");
409 LoopPred->addIncoming(V: InitialPred, BB: VectorLoopPreheaderBlock);
410 PHINode *VectorIndexPhi = Builder.CreatePHI(Ty: I64Type, NumReservedValues: 2, Name: "mismatch_vec_index");
411 VectorIndexPhi->addIncoming(V: ExtStart, BB: VectorLoopPreheaderBlock);
412 Type *VectorLoadType =
413 ScalableVectorType::get(ElementType: Builder.getInt8Ty(), MinNumElts: ByteCompareVF);
414 Value *Passthru = ConstantInt::getNullValue(Ty: VectorLoadType);
415
416 Value *VectorLhsGep =
417 Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: VectorIndexPhi, Name: "", NW: GEPA->isInBounds());
418 Value *VectorLhsLoad = Builder.CreateMaskedLoad(Ty: VectorLoadType, Ptr: VectorLhsGep,
419 Alignment: Align(1), Mask: LoopPred, PassThru: Passthru);
420
421 Value *VectorRhsGep =
422 Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: VectorIndexPhi, Name: "", NW: GEPB->isInBounds());
423 Value *VectorRhsLoad = Builder.CreateMaskedLoad(Ty: VectorLoadType, Ptr: VectorRhsGep,
424 Alignment: Align(1), Mask: LoopPred, PassThru: Passthru);
425
426 Value *VectorMatchCmp = Builder.CreateICmpNE(LHS: VectorLhsLoad, RHS: VectorRhsLoad);
427 VectorMatchCmp = Builder.CreateSelect(C: LoopPred, True: VectorMatchCmp, False: PFalse);
428 Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(Src: VectorMatchCmp);
429 BranchInst *VectorEarlyExit = BranchInst::Create(
430 IfTrue: VectorLoopMismatchBlock, IfFalse: VectorLoopIncBlock, Cond: VectorMatchHasActiveLanes);
431 Builder.Insert(I: VectorEarlyExit);
432
433 DTU.applyUpdates(
434 Updates: {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
435 {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
436
437 // Increment the index counter and calculate the predicate for the next
438 // iteration of the loop. We branch back to the start of the loop if there
439 // is at least one active lane.
440 Builder.SetInsertPoint(VectorLoopIncBlock);
441 Value *NewVectorIndexPhi =
442 Builder.CreateAdd(LHS: VectorIndexPhi, RHS: VecLen, Name: "",
443 /*HasNUW=*/true, /*HasNSW=*/true);
444 VectorIndexPhi->addIncoming(V: NewVectorIndexPhi, BB: VectorLoopIncBlock);
445 Value *NewPred =
446 Builder.CreateIntrinsic(ID: Intrinsic::get_active_lane_mask,
447 Types: {PredVTy, I64Type}, Args: {NewVectorIndexPhi, ExtEnd});
448 LoopPred->addIncoming(V: NewPred, BB: VectorLoopIncBlock);
449
450 Value *PredHasActiveLanes =
451 Builder.CreateExtractElement(Vec: NewPred, Idx: uint64_t(0));
452 BranchInst *VectorLoopBranchBack =
453 BranchInst::Create(IfTrue: VectorLoopStartBlock, IfFalse: EndBlock, Cond: PredHasActiveLanes);
454 Builder.Insert(I: VectorLoopBranchBack);
455
456 DTU.applyUpdates(
457 Updates: {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
458 {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
459
460 // If we found a mismatch then we need to calculate which lane in the vector
461 // had a mismatch and add that on to the current loop index.
462 Builder.SetInsertPoint(VectorLoopMismatchBlock);
463 PHINode *FoundPred = Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 1, Name: "mismatch_vec_found_pred");
464 FoundPred->addIncoming(V: VectorMatchCmp, BB: VectorLoopStartBlock);
465 PHINode *LastLoopPred =
466 Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 1, Name: "mismatch_vec_last_loop_pred");
467 LastLoopPred->addIncoming(V: LoopPred, BB: VectorLoopStartBlock);
468 PHINode *VectorFoundIndex =
469 Builder.CreatePHI(Ty: I64Type, NumReservedValues: 1, Name: "mismatch_vec_found_index");
470 VectorFoundIndex->addIncoming(V: VectorIndexPhi, BB: VectorLoopStartBlock);
471
472 Value *PredMatchCmp = Builder.CreateAnd(LHS: LastLoopPred, RHS: FoundPred);
473 Value *Ctz = Builder.CreateIntrinsic(
474 ID: Intrinsic::experimental_cttz_elts, Types: {ResType, PredMatchCmp->getType()},
475 Args: {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(V: true)});
476 Ctz = Builder.CreateZExt(V: Ctz, DestTy: I64Type);
477 Value *VectorLoopRes64 = Builder.CreateAdd(LHS: VectorFoundIndex, RHS: Ctz, Name: "",
478 /*HasNUW=*/true, /*HasNSW=*/true);
479 return Builder.CreateTrunc(V: VectorLoopRes64, DestTy: ResType);
480}
481
482Value *LoopIdiomVectorize::createPredicatedFindMismatch(
483 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
484 GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
485 Type *I64Type = Builder.getInt64Ty();
486 Type *I32Type = Builder.getInt32Ty();
487 Type *ResType = I32Type;
488 Type *LoadType = Builder.getInt8Ty();
489 Value *PtrA = GEPA->getPointerOperand();
490 Value *PtrB = GEPB->getPointerOperand();
491
492 auto *JumpToVectorLoop = BranchInst::Create(IfTrue: VectorLoopStartBlock);
493 Builder.Insert(I: JumpToVectorLoop);
494
495 DTU.applyUpdates(Updates: {{DominatorTree::Insert, VectorLoopPreheaderBlock,
496 VectorLoopStartBlock}});
497
498 // Set up the first Vector loop block by creating the PHIs, doing the vector
499 // loads and comparing the vectors.
500 Builder.SetInsertPoint(VectorLoopStartBlock);
501 auto *VectorIndexPhi = Builder.CreatePHI(Ty: I64Type, NumReservedValues: 2, Name: "mismatch_vector_index");
502 VectorIndexPhi->addIncoming(V: ExtStart, BB: VectorLoopPreheaderBlock);
503
504 // Calculate AVL by subtracting the vector loop index from the trip count
505 Value *AVL = Builder.CreateSub(LHS: ExtEnd, RHS: VectorIndexPhi, Name: "avl", /*HasNUW=*/true,
506 /*HasNSW=*/true);
507
508 auto *VectorLoadType = ScalableVectorType::get(ElementType: LoadType, MinNumElts: ByteCompareVF);
509 auto *VF = ConstantInt::get(Ty: I32Type, V: ByteCompareVF);
510
511 Value *VL = Builder.CreateIntrinsic(ID: Intrinsic::experimental_get_vector_length,
512 Types: {I64Type}, Args: {AVL, VF, Builder.getTrue()});
513 Value *GepOffset = VectorIndexPhi;
514
515 Value *VectorLhsGep =
516 Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: GepOffset, Name: "", NW: GEPA->isInBounds());
517 VectorType *TrueMaskTy =
518 VectorType::get(ElementType: Builder.getInt1Ty(), EC: VectorLoadType->getElementCount());
519 Value *AllTrueMask = Constant::getAllOnesValue(Ty: TrueMaskTy);
520 Value *VectorLhsLoad = Builder.CreateIntrinsic(
521 ID: Intrinsic::vp_load, Types: {VectorLoadType, VectorLhsGep->getType()},
522 Args: {VectorLhsGep, AllTrueMask, VL}, FMFSource: nullptr, Name: "lhs.load");
523
524 Value *VectorRhsGep =
525 Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: GepOffset, Name: "", NW: GEPB->isInBounds());
526 Value *VectorRhsLoad = Builder.CreateIntrinsic(
527 ID: Intrinsic::vp_load, Types: {VectorLoadType, VectorLhsGep->getType()},
528 Args: {VectorRhsGep, AllTrueMask, VL}, FMFSource: nullptr, Name: "rhs.load");
529
530 StringRef PredicateStr = CmpInst::getPredicateName(P: CmpInst::ICMP_NE);
531 auto *PredicateMDS = MDString::get(Context&: VectorLhsLoad->getContext(), Str: PredicateStr);
532 Value *Pred = MetadataAsValue::get(Context&: VectorLhsLoad->getContext(), MD: PredicateMDS);
533 Value *VectorMatchCmp = Builder.CreateIntrinsic(
534 ID: Intrinsic::vp_icmp, Types: {VectorLhsLoad->getType()},
535 Args: {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, FMFSource: nullptr,
536 Name: "mismatch.cmp");
537 Value *CTZ = Builder.CreateIntrinsic(
538 ID: Intrinsic::vp_cttz_elts, Types: {ResType, VectorMatchCmp->getType()},
539 Args: {VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(V: false), AllTrueMask,
540 VL});
541 Value *MismatchFound = Builder.CreateICmpNE(LHS: CTZ, RHS: VL);
542 auto *VectorEarlyExit = BranchInst::Create(IfTrue: VectorLoopMismatchBlock,
543 IfFalse: VectorLoopIncBlock, Cond: MismatchFound);
544 Builder.Insert(I: VectorEarlyExit);
545
546 DTU.applyUpdates(
547 Updates: {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
548 {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
549
550 // Increment the index counter and calculate the predicate for the next
551 // iteration of the loop. We branch back to the start of the loop if there
552 // is at least one active lane.
553 Builder.SetInsertPoint(VectorLoopIncBlock);
554 Value *VL64 = Builder.CreateZExt(V: VL, DestTy: I64Type);
555 Value *NewVectorIndexPhi =
556 Builder.CreateAdd(LHS: VectorIndexPhi, RHS: VL64, Name: "",
557 /*HasNUW=*/true, /*HasNSW=*/true);
558 VectorIndexPhi->addIncoming(V: NewVectorIndexPhi, BB: VectorLoopIncBlock);
559 Value *ExitCond = Builder.CreateICmpNE(LHS: NewVectorIndexPhi, RHS: ExtEnd);
560 auto *VectorLoopBranchBack =
561 BranchInst::Create(IfTrue: VectorLoopStartBlock, IfFalse: EndBlock, Cond: ExitCond);
562 Builder.Insert(I: VectorLoopBranchBack);
563
564 DTU.applyUpdates(
565 Updates: {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
566 {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
567
568 // If we found a mismatch then we need to calculate which lane in the vector
569 // had a mismatch and add that on to the current loop index.
570 Builder.SetInsertPoint(VectorLoopMismatchBlock);
571
572 // Add LCSSA phis for CTZ and VectorIndexPhi.
573 auto *CTZLCSSAPhi = Builder.CreatePHI(Ty: CTZ->getType(), NumReservedValues: 1, Name: "ctz");
574 CTZLCSSAPhi->addIncoming(V: CTZ, BB: VectorLoopStartBlock);
575 auto *VectorIndexLCSSAPhi =
576 Builder.CreatePHI(Ty: VectorIndexPhi->getType(), NumReservedValues: 1, Name: "mismatch_vector_index");
577 VectorIndexLCSSAPhi->addIncoming(V: VectorIndexPhi, BB: VectorLoopStartBlock);
578
579 Value *CTZI64 = Builder.CreateZExt(V: CTZLCSSAPhi, DestTy: I64Type);
580 Value *VectorLoopRes64 = Builder.CreateAdd(LHS: VectorIndexLCSSAPhi, RHS: CTZI64, Name: "",
581 /*HasNUW=*/true, /*HasNSW=*/true);
582 return Builder.CreateTrunc(V: VectorLoopRes64, DestTy: ResType);
583}
584
585Value *LoopIdiomVectorize::expandFindMismatch(
586 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
587 GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
588 Value *PtrA = GEPA->getPointerOperand();
589 Value *PtrB = GEPB->getPointerOperand();
590
591 // Get the arguments and types for the intrinsic.
592 BasicBlock *Preheader = CurLoop->getLoopPreheader();
593 BranchInst *PHBranch = cast<BranchInst>(Val: Preheader->getTerminator());
594 LLVMContext &Ctx = PHBranch->getContext();
595 Type *LoadType = Type::getInt8Ty(C&: Ctx);
596 Type *ResType = Builder.getInt32Ty();
597
598 // Split block in the original loop preheader.
599 EndBlock = SplitBlock(Old: Preheader, SplitPt: PHBranch, DT, LI, MSSAU: nullptr, BBName: "mismatch_end");
600
601 // Create the blocks that we're going to need:
602 // 1. A block for checking the zero-extended length exceeds 0
603 // 2. A block to check that the start and end addresses of a given array
604 // lie on the same page.
605 // 3. The vector loop preheader.
606 // 4. The first vector loop block.
607 // 5. The vector loop increment block.
608 // 6. A block we can jump to from the vector loop when a mismatch is found.
609 // 7. The first block of the scalar loop itself, containing PHIs , loads
610 // and cmp.
611 // 8. A scalar loop increment block to increment the PHIs and go back
612 // around the loop.
613
614 BasicBlock *MinItCheckBlock = BasicBlock::Create(
615 Context&: Ctx, Name: "mismatch_min_it_check", Parent: EndBlock->getParent(), InsertBefore: EndBlock);
616
617 // Update the terminator added by SplitBlock to branch to the first block
618 Preheader->getTerminator()->setSuccessor(Idx: 0, BB: MinItCheckBlock);
619
620 BasicBlock *MemCheckBlock = BasicBlock::Create(
621 Context&: Ctx, Name: "mismatch_mem_check", Parent: EndBlock->getParent(), InsertBefore: EndBlock);
622
623 VectorLoopPreheaderBlock = BasicBlock::Create(
624 Context&: Ctx, Name: "mismatch_vec_loop_preheader", Parent: EndBlock->getParent(), InsertBefore: EndBlock);
625
626 VectorLoopStartBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop",
627 Parent: EndBlock->getParent(), InsertBefore: EndBlock);
628
629 VectorLoopIncBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop_inc",
630 Parent: EndBlock->getParent(), InsertBefore: EndBlock);
631
632 VectorLoopMismatchBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop_found",
633 Parent: EndBlock->getParent(), InsertBefore: EndBlock);
634
635 BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
636 Context&: Ctx, Name: "mismatch_loop_pre", Parent: EndBlock->getParent(), InsertBefore: EndBlock);
637
638 BasicBlock *LoopStartBlock =
639 BasicBlock::Create(Context&: Ctx, Name: "mismatch_loop", Parent: EndBlock->getParent(), InsertBefore: EndBlock);
640
641 BasicBlock *LoopIncBlock = BasicBlock::Create(
642 Context&: Ctx, Name: "mismatch_loop_inc", Parent: EndBlock->getParent(), InsertBefore: EndBlock);
643
644 DTU.applyUpdates(Updates: {{DominatorTree::Insert, Preheader, MinItCheckBlock},
645 {DominatorTree::Delete, Preheader, EndBlock}});
646
647 // Update LoopInfo with the new vector & scalar loops.
648 auto VectorLoop = LI->AllocateLoop();
649 auto ScalarLoop = LI->AllocateLoop();
650
651 if (CurLoop->getParentLoop()) {
652 CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: MinItCheckBlock, LI&: *LI);
653 CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: MemCheckBlock, LI&: *LI);
654 CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: VectorLoopPreheaderBlock,
655 LI&: *LI);
656 CurLoop->getParentLoop()->addChildLoop(NewChild: VectorLoop);
657 CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: VectorLoopMismatchBlock, LI&: *LI);
658 CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: LoopPreHeaderBlock, LI&: *LI);
659 CurLoop->getParentLoop()->addChildLoop(NewChild: ScalarLoop);
660 } else {
661 LI->addTopLevelLoop(New: VectorLoop);
662 LI->addTopLevelLoop(New: ScalarLoop);
663 }
664
665 // Add the new basic blocks to their associated loops.
666 VectorLoop->addBasicBlockToLoop(NewBB: VectorLoopStartBlock, LI&: *LI);
667 VectorLoop->addBasicBlockToLoop(NewBB: VectorLoopIncBlock, LI&: *LI);
668
669 ScalarLoop->addBasicBlockToLoop(NewBB: LoopStartBlock, LI&: *LI);
670 ScalarLoop->addBasicBlockToLoop(NewBB: LoopIncBlock, LI&: *LI);
671
672 // Set up some types and constants that we intend to reuse.
673 Type *I64Type = Builder.getInt64Ty();
674
675 // Check the zero-extended iteration count > 0
676 Builder.SetInsertPoint(MinItCheckBlock);
677 Value *ExtStart = Builder.CreateZExt(V: Start, DestTy: I64Type);
678 Value *ExtEnd = Builder.CreateZExt(V: MaxLen, DestTy: I64Type);
679 // This check doesn't really cost us very much.
680
681 Value *LimitCheck = Builder.CreateICmpULE(LHS: Start, RHS: MaxLen);
682 BranchInst *MinItCheckBr =
683 BranchInst::Create(IfTrue: MemCheckBlock, IfFalse: LoopPreHeaderBlock, Cond: LimitCheck);
684 MinItCheckBr->setMetadata(
685 KindID: LLVMContext::MD_prof,
686 Node: MDBuilder(MinItCheckBr->getContext()).createBranchWeights(TrueWeight: 99, FalseWeight: 1));
687 Builder.Insert(I: MinItCheckBr);
688
689 DTU.applyUpdates(
690 Updates: {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
691 {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});
692
693 // For each of the arrays, check the start/end addresses are on the same
694 // page.
695 Builder.SetInsertPoint(MemCheckBlock);
696
697 // The early exit in the original loop means that when performing vector
698 // loads we are potentially reading ahead of the early exit. So we could
699 // fault if crossing a page boundary. Therefore, we create runtime memory
700 // checks based on the minimum page size as follows:
701 // 1. Calculate the addresses of the first memory accesses in the loop,
702 // i.e. LhsStart and RhsStart.
703 // 2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd.
704 // 3. Determine which pages correspond to all the memory accesses, i.e
705 // LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage.
706 // 4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then
707 // we know we won't cross any page boundaries in the loop so we can
708 // enter the vector loop! Otherwise we fall back on the scalar loop.
709 Value *LhsStartGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: ExtStart);
710 Value *RhsStartGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: ExtStart);
711 Value *RhsStart = Builder.CreatePtrToInt(V: RhsStartGEP, DestTy: I64Type);
712 Value *LhsStart = Builder.CreatePtrToInt(V: LhsStartGEP, DestTy: I64Type);
713 Value *LhsEndGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: ExtEnd);
714 Value *RhsEndGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: ExtEnd);
715 Value *LhsEnd = Builder.CreatePtrToInt(V: LhsEndGEP, DestTy: I64Type);
716 Value *RhsEnd = Builder.CreatePtrToInt(V: RhsEndGEP, DestTy: I64Type);
717
718 const uint64_t MinPageSize = TTI->getMinPageSize().value();
719 const uint64_t AddrShiftAmt = llvm::Log2_64(Value: MinPageSize);
720 Value *LhsStartPage = Builder.CreateLShr(LHS: LhsStart, RHS: AddrShiftAmt);
721 Value *LhsEndPage = Builder.CreateLShr(LHS: LhsEnd, RHS: AddrShiftAmt);
722 Value *RhsStartPage = Builder.CreateLShr(LHS: RhsStart, RHS: AddrShiftAmt);
723 Value *RhsEndPage = Builder.CreateLShr(LHS: RhsEnd, RHS: AddrShiftAmt);
724 Value *LhsPageCmp = Builder.CreateICmpNE(LHS: LhsStartPage, RHS: LhsEndPage);
725 Value *RhsPageCmp = Builder.CreateICmpNE(LHS: RhsStartPage, RHS: RhsEndPage);
726
727 Value *CombinedPageCmp = Builder.CreateOr(LHS: LhsPageCmp, RHS: RhsPageCmp);
728 BranchInst *CombinedPageCmpCmpBr = BranchInst::Create(
729 IfTrue: LoopPreHeaderBlock, IfFalse: VectorLoopPreheaderBlock, Cond: CombinedPageCmp);
730 CombinedPageCmpCmpBr->setMetadata(
731 KindID: LLVMContext::MD_prof, Node: MDBuilder(CombinedPageCmpCmpBr->getContext())
732 .createBranchWeights(TrueWeight: 10, FalseWeight: 90));
733 Builder.Insert(I: CombinedPageCmpCmpBr);
734
735 DTU.applyUpdates(
736 Updates: {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock},
737 {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}});
738
739 // Set up the vector loop preheader, i.e. calculate initial loop predicate,
740 // zero-extend MaxLen to 64-bits, determine the number of vector elements
741 // processed in each iteration, etc.
742 Builder.SetInsertPoint(VectorLoopPreheaderBlock);
743
744 // At this point we know two things must be true:
745 // 1. Start <= End
746 // 2. ExtMaxLen <= MinPageSize due to the page checks.
747 // Therefore, we know that we can use a 64-bit induction variable that
748 // starts from 0 -> ExtMaxLen and it will not overflow.
749 Value *VectorLoopRes = nullptr;
750 switch (VectorizeStyle) {
751 case LoopIdiomVectorizeStyle::Masked:
752 VectorLoopRes =
753 createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
754 break;
755 case LoopIdiomVectorizeStyle::Predicated:
756 VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB,
757 ExtStart, ExtEnd);
758 break;
759 }
760
761 Builder.Insert(I: BranchInst::Create(IfTrue: EndBlock));
762
763 DTU.applyUpdates(
764 Updates: {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}});
765
766 // Generate code for scalar loop.
767 Builder.SetInsertPoint(LoopPreHeaderBlock);
768 Builder.Insert(I: BranchInst::Create(IfTrue: LoopStartBlock));
769
770 DTU.applyUpdates(
771 Updates: {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}});
772
773 Builder.SetInsertPoint(LoopStartBlock);
774 PHINode *IndexPhi = Builder.CreatePHI(Ty: ResType, NumReservedValues: 2, Name: "mismatch_index");
775 IndexPhi->addIncoming(V: Start, BB: LoopPreHeaderBlock);
776
777 // Otherwise compare the values
778 // Load bytes from each array and compare them.
779 Value *GepOffset = Builder.CreateZExt(V: IndexPhi, DestTy: I64Type);
780
781 Value *LhsGep =
782 Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: GepOffset, Name: "", NW: GEPA->isInBounds());
783 Value *LhsLoad = Builder.CreateLoad(Ty: LoadType, Ptr: LhsGep);
784
785 Value *RhsGep =
786 Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: GepOffset, Name: "", NW: GEPB->isInBounds());
787 Value *RhsLoad = Builder.CreateLoad(Ty: LoadType, Ptr: RhsGep);
788
789 Value *MatchCmp = Builder.CreateICmpEQ(LHS: LhsLoad, RHS: RhsLoad);
790 // If we have a mismatch then exit the loop ...
791 BranchInst *MatchCmpBr = BranchInst::Create(IfTrue: LoopIncBlock, IfFalse: EndBlock, Cond: MatchCmp);
792 Builder.Insert(I: MatchCmpBr);
793
794 DTU.applyUpdates(Updates: {{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
795 {DominatorTree::Insert, LoopStartBlock, EndBlock}});
796
797 // Have we reached the maximum permitted length for the loop?
798 Builder.SetInsertPoint(LoopIncBlock);
799 Value *PhiInc = Builder.CreateAdd(LHS: IndexPhi, RHS: ConstantInt::get(Ty: ResType, V: 1), Name: "",
800 /*HasNUW=*/Index->hasNoUnsignedWrap(),
801 /*HasNSW=*/Index->hasNoSignedWrap());
802 IndexPhi->addIncoming(V: PhiInc, BB: LoopIncBlock);
803 Value *IVCmp = Builder.CreateICmpEQ(LHS: PhiInc, RHS: MaxLen);
804 BranchInst *IVCmpBr = BranchInst::Create(IfTrue: EndBlock, IfFalse: LoopStartBlock, Cond: IVCmp);
805 Builder.Insert(I: IVCmpBr);
806
807 DTU.applyUpdates(Updates: {{DominatorTree::Insert, LoopIncBlock, EndBlock},
808 {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});
809
810 // In the end block we need to insert a PHI node to deal with three cases:
811 // 1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
812 // 2. We exitted the scalar loop early due to a mismatch and need to return
813 // the index that we found.
814 // 3. We didn't find a mismatch in the vector loop, so we return MaxLen.
815 // 4. We exitted the vector loop early due to a mismatch and need to return
816 // the index that we found.
817 Builder.SetInsertPoint(TheBB: EndBlock, IP: EndBlock->getFirstInsertionPt());
818 PHINode *ResPhi = Builder.CreatePHI(Ty: ResType, NumReservedValues: 4, Name: "mismatch_result");
819 ResPhi->addIncoming(V: MaxLen, BB: LoopIncBlock);
820 ResPhi->addIncoming(V: IndexPhi, BB: LoopStartBlock);
821 ResPhi->addIncoming(V: MaxLen, BB: VectorLoopIncBlock);
822 ResPhi->addIncoming(V: VectorLoopRes, BB: VectorLoopMismatchBlock);
823
824 Value *FinalRes = Builder.CreateTrunc(V: ResPhi, DestTy: ResType);
825
826 if (VerifyLoops) {
827 ScalarLoop->verifyLoop();
828 VectorLoop->verifyLoop();
829 if (!VectorLoop->isRecursivelyLCSSAForm(DT: *DT, LI: *LI))
830 report_fatal_error(reason: "Loops must remain in LCSSA form!");
831 if (!ScalarLoop->isRecursivelyLCSSAForm(DT: *DT, LI: *LI))
832 report_fatal_error(reason: "Loops must remain in LCSSA form!");
833 }
834
835 return FinalRes;
836}
837
838void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
839 GetElementPtrInst *GEPB,
840 PHINode *IndPhi, Value *MaxLen,
841 Instruction *Index, Value *Start,
842 bool IncIdx, BasicBlock *FoundBB,
843 BasicBlock *EndBB) {
844
845 // Insert the byte compare code at the end of the preheader block
846 BasicBlock *Preheader = CurLoop->getLoopPreheader();
847 BasicBlock *Header = CurLoop->getHeader();
848 BranchInst *PHBranch = cast<BranchInst>(Val: Preheader->getTerminator());
849 IRBuilder<> Builder(PHBranch);
850 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
851 Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());
852
853 // Increment the pointer if this was done before the loads in the loop.
854 if (IncIdx)
855 Start = Builder.CreateAdd(LHS: Start, RHS: ConstantInt::get(Ty: Start->getType(), V: 1));
856
857 Value *ByteCmpRes =
858 expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);
859
860 // Replaces uses of index & induction Phi with intrinsic (we already
861 // checked that the the first instruction of Header is the Phi above).
862 assert(IndPhi->hasOneUse() && "Index phi node has more than one use!");
863 Index->replaceAllUsesWith(V: ByteCmpRes);
864
865 assert(PHBranch->isUnconditional() &&
866 "Expected preheader to terminate with an unconditional branch.");
867
868 // If no mismatch was found, we can jump to the end block. Create a
869 // new basic block for the compare instruction.
870 auto *CmpBB = BasicBlock::Create(Context&: Preheader->getContext(), Name: "byte.compare",
871 Parent: Preheader->getParent());
872 CmpBB->moveBefore(MovePos: EndBB);
873
874 // Replace the branch in the preheader with an always-true conditional branch.
875 // This ensures there is still a reference to the original loop.
876 Builder.CreateCondBr(Cond: Builder.getTrue(), True: CmpBB, False: Header);
877 PHBranch->eraseFromParent();
878
879 BasicBlock *MismatchEnd = cast<Instruction>(Val: ByteCmpRes)->getParent();
880 DTU.applyUpdates(Updates: {{DominatorTree::Insert, MismatchEnd, CmpBB}});
881
882 // Create the branch to either the end or found block depending on the value
883 // returned by the intrinsic.
884 Builder.SetInsertPoint(CmpBB);
885 if (FoundBB != EndBB) {
886 Value *FoundCmp = Builder.CreateICmpEQ(LHS: ByteCmpRes, RHS: MaxLen);
887 Builder.CreateCondBr(Cond: FoundCmp, True: EndBB, False: FoundBB);
888 DTU.applyUpdates(Updates: {{DominatorTree::Insert, CmpBB, FoundBB},
889 {DominatorTree::Insert, CmpBB, EndBB}});
890
891 } else {
892 Builder.CreateBr(Dest: FoundBB);
893 DTU.applyUpdates(Updates: {{DominatorTree::Insert, CmpBB, FoundBB}});
894 }
895
896 auto fixSuccessorPhis = [&](BasicBlock *SuccBB) {
897 for (PHINode &PN : SuccBB->phis()) {
898 // At this point we've already replaced all uses of the result from the
899 // loop with ByteCmp. Look through the incoming values to find ByteCmp,
900 // meaning this is a Phi collecting the results of the byte compare.
901 bool ResPhi = false;
902 for (Value *Op : PN.incoming_values())
903 if (Op == ByteCmpRes) {
904 ResPhi = true;
905 break;
906 }
907
908 // Any PHI that depended upon the result of the byte compare needs a new
909 // incoming value from CmpBB. This is because the original loop will get
910 // deleted.
911 if (ResPhi)
912 PN.addIncoming(V: ByteCmpRes, BB: CmpBB);
913 else {
914 // There should be no other outside uses of other values in the
915 // original loop. Any incoming values should either:
916 // 1. Be for blocks outside the loop, which aren't interesting. Or ..
917 // 2. These are from blocks in the loop with values defined outside
918 // the loop. We should a similar incoming value from CmpBB.
919 for (BasicBlock *BB : PN.blocks())
920 if (CurLoop->contains(BB)) {
921 PN.addIncoming(V: PN.getIncomingValueForBlock(BB), BB: CmpBB);
922 break;
923 }
924 }
925 }
926 };
927
928 // Ensure all Phis in the successors of CmpBB have an incoming value from it.
929 fixSuccessorPhis(EndBB);
930 if (EndBB != FoundBB)
931 fixSuccessorPhis(FoundBB);
932
933 // The new CmpBB block isn't part of the loop, but will need to be added to
934 // the outer loop if there is one.
935 if (!CurLoop->isOutermost())
936 CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: CmpBB, LI&: *LI);
937
938 if (VerifyLoops && CurLoop->getParentLoop()) {
939 CurLoop->getParentLoop()->verifyLoop();
940 if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(DT: *DT, LI: *LI))
941 report_fatal_error(reason: "Loops must remain in LCSSA form!");
942 }
943}
944