1 | //===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a pass that recognizes certain loop idioms and |
10 | // transforms them into more optimized versions of the same loop. In cases |
11 | // where this happens, it can be a significant performance win. |
12 | // |
13 | // We currently only recognize one loop that finds the first mismatched byte |
14 | // in an array and returns the index, i.e. something like: |
15 | // |
16 | // while (++i != n) { |
17 | // if (a[i] != b[i]) |
18 | // break; |
19 | // } |
20 | // |
21 | // In this example we can actually vectorize the loop despite the early exit, |
22 | // although the loop vectorizer does not support it. It requires some extra |
23 | // checks to deal with the possibility of faulting loads when crossing page |
24 | // boundaries. However, even with these checks it is still profitable to do the |
25 | // transformation. |
26 | // |
27 | //===----------------------------------------------------------------------===// |
28 | // |
29 | // NOTE: This Pass matches a really specific loop pattern because it's only |
30 | // supposed to be a temporary solution until our LoopVectorizer is powerful |
31 | // enough to vectorize it automatically. |
32 | // |
33 | // TODO List: |
34 | // |
35 | // * Add support for the inverse case where we scan for a matching element. |
36 | // * Permit 64-bit induction variable types. |
37 | // * Recognize loops that increment the IV *after* comparing bytes. |
38 | // * Allow 32-bit sign-extends of the IV used by the GEP. |
39 | // |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h" |
43 | #include "llvm/Analysis/DomTreeUpdater.h" |
44 | #include "llvm/Analysis/LoopPass.h" |
45 | #include "llvm/Analysis/TargetTransformInfo.h" |
46 | #include "llvm/IR/Dominators.h" |
47 | #include "llvm/IR/IRBuilder.h" |
48 | #include "llvm/IR/Intrinsics.h" |
49 | #include "llvm/IR/MDBuilder.h" |
50 | #include "llvm/IR/PatternMatch.h" |
51 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
52 | |
53 | using namespace llvm; |
54 | using namespace PatternMatch; |
55 | |
56 | #define DEBUG_TYPE "loop-idiom-vectorize" |
57 | |
58 | static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all" , cl::Hidden, |
59 | cl::init(Val: false), |
60 | cl::desc("Disable Loop Idiom Vectorize Pass." )); |
61 | |
62 | static cl::opt<LoopIdiomVectorizeStyle> |
63 | LITVecStyle("loop-idiom-vectorize-style" , cl::Hidden, |
64 | cl::desc("The vectorization style for loop idiom transform." ), |
65 | cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked" , |
66 | "Use masked vector intrinsics" ), |
67 | clEnumValN(LoopIdiomVectorizeStyle::Predicated, |
68 | "predicated" , "Use VP intrinsics" )), |
69 | cl::init(Val: LoopIdiomVectorizeStyle::Masked)); |
70 | |
71 | static cl::opt<bool> |
72 | DisableByteCmp("disable-loop-idiom-vectorize-bytecmp" , cl::Hidden, |
73 | cl::init(Val: false), |
74 | cl::desc("Proceed with Loop Idiom Vectorize Pass, but do " |
75 | "not convert byte-compare loop(s)." )); |
76 | |
77 | static cl::opt<unsigned> |
78 | ByteCmpVF("loop-idiom-vectorize-bytecmp-vf" , cl::Hidden, |
79 | cl::desc("The vectorization factor for byte-compare patterns." ), |
80 | cl::init(Val: 16)); |
81 | |
82 | static cl::opt<bool> |
83 | VerifyLoops("loop-idiom-vectorize-verify" , cl::Hidden, cl::init(Val: false), |
84 | cl::desc("Verify loops generated Loop Idiom Vectorize Pass." )); |
85 | |
86 | namespace { |
87 | class LoopIdiomVectorize { |
88 | LoopIdiomVectorizeStyle VectorizeStyle; |
89 | unsigned ByteCompareVF; |
90 | Loop *CurLoop = nullptr; |
91 | DominatorTree *DT; |
92 | LoopInfo *LI; |
93 | const TargetTransformInfo *TTI; |
94 | const DataLayout *DL; |
95 | |
96 | // Blocks that will be used for inserting vectorized code. |
97 | BasicBlock *EndBlock = nullptr; |
98 | BasicBlock * = nullptr; |
99 | BasicBlock *VectorLoopStartBlock = nullptr; |
100 | BasicBlock *VectorLoopMismatchBlock = nullptr; |
101 | BasicBlock *VectorLoopIncBlock = nullptr; |
102 | |
103 | public: |
104 | LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT, |
105 | LoopInfo *LI, const TargetTransformInfo *TTI, |
106 | const DataLayout *DL) |
107 | : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) { |
108 | } |
109 | |
110 | bool run(Loop *L); |
111 | |
112 | private: |
113 | /// \name Countable Loop Idiom Handling |
114 | /// @{ |
115 | |
116 | bool runOnCountableLoop(); |
117 | bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount, |
118 | SmallVectorImpl<BasicBlock *> &ExitBlocks); |
119 | |
120 | bool recognizeByteCompare(); |
121 | |
122 | Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, |
123 | GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, |
124 | Instruction *Index, Value *Start, Value *MaxLen); |
125 | |
126 | Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, |
127 | GetElementPtrInst *GEPA, |
128 | GetElementPtrInst *GEPB, Value *ExtStart, |
129 | Value *ExtEnd); |
130 | Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, |
131 | GetElementPtrInst *GEPA, |
132 | GetElementPtrInst *GEPB, Value *ExtStart, |
133 | Value *ExtEnd); |
134 | |
135 | void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, |
136 | PHINode *IndPhi, Value *MaxLen, Instruction *Index, |
137 | Value *Start, bool IncIdx, BasicBlock *FoundBB, |
138 | BasicBlock *EndBB); |
139 | /// @} |
140 | }; |
141 | } // anonymous namespace |
142 | |
143 | PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM, |
144 | LoopStandardAnalysisResults &AR, |
145 | LPMUpdater &) { |
146 | if (DisableAll) |
147 | return PreservedAnalyses::all(); |
148 | |
149 | const auto *DL = &L.getHeader()->getDataLayout(); |
150 | |
151 | LoopIdiomVectorizeStyle VecStyle = VectorizeStyle; |
152 | if (LITVecStyle.getNumOccurrences()) |
153 | VecStyle = LITVecStyle; |
154 | |
155 | unsigned BCVF = ByteCompareVF; |
156 | if (ByteCmpVF.getNumOccurrences()) |
157 | BCVF = ByteCmpVF; |
158 | |
159 | LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL); |
160 | if (!LIV.run(L: &L)) |
161 | return PreservedAnalyses::all(); |
162 | |
163 | return PreservedAnalyses::none(); |
164 | } |
165 | |
166 | //===----------------------------------------------------------------------===// |
167 | // |
168 | // Implementation of LoopIdiomVectorize |
169 | // |
170 | //===----------------------------------------------------------------------===// |
171 | |
172 | bool LoopIdiomVectorize::run(Loop *L) { |
173 | CurLoop = L; |
174 | |
175 | Function &F = *L->getHeader()->getParent(); |
176 | if (DisableAll || F.hasOptSize()) |
177 | return false; |
178 | |
179 | if (F.hasFnAttribute(Kind: Attribute::NoImplicitFloat)) { |
180 | LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName() |
181 | << " due to its NoImplicitFloat attribute" ); |
182 | return false; |
183 | } |
184 | |
185 | // If the loop could not be converted to canonical form, it must have an |
186 | // indirectbr in it, just give up. |
187 | if (!L->getLoopPreheader()) |
188 | return false; |
189 | |
190 | LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %" |
191 | << CurLoop->getHeader()->getName() << "\n" ); |
192 | |
193 | return recognizeByteCompare(); |
194 | } |
195 | |
196 | bool LoopIdiomVectorize::recognizeByteCompare() { |
197 | // Currently the transformation only works on scalable vector types, although |
198 | // there is no fundamental reason why it cannot be made to work for fixed |
199 | // width too. |
200 | |
201 | // We also need to know the minimum page size for the target in order to |
202 | // generate runtime memory checks to ensure the vector version won't fault. |
203 | if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() || |
204 | DisableByteCmp) |
205 | return false; |
206 | |
207 | BasicBlock * = CurLoop->getHeader(); |
208 | |
209 | // In LoopIdiomVectorize::run we have already checked that the loop |
210 | // has a preheader so we can assume it's in a canonical form. |
211 | if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2) |
212 | return false; |
213 | |
214 | PHINode *PN = dyn_cast<PHINode>(Val: &Header->front()); |
215 | if (!PN || PN->getNumIncomingValues() != 2) |
216 | return false; |
217 | |
218 | auto LoopBlocks = CurLoop->getBlocks(); |
219 | // The first block in the loop should contain only 4 instructions, e.g. |
220 | // |
221 | // while.cond: |
222 | // %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ] |
223 | // %inc = add i32 %res.phi, 1 |
224 | // %cmp.not = icmp eq i32 %inc, %n |
225 | // br i1 %cmp.not, label %while.end, label %while.body |
226 | // |
227 | if (LoopBlocks[0]->sizeWithoutDebug() > 4) |
228 | return false; |
229 | |
230 | // The second block should contain 7 instructions, e.g. |
231 | // |
232 | // while.body: |
233 | // %idx = zext i32 %inc to i64 |
234 | // %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx |
235 | // %load.a = load i8, ptr %idx.a |
236 | // %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx |
237 | // %load.b = load i8, ptr %idx.b |
238 | // %cmp.not.ld = icmp eq i8 %load.a, %load.b |
239 | // br i1 %cmp.not.ld, label %while.cond, label %while.end |
240 | // |
241 | if (LoopBlocks[1]->sizeWithoutDebug() > 7) |
242 | return false; |
243 | |
244 | // The incoming value to the PHI node from the loop should be an add of 1. |
245 | Value *StartIdx = nullptr; |
246 | Instruction *Index = nullptr; |
247 | if (!CurLoop->contains(BB: PN->getIncomingBlock(i: 0))) { |
248 | StartIdx = PN->getIncomingValue(i: 0); |
249 | Index = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 1)); |
250 | } else { |
251 | StartIdx = PN->getIncomingValue(i: 1); |
252 | Index = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 0)); |
253 | } |
254 | |
255 | // Limit to 32-bit types for now |
256 | if (!Index || !Index->getType()->isIntegerTy(Bitwidth: 32) || |
257 | !match(V: Index, P: m_c_Add(L: m_Specific(V: PN), R: m_One()))) |
258 | return false; |
259 | |
260 | // If we match the pattern, PN and Index will be replaced with the result of |
261 | // the cttz.elts intrinsic. If any other instructions are used outside of |
262 | // the loop, we cannot replace it. |
263 | for (BasicBlock *BB : LoopBlocks) |
264 | for (Instruction &I : *BB) |
265 | if (&I != PN && &I != Index) |
266 | for (User *U : I.users()) |
267 | if (!CurLoop->contains(Inst: cast<Instruction>(Val: U))) |
268 | return false; |
269 | |
270 | // Match the branch instruction for the header |
271 | ICmpInst::Predicate Pred; |
272 | Value *MaxLen; |
273 | BasicBlock *EndBB, *WhileBB; |
274 | if (!match(V: Header->getTerminator(), |
275 | P: m_Br(C: m_ICmp(Pred, L: m_Specific(V: Index), R: m_Value(V&: MaxLen)), |
276 | T: m_BasicBlock(V&: EndBB), F: m_BasicBlock(V&: WhileBB))) || |
277 | Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(BB: WhileBB)) |
278 | return false; |
279 | |
280 | // WhileBB should contain the pattern of load & compare instructions. Match |
281 | // the pattern and find the GEP instructions used by the loads. |
282 | ICmpInst::Predicate WhilePred; |
283 | BasicBlock *FoundBB; |
284 | BasicBlock *TrueBB; |
285 | Value *LoadA, *LoadB; |
286 | if (!match(V: WhileBB->getTerminator(), |
287 | P: m_Br(C: m_ICmp(Pred&: WhilePred, L: m_Value(V&: LoadA), R: m_Value(V&: LoadB)), |
288 | T: m_BasicBlock(V&: TrueBB), F: m_BasicBlock(V&: FoundBB))) || |
289 | WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(BB: TrueBB)) |
290 | return false; |
291 | |
292 | Value *A, *B; |
293 | if (!match(V: LoadA, P: m_Load(Op: m_Value(V&: A))) || !match(V: LoadB, P: m_Load(Op: m_Value(V&: B)))) |
294 | return false; |
295 | |
296 | LoadInst *LoadAI = cast<LoadInst>(Val: LoadA); |
297 | LoadInst *LoadBI = cast<LoadInst>(Val: LoadB); |
298 | if (!LoadAI->isSimple() || !LoadBI->isSimple()) |
299 | return false; |
300 | |
301 | GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(Val: A); |
302 | GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(Val: B); |
303 | |
304 | if (!GEPA || !GEPB) |
305 | return false; |
306 | |
307 | Value *PtrA = GEPA->getPointerOperand(); |
308 | Value *PtrB = GEPB->getPointerOperand(); |
309 | |
310 | // Check we are loading i8 values from two loop invariant pointers |
311 | if (!CurLoop->isLoopInvariant(V: PtrA) || !CurLoop->isLoopInvariant(V: PtrB) || |
312 | !GEPA->getResultElementType()->isIntegerTy(Bitwidth: 8) || |
313 | !GEPB->getResultElementType()->isIntegerTy(Bitwidth: 8) || |
314 | !LoadAI->getType()->isIntegerTy(Bitwidth: 8) || |
315 | !LoadBI->getType()->isIntegerTy(Bitwidth: 8) || PtrA == PtrB) |
316 | return false; |
317 | |
318 | // Check that the index to the GEPs is the index we found earlier |
319 | if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1) |
320 | return false; |
321 | |
322 | Value *IdxA = GEPA->getOperand(i_nocapture: GEPA->getNumIndices()); |
323 | Value *IdxB = GEPB->getOperand(i_nocapture: GEPB->getNumIndices()); |
324 | if (IdxA != IdxB || !match(V: IdxA, P: m_ZExt(Op: m_Specific(V: Index)))) |
325 | return false; |
326 | |
327 | // We only ever expect the pre-incremented index value to be used inside the |
328 | // loop. |
329 | if (!PN->hasOneUse()) |
330 | return false; |
331 | |
332 | // Ensure that when the Found and End blocks are identical the PHIs have the |
333 | // supported format. We don't currently allow cases like this: |
334 | // while.cond: |
335 | // ... |
336 | // br i1 %cmp.not, label %while.end, label %while.body |
337 | // |
338 | // while.body: |
339 | // ... |
340 | // br i1 %cmp.not2, label %while.cond, label %while.end |
341 | // |
342 | // while.end: |
343 | // %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ] |
344 | // |
345 | // Where the incoming values for %final_ptr are unique and from each of the |
346 | // loop blocks, but not actually defined in the loop. This requires extra |
347 | // work setting up the byte.compare block, i.e. by introducing a select to |
348 | // choose the correct value. |
349 | // TODO: We could add support for this in future. |
350 | if (FoundBB == EndBB) { |
351 | for (PHINode &EndPN : EndBB->phis()) { |
352 | Value *WhileCondVal = EndPN.getIncomingValueForBlock(BB: Header); |
353 | Value *WhileBodyVal = EndPN.getIncomingValueForBlock(BB: WhileBB); |
354 | |
355 | // The value of the index when leaving the while.cond block is always the |
356 | // same as the end value (MaxLen) so we permit either. The value when |
357 | // leaving the while.body block should only be the index. Otherwise for |
358 | // any other values we only allow ones that are same for both blocks. |
359 | if (WhileCondVal != WhileBodyVal && |
360 | ((WhileCondVal != Index && WhileCondVal != MaxLen) || |
361 | (WhileBodyVal != Index))) |
362 | return false; |
363 | } |
364 | } |
365 | |
366 | LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n" |
367 | << *(EndBB->getParent()) << "\n\n" ); |
368 | |
369 | // The index is incremented before the GEP/Load pair so we need to |
370 | // add 1 to the start value. |
371 | transformByteCompare(GEPA, GEPB, IndPhi: PN, MaxLen, Index, Start: StartIdx, /*IncIdx=*/true, |
372 | FoundBB, EndBB); |
373 | return true; |
374 | } |
375 | |
376 | Value *LoopIdiomVectorize::createMaskedFindMismatch( |
377 | IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, |
378 | GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { |
379 | Type *I64Type = Builder.getInt64Ty(); |
380 | Type *ResType = Builder.getInt32Ty(); |
381 | Type *LoadType = Builder.getInt8Ty(); |
382 | Value *PtrA = GEPA->getPointerOperand(); |
383 | Value *PtrB = GEPB->getPointerOperand(); |
384 | |
385 | ScalableVectorType *PredVTy = |
386 | ScalableVectorType::get(ElementType: Builder.getInt1Ty(), MinNumElts: ByteCompareVF); |
387 | |
388 | Value *InitialPred = Builder.CreateIntrinsic( |
389 | ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Type}, Args: {ExtStart, ExtEnd}); |
390 | |
391 | Value *VecLen = Builder.CreateIntrinsic(ID: Intrinsic::vscale, Types: {I64Type}, Args: {}); |
392 | VecLen = |
393 | Builder.CreateMul(LHS: VecLen, RHS: ConstantInt::get(Ty: I64Type, V: ByteCompareVF), Name: "" , |
394 | /*HasNUW=*/true, /*HasNSW=*/true); |
395 | |
396 | Value *PFalse = Builder.CreateVectorSplat(EC: PredVTy->getElementCount(), |
397 | V: Builder.getInt1(V: false)); |
398 | |
399 | BranchInst *JumpToVectorLoop = BranchInst::Create(IfTrue: VectorLoopStartBlock); |
400 | Builder.Insert(I: JumpToVectorLoop); |
401 | |
402 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, VectorLoopPreheaderBlock, |
403 | VectorLoopStartBlock}}); |
404 | |
405 | // Set up the first vector loop block by creating the PHIs, doing the vector |
406 | // loads and comparing the vectors. |
407 | Builder.SetInsertPoint(VectorLoopStartBlock); |
408 | PHINode *LoopPred = Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 2, Name: "mismatch_vec_loop_pred" ); |
409 | LoopPred->addIncoming(V: InitialPred, BB: VectorLoopPreheaderBlock); |
410 | PHINode *VectorIndexPhi = Builder.CreatePHI(Ty: I64Type, NumReservedValues: 2, Name: "mismatch_vec_index" ); |
411 | VectorIndexPhi->addIncoming(V: ExtStart, BB: VectorLoopPreheaderBlock); |
412 | Type *VectorLoadType = |
413 | ScalableVectorType::get(ElementType: Builder.getInt8Ty(), MinNumElts: ByteCompareVF); |
414 | Value *Passthru = ConstantInt::getNullValue(Ty: VectorLoadType); |
415 | |
416 | Value *VectorLhsGep = |
417 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: VectorIndexPhi, Name: "" , NW: GEPA->isInBounds()); |
418 | Value *VectorLhsLoad = Builder.CreateMaskedLoad(Ty: VectorLoadType, Ptr: VectorLhsGep, |
419 | Alignment: Align(1), Mask: LoopPred, PassThru: Passthru); |
420 | |
421 | Value *VectorRhsGep = |
422 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: VectorIndexPhi, Name: "" , NW: GEPB->isInBounds()); |
423 | Value *VectorRhsLoad = Builder.CreateMaskedLoad(Ty: VectorLoadType, Ptr: VectorRhsGep, |
424 | Alignment: Align(1), Mask: LoopPred, PassThru: Passthru); |
425 | |
426 | Value *VectorMatchCmp = Builder.CreateICmpNE(LHS: VectorLhsLoad, RHS: VectorRhsLoad); |
427 | VectorMatchCmp = Builder.CreateSelect(C: LoopPred, True: VectorMatchCmp, False: PFalse); |
428 | Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(Src: VectorMatchCmp); |
429 | BranchInst *VectorEarlyExit = BranchInst::Create( |
430 | IfTrue: VectorLoopMismatchBlock, IfFalse: VectorLoopIncBlock, Cond: VectorMatchHasActiveLanes); |
431 | Builder.Insert(I: VectorEarlyExit); |
432 | |
433 | DTU.applyUpdates( |
434 | Updates: {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, |
435 | {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); |
436 | |
437 | // Increment the index counter and calculate the predicate for the next |
438 | // iteration of the loop. We branch back to the start of the loop if there |
439 | // is at least one active lane. |
440 | Builder.SetInsertPoint(VectorLoopIncBlock); |
441 | Value *NewVectorIndexPhi = |
442 | Builder.CreateAdd(LHS: VectorIndexPhi, RHS: VecLen, Name: "" , |
443 | /*HasNUW=*/true, /*HasNSW=*/true); |
444 | VectorIndexPhi->addIncoming(V: NewVectorIndexPhi, BB: VectorLoopIncBlock); |
445 | Value *NewPred = |
446 | Builder.CreateIntrinsic(ID: Intrinsic::get_active_lane_mask, |
447 | Types: {PredVTy, I64Type}, Args: {NewVectorIndexPhi, ExtEnd}); |
448 | LoopPred->addIncoming(V: NewPred, BB: VectorLoopIncBlock); |
449 | |
450 | Value *PredHasActiveLanes = |
451 | Builder.CreateExtractElement(Vec: NewPred, Idx: uint64_t(0)); |
452 | BranchInst *VectorLoopBranchBack = |
453 | BranchInst::Create(IfTrue: VectorLoopStartBlock, IfFalse: EndBlock, Cond: PredHasActiveLanes); |
454 | Builder.Insert(I: VectorLoopBranchBack); |
455 | |
456 | DTU.applyUpdates( |
457 | Updates: {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, |
458 | {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); |
459 | |
460 | // If we found a mismatch then we need to calculate which lane in the vector |
461 | // had a mismatch and add that on to the current loop index. |
462 | Builder.SetInsertPoint(VectorLoopMismatchBlock); |
463 | PHINode *FoundPred = Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 1, Name: "mismatch_vec_found_pred" ); |
464 | FoundPred->addIncoming(V: VectorMatchCmp, BB: VectorLoopStartBlock); |
465 | PHINode *LastLoopPred = |
466 | Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 1, Name: "mismatch_vec_last_loop_pred" ); |
467 | LastLoopPred->addIncoming(V: LoopPred, BB: VectorLoopStartBlock); |
468 | PHINode *VectorFoundIndex = |
469 | Builder.CreatePHI(Ty: I64Type, NumReservedValues: 1, Name: "mismatch_vec_found_index" ); |
470 | VectorFoundIndex->addIncoming(V: VectorIndexPhi, BB: VectorLoopStartBlock); |
471 | |
472 | Value *PredMatchCmp = Builder.CreateAnd(LHS: LastLoopPred, RHS: FoundPred); |
473 | Value *Ctz = Builder.CreateIntrinsic( |
474 | ID: Intrinsic::experimental_cttz_elts, Types: {ResType, PredMatchCmp->getType()}, |
475 | Args: {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(V: true)}); |
476 | Ctz = Builder.CreateZExt(V: Ctz, DestTy: I64Type); |
477 | Value *VectorLoopRes64 = Builder.CreateAdd(LHS: VectorFoundIndex, RHS: Ctz, Name: "" , |
478 | /*HasNUW=*/true, /*HasNSW=*/true); |
479 | return Builder.CreateTrunc(V: VectorLoopRes64, DestTy: ResType); |
480 | } |
481 | |
482 | Value *LoopIdiomVectorize::createPredicatedFindMismatch( |
483 | IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, |
484 | GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { |
485 | Type *I64Type = Builder.getInt64Ty(); |
486 | Type *I32Type = Builder.getInt32Ty(); |
487 | Type *ResType = I32Type; |
488 | Type *LoadType = Builder.getInt8Ty(); |
489 | Value *PtrA = GEPA->getPointerOperand(); |
490 | Value *PtrB = GEPB->getPointerOperand(); |
491 | |
492 | auto *JumpToVectorLoop = BranchInst::Create(IfTrue: VectorLoopStartBlock); |
493 | Builder.Insert(I: JumpToVectorLoop); |
494 | |
495 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, VectorLoopPreheaderBlock, |
496 | VectorLoopStartBlock}}); |
497 | |
498 | // Set up the first Vector loop block by creating the PHIs, doing the vector |
499 | // loads and comparing the vectors. |
500 | Builder.SetInsertPoint(VectorLoopStartBlock); |
501 | auto *VectorIndexPhi = Builder.CreatePHI(Ty: I64Type, NumReservedValues: 2, Name: "mismatch_vector_index" ); |
502 | VectorIndexPhi->addIncoming(V: ExtStart, BB: VectorLoopPreheaderBlock); |
503 | |
504 | // Calculate AVL by subtracting the vector loop index from the trip count |
505 | Value *AVL = Builder.CreateSub(LHS: ExtEnd, RHS: VectorIndexPhi, Name: "avl" , /*HasNUW=*/true, |
506 | /*HasNSW=*/true); |
507 | |
508 | auto *VectorLoadType = ScalableVectorType::get(ElementType: LoadType, MinNumElts: ByteCompareVF); |
509 | auto *VF = ConstantInt::get(Ty: I32Type, V: ByteCompareVF); |
510 | |
511 | Value *VL = Builder.CreateIntrinsic(ID: Intrinsic::experimental_get_vector_length, |
512 | Types: {I64Type}, Args: {AVL, VF, Builder.getTrue()}); |
513 | Value *GepOffset = VectorIndexPhi; |
514 | |
515 | Value *VectorLhsGep = |
516 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: GepOffset, Name: "" , NW: GEPA->isInBounds()); |
517 | VectorType *TrueMaskTy = |
518 | VectorType::get(ElementType: Builder.getInt1Ty(), EC: VectorLoadType->getElementCount()); |
519 | Value *AllTrueMask = Constant::getAllOnesValue(Ty: TrueMaskTy); |
520 | Value *VectorLhsLoad = Builder.CreateIntrinsic( |
521 | ID: Intrinsic::vp_load, Types: {VectorLoadType, VectorLhsGep->getType()}, |
522 | Args: {VectorLhsGep, AllTrueMask, VL}, FMFSource: nullptr, Name: "lhs.load" ); |
523 | |
524 | Value *VectorRhsGep = |
525 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: GepOffset, Name: "" , NW: GEPB->isInBounds()); |
526 | Value *VectorRhsLoad = Builder.CreateIntrinsic( |
527 | ID: Intrinsic::vp_load, Types: {VectorLoadType, VectorLhsGep->getType()}, |
528 | Args: {VectorRhsGep, AllTrueMask, VL}, FMFSource: nullptr, Name: "rhs.load" ); |
529 | |
530 | StringRef PredicateStr = CmpInst::getPredicateName(P: CmpInst::ICMP_NE); |
531 | auto *PredicateMDS = MDString::get(Context&: VectorLhsLoad->getContext(), Str: PredicateStr); |
532 | Value *Pred = MetadataAsValue::get(Context&: VectorLhsLoad->getContext(), MD: PredicateMDS); |
533 | Value *VectorMatchCmp = Builder.CreateIntrinsic( |
534 | ID: Intrinsic::vp_icmp, Types: {VectorLhsLoad->getType()}, |
535 | Args: {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, FMFSource: nullptr, |
536 | Name: "mismatch.cmp" ); |
537 | Value *CTZ = Builder.CreateIntrinsic( |
538 | ID: Intrinsic::vp_cttz_elts, Types: {ResType, VectorMatchCmp->getType()}, |
539 | Args: {VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(V: false), AllTrueMask, |
540 | VL}); |
541 | Value *MismatchFound = Builder.CreateICmpNE(LHS: CTZ, RHS: VL); |
542 | auto *VectorEarlyExit = BranchInst::Create(IfTrue: VectorLoopMismatchBlock, |
543 | IfFalse: VectorLoopIncBlock, Cond: MismatchFound); |
544 | Builder.Insert(I: VectorEarlyExit); |
545 | |
546 | DTU.applyUpdates( |
547 | Updates: {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, |
548 | {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); |
549 | |
550 | // Increment the index counter and calculate the predicate for the next |
551 | // iteration of the loop. We branch back to the start of the loop if there |
552 | // is at least one active lane. |
553 | Builder.SetInsertPoint(VectorLoopIncBlock); |
554 | Value *VL64 = Builder.CreateZExt(V: VL, DestTy: I64Type); |
555 | Value *NewVectorIndexPhi = |
556 | Builder.CreateAdd(LHS: VectorIndexPhi, RHS: VL64, Name: "" , |
557 | /*HasNUW=*/true, /*HasNSW=*/true); |
558 | VectorIndexPhi->addIncoming(V: NewVectorIndexPhi, BB: VectorLoopIncBlock); |
559 | Value *ExitCond = Builder.CreateICmpNE(LHS: NewVectorIndexPhi, RHS: ExtEnd); |
560 | auto *VectorLoopBranchBack = |
561 | BranchInst::Create(IfTrue: VectorLoopStartBlock, IfFalse: EndBlock, Cond: ExitCond); |
562 | Builder.Insert(I: VectorLoopBranchBack); |
563 | |
564 | DTU.applyUpdates( |
565 | Updates: {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, |
566 | {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); |
567 | |
568 | // If we found a mismatch then we need to calculate which lane in the vector |
569 | // had a mismatch and add that on to the current loop index. |
570 | Builder.SetInsertPoint(VectorLoopMismatchBlock); |
571 | |
572 | // Add LCSSA phis for CTZ and VectorIndexPhi. |
573 | auto *CTZLCSSAPhi = Builder.CreatePHI(Ty: CTZ->getType(), NumReservedValues: 1, Name: "ctz" ); |
574 | CTZLCSSAPhi->addIncoming(V: CTZ, BB: VectorLoopStartBlock); |
575 | auto *VectorIndexLCSSAPhi = |
576 | Builder.CreatePHI(Ty: VectorIndexPhi->getType(), NumReservedValues: 1, Name: "mismatch_vector_index" ); |
577 | VectorIndexLCSSAPhi->addIncoming(V: VectorIndexPhi, BB: VectorLoopStartBlock); |
578 | |
579 | Value *CTZI64 = Builder.CreateZExt(V: CTZLCSSAPhi, DestTy: I64Type); |
580 | Value *VectorLoopRes64 = Builder.CreateAdd(LHS: VectorIndexLCSSAPhi, RHS: CTZI64, Name: "" , |
581 | /*HasNUW=*/true, /*HasNSW=*/true); |
582 | return Builder.CreateTrunc(V: VectorLoopRes64, DestTy: ResType); |
583 | } |
584 | |
585 | Value *LoopIdiomVectorize::expandFindMismatch( |
586 | IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, |
587 | GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) { |
588 | Value *PtrA = GEPA->getPointerOperand(); |
589 | Value *PtrB = GEPB->getPointerOperand(); |
590 | |
591 | // Get the arguments and types for the intrinsic. |
592 | BasicBlock * = CurLoop->getLoopPreheader(); |
593 | BranchInst *PHBranch = cast<BranchInst>(Val: Preheader->getTerminator()); |
594 | LLVMContext &Ctx = PHBranch->getContext(); |
595 | Type *LoadType = Type::getInt8Ty(C&: Ctx); |
596 | Type *ResType = Builder.getInt32Ty(); |
597 | |
598 | // Split block in the original loop preheader. |
599 | EndBlock = SplitBlock(Old: Preheader, SplitPt: PHBranch, DT, LI, MSSAU: nullptr, BBName: "mismatch_end" ); |
600 | |
601 | // Create the blocks that we're going to need: |
602 | // 1. A block for checking the zero-extended length exceeds 0 |
603 | // 2. A block to check that the start and end addresses of a given array |
604 | // lie on the same page. |
605 | // 3. The vector loop preheader. |
606 | // 4. The first vector loop block. |
607 | // 5. The vector loop increment block. |
608 | // 6. A block we can jump to from the vector loop when a mismatch is found. |
609 | // 7. The first block of the scalar loop itself, containing PHIs, loads |
610 | // and cmp. |
611 | // 8. A scalar loop increment block to increment the PHIs and go back |
612 | // around the loop. |
613 | |
614 | BasicBlock *MinItCheckBlock = BasicBlock::Create( |
615 | Context&: Ctx, Name: "mismatch_min_it_check" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
616 | |
617 | // Update the terminator added by SplitBlock to branch to the first block |
618 | Preheader->getTerminator()->setSuccessor(Idx: 0, BB: MinItCheckBlock); |
619 | |
620 | BasicBlock *MemCheckBlock = BasicBlock::Create( |
621 | Context&: Ctx, Name: "mismatch_mem_check" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
622 | |
623 | VectorLoopPreheaderBlock = BasicBlock::Create( |
624 | Context&: Ctx, Name: "mismatch_vec_loop_preheader" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
625 | |
626 | VectorLoopStartBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop" , |
627 | Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
628 | |
629 | VectorLoopIncBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop_inc" , |
630 | Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
631 | |
632 | VectorLoopMismatchBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop_found" , |
633 | Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
634 | |
635 | BasicBlock * = BasicBlock::Create( |
636 | Context&: Ctx, Name: "mismatch_loop_pre" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
637 | |
638 | BasicBlock *LoopStartBlock = |
639 | BasicBlock::Create(Context&: Ctx, Name: "mismatch_loop" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
640 | |
641 | BasicBlock *LoopIncBlock = BasicBlock::Create( |
642 | Context&: Ctx, Name: "mismatch_loop_inc" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
643 | |
644 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, Preheader, MinItCheckBlock}, |
645 | {DominatorTree::Delete, Preheader, EndBlock}}); |
646 | |
647 | // Update LoopInfo with the new vector & scalar loops. |
648 | auto VectorLoop = LI->AllocateLoop(); |
649 | auto ScalarLoop = LI->AllocateLoop(); |
650 | |
651 | if (CurLoop->getParentLoop()) { |
652 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: MinItCheckBlock, LI&: *LI); |
653 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: MemCheckBlock, LI&: *LI); |
654 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: VectorLoopPreheaderBlock, |
655 | LI&: *LI); |
656 | CurLoop->getParentLoop()->addChildLoop(NewChild: VectorLoop); |
657 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: VectorLoopMismatchBlock, LI&: *LI); |
658 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: LoopPreHeaderBlock, LI&: *LI); |
659 | CurLoop->getParentLoop()->addChildLoop(NewChild: ScalarLoop); |
660 | } else { |
661 | LI->addTopLevelLoop(New: VectorLoop); |
662 | LI->addTopLevelLoop(New: ScalarLoop); |
663 | } |
664 | |
665 | // Add the new basic blocks to their associated loops. |
666 | VectorLoop->addBasicBlockToLoop(NewBB: VectorLoopStartBlock, LI&: *LI); |
667 | VectorLoop->addBasicBlockToLoop(NewBB: VectorLoopIncBlock, LI&: *LI); |
668 | |
669 | ScalarLoop->addBasicBlockToLoop(NewBB: LoopStartBlock, LI&: *LI); |
670 | ScalarLoop->addBasicBlockToLoop(NewBB: LoopIncBlock, LI&: *LI); |
671 | |
672 | // Set up some types and constants that we intend to reuse. |
673 | Type *I64Type = Builder.getInt64Ty(); |
674 | |
675 | // Check the zero-extended iteration count > 0 |
676 | Builder.SetInsertPoint(MinItCheckBlock); |
677 | Value *ExtStart = Builder.CreateZExt(V: Start, DestTy: I64Type); |
678 | Value *ExtEnd = Builder.CreateZExt(V: MaxLen, DestTy: I64Type); |
679 | // This check doesn't really cost us very much. |
680 | |
681 | Value *LimitCheck = Builder.CreateICmpULE(LHS: Start, RHS: MaxLen); |
682 | BranchInst *MinItCheckBr = |
683 | BranchInst::Create(IfTrue: MemCheckBlock, IfFalse: LoopPreHeaderBlock, Cond: LimitCheck); |
684 | MinItCheckBr->setMetadata( |
685 | KindID: LLVMContext::MD_prof, |
686 | Node: MDBuilder(MinItCheckBr->getContext()).createBranchWeights(TrueWeight: 99, FalseWeight: 1)); |
687 | Builder.Insert(I: MinItCheckBr); |
688 | |
689 | DTU.applyUpdates( |
690 | Updates: {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock}, |
691 | {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}}); |
692 | |
693 | // For each of the arrays, check the start/end addresses are on the same |
694 | // page. |
695 | Builder.SetInsertPoint(MemCheckBlock); |
696 | |
697 | // The early exit in the original loop means that when performing vector |
698 | // loads we are potentially reading ahead of the early exit. So we could |
699 | // fault if crossing a page boundary. Therefore, we create runtime memory |
700 | // checks based on the minimum page size as follows: |
701 | // 1. Calculate the addresses of the first memory accesses in the loop, |
702 | // i.e. LhsStart and RhsStart. |
703 | // 2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd. |
704 | // 3. Determine which pages correspond to all the memory accesses, i.e |
705 | // LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage. |
706 | // 4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then |
707 | // we know we won't cross any page boundaries in the loop so we can |
708 | // enter the vector loop! Otherwise we fall back on the scalar loop. |
709 | Value *LhsStartGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: ExtStart); |
710 | Value *RhsStartGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: ExtStart); |
711 | Value *RhsStart = Builder.CreatePtrToInt(V: RhsStartGEP, DestTy: I64Type); |
712 | Value *LhsStart = Builder.CreatePtrToInt(V: LhsStartGEP, DestTy: I64Type); |
713 | Value *LhsEndGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: ExtEnd); |
714 | Value *RhsEndGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: ExtEnd); |
715 | Value *LhsEnd = Builder.CreatePtrToInt(V: LhsEndGEP, DestTy: I64Type); |
716 | Value *RhsEnd = Builder.CreatePtrToInt(V: RhsEndGEP, DestTy: I64Type); |
717 | |
718 | const uint64_t MinPageSize = TTI->getMinPageSize().value(); |
719 | const uint64_t AddrShiftAmt = llvm::Log2_64(Value: MinPageSize); |
720 | Value *LhsStartPage = Builder.CreateLShr(LHS: LhsStart, RHS: AddrShiftAmt); |
721 | Value *LhsEndPage = Builder.CreateLShr(LHS: LhsEnd, RHS: AddrShiftAmt); |
722 | Value *RhsStartPage = Builder.CreateLShr(LHS: RhsStart, RHS: AddrShiftAmt); |
723 | Value *RhsEndPage = Builder.CreateLShr(LHS: RhsEnd, RHS: AddrShiftAmt); |
724 | Value *LhsPageCmp = Builder.CreateICmpNE(LHS: LhsStartPage, RHS: LhsEndPage); |
725 | Value *RhsPageCmp = Builder.CreateICmpNE(LHS: RhsStartPage, RHS: RhsEndPage); |
726 | |
727 | Value *CombinedPageCmp = Builder.CreateOr(LHS: LhsPageCmp, RHS: RhsPageCmp); |
728 | BranchInst *CombinedPageCmpCmpBr = BranchInst::Create( |
729 | IfTrue: LoopPreHeaderBlock, IfFalse: VectorLoopPreheaderBlock, Cond: CombinedPageCmp); |
730 | CombinedPageCmpCmpBr->setMetadata( |
731 | KindID: LLVMContext::MD_prof, Node: MDBuilder(CombinedPageCmpCmpBr->getContext()) |
732 | .createBranchWeights(TrueWeight: 10, FalseWeight: 90)); |
733 | Builder.Insert(I: CombinedPageCmpCmpBr); |
734 | |
735 | DTU.applyUpdates( |
736 | Updates: {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock}, |
737 | {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}}); |
738 | |
739 | // Set up the vector loop preheader, i.e. calculate initial loop predicate, |
740 | // zero-extend MaxLen to 64-bits, determine the number of vector elements |
741 | // processed in each iteration, etc. |
742 | Builder.SetInsertPoint(VectorLoopPreheaderBlock); |
743 | |
744 | // At this point we know two things must be true: |
745 | // 1. Start <= End |
746 | // 2. ExtMaxLen <= MinPageSize due to the page checks. |
747 | // Therefore, we know that we can use a 64-bit induction variable that |
748 | // starts from 0 -> ExtMaxLen and it will not overflow. |
749 | Value *VectorLoopRes = nullptr; |
750 | switch (VectorizeStyle) { |
751 | case LoopIdiomVectorizeStyle::Masked: |
752 | VectorLoopRes = |
753 | createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd); |
754 | break; |
755 | case LoopIdiomVectorizeStyle::Predicated: |
756 | VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB, |
757 | ExtStart, ExtEnd); |
758 | break; |
759 | } |
760 | |
761 | Builder.Insert(I: BranchInst::Create(IfTrue: EndBlock)); |
762 | |
763 | DTU.applyUpdates( |
764 | Updates: {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}}); |
765 | |
766 | // Generate code for scalar loop. |
767 | Builder.SetInsertPoint(LoopPreHeaderBlock); |
768 | Builder.Insert(I: BranchInst::Create(IfTrue: LoopStartBlock)); |
769 | |
770 | DTU.applyUpdates( |
771 | Updates: {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}}); |
772 | |
773 | Builder.SetInsertPoint(LoopStartBlock); |
774 | PHINode *IndexPhi = Builder.CreatePHI(Ty: ResType, NumReservedValues: 2, Name: "mismatch_index" ); |
775 | IndexPhi->addIncoming(V: Start, BB: LoopPreHeaderBlock); |
776 | |
777 | // Otherwise compare the values |
778 | // Load bytes from each array and compare them. |
779 | Value *GepOffset = Builder.CreateZExt(V: IndexPhi, DestTy: I64Type); |
780 | |
781 | Value *LhsGep = |
782 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: GepOffset, Name: "" , NW: GEPA->isInBounds()); |
783 | Value *LhsLoad = Builder.CreateLoad(Ty: LoadType, Ptr: LhsGep); |
784 | |
785 | Value *RhsGep = |
786 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: GepOffset, Name: "" , NW: GEPB->isInBounds()); |
787 | Value *RhsLoad = Builder.CreateLoad(Ty: LoadType, Ptr: RhsGep); |
788 | |
789 | Value *MatchCmp = Builder.CreateICmpEQ(LHS: LhsLoad, RHS: RhsLoad); |
790 | // If we have a mismatch then exit the loop ... |
791 | BranchInst *MatchCmpBr = BranchInst::Create(IfTrue: LoopIncBlock, IfFalse: EndBlock, Cond: MatchCmp); |
792 | Builder.Insert(I: MatchCmpBr); |
793 | |
794 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, LoopStartBlock, LoopIncBlock}, |
795 | {DominatorTree::Insert, LoopStartBlock, EndBlock}}); |
796 | |
797 | // Have we reached the maximum permitted length for the loop? |
798 | Builder.SetInsertPoint(LoopIncBlock); |
799 | Value *PhiInc = Builder.CreateAdd(LHS: IndexPhi, RHS: ConstantInt::get(Ty: ResType, V: 1), Name: "" , |
800 | /*HasNUW=*/Index->hasNoUnsignedWrap(), |
801 | /*HasNSW=*/Index->hasNoSignedWrap()); |
802 | IndexPhi->addIncoming(V: PhiInc, BB: LoopIncBlock); |
803 | Value *IVCmp = Builder.CreateICmpEQ(LHS: PhiInc, RHS: MaxLen); |
804 | BranchInst *IVCmpBr = BranchInst::Create(IfTrue: EndBlock, IfFalse: LoopStartBlock, Cond: IVCmp); |
805 | Builder.Insert(I: IVCmpBr); |
806 | |
807 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, LoopIncBlock, EndBlock}, |
808 | {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}}); |
809 | |
// In the end block we need to insert a PHI node to deal with four cases:
// 1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
// 2. We exited the scalar loop early due to a mismatch and need to return
//    the index that we found.
// 3. We didn't find a mismatch in the vector loop, so we return MaxLen.
// 4. We exited the vector loop early due to a mismatch and need to return
//    the index that we found.
817 | Builder.SetInsertPoint(TheBB: EndBlock, IP: EndBlock->getFirstInsertionPt()); |
818 | PHINode *ResPhi = Builder.CreatePHI(Ty: ResType, NumReservedValues: 4, Name: "mismatch_result" ); |
819 | ResPhi->addIncoming(V: MaxLen, BB: LoopIncBlock); |
820 | ResPhi->addIncoming(V: IndexPhi, BB: LoopStartBlock); |
821 | ResPhi->addIncoming(V: MaxLen, BB: VectorLoopIncBlock); |
822 | ResPhi->addIncoming(V: VectorLoopRes, BB: VectorLoopMismatchBlock); |
823 | |
824 | Value *FinalRes = Builder.CreateTrunc(V: ResPhi, DestTy: ResType); |
825 | |
826 | if (VerifyLoops) { |
827 | ScalarLoop->verifyLoop(); |
828 | VectorLoop->verifyLoop(); |
829 | if (!VectorLoop->isRecursivelyLCSSAForm(DT: *DT, LI: *LI)) |
830 | report_fatal_error(reason: "Loops must remain in LCSSA form!" ); |
831 | if (!ScalarLoop->isRecursivelyLCSSAForm(DT: *DT, LI: *LI)) |
832 | report_fatal_error(reason: "Loops must remain in LCSSA form!" ); |
833 | } |
834 | |
835 | return FinalRes; |
836 | } |
837 | |
838 | void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA, |
839 | GetElementPtrInst *GEPB, |
840 | PHINode *IndPhi, Value *MaxLen, |
841 | Instruction *Index, Value *Start, |
842 | bool IncIdx, BasicBlock *FoundBB, |
843 | BasicBlock *EndBB) { |
844 | |
845 | // Insert the byte compare code at the end of the preheader block |
846 | BasicBlock * = CurLoop->getLoopPreheader(); |
847 | BasicBlock * = CurLoop->getHeader(); |
848 | BranchInst *PHBranch = cast<BranchInst>(Val: Preheader->getTerminator()); |
849 | IRBuilder<> Builder(PHBranch); |
850 | DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); |
851 | Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc()); |
852 | |
853 | // Increment the pointer if this was done before the loads in the loop. |
854 | if (IncIdx) |
855 | Start = Builder.CreateAdd(LHS: Start, RHS: ConstantInt::get(Ty: Start->getType(), V: 1)); |
856 | |
857 | Value *ByteCmpRes = |
858 | expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen); |
859 | |
860 | // Replaces uses of index & induction Phi with intrinsic (we already |
861 | // checked that the the first instruction of Header is the Phi above). |
862 | assert(IndPhi->hasOneUse() && "Index phi node has more than one use!" ); |
863 | Index->replaceAllUsesWith(V: ByteCmpRes); |
864 | |
865 | assert(PHBranch->isUnconditional() && |
866 | "Expected preheader to terminate with an unconditional branch." ); |
867 | |
868 | // If no mismatch was found, we can jump to the end block. Create a |
869 | // new basic block for the compare instruction. |
870 | auto *CmpBB = BasicBlock::Create(Context&: Preheader->getContext(), Name: "byte.compare" , |
871 | Parent: Preheader->getParent()); |
872 | CmpBB->moveBefore(MovePos: EndBB); |
873 | |
874 | // Replace the branch in the preheader with an always-true conditional branch. |
875 | // This ensures there is still a reference to the original loop. |
876 | Builder.CreateCondBr(Cond: Builder.getTrue(), True: CmpBB, False: Header); |
877 | PHBranch->eraseFromParent(); |
878 | |
879 | BasicBlock *MismatchEnd = cast<Instruction>(Val: ByteCmpRes)->getParent(); |
880 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, MismatchEnd, CmpBB}}); |
881 | |
882 | // Create the branch to either the end or found block depending on the value |
883 | // returned by the intrinsic. |
884 | Builder.SetInsertPoint(CmpBB); |
885 | if (FoundBB != EndBB) { |
886 | Value *FoundCmp = Builder.CreateICmpEQ(LHS: ByteCmpRes, RHS: MaxLen); |
887 | Builder.CreateCondBr(Cond: FoundCmp, True: EndBB, False: FoundBB); |
888 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, CmpBB, FoundBB}, |
889 | {DominatorTree::Insert, CmpBB, EndBB}}); |
890 | |
891 | } else { |
892 | Builder.CreateBr(Dest: FoundBB); |
893 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, CmpBB, FoundBB}}); |
894 | } |
895 | |
896 | auto fixSuccessorPhis = [&](BasicBlock *SuccBB) { |
897 | for (PHINode &PN : SuccBB->phis()) { |
898 | // At this point we've already replaced all uses of the result from the |
899 | // loop with ByteCmp. Look through the incoming values to find ByteCmp, |
900 | // meaning this is a Phi collecting the results of the byte compare. |
901 | bool ResPhi = false; |
902 | for (Value *Op : PN.incoming_values()) |
903 | if (Op == ByteCmpRes) { |
904 | ResPhi = true; |
905 | break; |
906 | } |
907 | |
908 | // Any PHI that depended upon the result of the byte compare needs a new |
909 | // incoming value from CmpBB. This is because the original loop will get |
910 | // deleted. |
911 | if (ResPhi) |
912 | PN.addIncoming(V: ByteCmpRes, BB: CmpBB); |
913 | else { |
914 | // There should be no other outside uses of other values in the |
915 | // original loop. Any incoming values should either: |
916 | // 1. Be for blocks outside the loop, which aren't interesting. Or .. |
917 | // 2. These are from blocks in the loop with values defined outside |
918 | // the loop. We should a similar incoming value from CmpBB. |
919 | for (BasicBlock *BB : PN.blocks()) |
920 | if (CurLoop->contains(BB)) { |
921 | PN.addIncoming(V: PN.getIncomingValueForBlock(BB), BB: CmpBB); |
922 | break; |
923 | } |
924 | } |
925 | } |
926 | }; |
927 | |
928 | // Ensure all Phis in the successors of CmpBB have an incoming value from it. |
929 | fixSuccessorPhis(EndBB); |
930 | if (EndBB != FoundBB) |
931 | fixSuccessorPhis(FoundBB); |
932 | |
933 | // The new CmpBB block isn't part of the loop, but will need to be added to |
934 | // the outer loop if there is one. |
935 | if (!CurLoop->isOutermost()) |
936 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: CmpBB, LI&: *LI); |
937 | |
938 | if (VerifyLoops && CurLoop->getParentLoop()) { |
939 | CurLoop->getParentLoop()->verifyLoop(); |
940 | if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(DT: *DT, LI: *LI)) |
941 | report_fatal_error(reason: "Loops must remain in LCSSA form!" ); |
942 | } |
943 | } |
944 | |