1 | //===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a pass that recognizes certain loop idioms and |
10 | // transforms them into more optimized versions of the same loop. In cases |
11 | // where this happens, it can be a significant performance win. |
12 | // |
13 | // We currently support two loops: |
14 | // |
15 | // 1. A loop that finds the first mismatched byte in an array and returns the |
16 | // index, i.e. something like: |
17 | // |
18 | // while (++i != n) { |
19 | // if (a[i] != b[i]) |
20 | // break; |
21 | // } |
22 | // |
23 | // In this example we can actually vectorize the loop despite the early exit, |
24 | // although the loop vectorizer does not support it. It requires some extra |
25 | // checks to deal with the possibility of faulting loads when crossing page |
26 | // boundaries. However, even with these checks it is still profitable to do the |
27 | // transformation. |
28 | // |
29 | // TODO List: |
30 | // |
31 | // * Add support for the inverse case where we scan for a matching element. |
32 | // * Permit 64-bit induction variable types. |
33 | // * Recognize loops that increment the IV *after* comparing bytes. |
34 | // * Allow 32-bit sign-extends of the IV used by the GEP. |
35 | // |
36 | // 2. A loop that finds the first matching character in an array among a set of |
37 | // possible matches, e.g.: |
38 | // |
39 | // for (; first != last; ++first) |
40 | // for (s_it = s_first; s_it != s_last; ++s_it) |
41 | // if (*first == *s_it) |
42 | // return first; |
43 | // return last; |
44 | // |
45 | // This corresponds to std::find_first_of (for arrays of bytes) from the C++ |
46 | // standard library. This function can be implemented efficiently for targets |
47 | // that support @llvm.experimental.vector.match. For example, on AArch64 targets |
48 | // that implement SVE2, this lowers to a MATCH instruction, which enables us to |
49 | // perform up to 16x16=256 comparisons in one go. This can lead to very |
50 | // significant speedups. |
51 | // |
52 | // TODO: |
53 | // |
54 | // * Add support for `find_first_not_of' loops (i.e. with not-equal comparison). |
55 | // * Make VF a configurable parameter (right now we assume 128-bit vectors). |
56 | // * Potentially adjust the cost model to let the transformation kick-in even if |
57 | // @llvm.experimental.vector.match doesn't have direct support in hardware. |
58 | // |
59 | //===----------------------------------------------------------------------===// |
60 | // |
61 | // NOTE: This Pass matches really specific loop patterns because it's only |
62 | // supposed to be a temporary solution until our LoopVectorizer is powerful |
63 | // enough to vectorize them automatically. |
64 | // |
65 | //===----------------------------------------------------------------------===// |
66 | |
67 | #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h" |
68 | #include "llvm/Analysis/DomTreeUpdater.h" |
69 | #include "llvm/Analysis/LoopPass.h" |
70 | #include "llvm/Analysis/TargetTransformInfo.h" |
71 | #include "llvm/IR/Dominators.h" |
72 | #include "llvm/IR/IRBuilder.h" |
73 | #include "llvm/IR/Intrinsics.h" |
74 | #include "llvm/IR/MDBuilder.h" |
75 | #include "llvm/IR/PatternMatch.h" |
76 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
77 | |
78 | using namespace llvm; |
79 | using namespace PatternMatch; |
80 | |
81 | #define DEBUG_TYPE "loop-idiom-vectorize" |
82 | |
83 | static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all" , cl::Hidden, |
84 | cl::init(Val: false), |
85 | cl::desc("Disable Loop Idiom Vectorize Pass." )); |
86 | |
87 | static cl::opt<LoopIdiomVectorizeStyle> |
88 | LITVecStyle("loop-idiom-vectorize-style" , cl::Hidden, |
89 | cl::desc("The vectorization style for loop idiom transform." ), |
90 | cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked" , |
91 | "Use masked vector intrinsics" ), |
92 | clEnumValN(LoopIdiomVectorizeStyle::Predicated, |
93 | "predicated" , "Use VP intrinsics" )), |
94 | cl::init(Val: LoopIdiomVectorizeStyle::Masked)); |
95 | |
96 | static cl::opt<bool> |
97 | DisableByteCmp("disable-loop-idiom-vectorize-bytecmp" , cl::Hidden, |
98 | cl::init(Val: false), |
99 | cl::desc("Proceed with Loop Idiom Vectorize Pass, but do " |
100 | "not convert byte-compare loop(s)." )); |
101 | |
102 | static cl::opt<unsigned> |
103 | ByteCmpVF("loop-idiom-vectorize-bytecmp-vf" , cl::Hidden, |
104 | cl::desc("The vectorization factor for byte-compare patterns." ), |
105 | cl::init(Val: 16)); |
106 | |
107 | static cl::opt<bool> |
108 | DisableFindFirstByte("disable-loop-idiom-vectorize-find-first-byte" , |
109 | cl::Hidden, cl::init(Val: false), |
110 | cl::desc("Do not convert find-first-byte loop(s)." )); |
111 | |
112 | static cl::opt<bool> |
113 | VerifyLoops("loop-idiom-vectorize-verify" , cl::Hidden, cl::init(Val: false), |
114 | cl::desc("Verify loops generated Loop Idiom Vectorize Pass." )); |
115 | |
116 | namespace { |
117 | class LoopIdiomVectorize { |
118 | LoopIdiomVectorizeStyle VectorizeStyle; |
119 | unsigned ByteCompareVF; |
120 | Loop *CurLoop = nullptr; |
121 | DominatorTree *DT; |
122 | LoopInfo *LI; |
123 | const TargetTransformInfo *TTI; |
124 | const DataLayout *DL; |
125 | |
126 | // Blocks that will be used for inserting vectorized code. |
127 | BasicBlock *EndBlock = nullptr; |
128 | BasicBlock * = nullptr; |
129 | BasicBlock *VectorLoopStartBlock = nullptr; |
130 | BasicBlock *VectorLoopMismatchBlock = nullptr; |
131 | BasicBlock *VectorLoopIncBlock = nullptr; |
132 | |
133 | public: |
134 | LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT, |
135 | LoopInfo *LI, const TargetTransformInfo *TTI, |
136 | const DataLayout *DL) |
137 | : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) { |
138 | } |
139 | |
140 | bool run(Loop *L); |
141 | |
142 | private: |
143 | /// \name Countable Loop Idiom Handling |
144 | /// @{ |
145 | |
146 | bool runOnCountableLoop(); |
147 | bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount, |
148 | SmallVectorImpl<BasicBlock *> &ExitBlocks); |
149 | |
150 | bool recognizeByteCompare(); |
151 | |
152 | Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, |
153 | GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, |
154 | Instruction *Index, Value *Start, Value *MaxLen); |
155 | |
156 | Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, |
157 | GetElementPtrInst *GEPA, |
158 | GetElementPtrInst *GEPB, Value *ExtStart, |
159 | Value *ExtEnd); |
160 | Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, |
161 | GetElementPtrInst *GEPA, |
162 | GetElementPtrInst *GEPB, Value *ExtStart, |
163 | Value *ExtEnd); |
164 | |
165 | void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, |
166 | PHINode *IndPhi, Value *MaxLen, Instruction *Index, |
167 | Value *Start, bool IncIdx, BasicBlock *FoundBB, |
168 | BasicBlock *EndBB); |
169 | |
170 | bool recognizeFindFirstByte(); |
171 | |
172 | Value *expandFindFirstByte(IRBuilder<> &Builder, DomTreeUpdater &DTU, |
173 | unsigned VF, Type *CharTy, BasicBlock *ExitSucc, |
174 | BasicBlock *ExitFail, Value *SearchStart, |
175 | Value *SearchEnd, Value *NeedleStart, |
176 | Value *NeedleEnd); |
177 | |
178 | void transformFindFirstByte(PHINode *IndPhi, unsigned VF, Type *CharTy, |
179 | BasicBlock *ExitSucc, BasicBlock *ExitFail, |
180 | Value *SearchStart, Value *SearchEnd, |
181 | Value *NeedleStart, Value *NeedleEnd); |
182 | /// @} |
183 | }; |
184 | } // anonymous namespace |
185 | |
186 | PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM, |
187 | LoopStandardAnalysisResults &AR, |
188 | LPMUpdater &) { |
189 | if (DisableAll) |
190 | return PreservedAnalyses::all(); |
191 | |
192 | const auto *DL = &L.getHeader()->getDataLayout(); |
193 | |
194 | LoopIdiomVectorizeStyle VecStyle = VectorizeStyle; |
195 | if (LITVecStyle.getNumOccurrences()) |
196 | VecStyle = LITVecStyle; |
197 | |
198 | unsigned BCVF = ByteCompareVF; |
199 | if (ByteCmpVF.getNumOccurrences()) |
200 | BCVF = ByteCmpVF; |
201 | |
202 | LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL); |
203 | if (!LIV.run(L: &L)) |
204 | return PreservedAnalyses::all(); |
205 | |
206 | return PreservedAnalyses::none(); |
207 | } |
208 | |
209 | //===----------------------------------------------------------------------===// |
210 | // |
211 | // Implementation of LoopIdiomVectorize |
212 | // |
213 | //===----------------------------------------------------------------------===// |
214 | |
215 | bool LoopIdiomVectorize::run(Loop *L) { |
216 | CurLoop = L; |
217 | |
218 | Function &F = *L->getHeader()->getParent(); |
219 | if (DisableAll || F.hasOptSize()) |
220 | return false; |
221 | |
222 | if (F.hasFnAttribute(Kind: Attribute::NoImplicitFloat)) { |
223 | LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName() |
224 | << " due to its NoImplicitFloat attribute" ); |
225 | return false; |
226 | } |
227 | |
228 | // If the loop could not be converted to canonical form, it must have an |
229 | // indirectbr in it, just give up. |
230 | if (!L->getLoopPreheader()) |
231 | return false; |
232 | |
233 | LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %" |
234 | << CurLoop->getHeader()->getName() << "\n" ); |
235 | |
236 | if (recognizeByteCompare()) |
237 | return true; |
238 | |
239 | if (recognizeFindFirstByte()) |
240 | return true; |
241 | |
242 | return false; |
243 | } |
244 | |
245 | bool LoopIdiomVectorize::recognizeByteCompare() { |
246 | // Currently the transformation only works on scalable vector types, although |
247 | // there is no fundamental reason why it cannot be made to work for fixed |
248 | // width too. |
249 | |
250 | // We also need to know the minimum page size for the target in order to |
251 | // generate runtime memory checks to ensure the vector version won't fault. |
252 | if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() || |
253 | DisableByteCmp) |
254 | return false; |
255 | |
256 | BasicBlock * = CurLoop->getHeader(); |
257 | |
258 | // In LoopIdiomVectorize::run we have already checked that the loop |
259 | // has a preheader so we can assume it's in a canonical form. |
260 | if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2) |
261 | return false; |
262 | |
263 | PHINode *PN = dyn_cast<PHINode>(Val: &Header->front()); |
264 | if (!PN || PN->getNumIncomingValues() != 2) |
265 | return false; |
266 | |
267 | auto LoopBlocks = CurLoop->getBlocks(); |
268 | // The first block in the loop should contain only 4 instructions, e.g. |
269 | // |
270 | // while.cond: |
271 | // %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ] |
272 | // %inc = add i32 %res.phi, 1 |
273 | // %cmp.not = icmp eq i32 %inc, %n |
274 | // br i1 %cmp.not, label %while.end, label %while.body |
275 | // |
276 | if (LoopBlocks[0]->sizeWithoutDebug() > 4) |
277 | return false; |
278 | |
279 | // The second block should contain 7 instructions, e.g. |
280 | // |
281 | // while.body: |
282 | // %idx = zext i32 %inc to i64 |
283 | // %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx |
284 | // %load.a = load i8, ptr %idx.a |
285 | // %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx |
286 | // %load.b = load i8, ptr %idx.b |
287 | // %cmp.not.ld = icmp eq i8 %load.a, %load.b |
288 | // br i1 %cmp.not.ld, label %while.cond, label %while.end |
289 | // |
290 | if (LoopBlocks[1]->sizeWithoutDebug() > 7) |
291 | return false; |
292 | |
293 | // The incoming value to the PHI node from the loop should be an add of 1. |
294 | Value *StartIdx = nullptr; |
295 | Instruction *Index = nullptr; |
296 | if (!CurLoop->contains(BB: PN->getIncomingBlock(i: 0))) { |
297 | StartIdx = PN->getIncomingValue(i: 0); |
298 | Index = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 1)); |
299 | } else { |
300 | StartIdx = PN->getIncomingValue(i: 1); |
301 | Index = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 0)); |
302 | } |
303 | |
304 | // Limit to 32-bit types for now |
305 | if (!Index || !Index->getType()->isIntegerTy(Bitwidth: 32) || |
306 | !match(V: Index, P: m_c_Add(L: m_Specific(V: PN), R: m_One()))) |
307 | return false; |
308 | |
309 | // If we match the pattern, PN and Index will be replaced with the result of |
310 | // the cttz.elts intrinsic. If any other instructions are used outside of |
311 | // the loop, we cannot replace it. |
312 | for (BasicBlock *BB : LoopBlocks) |
313 | for (Instruction &I : *BB) |
314 | if (&I != PN && &I != Index) |
315 | for (User *U : I.users()) |
316 | if (!CurLoop->contains(Inst: cast<Instruction>(Val: U))) |
317 | return false; |
318 | |
319 | // Match the branch instruction for the header |
320 | Value *MaxLen; |
321 | BasicBlock *EndBB, *WhileBB; |
322 | if (!match(V: Header->getTerminator(), |
323 | P: m_Br(C: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ, L: m_Specific(V: Index), |
324 | R: m_Value(V&: MaxLen)), |
325 | T: m_BasicBlock(V&: EndBB), F: m_BasicBlock(V&: WhileBB))) || |
326 | !CurLoop->contains(BB: WhileBB)) |
327 | return false; |
328 | |
329 | // WhileBB should contain the pattern of load & compare instructions. Match |
330 | // the pattern and find the GEP instructions used by the loads. |
331 | BasicBlock *FoundBB; |
332 | BasicBlock *TrueBB; |
333 | Value *LoadA, *LoadB; |
334 | if (!match(V: WhileBB->getTerminator(), |
335 | P: m_Br(C: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ, L: m_Value(V&: LoadA), |
336 | R: m_Value(V&: LoadB)), |
337 | T: m_BasicBlock(V&: TrueBB), F: m_BasicBlock(V&: FoundBB))) || |
338 | !CurLoop->contains(BB: TrueBB)) |
339 | return false; |
340 | |
341 | Value *A, *B; |
342 | if (!match(V: LoadA, P: m_Load(Op: m_Value(V&: A))) || !match(V: LoadB, P: m_Load(Op: m_Value(V&: B)))) |
343 | return false; |
344 | |
345 | LoadInst *LoadAI = cast<LoadInst>(Val: LoadA); |
346 | LoadInst *LoadBI = cast<LoadInst>(Val: LoadB); |
347 | if (!LoadAI->isSimple() || !LoadBI->isSimple()) |
348 | return false; |
349 | |
350 | GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(Val: A); |
351 | GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(Val: B); |
352 | |
353 | if (!GEPA || !GEPB) |
354 | return false; |
355 | |
356 | Value *PtrA = GEPA->getPointerOperand(); |
357 | Value *PtrB = GEPB->getPointerOperand(); |
358 | |
359 | // Check we are loading i8 values from two loop invariant pointers |
360 | if (!CurLoop->isLoopInvariant(V: PtrA) || !CurLoop->isLoopInvariant(V: PtrB) || |
361 | !GEPA->getResultElementType()->isIntegerTy(Bitwidth: 8) || |
362 | !GEPB->getResultElementType()->isIntegerTy(Bitwidth: 8) || |
363 | !LoadAI->getType()->isIntegerTy(Bitwidth: 8) || |
364 | !LoadBI->getType()->isIntegerTy(Bitwidth: 8) || PtrA == PtrB) |
365 | return false; |
366 | |
367 | // Check that the index to the GEPs is the index we found earlier |
368 | if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1) |
369 | return false; |
370 | |
371 | Value *IdxA = GEPA->getOperand(i_nocapture: GEPA->getNumIndices()); |
372 | Value *IdxB = GEPB->getOperand(i_nocapture: GEPB->getNumIndices()); |
373 | if (IdxA != IdxB || !match(V: IdxA, P: m_ZExt(Op: m_Specific(V: Index)))) |
374 | return false; |
375 | |
376 | // We only ever expect the pre-incremented index value to be used inside the |
377 | // loop. |
378 | if (!PN->hasOneUse()) |
379 | return false; |
380 | |
381 | // Ensure that when the Found and End blocks are identical the PHIs have the |
382 | // supported format. We don't currently allow cases like this: |
383 | // while.cond: |
384 | // ... |
385 | // br i1 %cmp.not, label %while.end, label %while.body |
386 | // |
387 | // while.body: |
388 | // ... |
389 | // br i1 %cmp.not2, label %while.cond, label %while.end |
390 | // |
391 | // while.end: |
392 | // %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ] |
393 | // |
394 | // Where the incoming values for %final_ptr are unique and from each of the |
395 | // loop blocks, but not actually defined in the loop. This requires extra |
396 | // work setting up the byte.compare block, i.e. by introducing a select to |
397 | // choose the correct value. |
398 | // TODO: We could add support for this in future. |
399 | if (FoundBB == EndBB) { |
400 | for (PHINode &EndPN : EndBB->phis()) { |
401 | Value *WhileCondVal = EndPN.getIncomingValueForBlock(BB: Header); |
402 | Value *WhileBodyVal = EndPN.getIncomingValueForBlock(BB: WhileBB); |
403 | |
404 | // The value of the index when leaving the while.cond block is always the |
405 | // same as the end value (MaxLen) so we permit either. The value when |
406 | // leaving the while.body block should only be the index. Otherwise for |
407 | // any other values we only allow ones that are same for both blocks. |
408 | if (WhileCondVal != WhileBodyVal && |
409 | ((WhileCondVal != Index && WhileCondVal != MaxLen) || |
410 | (WhileBodyVal != Index))) |
411 | return false; |
412 | } |
413 | } |
414 | |
415 | LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n" |
416 | << *(EndBB->getParent()) << "\n\n" ); |
417 | |
418 | // The index is incremented before the GEP/Load pair so we need to |
419 | // add 1 to the start value. |
420 | transformByteCompare(GEPA, GEPB, IndPhi: PN, MaxLen, Index, Start: StartIdx, /*IncIdx=*/true, |
421 | FoundBB, EndBB); |
422 | return true; |
423 | } |
424 | |
425 | Value *LoopIdiomVectorize::createMaskedFindMismatch( |
426 | IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, |
427 | GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { |
428 | Type *I64Type = Builder.getInt64Ty(); |
429 | Type *ResType = Builder.getInt32Ty(); |
430 | Type *LoadType = Builder.getInt8Ty(); |
431 | Value *PtrA = GEPA->getPointerOperand(); |
432 | Value *PtrB = GEPB->getPointerOperand(); |
433 | |
434 | ScalableVectorType *PredVTy = |
435 | ScalableVectorType::get(ElementType: Builder.getInt1Ty(), MinNumElts: ByteCompareVF); |
436 | |
437 | Value *InitialPred = Builder.CreateIntrinsic( |
438 | ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Type}, Args: {ExtStart, ExtEnd}); |
439 | |
440 | Value *VecLen = Builder.CreateVScale(Ty: I64Type); |
441 | VecLen = |
442 | Builder.CreateMul(LHS: VecLen, RHS: ConstantInt::get(Ty: I64Type, V: ByteCompareVF), Name: "" , |
443 | /*HasNUW=*/true, /*HasNSW=*/true); |
444 | |
445 | Value *PFalse = Builder.CreateVectorSplat(EC: PredVTy->getElementCount(), |
446 | V: Builder.getInt1(V: false)); |
447 | |
448 | BranchInst *JumpToVectorLoop = BranchInst::Create(IfTrue: VectorLoopStartBlock); |
449 | Builder.Insert(I: JumpToVectorLoop); |
450 | |
451 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, VectorLoopPreheaderBlock, |
452 | VectorLoopStartBlock}}); |
453 | |
454 | // Set up the first vector loop block by creating the PHIs, doing the vector |
455 | // loads and comparing the vectors. |
456 | Builder.SetInsertPoint(VectorLoopStartBlock); |
457 | PHINode *LoopPred = Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 2, Name: "mismatch_vec_loop_pred" ); |
458 | LoopPred->addIncoming(V: InitialPred, BB: VectorLoopPreheaderBlock); |
459 | PHINode *VectorIndexPhi = Builder.CreatePHI(Ty: I64Type, NumReservedValues: 2, Name: "mismatch_vec_index" ); |
460 | VectorIndexPhi->addIncoming(V: ExtStart, BB: VectorLoopPreheaderBlock); |
461 | Type *VectorLoadType = |
462 | ScalableVectorType::get(ElementType: Builder.getInt8Ty(), MinNumElts: ByteCompareVF); |
463 | Value *Passthru = ConstantInt::getNullValue(Ty: VectorLoadType); |
464 | |
465 | Value *VectorLhsGep = |
466 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: VectorIndexPhi, Name: "" , NW: GEPA->isInBounds()); |
467 | Value *VectorLhsLoad = Builder.CreateMaskedLoad(Ty: VectorLoadType, Ptr: VectorLhsGep, |
468 | Alignment: Align(1), Mask: LoopPred, PassThru: Passthru); |
469 | |
470 | Value *VectorRhsGep = |
471 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: VectorIndexPhi, Name: "" , NW: GEPB->isInBounds()); |
472 | Value *VectorRhsLoad = Builder.CreateMaskedLoad(Ty: VectorLoadType, Ptr: VectorRhsGep, |
473 | Alignment: Align(1), Mask: LoopPred, PassThru: Passthru); |
474 | |
475 | Value *VectorMatchCmp = Builder.CreateICmpNE(LHS: VectorLhsLoad, RHS: VectorRhsLoad); |
476 | VectorMatchCmp = Builder.CreateSelect(C: LoopPred, True: VectorMatchCmp, False: PFalse); |
477 | Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(Src: VectorMatchCmp); |
478 | BranchInst *VectorEarlyExit = BranchInst::Create( |
479 | IfTrue: VectorLoopMismatchBlock, IfFalse: VectorLoopIncBlock, Cond: VectorMatchHasActiveLanes); |
480 | Builder.Insert(I: VectorEarlyExit); |
481 | |
482 | DTU.applyUpdates( |
483 | Updates: {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, |
484 | {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); |
485 | |
486 | // Increment the index counter and calculate the predicate for the next |
487 | // iteration of the loop. We branch back to the start of the loop if there |
488 | // is at least one active lane. |
489 | Builder.SetInsertPoint(VectorLoopIncBlock); |
490 | Value *NewVectorIndexPhi = |
491 | Builder.CreateAdd(LHS: VectorIndexPhi, RHS: VecLen, Name: "" , |
492 | /*HasNUW=*/true, /*HasNSW=*/true); |
493 | VectorIndexPhi->addIncoming(V: NewVectorIndexPhi, BB: VectorLoopIncBlock); |
494 | Value *NewPred = |
495 | Builder.CreateIntrinsic(ID: Intrinsic::get_active_lane_mask, |
496 | Types: {PredVTy, I64Type}, Args: {NewVectorIndexPhi, ExtEnd}); |
497 | LoopPred->addIncoming(V: NewPred, BB: VectorLoopIncBlock); |
498 | |
499 | Value *PredHasActiveLanes = |
500 | Builder.CreateExtractElement(Vec: NewPred, Idx: uint64_t(0)); |
501 | BranchInst *VectorLoopBranchBack = |
502 | BranchInst::Create(IfTrue: VectorLoopStartBlock, IfFalse: EndBlock, Cond: PredHasActiveLanes); |
503 | Builder.Insert(I: VectorLoopBranchBack); |
504 | |
505 | DTU.applyUpdates( |
506 | Updates: {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, |
507 | {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); |
508 | |
509 | // If we found a mismatch then we need to calculate which lane in the vector |
510 | // had a mismatch and add that on to the current loop index. |
511 | Builder.SetInsertPoint(VectorLoopMismatchBlock); |
512 | PHINode *FoundPred = Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 1, Name: "mismatch_vec_found_pred" ); |
513 | FoundPred->addIncoming(V: VectorMatchCmp, BB: VectorLoopStartBlock); |
514 | PHINode *LastLoopPred = |
515 | Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 1, Name: "mismatch_vec_last_loop_pred" ); |
516 | LastLoopPred->addIncoming(V: LoopPred, BB: VectorLoopStartBlock); |
517 | PHINode *VectorFoundIndex = |
518 | Builder.CreatePHI(Ty: I64Type, NumReservedValues: 1, Name: "mismatch_vec_found_index" ); |
519 | VectorFoundIndex->addIncoming(V: VectorIndexPhi, BB: VectorLoopStartBlock); |
520 | |
521 | Value *PredMatchCmp = Builder.CreateAnd(LHS: LastLoopPred, RHS: FoundPred); |
522 | Value *Ctz = Builder.CreateCountTrailingZeroElems(ResTy: ResType, Mask: PredMatchCmp); |
523 | Ctz = Builder.CreateZExt(V: Ctz, DestTy: I64Type); |
524 | Value *VectorLoopRes64 = Builder.CreateAdd(LHS: VectorFoundIndex, RHS: Ctz, Name: "" , |
525 | /*HasNUW=*/true, /*HasNSW=*/true); |
526 | return Builder.CreateTrunc(V: VectorLoopRes64, DestTy: ResType); |
527 | } |
528 | |
529 | Value *LoopIdiomVectorize::createPredicatedFindMismatch( |
530 | IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, |
531 | GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { |
532 | Type *I64Type = Builder.getInt64Ty(); |
533 | Type *I32Type = Builder.getInt32Ty(); |
534 | Type *ResType = I32Type; |
535 | Type *LoadType = Builder.getInt8Ty(); |
536 | Value *PtrA = GEPA->getPointerOperand(); |
537 | Value *PtrB = GEPB->getPointerOperand(); |
538 | |
539 | auto *JumpToVectorLoop = BranchInst::Create(IfTrue: VectorLoopStartBlock); |
540 | Builder.Insert(I: JumpToVectorLoop); |
541 | |
542 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, VectorLoopPreheaderBlock, |
543 | VectorLoopStartBlock}}); |
544 | |
545 | // Set up the first Vector loop block by creating the PHIs, doing the vector |
546 | // loads and comparing the vectors. |
547 | Builder.SetInsertPoint(VectorLoopStartBlock); |
548 | auto *VectorIndexPhi = Builder.CreatePHI(Ty: I64Type, NumReservedValues: 2, Name: "mismatch_vector_index" ); |
549 | VectorIndexPhi->addIncoming(V: ExtStart, BB: VectorLoopPreheaderBlock); |
550 | |
551 | // Calculate AVL by subtracting the vector loop index from the trip count |
552 | Value *AVL = Builder.CreateSub(LHS: ExtEnd, RHS: VectorIndexPhi, Name: "avl" , /*HasNUW=*/true, |
553 | /*HasNSW=*/true); |
554 | |
555 | auto *VectorLoadType = ScalableVectorType::get(ElementType: LoadType, MinNumElts: ByteCompareVF); |
556 | auto *VF = ConstantInt::get(Ty: I32Type, V: ByteCompareVF); |
557 | |
558 | Value *VL = Builder.CreateIntrinsic(ID: Intrinsic::experimental_get_vector_length, |
559 | Types: {I64Type}, Args: {AVL, VF, Builder.getTrue()}); |
560 | Value *GepOffset = VectorIndexPhi; |
561 | |
562 | Value *VectorLhsGep = |
563 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: GepOffset, Name: "" , NW: GEPA->isInBounds()); |
564 | VectorType *TrueMaskTy = |
565 | VectorType::get(ElementType: Builder.getInt1Ty(), EC: VectorLoadType->getElementCount()); |
566 | Value *AllTrueMask = Constant::getAllOnesValue(Ty: TrueMaskTy); |
567 | Value *VectorLhsLoad = Builder.CreateIntrinsic( |
568 | ID: Intrinsic::vp_load, Types: {VectorLoadType, VectorLhsGep->getType()}, |
569 | Args: {VectorLhsGep, AllTrueMask, VL}, FMFSource: nullptr, Name: "lhs.load" ); |
570 | |
571 | Value *VectorRhsGep = |
572 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: GepOffset, Name: "" , NW: GEPB->isInBounds()); |
573 | Value *VectorRhsLoad = Builder.CreateIntrinsic( |
574 | ID: Intrinsic::vp_load, Types: {VectorLoadType, VectorLhsGep->getType()}, |
575 | Args: {VectorRhsGep, AllTrueMask, VL}, FMFSource: nullptr, Name: "rhs.load" ); |
576 | |
577 | StringRef PredicateStr = CmpInst::getPredicateName(P: CmpInst::ICMP_NE); |
578 | auto *PredicateMDS = MDString::get(Context&: VectorLhsLoad->getContext(), Str: PredicateStr); |
579 | Value *Pred = MetadataAsValue::get(Context&: VectorLhsLoad->getContext(), MD: PredicateMDS); |
580 | Value *VectorMatchCmp = Builder.CreateIntrinsic( |
581 | ID: Intrinsic::vp_icmp, Types: {VectorLhsLoad->getType()}, |
582 | Args: {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, FMFSource: nullptr, |
583 | Name: "mismatch.cmp" ); |
584 | Value *CTZ = Builder.CreateIntrinsic( |
585 | ID: Intrinsic::vp_cttz_elts, Types: {ResType, VectorMatchCmp->getType()}, |
586 | Args: {VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(V: false), AllTrueMask, |
587 | VL}); |
588 | Value *MismatchFound = Builder.CreateICmpNE(LHS: CTZ, RHS: VL); |
589 | auto *VectorEarlyExit = BranchInst::Create(IfTrue: VectorLoopMismatchBlock, |
590 | IfFalse: VectorLoopIncBlock, Cond: MismatchFound); |
591 | Builder.Insert(I: VectorEarlyExit); |
592 | |
593 | DTU.applyUpdates( |
594 | Updates: {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, |
595 | {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); |
596 | |
597 | // Increment the index counter and calculate the predicate for the next |
598 | // iteration of the loop. We branch back to the start of the loop if there |
599 | // is at least one active lane. |
600 | Builder.SetInsertPoint(VectorLoopIncBlock); |
601 | Value *VL64 = Builder.CreateZExt(V: VL, DestTy: I64Type); |
602 | Value *NewVectorIndexPhi = |
603 | Builder.CreateAdd(LHS: VectorIndexPhi, RHS: VL64, Name: "" , |
604 | /*HasNUW=*/true, /*HasNSW=*/true); |
605 | VectorIndexPhi->addIncoming(V: NewVectorIndexPhi, BB: VectorLoopIncBlock); |
606 | Value *ExitCond = Builder.CreateICmpNE(LHS: NewVectorIndexPhi, RHS: ExtEnd); |
607 | auto *VectorLoopBranchBack = |
608 | BranchInst::Create(IfTrue: VectorLoopStartBlock, IfFalse: EndBlock, Cond: ExitCond); |
609 | Builder.Insert(I: VectorLoopBranchBack); |
610 | |
611 | DTU.applyUpdates( |
612 | Updates: {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, |
613 | {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); |
614 | |
615 | // If we found a mismatch then we need to calculate which lane in the vector |
616 | // had a mismatch and add that on to the current loop index. |
617 | Builder.SetInsertPoint(VectorLoopMismatchBlock); |
618 | |
619 | // Add LCSSA phis for CTZ and VectorIndexPhi. |
620 | auto *CTZLCSSAPhi = Builder.CreatePHI(Ty: CTZ->getType(), NumReservedValues: 1, Name: "ctz" ); |
621 | CTZLCSSAPhi->addIncoming(V: CTZ, BB: VectorLoopStartBlock); |
622 | auto *VectorIndexLCSSAPhi = |
623 | Builder.CreatePHI(Ty: VectorIndexPhi->getType(), NumReservedValues: 1, Name: "mismatch_vector_index" ); |
624 | VectorIndexLCSSAPhi->addIncoming(V: VectorIndexPhi, BB: VectorLoopStartBlock); |
625 | |
626 | Value *CTZI64 = Builder.CreateZExt(V: CTZLCSSAPhi, DestTy: I64Type); |
627 | Value *VectorLoopRes64 = Builder.CreateAdd(LHS: VectorIndexLCSSAPhi, RHS: CTZI64, Name: "" , |
628 | /*HasNUW=*/true, /*HasNSW=*/true); |
629 | return Builder.CreateTrunc(V: VectorLoopRes64, DestTy: ResType); |
630 | } |
631 | |
632 | Value *LoopIdiomVectorize::expandFindMismatch( |
633 | IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, |
634 | GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) { |
635 | Value *PtrA = GEPA->getPointerOperand(); |
636 | Value *PtrB = GEPB->getPointerOperand(); |
637 | |
638 | // Get the arguments and types for the intrinsic. |
639 | BasicBlock * = CurLoop->getLoopPreheader(); |
640 | BranchInst *PHBranch = cast<BranchInst>(Val: Preheader->getTerminator()); |
641 | LLVMContext &Ctx = PHBranch->getContext(); |
642 | Type *LoadType = Type::getInt8Ty(C&: Ctx); |
643 | Type *ResType = Builder.getInt32Ty(); |
644 | |
645 | // Split block in the original loop preheader. |
646 | EndBlock = SplitBlock(Old: Preheader, SplitPt: PHBranch, DT, LI, MSSAU: nullptr, BBName: "mismatch_end" ); |
647 | |
648 | // Create the blocks that we're going to need: |
649 | // 1. A block for checking the zero-extended length exceeds 0 |
650 | // 2. A block to check that the start and end addresses of a given array |
651 | // lie on the same page. |
652 | // 3. The vector loop preheader. |
653 | // 4. The first vector loop block. |
654 | // 5. The vector loop increment block. |
655 | // 6. A block we can jump to from the vector loop when a mismatch is found. |
656 | // 7. The first block of the scalar loop itself, containing PHIs , loads |
657 | // and cmp. |
658 | // 8. A scalar loop increment block to increment the PHIs and go back |
659 | // around the loop. |
660 | |
661 | BasicBlock *MinItCheckBlock = BasicBlock::Create( |
662 | Context&: Ctx, Name: "mismatch_min_it_check" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
663 | |
664 | // Update the terminator added by SplitBlock to branch to the first block |
665 | Preheader->getTerminator()->setSuccessor(Idx: 0, BB: MinItCheckBlock); |
666 | |
667 | BasicBlock *MemCheckBlock = BasicBlock::Create( |
668 | Context&: Ctx, Name: "mismatch_mem_check" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
669 | |
670 | VectorLoopPreheaderBlock = BasicBlock::Create( |
671 | Context&: Ctx, Name: "mismatch_vec_loop_preheader" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
672 | |
673 | VectorLoopStartBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop" , |
674 | Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
675 | |
676 | VectorLoopIncBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop_inc" , |
677 | Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
678 | |
679 | VectorLoopMismatchBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop_found" , |
680 | Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
681 | |
682 | BasicBlock * = BasicBlock::Create( |
683 | Context&: Ctx, Name: "mismatch_loop_pre" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
684 | |
685 | BasicBlock *LoopStartBlock = |
686 | BasicBlock::Create(Context&: Ctx, Name: "mismatch_loop" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
687 | |
688 | BasicBlock *LoopIncBlock = BasicBlock::Create( |
689 | Context&: Ctx, Name: "mismatch_loop_inc" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
690 | |
691 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, Preheader, MinItCheckBlock}, |
692 | {DominatorTree::Delete, Preheader, EndBlock}}); |
693 | |
694 | // Update LoopInfo with the new vector & scalar loops. |
695 | auto VectorLoop = LI->AllocateLoop(); |
696 | auto ScalarLoop = LI->AllocateLoop(); |
697 | |
698 | if (CurLoop->getParentLoop()) { |
699 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: MinItCheckBlock, LI&: *LI); |
700 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: MemCheckBlock, LI&: *LI); |
701 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: VectorLoopPreheaderBlock, |
702 | LI&: *LI); |
703 | CurLoop->getParentLoop()->addChildLoop(NewChild: VectorLoop); |
704 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: VectorLoopMismatchBlock, LI&: *LI); |
705 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: LoopPreHeaderBlock, LI&: *LI); |
706 | CurLoop->getParentLoop()->addChildLoop(NewChild: ScalarLoop); |
707 | } else { |
708 | LI->addTopLevelLoop(New: VectorLoop); |
709 | LI->addTopLevelLoop(New: ScalarLoop); |
710 | } |
711 | |
712 | // Add the new basic blocks to their associated loops. |
713 | VectorLoop->addBasicBlockToLoop(NewBB: VectorLoopStartBlock, LI&: *LI); |
714 | VectorLoop->addBasicBlockToLoop(NewBB: VectorLoopIncBlock, LI&: *LI); |
715 | |
716 | ScalarLoop->addBasicBlockToLoop(NewBB: LoopStartBlock, LI&: *LI); |
717 | ScalarLoop->addBasicBlockToLoop(NewBB: LoopIncBlock, LI&: *LI); |
718 | |
719 | // Set up some types and constants that we intend to reuse. |
720 | Type *I64Type = Builder.getInt64Ty(); |
721 | |
722 | // Check the zero-extended iteration count > 0 |
723 | Builder.SetInsertPoint(MinItCheckBlock); |
724 | Value *ExtStart = Builder.CreateZExt(V: Start, DestTy: I64Type); |
725 | Value *ExtEnd = Builder.CreateZExt(V: MaxLen, DestTy: I64Type); |
726 | // This check doesn't really cost us very much. |
727 | |
728 | Value *LimitCheck = Builder.CreateICmpULE(LHS: Start, RHS: MaxLen); |
729 | BranchInst *MinItCheckBr = |
730 | BranchInst::Create(IfTrue: MemCheckBlock, IfFalse: LoopPreHeaderBlock, Cond: LimitCheck); |
731 | MinItCheckBr->setMetadata( |
732 | KindID: LLVMContext::MD_prof, |
733 | Node: MDBuilder(MinItCheckBr->getContext()).createBranchWeights(TrueWeight: 99, FalseWeight: 1)); |
734 | Builder.Insert(I: MinItCheckBr); |
735 | |
736 | DTU.applyUpdates( |
737 | Updates: {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock}, |
738 | {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}}); |
739 | |
740 | // For each of the arrays, check the start/end addresses are on the same |
741 | // page. |
742 | Builder.SetInsertPoint(MemCheckBlock); |
743 | |
744 | // The early exit in the original loop means that when performing vector |
745 | // loads we are potentially reading ahead of the early exit. So we could |
746 | // fault if crossing a page boundary. Therefore, we create runtime memory |
747 | // checks based on the minimum page size as follows: |
748 | // 1. Calculate the addresses of the first memory accesses in the loop, |
749 | // i.e. LhsStart and RhsStart. |
750 | // 2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd. |
751 | // 3. Determine which pages correspond to all the memory accesses, i.e |
752 | // LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage. |
753 | // 4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then |
754 | // we know we won't cross any page boundaries in the loop so we can |
755 | // enter the vector loop! Otherwise we fall back on the scalar loop. |
756 | Value *LhsStartGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: ExtStart); |
757 | Value *RhsStartGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: ExtStart); |
758 | Value *RhsStart = Builder.CreatePtrToInt(V: RhsStartGEP, DestTy: I64Type); |
759 | Value *LhsStart = Builder.CreatePtrToInt(V: LhsStartGEP, DestTy: I64Type); |
760 | Value *LhsEndGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: ExtEnd); |
761 | Value *RhsEndGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: ExtEnd); |
762 | Value *LhsEnd = Builder.CreatePtrToInt(V: LhsEndGEP, DestTy: I64Type); |
763 | Value *RhsEnd = Builder.CreatePtrToInt(V: RhsEndGEP, DestTy: I64Type); |
764 | |
765 | const uint64_t MinPageSize = TTI->getMinPageSize().value(); |
766 | const uint64_t AddrShiftAmt = llvm::Log2_64(Value: MinPageSize); |
767 | Value *LhsStartPage = Builder.CreateLShr(LHS: LhsStart, RHS: AddrShiftAmt); |
768 | Value *LhsEndPage = Builder.CreateLShr(LHS: LhsEnd, RHS: AddrShiftAmt); |
769 | Value *RhsStartPage = Builder.CreateLShr(LHS: RhsStart, RHS: AddrShiftAmt); |
770 | Value *RhsEndPage = Builder.CreateLShr(LHS: RhsEnd, RHS: AddrShiftAmt); |
771 | Value *LhsPageCmp = Builder.CreateICmpNE(LHS: LhsStartPage, RHS: LhsEndPage); |
772 | Value *RhsPageCmp = Builder.CreateICmpNE(LHS: RhsStartPage, RHS: RhsEndPage); |
773 | |
774 | Value *CombinedPageCmp = Builder.CreateOr(LHS: LhsPageCmp, RHS: RhsPageCmp); |
775 | BranchInst *CombinedPageCmpCmpBr = BranchInst::Create( |
776 | IfTrue: LoopPreHeaderBlock, IfFalse: VectorLoopPreheaderBlock, Cond: CombinedPageCmp); |
777 | CombinedPageCmpCmpBr->setMetadata( |
778 | KindID: LLVMContext::MD_prof, Node: MDBuilder(CombinedPageCmpCmpBr->getContext()) |
779 | .createBranchWeights(TrueWeight: 10, FalseWeight: 90)); |
780 | Builder.Insert(I: CombinedPageCmpCmpBr); |
781 | |
782 | DTU.applyUpdates( |
783 | Updates: {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock}, |
784 | {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}}); |
785 | |
786 | // Set up the vector loop preheader, i.e. calculate initial loop predicate, |
787 | // zero-extend MaxLen to 64-bits, determine the number of vector elements |
788 | // processed in each iteration, etc. |
789 | Builder.SetInsertPoint(VectorLoopPreheaderBlock); |
790 | |
791 | // At this point we know two things must be true: |
792 | // 1. Start <= End |
793 | // 2. ExtMaxLen <= MinPageSize due to the page checks. |
794 | // Therefore, we know that we can use a 64-bit induction variable that |
795 | // starts from 0 -> ExtMaxLen and it will not overflow. |
796 | Value *VectorLoopRes = nullptr; |
797 | switch (VectorizeStyle) { |
798 | case LoopIdiomVectorizeStyle::Masked: |
799 | VectorLoopRes = |
800 | createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd); |
801 | break; |
802 | case LoopIdiomVectorizeStyle::Predicated: |
803 | VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB, |
804 | ExtStart, ExtEnd); |
805 | break; |
806 | } |
807 | |
808 | Builder.Insert(I: BranchInst::Create(IfTrue: EndBlock)); |
809 | |
810 | DTU.applyUpdates( |
811 | Updates: {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}}); |
812 | |
813 | // Generate code for scalar loop. |
814 | Builder.SetInsertPoint(LoopPreHeaderBlock); |
815 | Builder.Insert(I: BranchInst::Create(IfTrue: LoopStartBlock)); |
816 | |
817 | DTU.applyUpdates( |
818 | Updates: {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}}); |
819 | |
820 | Builder.SetInsertPoint(LoopStartBlock); |
821 | PHINode *IndexPhi = Builder.CreatePHI(Ty: ResType, NumReservedValues: 2, Name: "mismatch_index" ); |
822 | IndexPhi->addIncoming(V: Start, BB: LoopPreHeaderBlock); |
823 | |
824 | // Otherwise compare the values |
825 | // Load bytes from each array and compare them. |
826 | Value *GepOffset = Builder.CreateZExt(V: IndexPhi, DestTy: I64Type); |
827 | |
828 | Value *LhsGep = |
829 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: GepOffset, Name: "" , NW: GEPA->isInBounds()); |
830 | Value *LhsLoad = Builder.CreateLoad(Ty: LoadType, Ptr: LhsGep); |
831 | |
832 | Value *RhsGep = |
833 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: GepOffset, Name: "" , NW: GEPB->isInBounds()); |
834 | Value *RhsLoad = Builder.CreateLoad(Ty: LoadType, Ptr: RhsGep); |
835 | |
836 | Value *MatchCmp = Builder.CreateICmpEQ(LHS: LhsLoad, RHS: RhsLoad); |
837 | // If we have a mismatch then exit the loop ... |
838 | BranchInst *MatchCmpBr = BranchInst::Create(IfTrue: LoopIncBlock, IfFalse: EndBlock, Cond: MatchCmp); |
839 | Builder.Insert(I: MatchCmpBr); |
840 | |
841 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, LoopStartBlock, LoopIncBlock}, |
842 | {DominatorTree::Insert, LoopStartBlock, EndBlock}}); |
843 | |
844 | // Have we reached the maximum permitted length for the loop? |
845 | Builder.SetInsertPoint(LoopIncBlock); |
846 | Value *PhiInc = Builder.CreateAdd(LHS: IndexPhi, RHS: ConstantInt::get(Ty: ResType, V: 1), Name: "" , |
847 | /*HasNUW=*/Index->hasNoUnsignedWrap(), |
848 | /*HasNSW=*/Index->hasNoSignedWrap()); |
849 | IndexPhi->addIncoming(V: PhiInc, BB: LoopIncBlock); |
850 | Value *IVCmp = Builder.CreateICmpEQ(LHS: PhiInc, RHS: MaxLen); |
851 | BranchInst *IVCmpBr = BranchInst::Create(IfTrue: EndBlock, IfFalse: LoopStartBlock, Cond: IVCmp); |
852 | Builder.Insert(I: IVCmpBr); |
853 | |
854 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, LoopIncBlock, EndBlock}, |
855 | {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}}); |
856 | |
857 | // In the end block we need to insert a PHI node to deal with three cases: |
858 | // 1. We didn't find a mismatch in the scalar loop, so we return MaxLen. |
859 | // 2. We exitted the scalar loop early due to a mismatch and need to return |
860 | // the index that we found. |
861 | // 3. We didn't find a mismatch in the vector loop, so we return MaxLen. |
862 | // 4. We exitted the vector loop early due to a mismatch and need to return |
863 | // the index that we found. |
864 | Builder.SetInsertPoint(TheBB: EndBlock, IP: EndBlock->getFirstInsertionPt()); |
865 | PHINode *ResPhi = Builder.CreatePHI(Ty: ResType, NumReservedValues: 4, Name: "mismatch_result" ); |
866 | ResPhi->addIncoming(V: MaxLen, BB: LoopIncBlock); |
867 | ResPhi->addIncoming(V: IndexPhi, BB: LoopStartBlock); |
868 | ResPhi->addIncoming(V: MaxLen, BB: VectorLoopIncBlock); |
869 | ResPhi->addIncoming(V: VectorLoopRes, BB: VectorLoopMismatchBlock); |
870 | |
871 | Value *FinalRes = Builder.CreateTrunc(V: ResPhi, DestTy: ResType); |
872 | |
873 | if (VerifyLoops) { |
874 | ScalarLoop->verifyLoop(); |
875 | VectorLoop->verifyLoop(); |
876 | if (!VectorLoop->isRecursivelyLCSSAForm(DT: *DT, LI: *LI)) |
877 | report_fatal_error(reason: "Loops must remain in LCSSA form!" ); |
878 | if (!ScalarLoop->isRecursivelyLCSSAForm(DT: *DT, LI: *LI)) |
879 | report_fatal_error(reason: "Loops must remain in LCSSA form!" ); |
880 | } |
881 | |
882 | return FinalRes; |
883 | } |
884 | |
885 | void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA, |
886 | GetElementPtrInst *GEPB, |
887 | PHINode *IndPhi, Value *MaxLen, |
888 | Instruction *Index, Value *Start, |
889 | bool IncIdx, BasicBlock *FoundBB, |
890 | BasicBlock *EndBB) { |
891 | |
892 | // Insert the byte compare code at the end of the preheader block |
893 | BasicBlock * = CurLoop->getLoopPreheader(); |
894 | BasicBlock * = CurLoop->getHeader(); |
895 | BranchInst *PHBranch = cast<BranchInst>(Val: Preheader->getTerminator()); |
896 | IRBuilder<> Builder(PHBranch); |
897 | DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); |
898 | Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc()); |
899 | |
900 | // Increment the pointer if this was done before the loads in the loop. |
901 | if (IncIdx) |
902 | Start = Builder.CreateAdd(LHS: Start, RHS: ConstantInt::get(Ty: Start->getType(), V: 1)); |
903 | |
904 | Value *ByteCmpRes = |
905 | expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen); |
906 | |
907 | // Replaces uses of index & induction Phi with intrinsic (we already |
908 | // checked that the the first instruction of Header is the Phi above). |
909 | assert(IndPhi->hasOneUse() && "Index phi node has more than one use!" ); |
910 | Index->replaceAllUsesWith(V: ByteCmpRes); |
911 | |
912 | assert(PHBranch->isUnconditional() && |
913 | "Expected preheader to terminate with an unconditional branch." ); |
914 | |
915 | // If no mismatch was found, we can jump to the end block. Create a |
916 | // new basic block for the compare instruction. |
917 | auto *CmpBB = BasicBlock::Create(Context&: Preheader->getContext(), Name: "byte.compare" , |
918 | Parent: Preheader->getParent()); |
919 | CmpBB->moveBefore(MovePos: EndBB); |
920 | |
921 | // Replace the branch in the preheader with an always-true conditional branch. |
922 | // This ensures there is still a reference to the original loop. |
923 | Builder.CreateCondBr(Cond: Builder.getTrue(), True: CmpBB, False: Header); |
924 | PHBranch->eraseFromParent(); |
925 | |
926 | BasicBlock *MismatchEnd = cast<Instruction>(Val: ByteCmpRes)->getParent(); |
927 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, MismatchEnd, CmpBB}}); |
928 | |
929 | // Create the branch to either the end or found block depending on the value |
930 | // returned by the intrinsic. |
931 | Builder.SetInsertPoint(CmpBB); |
932 | if (FoundBB != EndBB) { |
933 | Value *FoundCmp = Builder.CreateICmpEQ(LHS: ByteCmpRes, RHS: MaxLen); |
934 | Builder.CreateCondBr(Cond: FoundCmp, True: EndBB, False: FoundBB); |
935 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, CmpBB, FoundBB}, |
936 | {DominatorTree::Insert, CmpBB, EndBB}}); |
937 | |
938 | } else { |
939 | Builder.CreateBr(Dest: FoundBB); |
940 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, CmpBB, FoundBB}}); |
941 | } |
942 | |
943 | auto fixSuccessorPhis = [&](BasicBlock *SuccBB) { |
944 | for (PHINode &PN : SuccBB->phis()) { |
945 | // At this point we've already replaced all uses of the result from the |
946 | // loop with ByteCmp. Look through the incoming values to find ByteCmp, |
947 | // meaning this is a Phi collecting the results of the byte compare. |
948 | bool ResPhi = false; |
949 | for (Value *Op : PN.incoming_values()) |
950 | if (Op == ByteCmpRes) { |
951 | ResPhi = true; |
952 | break; |
953 | } |
954 | |
955 | // Any PHI that depended upon the result of the byte compare needs a new |
956 | // incoming value from CmpBB. This is because the original loop will get |
957 | // deleted. |
958 | if (ResPhi) |
959 | PN.addIncoming(V: ByteCmpRes, BB: CmpBB); |
960 | else { |
961 | // There should be no other outside uses of other values in the |
962 | // original loop. Any incoming values should either: |
963 | // 1. Be for blocks outside the loop, which aren't interesting. Or .. |
964 | // 2. These are from blocks in the loop with values defined outside |
965 | // the loop. We should a similar incoming value from CmpBB. |
966 | for (BasicBlock *BB : PN.blocks()) |
967 | if (CurLoop->contains(BB)) { |
968 | PN.addIncoming(V: PN.getIncomingValueForBlock(BB), BB: CmpBB); |
969 | break; |
970 | } |
971 | } |
972 | } |
973 | }; |
974 | |
975 | // Ensure all Phis in the successors of CmpBB have an incoming value from it. |
976 | fixSuccessorPhis(EndBB); |
977 | if (EndBB != FoundBB) |
978 | fixSuccessorPhis(FoundBB); |
979 | |
980 | // The new CmpBB block isn't part of the loop, but will need to be added to |
981 | // the outer loop if there is one. |
982 | if (!CurLoop->isOutermost()) |
983 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: CmpBB, LI&: *LI); |
984 | |
985 | if (VerifyLoops && CurLoop->getParentLoop()) { |
986 | CurLoop->getParentLoop()->verifyLoop(); |
987 | if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(DT: *DT, LI: *LI)) |
988 | report_fatal_error(reason: "Loops must remain in LCSSA form!" ); |
989 | } |
990 | } |
991 | |
992 | bool LoopIdiomVectorize::recognizeFindFirstByte() { |
993 | // Currently the transformation only works on scalable vector types, although |
994 | // there is no fundamental reason why it cannot be made to work for fixed |
995 | // vectors. We also need to know the target's minimum page size in order to |
996 | // generate runtime memory checks to ensure the vector version won't fault. |
997 | if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() || |
998 | DisableFindFirstByte) |
999 | return false; |
1000 | |
1001 | // Define some constants we need throughout. |
1002 | BasicBlock * = CurLoop->getHeader(); |
1003 | LLVMContext &Ctx = Header->getContext(); |
1004 | |
1005 | // We are expecting the four blocks defined below: Header, MatchBB, InnerBB, |
1006 | // and OuterBB. For now, we will bail our for almost anything else. The Four |
1007 | // blocks contain one nested loop. |
1008 | if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 4 || |
1009 | CurLoop->getSubLoops().size() != 1) |
1010 | return false; |
1011 | |
1012 | auto *InnerLoop = CurLoop->getSubLoops().front(); |
1013 | PHINode *IndPhi = dyn_cast<PHINode>(Val: &Header->front()); |
1014 | if (!IndPhi || IndPhi->getNumIncomingValues() != 2) |
1015 | return false; |
1016 | |
1017 | // Check instruction counts. |
1018 | auto LoopBlocks = CurLoop->getBlocks(); |
1019 | if (LoopBlocks[0]->sizeWithoutDebug() > 3 || |
1020 | LoopBlocks[1]->sizeWithoutDebug() > 4 || |
1021 | LoopBlocks[2]->sizeWithoutDebug() > 3 || |
1022 | LoopBlocks[3]->sizeWithoutDebug() > 3) |
1023 | return false; |
1024 | |
1025 | // Check that no instruction other than IndPhi has outside uses. |
1026 | for (BasicBlock *BB : LoopBlocks) |
1027 | for (Instruction &I : *BB) |
1028 | if (&I != IndPhi) |
1029 | for (User *U : I.users()) |
1030 | if (!CurLoop->contains(Inst: cast<Instruction>(Val: U))) |
1031 | return false; |
1032 | |
1033 | // Match the branch instruction in the header. We are expecting an |
1034 | // unconditional branch to the inner loop. |
1035 | // |
1036 | // Header: |
1037 | // %14 = phi ptr [ %24, %OuterBB ], [ %3, %Header.preheader ] |
1038 | // %15 = load i8, ptr %14, align 1 |
1039 | // br label %MatchBB |
1040 | BasicBlock *MatchBB; |
1041 | if (!match(V: Header->getTerminator(), P: m_UnconditionalBr(Succ&: MatchBB)) || |
1042 | !InnerLoop->contains(BB: MatchBB)) |
1043 | return false; |
1044 | |
1045 | // MatchBB should be the entrypoint into the inner loop containing the |
1046 | // comparison between a search element and a needle. |
1047 | // |
1048 | // MatchBB: |
1049 | // %20 = phi ptr [ %7, %Header ], [ %17, %InnerBB ] |
1050 | // %21 = load i8, ptr %20, align 1 |
1051 | // %22 = icmp eq i8 %15, %21 |
1052 | // br i1 %22, label %ExitSucc, label %InnerBB |
1053 | BasicBlock *ExitSucc, *InnerBB; |
1054 | Value *LoadSearch, *LoadNeedle; |
1055 | CmpPredicate MatchPred; |
1056 | if (!match(V: MatchBB->getTerminator(), |
1057 | P: m_Br(C: m_ICmp(Pred&: MatchPred, L: m_Value(V&: LoadSearch), R: m_Value(V&: LoadNeedle)), |
1058 | T: m_BasicBlock(V&: ExitSucc), F: m_BasicBlock(V&: InnerBB))) || |
1059 | MatchPred != ICmpInst::ICMP_EQ || !InnerLoop->contains(BB: InnerBB)) |
1060 | return false; |
1061 | |
1062 | // We expect outside uses of `IndPhi' in ExitSucc (and only there). |
1063 | for (User *U : IndPhi->users()) |
1064 | if (!CurLoop->contains(Inst: cast<Instruction>(Val: U))) { |
1065 | auto *PN = dyn_cast<PHINode>(Val: U); |
1066 | if (!PN || PN->getParent() != ExitSucc) |
1067 | return false; |
1068 | } |
1069 | |
1070 | // Match the loads and check they are simple. |
1071 | Value *Search, *Needle; |
1072 | if (!match(V: LoadSearch, P: m_Load(Op: m_Value(V&: Search))) || |
1073 | !match(V: LoadNeedle, P: m_Load(Op: m_Value(V&: Needle))) || |
1074 | !cast<LoadInst>(Val: LoadSearch)->isSimple() || |
1075 | !cast<LoadInst>(Val: LoadNeedle)->isSimple()) |
1076 | return false; |
1077 | |
1078 | // Check we are loading valid characters. |
1079 | Type *CharTy = LoadSearch->getType(); |
1080 | if (!CharTy->isIntegerTy() || LoadNeedle->getType() != CharTy) |
1081 | return false; |
1082 | |
1083 | // Pick the vectorisation factor based on CharTy, work out the cost of the |
1084 | // match intrinsic and decide if we should use it. |
1085 | // Note: For the time being we assume 128-bit vectors. |
1086 | unsigned VF = 128 / CharTy->getIntegerBitWidth(); |
1087 | SmallVector<Type *> Args = { |
1088 | ScalableVectorType::get(ElementType: CharTy, MinNumElts: VF), FixedVectorType::get(ElementType: CharTy, NumElts: VF), |
1089 | ScalableVectorType::get(ElementType: Type::getInt1Ty(C&: Ctx), MinNumElts: VF)}; |
1090 | IntrinsicCostAttributes Attrs(Intrinsic::experimental_vector_match, Args[2], |
1091 | Args); |
1092 | if (TTI->getIntrinsicInstrCost(ICA: Attrs, CostKind: TTI::TCK_SizeAndLatency) > 4) |
1093 | return false; |
1094 | |
1095 | // The loads come from two PHIs, each with two incoming values. |
1096 | PHINode *PSearch = dyn_cast<PHINode>(Val: Search); |
1097 | PHINode *PNeedle = dyn_cast<PHINode>(Val: Needle); |
1098 | if (!PSearch || PSearch->getNumIncomingValues() != 2 || !PNeedle || |
1099 | PNeedle->getNumIncomingValues() != 2) |
1100 | return false; |
1101 | |
1102 | // One PHI comes from the outer loop (PSearch), the other one from the inner |
1103 | // loop (PNeedle). PSearch effectively corresponds to IndPhi. |
1104 | if (InnerLoop->contains(Inst: PSearch)) |
1105 | std::swap(a&: PSearch, b&: PNeedle); |
1106 | if (PSearch != &Header->front() || PNeedle != &MatchBB->front()) |
1107 | return false; |
1108 | |
1109 | // The incoming values of both PHI nodes should be a gep of 1. |
1110 | Value *SearchStart = PSearch->getIncomingValue(i: 0); |
1111 | Value *SearchIndex = PSearch->getIncomingValue(i: 1); |
1112 | if (CurLoop->contains(BB: PSearch->getIncomingBlock(i: 0))) |
1113 | std::swap(a&: SearchStart, b&: SearchIndex); |
1114 | |
1115 | Value *NeedleStart = PNeedle->getIncomingValue(i: 0); |
1116 | Value *NeedleIndex = PNeedle->getIncomingValue(i: 1); |
1117 | if (InnerLoop->contains(BB: PNeedle->getIncomingBlock(i: 0))) |
1118 | std::swap(a&: NeedleStart, b&: NeedleIndex); |
1119 | |
1120 | // Match the GEPs. |
1121 | if (!match(V: SearchIndex, P: m_GEP(Ops: m_Specific(V: PSearch), Ops: m_One())) || |
1122 | !match(V: NeedleIndex, P: m_GEP(Ops: m_Specific(V: PNeedle), Ops: m_One()))) |
1123 | return false; |
1124 | |
1125 | // Check the GEPs result type matches `CharTy'. |
1126 | GetElementPtrInst *GEPSearch = cast<GetElementPtrInst>(Val: SearchIndex); |
1127 | GetElementPtrInst *GEPNeedle = cast<GetElementPtrInst>(Val: NeedleIndex); |
1128 | if (GEPSearch->getResultElementType() != CharTy || |
1129 | GEPNeedle->getResultElementType() != CharTy) |
1130 | return false; |
1131 | |
1132 | // InnerBB should increment the address of the needle pointer. |
1133 | // |
1134 | // InnerBB: |
1135 | // %17 = getelementptr inbounds i8, ptr %20, i64 1 |
1136 | // %18 = icmp eq ptr %17, %10 |
1137 | // br i1 %18, label %OuterBB, label %MatchBB |
1138 | BasicBlock *OuterBB; |
1139 | Value *NeedleEnd; |
1140 | if (!match(V: InnerBB->getTerminator(), |
1141 | P: m_Br(C: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ, L: m_Specific(V: GEPNeedle), |
1142 | R: m_Value(V&: NeedleEnd)), |
1143 | T: m_BasicBlock(V&: OuterBB), F: m_Specific(V: MatchBB))) || |
1144 | !CurLoop->contains(BB: OuterBB)) |
1145 | return false; |
1146 | |
1147 | // OuterBB should increment the address of the search element pointer. |
1148 | // |
1149 | // OuterBB: |
1150 | // %24 = getelementptr inbounds i8, ptr %14, i64 1 |
1151 | // %25 = icmp eq ptr %24, %6 |
1152 | // br i1 %25, label %ExitFail, label %Header |
1153 | BasicBlock *ExitFail; |
1154 | Value *SearchEnd; |
1155 | if (!match(V: OuterBB->getTerminator(), |
1156 | P: m_Br(C: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ, L: m_Specific(V: GEPSearch), |
1157 | R: m_Value(V&: SearchEnd)), |
1158 | T: m_BasicBlock(V&: ExitFail), F: m_Specific(V: Header)))) |
1159 | return false; |
1160 | |
1161 | if (!CurLoop->isLoopInvariant(V: SearchStart) || |
1162 | !CurLoop->isLoopInvariant(V: SearchEnd) || |
1163 | !CurLoop->isLoopInvariant(V: NeedleStart) || |
1164 | !CurLoop->isLoopInvariant(V: NeedleEnd)) |
1165 | return false; |
1166 | |
1167 | LLVM_DEBUG(dbgs() << "Found idiom in loop: \n" << *CurLoop << "\n\n" ); |
1168 | |
1169 | transformFindFirstByte(IndPhi, VF, CharTy, ExitSucc, ExitFail, SearchStart, |
1170 | SearchEnd, NeedleStart, NeedleEnd); |
1171 | return true; |
1172 | } |
1173 | |
1174 | Value *LoopIdiomVectorize::expandFindFirstByte( |
1175 | IRBuilder<> &Builder, DomTreeUpdater &DTU, unsigned VF, Type *CharTy, |
1176 | BasicBlock *ExitSucc, BasicBlock *ExitFail, Value *SearchStart, |
1177 | Value *SearchEnd, Value *NeedleStart, Value *NeedleEnd) { |
1178 | // Set up some types and constants that we intend to reuse. |
1179 | auto *PtrTy = Builder.getPtrTy(); |
1180 | auto *I64Ty = Builder.getInt64Ty(); |
1181 | auto *PredVTy = ScalableVectorType::get(ElementType: Builder.getInt1Ty(), MinNumElts: VF); |
1182 | auto *CharVTy = ScalableVectorType::get(ElementType: CharTy, MinNumElts: VF); |
1183 | auto *ConstVF = ConstantInt::get(Ty: I64Ty, V: VF); |
1184 | |
1185 | // Other common arguments. |
1186 | BasicBlock * = CurLoop->getLoopPreheader(); |
1187 | LLVMContext &Ctx = Preheader->getContext(); |
1188 | Value *Passthru = ConstantInt::getNullValue(Ty: CharVTy); |
1189 | |
1190 | // Split block in the original loop preheader. |
1191 | // SPH is the new preheader to the old scalar loop. |
1192 | BasicBlock *SPH = SplitBlock(Old: Preheader, SplitPt: Preheader->getTerminator(), DT, LI, |
1193 | MSSAU: nullptr, BBName: "scalar_preheader" ); |
1194 | |
1195 | // Create the blocks that we're going to use. |
1196 | // |
1197 | // We will have the following loops: |
1198 | // (O) Outer loop where we iterate over the elements of the search array. |
1199 | // (I) Inner loop where we iterate over the elements of the needle array. |
1200 | // |
1201 | // Overall, the blocks do the following: |
1202 | // (0) Check if the arrays can't cross page boundaries. If so go to (1), |
1203 | // otherwise fall back to the original scalar loop. |
1204 | // (1) Load the search array. Go to (2). |
1205 | // (2) (a) Load the needle array. |
1206 | // (b) Splat the first element to the inactive lanes. |
1207 | // (c) Check if any elements match. If so go to (3), otherwise go to (4). |
1208 | // (3) Compute the index of the first match and exit. |
1209 | // (4) Check if we've reached the end of the needle array. If not loop back to |
1210 | // (2), otherwise go to (5). |
1211 | // (5) Check if we've reached the end of the search array. If not loop back to |
1212 | // (1), otherwise exit. |
1213 | // Blocks (0,3) are not part of any loop. Blocks (1,5) and (2,4) belong to |
1214 | // the outer and inner loops, respectively. |
1215 | BasicBlock *BB0 = BasicBlock::Create(Context&: Ctx, Name: "mem_check" , Parent: SPH->getParent(), InsertBefore: SPH); |
1216 | BasicBlock *BB1 = |
1217 | BasicBlock::Create(Context&: Ctx, Name: "find_first_vec_header" , Parent: SPH->getParent(), InsertBefore: SPH); |
1218 | BasicBlock *BB2 = |
1219 | BasicBlock::Create(Context&: Ctx, Name: "match_check_vec" , Parent: SPH->getParent(), InsertBefore: SPH); |
1220 | BasicBlock *BB3 = |
1221 | BasicBlock::Create(Context&: Ctx, Name: "calculate_match" , Parent: SPH->getParent(), InsertBefore: SPH); |
1222 | BasicBlock *BB4 = |
1223 | BasicBlock::Create(Context&: Ctx, Name: "needle_check_vec" , Parent: SPH->getParent(), InsertBefore: SPH); |
1224 | BasicBlock *BB5 = |
1225 | BasicBlock::Create(Context&: Ctx, Name: "search_check_vec" , Parent: SPH->getParent(), InsertBefore: SPH); |
1226 | |
1227 | // Update LoopInfo with the new loops. |
1228 | auto OuterLoop = LI->AllocateLoop(); |
1229 | auto InnerLoop = LI->AllocateLoop(); |
1230 | |
1231 | if (auto ParentLoop = CurLoop->getParentLoop()) { |
1232 | ParentLoop->addBasicBlockToLoop(NewBB: BB0, LI&: *LI); |
1233 | ParentLoop->addChildLoop(NewChild: OuterLoop); |
1234 | ParentLoop->addBasicBlockToLoop(NewBB: BB3, LI&: *LI); |
1235 | } else { |
1236 | LI->addTopLevelLoop(New: OuterLoop); |
1237 | } |
1238 | |
1239 | // Add the inner loop to the outer. |
1240 | OuterLoop->addChildLoop(NewChild: InnerLoop); |
1241 | |
1242 | // Add the new basic blocks to the corresponding loops. |
1243 | OuterLoop->addBasicBlockToLoop(NewBB: BB1, LI&: *LI); |
1244 | OuterLoop->addBasicBlockToLoop(NewBB: BB5, LI&: *LI); |
1245 | InnerLoop->addBasicBlockToLoop(NewBB: BB2, LI&: *LI); |
1246 | InnerLoop->addBasicBlockToLoop(NewBB: BB4, LI&: *LI); |
1247 | |
1248 | // Update the terminator added by SplitBlock to branch to the first block. |
1249 | Preheader->getTerminator()->setSuccessor(Idx: 0, BB: BB0); |
1250 | DTU.applyUpdates(Updates: {{DominatorTree::Delete, Preheader, SPH}, |
1251 | {DominatorTree::Insert, Preheader, BB0}}); |
1252 | |
1253 | // (0) Check if we could be crossing a page boundary; if so, fallback to the |
1254 | // old scalar loops. Also create a predicate of VF elements to be used in the |
1255 | // vector loops. |
1256 | Builder.SetInsertPoint(BB0); |
1257 | Value *ISearchStart = |
1258 | Builder.CreatePtrToInt(V: SearchStart, DestTy: I64Ty, Name: "search_start_int" ); |
1259 | Value *ISearchEnd = |
1260 | Builder.CreatePtrToInt(V: SearchEnd, DestTy: I64Ty, Name: "search_end_int" ); |
1261 | Value *INeedleStart = |
1262 | Builder.CreatePtrToInt(V: NeedleStart, DestTy: I64Ty, Name: "needle_start_int" ); |
1263 | Value *INeedleEnd = |
1264 | Builder.CreatePtrToInt(V: NeedleEnd, DestTy: I64Ty, Name: "needle_end_int" ); |
1265 | Value *PredVF = |
1266 | Builder.CreateIntrinsic(ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Ty}, |
1267 | Args: {ConstantInt::get(Ty: I64Ty, V: 0), ConstVF}); |
1268 | |
1269 | const uint64_t MinPageSize = TTI->getMinPageSize().value(); |
1270 | const uint64_t AddrShiftAmt = llvm::Log2_64(Value: MinPageSize); |
1271 | Value *SearchStartPage = |
1272 | Builder.CreateLShr(LHS: ISearchStart, RHS: AddrShiftAmt, Name: "search_start_page" ); |
1273 | Value *SearchEndPage = |
1274 | Builder.CreateLShr(LHS: ISearchEnd, RHS: AddrShiftAmt, Name: "search_end_page" ); |
1275 | Value *NeedleStartPage = |
1276 | Builder.CreateLShr(LHS: INeedleStart, RHS: AddrShiftAmt, Name: "needle_start_page" ); |
1277 | Value *NeedleEndPage = |
1278 | Builder.CreateLShr(LHS: INeedleEnd, RHS: AddrShiftAmt, Name: "needle_end_page" ); |
1279 | Value *SearchPageCmp = |
1280 | Builder.CreateICmpNE(LHS: SearchStartPage, RHS: SearchEndPage, Name: "search_page_cmp" ); |
1281 | Value *NeedlePageCmp = |
1282 | Builder.CreateICmpNE(LHS: NeedleStartPage, RHS: NeedleEndPage, Name: "needle_page_cmp" ); |
1283 | |
1284 | Value *CombinedPageCmp = |
1285 | Builder.CreateOr(LHS: SearchPageCmp, RHS: NeedlePageCmp, Name: "combined_page_cmp" ); |
1286 | BranchInst *CombinedPageBr = Builder.CreateCondBr(Cond: CombinedPageCmp, True: SPH, False: BB1); |
1287 | CombinedPageBr->setMetadata(KindID: LLVMContext::MD_prof, |
1288 | Node: MDBuilder(Ctx).createBranchWeights(TrueWeight: 10, FalseWeight: 90)); |
1289 | DTU.applyUpdates( |
1290 | Updates: {{DominatorTree::Insert, BB0, SPH}, {DominatorTree::Insert, BB0, BB1}}); |
1291 | |
1292 | // (1) Load the search array and branch to the inner loop. |
1293 | Builder.SetInsertPoint(BB1); |
1294 | PHINode *Search = Builder.CreatePHI(Ty: PtrTy, NumReservedValues: 2, Name: "psearch" ); |
1295 | Value *PredSearch = Builder.CreateIntrinsic( |
1296 | ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Ty}, |
1297 | Args: {Builder.CreatePtrToInt(V: Search, DestTy: I64Ty), ISearchEnd}, FMFSource: nullptr, |
1298 | Name: "search_pred" ); |
1299 | PredSearch = Builder.CreateAnd(LHS: PredVF, RHS: PredSearch, Name: "search_masked" ); |
1300 | Value *LoadSearch = Builder.CreateMaskedLoad( |
1301 | Ty: CharVTy, Ptr: Search, Alignment: Align(1), Mask: PredSearch, PassThru: Passthru, Name: "search_load_vec" ); |
1302 | Builder.CreateBr(Dest: BB2); |
1303 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, BB1, BB2}}); |
1304 | |
1305 | // (2) Inner loop. |
1306 | Builder.SetInsertPoint(BB2); |
1307 | PHINode *Needle = Builder.CreatePHI(Ty: PtrTy, NumReservedValues: 2, Name: "pneedle" ); |
1308 | |
1309 | // (2.a) Load the needle array. |
1310 | Value *PredNeedle = Builder.CreateIntrinsic( |
1311 | ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Ty}, |
1312 | Args: {Builder.CreatePtrToInt(V: Needle, DestTy: I64Ty), INeedleEnd}, FMFSource: nullptr, |
1313 | Name: "needle_pred" ); |
1314 | PredNeedle = Builder.CreateAnd(LHS: PredVF, RHS: PredNeedle, Name: "needle_masked" ); |
1315 | Value *LoadNeedle = Builder.CreateMaskedLoad( |
1316 | Ty: CharVTy, Ptr: Needle, Alignment: Align(1), Mask: PredNeedle, PassThru: Passthru, Name: "needle_load_vec" ); |
1317 | |
1318 | // (2.b) Splat the first element to the inactive lanes. |
1319 | Value *Needle0 = |
1320 | Builder.CreateExtractElement(Vec: LoadNeedle, Idx: uint64_t(0), Name: "needle0" ); |
1321 | Value *Needle0Splat = Builder.CreateVectorSplat(EC: ElementCount::getScalable(MinVal: VF), |
1322 | V: Needle0, Name: "needle0" ); |
1323 | LoadNeedle = Builder.CreateSelect(C: PredNeedle, True: LoadNeedle, False: Needle0Splat, |
1324 | Name: "needle_splat" ); |
1325 | LoadNeedle = Builder.CreateExtractVector( |
1326 | DstType: FixedVectorType::get(ElementType: CharTy, NumElts: VF), SrcVec: LoadNeedle, Idx: uint64_t(0), Name: "needle_vec" ); |
1327 | |
1328 | // (2.c) Test if there's a match. |
1329 | Value *MatchPred = Builder.CreateIntrinsic( |
1330 | ID: Intrinsic::experimental_vector_match, Types: {CharVTy, LoadNeedle->getType()}, |
1331 | Args: {LoadSearch, LoadNeedle, PredSearch}, FMFSource: nullptr, Name: "match_pred" ); |
1332 | Value *IfAnyMatch = Builder.CreateOrReduce(Src: MatchPred); |
1333 | Builder.CreateCondBr(Cond: IfAnyMatch, True: BB3, False: BB4); |
1334 | DTU.applyUpdates( |
1335 | Updates: {{DominatorTree::Insert, BB2, BB3}, {DominatorTree::Insert, BB2, BB4}}); |
1336 | |
1337 | // (3) We found a match. Compute the index of its location and exit. |
1338 | Builder.SetInsertPoint(BB3); |
1339 | PHINode *MatchLCSSA = Builder.CreatePHI(Ty: PtrTy, NumReservedValues: 1, Name: "match_start" ); |
1340 | PHINode *MatchPredLCSSA = |
1341 | Builder.CreatePHI(Ty: MatchPred->getType(), NumReservedValues: 1, Name: "match_vec" ); |
1342 | Value *MatchCnt = Builder.CreateIntrinsic( |
1343 | ID: Intrinsic::experimental_cttz_elts, Types: {I64Ty, MatchPred->getType()}, |
1344 | Args: {MatchPredLCSSA, /*ZeroIsPoison=*/Builder.getInt1(V: true)}, FMFSource: nullptr, |
1345 | Name: "match_idx" ); |
1346 | Value *MatchVal = |
1347 | Builder.CreateGEP(Ty: CharTy, Ptr: MatchLCSSA, IdxList: MatchCnt, Name: "match_res" ); |
1348 | Builder.CreateBr(Dest: ExitSucc); |
1349 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, BB3, ExitSucc}}); |
1350 | |
1351 | // (4) Check if we've reached the end of the needle array. |
1352 | Builder.SetInsertPoint(BB4); |
1353 | Value *NextNeedle = |
1354 | Builder.CreateGEP(Ty: CharTy, Ptr: Needle, IdxList: ConstVF, Name: "needle_next_vec" ); |
1355 | Builder.CreateCondBr(Cond: Builder.CreateICmpULT(LHS: NextNeedle, RHS: NeedleEnd), True: BB2, False: BB5); |
1356 | DTU.applyUpdates( |
1357 | Updates: {{DominatorTree::Insert, BB4, BB2}, {DominatorTree::Insert, BB4, BB5}}); |
1358 | |
1359 | // (5) Check if we've reached the end of the search array. |
1360 | Builder.SetInsertPoint(BB5); |
1361 | Value *NextSearch = |
1362 | Builder.CreateGEP(Ty: CharTy, Ptr: Search, IdxList: ConstVF, Name: "search_next_vec" ); |
1363 | Builder.CreateCondBr(Cond: Builder.CreateICmpULT(LHS: NextSearch, RHS: SearchEnd), True: BB1, |
1364 | False: ExitFail); |
1365 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, BB5, BB1}, |
1366 | {DominatorTree::Insert, BB5, ExitFail}}); |
1367 | |
1368 | // Set up the PHI nodes. |
1369 | Search->addIncoming(V: SearchStart, BB: BB0); |
1370 | Search->addIncoming(V: NextSearch, BB: BB5); |
1371 | Needle->addIncoming(V: NeedleStart, BB: BB1); |
1372 | Needle->addIncoming(V: NextNeedle, BB: BB4); |
1373 | // These are needed to retain LCSSA form. |
1374 | MatchLCSSA->addIncoming(V: Search, BB: BB2); |
1375 | MatchPredLCSSA->addIncoming(V: MatchPred, BB: BB2); |
1376 | |
1377 | if (VerifyLoops) { |
1378 | OuterLoop->verifyLoop(); |
1379 | InnerLoop->verifyLoop(); |
1380 | if (!OuterLoop->isRecursivelyLCSSAForm(DT: *DT, LI: *LI)) |
1381 | report_fatal_error(reason: "Loops must remain in LCSSA form!" ); |
1382 | } |
1383 | |
1384 | return MatchVal; |
1385 | } |
1386 | |
1387 | void LoopIdiomVectorize::transformFindFirstByte( |
1388 | PHINode *IndPhi, unsigned VF, Type *CharTy, BasicBlock *ExitSucc, |
1389 | BasicBlock *ExitFail, Value *SearchStart, Value *SearchEnd, |
1390 | Value *NeedleStart, Value *NeedleEnd) { |
1391 | // Insert the find first byte code at the end of the preheader block. |
1392 | BasicBlock * = CurLoop->getLoopPreheader(); |
1393 | BranchInst *PHBranch = cast<BranchInst>(Val: Preheader->getTerminator()); |
1394 | IRBuilder<> Builder(PHBranch); |
1395 | DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); |
1396 | Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc()); |
1397 | |
1398 | Value *MatchVal = |
1399 | expandFindFirstByte(Builder, DTU, VF, CharTy, ExitSucc, ExitFail, |
1400 | SearchStart, SearchEnd, NeedleStart, NeedleEnd); |
1401 | |
1402 | assert(PHBranch->isUnconditional() && |
1403 | "Expected preheader to terminate with an unconditional branch." ); |
1404 | |
1405 | // Add new incoming values with the result of the transformation to PHINodes |
1406 | // of ExitSucc that use IndPhi. |
1407 | for (auto *U : llvm::make_early_inc_range(Range: IndPhi->users())) { |
1408 | auto *PN = dyn_cast<PHINode>(Val: U); |
1409 | if (PN && PN->getParent() == ExitSucc) |
1410 | PN->addIncoming(V: MatchVal, BB: cast<Instruction>(Val: MatchVal)->getParent()); |
1411 | } |
1412 | |
1413 | if (VerifyLoops && CurLoop->getParentLoop()) { |
1414 | CurLoop->getParentLoop()->verifyLoop(); |
1415 | if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(DT: *DT, LI: *LI)) |
1416 | report_fatal_error(reason: "Loops must remain in LCSSA form!" ); |
1417 | } |
1418 | } |
1419 | |