| 1 | //===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This pass recognizes certain loop idioms and transforms them into more |
| 10 | // optimized versions of the same loop. In cases where this happens, it can |
| 11 | // be a significant performance win. |
| 12 | // |
| 13 | // We currently support two loops: |
| 14 | // |
| 15 | // 1. A loop that finds the first mismatched byte in an array and returns the |
| 16 | // index, i.e. something like: |
| 17 | // |
| 18 | // while (++i != n) { |
| 19 | // if (a[i] != b[i]) |
| 20 | // break; |
| 21 | // } |
| 22 | // |
| 23 | // In this example we can actually vectorize the loop despite the early exit, |
| 24 | // although the loop vectorizer does not support it. It requires some extra |
| 25 | // checks to deal with the possibility of faulting loads when crossing page |
| 26 | // boundaries. However, even with these checks it is still profitable to do the |
| 27 | // transformation. |
| 28 | // |
| 29 | // TODO List: |
| 30 | // |
| 31 | // * Add support for the inverse case where we scan for a matching element. |
| 32 | // * Permit 64-bit induction variable types. |
| 33 | // * Recognize loops that increment the IV *after* comparing bytes. |
| 34 | // * Allow 32-bit sign-extends of the IV used by the GEP. |
| 35 | // |
| 36 | // 2. A loop that finds the first matching character in an array among a set of |
| 37 | // possible matches, e.g.: |
| 38 | // |
| 39 | // for (; first != last; ++first) |
| 40 | // for (s_it = s_first; s_it != s_last; ++s_it) |
| 41 | // if (*first == *s_it) |
| 42 | // return first; |
| 43 | // return last; |
| 44 | // |
| 45 | // This corresponds to std::find_first_of (for arrays of bytes) from the C++ |
| 46 | // standard library. This function can be implemented efficiently for targets |
| 47 | // that support @llvm.experimental.vector.match. For example, on AArch64 targets |
| 48 | // that implement SVE2, this lowers to a MATCH instruction, which enables us to |
| 49 | // perform up to 16x16=256 comparisons in one go. This can lead to very |
| 50 | // significant speedups. |
| 51 | // |
| 52 | // TODO: |
| 53 | // |
| 54 | // * Add support for `find_first_not_of' loops (i.e. with not-equal comparison). |
| 55 | // * Make VF a configurable parameter (right now we assume 128-bit vectors). |
| 56 | // * Potentially adjust the cost model to let the transformation kick-in even if |
| 57 | // @llvm.experimental.vector.match doesn't have direct support in hardware. |
| 58 | // |
| 59 | //===----------------------------------------------------------------------===// |
| 60 | // |
| 61 | // NOTE: This Pass matches really specific loop patterns because it's only |
| 62 | // supposed to be a temporary solution until our LoopVectorizer is powerful |
| 63 | // enough to vectorize them automatically. |
| 64 | // |
| 65 | //===----------------------------------------------------------------------===// |
| 66 | |
| 67 | #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h" |
| 68 | #include "llvm/Analysis/DomTreeUpdater.h" |
| 69 | #include "llvm/Analysis/LoopPass.h" |
| 70 | #include "llvm/Analysis/TargetTransformInfo.h" |
| 71 | #include "llvm/IR/Dominators.h" |
| 72 | #include "llvm/IR/IRBuilder.h" |
| 73 | #include "llvm/IR/Intrinsics.h" |
| 74 | #include "llvm/IR/MDBuilder.h" |
| 75 | #include "llvm/IR/PatternMatch.h" |
| 76 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
| 77 | |
| 78 | using namespace llvm; |
| 79 | using namespace PatternMatch; |
| 80 | |
| 81 | #define DEBUG_TYPE "loop-idiom-vectorize" |
| 82 | |
| 83 | static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all" , cl::Hidden, |
| 84 | cl::init(Val: false), |
| 85 | cl::desc("Disable Loop Idiom Vectorize Pass." )); |
| 86 | |
| 87 | static cl::opt<LoopIdiomVectorizeStyle> |
| 88 | LITVecStyle("loop-idiom-vectorize-style" , cl::Hidden, |
| 89 | cl::desc("The vectorization style for loop idiom transform." ), |
| 90 | cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked" , |
| 91 | "Use masked vector intrinsics" ), |
| 92 | clEnumValN(LoopIdiomVectorizeStyle::Predicated, |
| 93 | "predicated" , "Use VP intrinsics" )), |
| 94 | cl::init(Val: LoopIdiomVectorizeStyle::Masked)); |
| 95 | |
| 96 | static cl::opt<bool> |
| 97 | DisableByteCmp("disable-loop-idiom-vectorize-bytecmp" , cl::Hidden, |
| 98 | cl::init(Val: false), |
| 99 | cl::desc("Proceed with Loop Idiom Vectorize Pass, but do " |
| 100 | "not convert byte-compare loop(s)." )); |
| 101 | |
| 102 | static cl::opt<unsigned> |
| 103 | ByteCmpVF("loop-idiom-vectorize-bytecmp-vf" , cl::Hidden, |
| 104 | cl::desc("The vectorization factor for byte-compare patterns." ), |
| 105 | cl::init(Val: 16)); |
| 106 | |
| 107 | static cl::opt<bool> |
| 108 | DisableFindFirstByte("disable-loop-idiom-vectorize-find-first-byte" , |
| 109 | cl::Hidden, cl::init(Val: false), |
| 110 | cl::desc("Do not convert find-first-byte loop(s)." )); |
| 111 | |
| 112 | static cl::opt<bool> |
| 113 | VerifyLoops("loop-idiom-vectorize-verify" , cl::Hidden, cl::init(Val: false), |
| 114 | cl::desc("Verify loops generated Loop Idiom Vectorize Pass." )); |
| 115 | |
| 116 | namespace { |
| 117 | class LoopIdiomVectorize { |
| 118 | LoopIdiomVectorizeStyle VectorizeStyle; |
| 119 | unsigned ByteCompareVF; |
| 120 | Loop *CurLoop = nullptr; |
| 121 | DominatorTree *DT; |
| 122 | LoopInfo *LI; |
| 123 | const TargetTransformInfo *TTI; |
| 124 | const DataLayout *DL; |
| 125 | |
| 126 | // Blocks that will be used for inserting vectorized code. |
| 127 | BasicBlock *EndBlock = nullptr; |
| 128 | BasicBlock * = nullptr; |
| 129 | BasicBlock *VectorLoopStartBlock = nullptr; |
| 130 | BasicBlock *VectorLoopMismatchBlock = nullptr; |
| 131 | BasicBlock *VectorLoopIncBlock = nullptr; |
| 132 | |
| 133 | public: |
| 134 | LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT, |
| 135 | LoopInfo *LI, const TargetTransformInfo *TTI, |
| 136 | const DataLayout *DL) |
| 137 | : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) { |
| 138 | } |
| 139 | |
| 140 | bool run(Loop *L); |
| 141 | |
| 142 | private: |
| 143 | /// \name Countable Loop Idiom Handling |
| 144 | /// @{ |
| 145 | |
| 146 | bool runOnCountableLoop(); |
| 147 | bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount, |
| 148 | SmallVectorImpl<BasicBlock *> &ExitBlocks); |
| 149 | |
| 150 | bool recognizeByteCompare(); |
| 151 | |
| 152 | Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, |
| 153 | GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, |
| 154 | Instruction *Index, Value *Start, Value *MaxLen); |
| 155 | |
| 156 | Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, |
| 157 | GetElementPtrInst *GEPA, |
| 158 | GetElementPtrInst *GEPB, Value *ExtStart, |
| 159 | Value *ExtEnd); |
| 160 | Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, |
| 161 | GetElementPtrInst *GEPA, |
| 162 | GetElementPtrInst *GEPB, Value *ExtStart, |
| 163 | Value *ExtEnd); |
| 164 | |
| 165 | void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, |
| 166 | PHINode *IndPhi, Value *MaxLen, Instruction *Index, |
| 167 | Value *Start, bool IncIdx, BasicBlock *FoundBB, |
| 168 | BasicBlock *EndBB); |
| 169 | |
| 170 | bool recognizeFindFirstByte(); |
| 171 | |
| 172 | Value *expandFindFirstByte(IRBuilder<> &Builder, DomTreeUpdater &DTU, |
| 173 | unsigned VF, Type *CharTy, BasicBlock *ExitSucc, |
| 174 | BasicBlock *ExitFail, Value *SearchStart, |
| 175 | Value *SearchEnd, Value *NeedleStart, |
| 176 | Value *NeedleEnd); |
| 177 | |
| 178 | void transformFindFirstByte(PHINode *IndPhi, unsigned VF, Type *CharTy, |
| 179 | BasicBlock *ExitSucc, BasicBlock *ExitFail, |
| 180 | Value *SearchStart, Value *SearchEnd, |
| 181 | Value *NeedleStart, Value *NeedleEnd); |
| 182 | /// @} |
| 183 | }; |
| 184 | } // anonymous namespace |
| 185 | |
| 186 | PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM, |
| 187 | LoopStandardAnalysisResults &AR, |
| 188 | LPMUpdater &) { |
| 189 | if (DisableAll) |
| 190 | return PreservedAnalyses::all(); |
| 191 | |
| 192 | const auto *DL = &L.getHeader()->getDataLayout(); |
| 193 | |
| 194 | LoopIdiomVectorizeStyle VecStyle = VectorizeStyle; |
| 195 | if (LITVecStyle.getNumOccurrences()) |
| 196 | VecStyle = LITVecStyle; |
| 197 | |
| 198 | unsigned BCVF = ByteCompareVF; |
| 199 | if (ByteCmpVF.getNumOccurrences()) |
| 200 | BCVF = ByteCmpVF; |
| 201 | |
| 202 | LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL); |
| 203 | if (!LIV.run(L: &L)) |
| 204 | return PreservedAnalyses::all(); |
| 205 | |
| 206 | return PreservedAnalyses::none(); |
| 207 | } |
| 208 | |
| 209 | //===----------------------------------------------------------------------===// |
| 210 | // |
| 211 | // Implementation of LoopIdiomVectorize |
| 212 | // |
| 213 | //===----------------------------------------------------------------------===// |
| 214 | |
| 215 | bool LoopIdiomVectorize::run(Loop *L) { |
| 216 | CurLoop = L; |
| 217 | |
| 218 | Function &F = *L->getHeader()->getParent(); |
| 219 | if (DisableAll || F.hasOptSize()) |
| 220 | return false; |
| 221 | |
| 222 | if (F.hasFnAttribute(Kind: Attribute::NoImplicitFloat)) { |
| 223 | LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName() |
| 224 | << " due to its NoImplicitFloat attribute" ); |
| 225 | return false; |
| 226 | } |
| 227 | |
| 228 | // If the loop could not be converted to canonical form, it must have an |
| 229 | // indirectbr in it, just give up. |
| 230 | if (!L->getLoopPreheader()) |
| 231 | return false; |
| 232 | |
| 233 | LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %" |
| 234 | << CurLoop->getHeader()->getName() << "\n" ); |
| 235 | |
| 236 | if (recognizeByteCompare()) |
| 237 | return true; |
| 238 | |
| 239 | if (recognizeFindFirstByte()) |
| 240 | return true; |
| 241 | |
| 242 | return false; |
| 243 | } |
| 244 | |
| 245 | bool LoopIdiomVectorize::recognizeByteCompare() { |
| 246 | // Currently the transformation only works on scalable vector types, although |
| 247 | // there is no fundamental reason why it cannot be made to work for fixed |
| 248 | // width too. |
| 249 | |
| 250 | // We also need to know the minimum page size for the target in order to |
| 251 | // generate runtime memory checks to ensure the vector version won't fault. |
| 252 | if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() || |
| 253 | DisableByteCmp) |
| 254 | return false; |
| 255 | |
| 256 | BasicBlock * = CurLoop->getHeader(); |
| 257 | |
| 258 | // In LoopIdiomVectorize::run we have already checked that the loop |
| 259 | // has a preheader so we can assume it's in a canonical form. |
| 260 | if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2) |
| 261 | return false; |
| 262 | |
| 263 | PHINode *PN = dyn_cast<PHINode>(Val: &Header->front()); |
| 264 | if (!PN || PN->getNumIncomingValues() != 2) |
| 265 | return false; |
| 266 | |
| 267 | auto LoopBlocks = CurLoop->getBlocks(); |
| 268 | // The first block in the loop should contain only 4 instructions, e.g. |
| 269 | // |
| 270 | // while.cond: |
| 271 | // %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ] |
| 272 | // %inc = add i32 %res.phi, 1 |
| 273 | // %cmp.not = icmp eq i32 %inc, %n |
| 274 | // br i1 %cmp.not, label %while.end, label %while.body |
| 275 | // |
| 276 | if (LoopBlocks[0]->sizeWithoutDebug() > 4) |
| 277 | return false; |
| 278 | |
| 279 | // The second block should contain 7 instructions, e.g. |
| 280 | // |
| 281 | // while.body: |
| 282 | // %idx = zext i32 %inc to i64 |
| 283 | // %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx |
| 284 | // %load.a = load i8, ptr %idx.a |
| 285 | // %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx |
| 286 | // %load.b = load i8, ptr %idx.b |
| 287 | // %cmp.not.ld = icmp eq i8 %load.a, %load.b |
| 288 | // br i1 %cmp.not.ld, label %while.cond, label %while.end |
| 289 | // |
| 290 | if (LoopBlocks[1]->sizeWithoutDebug() > 7) |
| 291 | return false; |
| 292 | |
| 293 | // The incoming value to the PHI node from the loop should be an add of 1. |
| 294 | Value *StartIdx = nullptr; |
| 295 | Instruction *Index = nullptr; |
| 296 | if (!CurLoop->contains(BB: PN->getIncomingBlock(i: 0))) { |
| 297 | StartIdx = PN->getIncomingValue(i: 0); |
| 298 | Index = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 1)); |
| 299 | } else { |
| 300 | StartIdx = PN->getIncomingValue(i: 1); |
| 301 | Index = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 0)); |
| 302 | } |
| 303 | |
| 304 | // Limit to 32-bit types for now |
| 305 | if (!Index || !Index->getType()->isIntegerTy(Bitwidth: 32) || |
| 306 | !match(V: Index, P: m_c_Add(L: m_Specific(V: PN), R: m_One()))) |
| 307 | return false; |
| 308 | |
| 309 | // If we match the pattern, PN and Index will be replaced with the result of |
| 310 | // the cttz.elts intrinsic. If any other instructions are used outside of |
| 311 | // the loop, we cannot replace it. |
| 312 | for (BasicBlock *BB : LoopBlocks) |
| 313 | for (Instruction &I : *BB) |
| 314 | if (&I != PN && &I != Index) |
| 315 | for (User *U : I.users()) |
| 316 | if (!CurLoop->contains(Inst: cast<Instruction>(Val: U))) |
| 317 | return false; |
| 318 | |
| 319 | // Match the branch instruction for the header |
| 320 | Value *MaxLen; |
| 321 | BasicBlock *EndBB, *WhileBB; |
| 322 | if (!match(V: Header->getTerminator(), |
| 323 | P: m_Br(C: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ, L: m_Specific(V: Index), |
| 324 | R: m_Value(V&: MaxLen)), |
| 325 | T: m_BasicBlock(V&: EndBB), F: m_BasicBlock(V&: WhileBB))) || |
| 326 | !CurLoop->contains(BB: WhileBB)) |
| 327 | return false; |
| 328 | |
| 329 | // WhileBB should contain the pattern of load & compare instructions. Match |
| 330 | // the pattern and find the GEP instructions used by the loads. |
| 331 | BasicBlock *FoundBB; |
| 332 | BasicBlock *TrueBB; |
| 333 | Value *LoadA, *LoadB; |
| 334 | if (!match(V: WhileBB->getTerminator(), |
| 335 | P: m_Br(C: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ, L: m_Value(V&: LoadA), |
| 336 | R: m_Value(V&: LoadB)), |
| 337 | T: m_BasicBlock(V&: TrueBB), F: m_BasicBlock(V&: FoundBB))) || |
| 338 | !CurLoop->contains(BB: TrueBB)) |
| 339 | return false; |
| 340 | |
| 341 | Value *A, *B; |
| 342 | if (!match(V: LoadA, P: m_Load(Op: m_Value(V&: A))) || !match(V: LoadB, P: m_Load(Op: m_Value(V&: B)))) |
| 343 | return false; |
| 344 | |
| 345 | LoadInst *LoadAI = cast<LoadInst>(Val: LoadA); |
| 346 | LoadInst *LoadBI = cast<LoadInst>(Val: LoadB); |
| 347 | if (!LoadAI->isSimple() || !LoadBI->isSimple()) |
| 348 | return false; |
| 349 | |
| 350 | GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(Val: A); |
| 351 | GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(Val: B); |
| 352 | |
| 353 | if (!GEPA || !GEPB) |
| 354 | return false; |
| 355 | |
| 356 | Value *PtrA = GEPA->getPointerOperand(); |
| 357 | Value *PtrB = GEPB->getPointerOperand(); |
| 358 | |
| 359 | // Check we are loading i8 values from two loop invariant pointers |
| 360 | if (!CurLoop->isLoopInvariant(V: PtrA) || !CurLoop->isLoopInvariant(V: PtrB) || |
| 361 | !GEPA->getResultElementType()->isIntegerTy(Bitwidth: 8) || |
| 362 | !GEPB->getResultElementType()->isIntegerTy(Bitwidth: 8) || |
| 363 | !LoadAI->getType()->isIntegerTy(Bitwidth: 8) || |
| 364 | !LoadBI->getType()->isIntegerTy(Bitwidth: 8) || PtrA == PtrB) |
| 365 | return false; |
| 366 | |
| 367 | // Check that the index to the GEPs is the index we found earlier |
| 368 | if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1) |
| 369 | return false; |
| 370 | |
| 371 | Value *IdxA = GEPA->getOperand(i_nocapture: GEPA->getNumIndices()); |
| 372 | Value *IdxB = GEPB->getOperand(i_nocapture: GEPB->getNumIndices()); |
| 373 | if (IdxA != IdxB || !match(V: IdxA, P: m_ZExt(Op: m_Specific(V: Index)))) |
| 374 | return false; |
| 375 | |
| 376 | // We only ever expect the pre-incremented index value to be used inside the |
| 377 | // loop. |
| 378 | if (!PN->hasOneUse()) |
| 379 | return false; |
| 380 | |
| 381 | // Ensure that when the Found and End blocks are identical the PHIs have the |
| 382 | // supported format. We don't currently allow cases like this: |
| 383 | // while.cond: |
| 384 | // ... |
| 385 | // br i1 %cmp.not, label %while.end, label %while.body |
| 386 | // |
| 387 | // while.body: |
| 388 | // ... |
| 389 | // br i1 %cmp.not2, label %while.cond, label %while.end |
| 390 | // |
| 391 | // while.end: |
| 392 | // %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ] |
| 393 | // |
| 394 | // Where the incoming values for %final_ptr are unique and from each of the |
| 395 | // loop blocks, but not actually defined in the loop. This requires extra |
| 396 | // work setting up the byte.compare block, i.e. by introducing a select to |
| 397 | // choose the correct value. |
| 398 | // TODO: We could add support for this in future. |
| 399 | if (FoundBB == EndBB) { |
| 400 | for (PHINode &EndPN : EndBB->phis()) { |
| 401 | Value *WhileCondVal = EndPN.getIncomingValueForBlock(BB: Header); |
| 402 | Value *WhileBodyVal = EndPN.getIncomingValueForBlock(BB: WhileBB); |
| 403 | |
| 404 | // The value of the index when leaving the while.cond block is always the |
| 405 | // same as the end value (MaxLen) so we permit either. The value when |
| 406 | // leaving the while.body block should only be the index. Otherwise for |
| 407 | // any other values we only allow ones that are same for both blocks. |
| 408 | if (WhileCondVal != WhileBodyVal && |
| 409 | ((WhileCondVal != Index && WhileCondVal != MaxLen) || |
| 410 | (WhileBodyVal != Index))) |
| 411 | return false; |
| 412 | } |
| 413 | } |
| 414 | |
| 415 | LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n" |
| 416 | << *(EndBB->getParent()) << "\n\n" ); |
| 417 | |
| 418 | // The index is incremented before the GEP/Load pair so we need to |
| 419 | // add 1 to the start value. |
| 420 | transformByteCompare(GEPA, GEPB, IndPhi: PN, MaxLen, Index, Start: StartIdx, /*IncIdx=*/true, |
| 421 | FoundBB, EndBB); |
| 422 | return true; |
| 423 | } |
| 424 | |
| 425 | Value *LoopIdiomVectorize::createMaskedFindMismatch( |
| 426 | IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, |
| 427 | GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { |
| 428 | Type *I64Type = Builder.getInt64Ty(); |
| 429 | Type *ResType = Builder.getInt32Ty(); |
| 430 | Type *LoadType = Builder.getInt8Ty(); |
| 431 | Value *PtrA = GEPA->getPointerOperand(); |
| 432 | Value *PtrB = GEPB->getPointerOperand(); |
| 433 | |
| 434 | ScalableVectorType *PredVTy = |
| 435 | ScalableVectorType::get(ElementType: Builder.getInt1Ty(), MinNumElts: ByteCompareVF); |
| 436 | |
| 437 | Value *InitialPred = Builder.CreateIntrinsic( |
| 438 | ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Type}, Args: {ExtStart, ExtEnd}); |
| 439 | |
| 440 | Value *VecLen = Builder.CreateVScale(Ty: I64Type); |
| 441 | VecLen = |
| 442 | Builder.CreateMul(LHS: VecLen, RHS: ConstantInt::get(Ty: I64Type, V: ByteCompareVF), Name: "" , |
| 443 | /*HasNUW=*/true, /*HasNSW=*/true); |
| 444 | |
| 445 | Value *PFalse = Builder.CreateVectorSplat(EC: PredVTy->getElementCount(), |
| 446 | V: Builder.getInt1(V: false)); |
| 447 | |
| 448 | BranchInst *JumpToVectorLoop = BranchInst::Create(IfTrue: VectorLoopStartBlock); |
| 449 | Builder.Insert(I: JumpToVectorLoop); |
| 450 | |
| 451 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, VectorLoopPreheaderBlock, |
| 452 | VectorLoopStartBlock}}); |
| 453 | |
| 454 | // Set up the first vector loop block by creating the PHIs, doing the vector |
| 455 | // loads and comparing the vectors. |
| 456 | Builder.SetInsertPoint(VectorLoopStartBlock); |
| 457 | PHINode *LoopPred = Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 2, Name: "mismatch_vec_loop_pred" ); |
| 458 | LoopPred->addIncoming(V: InitialPred, BB: VectorLoopPreheaderBlock); |
| 459 | PHINode *VectorIndexPhi = Builder.CreatePHI(Ty: I64Type, NumReservedValues: 2, Name: "mismatch_vec_index" ); |
| 460 | VectorIndexPhi->addIncoming(V: ExtStart, BB: VectorLoopPreheaderBlock); |
| 461 | Type *VectorLoadType = |
| 462 | ScalableVectorType::get(ElementType: Builder.getInt8Ty(), MinNumElts: ByteCompareVF); |
| 463 | Value *Passthru = ConstantInt::getNullValue(Ty: VectorLoadType); |
| 464 | |
| 465 | Value *VectorLhsGep = |
| 466 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: VectorIndexPhi, Name: "" , NW: GEPA->isInBounds()); |
| 467 | Value *VectorLhsLoad = Builder.CreateMaskedLoad(Ty: VectorLoadType, Ptr: VectorLhsGep, |
| 468 | Alignment: Align(1), Mask: LoopPred, PassThru: Passthru); |
| 469 | |
| 470 | Value *VectorRhsGep = |
| 471 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: VectorIndexPhi, Name: "" , NW: GEPB->isInBounds()); |
| 472 | Value *VectorRhsLoad = Builder.CreateMaskedLoad(Ty: VectorLoadType, Ptr: VectorRhsGep, |
| 473 | Alignment: Align(1), Mask: LoopPred, PassThru: Passthru); |
| 474 | |
| 475 | Value *VectorMatchCmp = Builder.CreateICmpNE(LHS: VectorLhsLoad, RHS: VectorRhsLoad); |
| 476 | VectorMatchCmp = Builder.CreateSelect(C: LoopPred, True: VectorMatchCmp, False: PFalse); |
| 477 | Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(Src: VectorMatchCmp); |
| 478 | BranchInst *VectorEarlyExit = BranchInst::Create( |
| 479 | IfTrue: VectorLoopMismatchBlock, IfFalse: VectorLoopIncBlock, Cond: VectorMatchHasActiveLanes); |
| 480 | Builder.Insert(I: VectorEarlyExit); |
| 481 | |
| 482 | DTU.applyUpdates( |
| 483 | Updates: {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, |
| 484 | {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); |
| 485 | |
| 486 | // Increment the index counter and calculate the predicate for the next |
| 487 | // iteration of the loop. We branch back to the start of the loop if there |
| 488 | // is at least one active lane. |
| 489 | Builder.SetInsertPoint(VectorLoopIncBlock); |
| 490 | Value *NewVectorIndexPhi = |
| 491 | Builder.CreateAdd(LHS: VectorIndexPhi, RHS: VecLen, Name: "" , |
| 492 | /*HasNUW=*/true, /*HasNSW=*/true); |
| 493 | VectorIndexPhi->addIncoming(V: NewVectorIndexPhi, BB: VectorLoopIncBlock); |
| 494 | Value *NewPred = |
| 495 | Builder.CreateIntrinsic(ID: Intrinsic::get_active_lane_mask, |
| 496 | Types: {PredVTy, I64Type}, Args: {NewVectorIndexPhi, ExtEnd}); |
| 497 | LoopPred->addIncoming(V: NewPred, BB: VectorLoopIncBlock); |
| 498 | |
| 499 | Value *PredHasActiveLanes = |
| 500 | Builder.CreateExtractElement(Vec: NewPred, Idx: uint64_t(0)); |
| 501 | BranchInst *VectorLoopBranchBack = |
| 502 | BranchInst::Create(IfTrue: VectorLoopStartBlock, IfFalse: EndBlock, Cond: PredHasActiveLanes); |
| 503 | Builder.Insert(I: VectorLoopBranchBack); |
| 504 | |
| 505 | DTU.applyUpdates( |
| 506 | Updates: {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, |
| 507 | {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); |
| 508 | |
| 509 | // If we found a mismatch then we need to calculate which lane in the vector |
| 510 | // had a mismatch and add that on to the current loop index. |
| 511 | Builder.SetInsertPoint(VectorLoopMismatchBlock); |
| 512 | PHINode *FoundPred = Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 1, Name: "mismatch_vec_found_pred" ); |
| 513 | FoundPred->addIncoming(V: VectorMatchCmp, BB: VectorLoopStartBlock); |
| 514 | PHINode *LastLoopPred = |
| 515 | Builder.CreatePHI(Ty: PredVTy, NumReservedValues: 1, Name: "mismatch_vec_last_loop_pred" ); |
| 516 | LastLoopPred->addIncoming(V: LoopPred, BB: VectorLoopStartBlock); |
| 517 | PHINode *VectorFoundIndex = |
| 518 | Builder.CreatePHI(Ty: I64Type, NumReservedValues: 1, Name: "mismatch_vec_found_index" ); |
| 519 | VectorFoundIndex->addIncoming(V: VectorIndexPhi, BB: VectorLoopStartBlock); |
| 520 | |
| 521 | Value *PredMatchCmp = Builder.CreateAnd(LHS: LastLoopPred, RHS: FoundPred); |
| 522 | Value *Ctz = Builder.CreateCountTrailingZeroElems(ResTy: ResType, Mask: PredMatchCmp); |
| 523 | Ctz = Builder.CreateZExt(V: Ctz, DestTy: I64Type); |
| 524 | Value *VectorLoopRes64 = Builder.CreateAdd(LHS: VectorFoundIndex, RHS: Ctz, Name: "" , |
| 525 | /*HasNUW=*/true, /*HasNSW=*/true); |
| 526 | return Builder.CreateTrunc(V: VectorLoopRes64, DestTy: ResType); |
| 527 | } |
| 528 | |
| 529 | Value *LoopIdiomVectorize::createPredicatedFindMismatch( |
| 530 | IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, |
| 531 | GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { |
| 532 | Type *I64Type = Builder.getInt64Ty(); |
| 533 | Type *I32Type = Builder.getInt32Ty(); |
| 534 | Type *ResType = I32Type; |
| 535 | Type *LoadType = Builder.getInt8Ty(); |
| 536 | Value *PtrA = GEPA->getPointerOperand(); |
| 537 | Value *PtrB = GEPB->getPointerOperand(); |
| 538 | |
| 539 | auto *JumpToVectorLoop = BranchInst::Create(IfTrue: VectorLoopStartBlock); |
| 540 | Builder.Insert(I: JumpToVectorLoop); |
| 541 | |
| 542 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, VectorLoopPreheaderBlock, |
| 543 | VectorLoopStartBlock}}); |
| 544 | |
| 545 | // Set up the first Vector loop block by creating the PHIs, doing the vector |
| 546 | // loads and comparing the vectors. |
| 547 | Builder.SetInsertPoint(VectorLoopStartBlock); |
| 548 | auto *VectorIndexPhi = Builder.CreatePHI(Ty: I64Type, NumReservedValues: 2, Name: "mismatch_vector_index" ); |
| 549 | VectorIndexPhi->addIncoming(V: ExtStart, BB: VectorLoopPreheaderBlock); |
| 550 | |
| 551 | // Calculate AVL by subtracting the vector loop index from the trip count |
| 552 | Value *AVL = Builder.CreateSub(LHS: ExtEnd, RHS: VectorIndexPhi, Name: "avl" , /*HasNUW=*/true, |
| 553 | /*HasNSW=*/true); |
| 554 | |
| 555 | auto *VectorLoadType = ScalableVectorType::get(ElementType: LoadType, MinNumElts: ByteCompareVF); |
| 556 | auto *VF = ConstantInt::get(Ty: I32Type, V: ByteCompareVF); |
| 557 | |
| 558 | Value *VL = Builder.CreateIntrinsic(ID: Intrinsic::experimental_get_vector_length, |
| 559 | Types: {I64Type}, Args: {AVL, VF, Builder.getTrue()}); |
| 560 | Value *GepOffset = VectorIndexPhi; |
| 561 | |
| 562 | Value *VectorLhsGep = |
| 563 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: GepOffset, Name: "" , NW: GEPA->isInBounds()); |
| 564 | VectorType *TrueMaskTy = |
| 565 | VectorType::get(ElementType: Builder.getInt1Ty(), EC: VectorLoadType->getElementCount()); |
| 566 | Value *AllTrueMask = Constant::getAllOnesValue(Ty: TrueMaskTy); |
| 567 | Value *VectorLhsLoad = Builder.CreateIntrinsic( |
| 568 | ID: Intrinsic::vp_load, Types: {VectorLoadType, VectorLhsGep->getType()}, |
| 569 | Args: {VectorLhsGep, AllTrueMask, VL}, FMFSource: nullptr, Name: "lhs.load" ); |
| 570 | |
| 571 | Value *VectorRhsGep = |
| 572 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: GepOffset, Name: "" , NW: GEPB->isInBounds()); |
| 573 | Value *VectorRhsLoad = Builder.CreateIntrinsic( |
| 574 | ID: Intrinsic::vp_load, Types: {VectorLoadType, VectorLhsGep->getType()}, |
| 575 | Args: {VectorRhsGep, AllTrueMask, VL}, FMFSource: nullptr, Name: "rhs.load" ); |
| 576 | |
| 577 | StringRef PredicateStr = CmpInst::getPredicateName(P: CmpInst::ICMP_NE); |
| 578 | auto *PredicateMDS = MDString::get(Context&: VectorLhsLoad->getContext(), Str: PredicateStr); |
| 579 | Value *Pred = MetadataAsValue::get(Context&: VectorLhsLoad->getContext(), MD: PredicateMDS); |
| 580 | Value *VectorMatchCmp = Builder.CreateIntrinsic( |
| 581 | ID: Intrinsic::vp_icmp, Types: {VectorLhsLoad->getType()}, |
| 582 | Args: {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, FMFSource: nullptr, |
| 583 | Name: "mismatch.cmp" ); |
| 584 | Value *CTZ = Builder.CreateIntrinsic( |
| 585 | ID: Intrinsic::vp_cttz_elts, Types: {ResType, VectorMatchCmp->getType()}, |
| 586 | Args: {VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(V: false), AllTrueMask, |
| 587 | VL}); |
| 588 | Value *MismatchFound = Builder.CreateICmpNE(LHS: CTZ, RHS: VL); |
| 589 | auto *VectorEarlyExit = BranchInst::Create(IfTrue: VectorLoopMismatchBlock, |
| 590 | IfFalse: VectorLoopIncBlock, Cond: MismatchFound); |
| 591 | Builder.Insert(I: VectorEarlyExit); |
| 592 | |
| 593 | DTU.applyUpdates( |
| 594 | Updates: {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, |
| 595 | {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); |
| 596 | |
| 597 | // Increment the index counter and calculate the predicate for the next |
| 598 | // iteration of the loop. We branch back to the start of the loop if there |
| 599 | // is at least one active lane. |
| 600 | Builder.SetInsertPoint(VectorLoopIncBlock); |
| 601 | Value *VL64 = Builder.CreateZExt(V: VL, DestTy: I64Type); |
| 602 | Value *NewVectorIndexPhi = |
| 603 | Builder.CreateAdd(LHS: VectorIndexPhi, RHS: VL64, Name: "" , |
| 604 | /*HasNUW=*/true, /*HasNSW=*/true); |
| 605 | VectorIndexPhi->addIncoming(V: NewVectorIndexPhi, BB: VectorLoopIncBlock); |
| 606 | Value *ExitCond = Builder.CreateICmpNE(LHS: NewVectorIndexPhi, RHS: ExtEnd); |
| 607 | auto *VectorLoopBranchBack = |
| 608 | BranchInst::Create(IfTrue: VectorLoopStartBlock, IfFalse: EndBlock, Cond: ExitCond); |
| 609 | Builder.Insert(I: VectorLoopBranchBack); |
| 610 | |
| 611 | DTU.applyUpdates( |
| 612 | Updates: {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, |
| 613 | {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); |
| 614 | |
| 615 | // If we found a mismatch then we need to calculate which lane in the vector |
| 616 | // had a mismatch and add that on to the current loop index. |
| 617 | Builder.SetInsertPoint(VectorLoopMismatchBlock); |
| 618 | |
| 619 | // Add LCSSA phis for CTZ and VectorIndexPhi. |
| 620 | auto *CTZLCSSAPhi = Builder.CreatePHI(Ty: CTZ->getType(), NumReservedValues: 1, Name: "ctz" ); |
| 621 | CTZLCSSAPhi->addIncoming(V: CTZ, BB: VectorLoopStartBlock); |
| 622 | auto *VectorIndexLCSSAPhi = |
| 623 | Builder.CreatePHI(Ty: VectorIndexPhi->getType(), NumReservedValues: 1, Name: "mismatch_vector_index" ); |
| 624 | VectorIndexLCSSAPhi->addIncoming(V: VectorIndexPhi, BB: VectorLoopStartBlock); |
| 625 | |
| 626 | Value *CTZI64 = Builder.CreateZExt(V: CTZLCSSAPhi, DestTy: I64Type); |
| 627 | Value *VectorLoopRes64 = Builder.CreateAdd(LHS: VectorIndexLCSSAPhi, RHS: CTZI64, Name: "" , |
| 628 | /*HasNUW=*/true, /*HasNSW=*/true); |
| 629 | return Builder.CreateTrunc(V: VectorLoopRes64, DestTy: ResType); |
| 630 | } |
| 631 | |
| 632 | Value *LoopIdiomVectorize::expandFindMismatch( |
| 633 | IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, |
| 634 | GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) { |
| 635 | Value *PtrA = GEPA->getPointerOperand(); |
| 636 | Value *PtrB = GEPB->getPointerOperand(); |
| 637 | |
| 638 | // Get the arguments and types for the intrinsic. |
| 639 | BasicBlock * = CurLoop->getLoopPreheader(); |
| 640 | BranchInst *PHBranch = cast<BranchInst>(Val: Preheader->getTerminator()); |
| 641 | LLVMContext &Ctx = PHBranch->getContext(); |
| 642 | Type *LoadType = Type::getInt8Ty(C&: Ctx); |
| 643 | Type *ResType = Builder.getInt32Ty(); |
| 644 | |
| 645 | // Split block in the original loop preheader. |
| 646 | EndBlock = SplitBlock(Old: Preheader, SplitPt: PHBranch, DT, LI, MSSAU: nullptr, BBName: "mismatch_end" ); |
| 647 | |
| 648 | // Create the blocks that we're going to need: |
| 649 | // 1. A block for checking the zero-extended length exceeds 0 |
| 650 | // 2. A block to check that the start and end addresses of a given array |
| 651 | // lie on the same page. |
| 652 | // 3. The vector loop preheader. |
| 653 | // 4. The first vector loop block. |
| 654 | // 5. The vector loop increment block. |
| 655 | // 6. A block we can jump to from the vector loop when a mismatch is found. |
| 656 | // 7. The first block of the scalar loop itself, containing PHIs , loads |
| 657 | // and cmp. |
| 658 | // 8. A scalar loop increment block to increment the PHIs and go back |
| 659 | // around the loop. |
| 660 | |
| 661 | BasicBlock *MinItCheckBlock = BasicBlock::Create( |
| 662 | Context&: Ctx, Name: "mismatch_min_it_check" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
| 663 | |
| 664 | // Update the terminator added by SplitBlock to branch to the first block |
| 665 | Preheader->getTerminator()->setSuccessor(Idx: 0, BB: MinItCheckBlock); |
| 666 | |
| 667 | BasicBlock *MemCheckBlock = BasicBlock::Create( |
| 668 | Context&: Ctx, Name: "mismatch_mem_check" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
| 669 | |
| 670 | VectorLoopPreheaderBlock = BasicBlock::Create( |
| 671 | Context&: Ctx, Name: "mismatch_vec_loop_preheader" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
| 672 | |
| 673 | VectorLoopStartBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop" , |
| 674 | Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
| 675 | |
| 676 | VectorLoopIncBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop_inc" , |
| 677 | Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
| 678 | |
| 679 | VectorLoopMismatchBlock = BasicBlock::Create(Context&: Ctx, Name: "mismatch_vec_loop_found" , |
| 680 | Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
| 681 | |
| 682 | BasicBlock * = BasicBlock::Create( |
| 683 | Context&: Ctx, Name: "mismatch_loop_pre" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
| 684 | |
| 685 | BasicBlock *LoopStartBlock = |
| 686 | BasicBlock::Create(Context&: Ctx, Name: "mismatch_loop" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
| 687 | |
| 688 | BasicBlock *LoopIncBlock = BasicBlock::Create( |
| 689 | Context&: Ctx, Name: "mismatch_loop_inc" , Parent: EndBlock->getParent(), InsertBefore: EndBlock); |
| 690 | |
| 691 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, Preheader, MinItCheckBlock}, |
| 692 | {DominatorTree::Delete, Preheader, EndBlock}}); |
| 693 | |
| 694 | // Update LoopInfo with the new vector & scalar loops. |
| 695 | auto VectorLoop = LI->AllocateLoop(); |
| 696 | auto ScalarLoop = LI->AllocateLoop(); |
| 697 | |
| 698 | if (CurLoop->getParentLoop()) { |
| 699 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: MinItCheckBlock, LI&: *LI); |
| 700 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: MemCheckBlock, LI&: *LI); |
| 701 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: VectorLoopPreheaderBlock, |
| 702 | LI&: *LI); |
| 703 | CurLoop->getParentLoop()->addChildLoop(NewChild: VectorLoop); |
| 704 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: VectorLoopMismatchBlock, LI&: *LI); |
| 705 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: LoopPreHeaderBlock, LI&: *LI); |
| 706 | CurLoop->getParentLoop()->addChildLoop(NewChild: ScalarLoop); |
| 707 | } else { |
| 708 | LI->addTopLevelLoop(New: VectorLoop); |
| 709 | LI->addTopLevelLoop(New: ScalarLoop); |
| 710 | } |
| 711 | |
| 712 | // Add the new basic blocks to their associated loops. |
| 713 | VectorLoop->addBasicBlockToLoop(NewBB: VectorLoopStartBlock, LI&: *LI); |
| 714 | VectorLoop->addBasicBlockToLoop(NewBB: VectorLoopIncBlock, LI&: *LI); |
| 715 | |
| 716 | ScalarLoop->addBasicBlockToLoop(NewBB: LoopStartBlock, LI&: *LI); |
| 717 | ScalarLoop->addBasicBlockToLoop(NewBB: LoopIncBlock, LI&: *LI); |
| 718 | |
| 719 | // Set up some types and constants that we intend to reuse. |
| 720 | Type *I64Type = Builder.getInt64Ty(); |
| 721 | |
| 722 | // Check the zero-extended iteration count > 0 |
| 723 | Builder.SetInsertPoint(MinItCheckBlock); |
| 724 | Value *ExtStart = Builder.CreateZExt(V: Start, DestTy: I64Type); |
| 725 | Value *ExtEnd = Builder.CreateZExt(V: MaxLen, DestTy: I64Type); |
| 726 | // This check doesn't really cost us very much. |
| 727 | |
| 728 | Value *LimitCheck = Builder.CreateICmpULE(LHS: Start, RHS: MaxLen); |
| 729 | BranchInst *MinItCheckBr = |
| 730 | BranchInst::Create(IfTrue: MemCheckBlock, IfFalse: LoopPreHeaderBlock, Cond: LimitCheck); |
| 731 | MinItCheckBr->setMetadata( |
| 732 | KindID: LLVMContext::MD_prof, |
| 733 | Node: MDBuilder(MinItCheckBr->getContext()).createBranchWeights(TrueWeight: 99, FalseWeight: 1)); |
| 734 | Builder.Insert(I: MinItCheckBr); |
| 735 | |
| 736 | DTU.applyUpdates( |
| 737 | Updates: {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock}, |
| 738 | {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}}); |
| 739 | |
| 740 | // For each of the arrays, check the start/end addresses are on the same |
| 741 | // page. |
| 742 | Builder.SetInsertPoint(MemCheckBlock); |
| 743 | |
| 744 | // The early exit in the original loop means that when performing vector |
| 745 | // loads we are potentially reading ahead of the early exit. So we could |
| 746 | // fault if crossing a page boundary. Therefore, we create runtime memory |
| 747 | // checks based on the minimum page size as follows: |
| 748 | // 1. Calculate the addresses of the first memory accesses in the loop, |
| 749 | // i.e. LhsStart and RhsStart. |
| 750 | // 2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd. |
| 751 | // 3. Determine which pages correspond to all the memory accesses, i.e |
| 752 | // LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage. |
| 753 | // 4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then |
| 754 | // we know we won't cross any page boundaries in the loop so we can |
| 755 | // enter the vector loop! Otherwise we fall back on the scalar loop. |
| 756 | Value *LhsStartGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: ExtStart); |
| 757 | Value *RhsStartGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: ExtStart); |
| 758 | Value *RhsStart = Builder.CreatePtrToInt(V: RhsStartGEP, DestTy: I64Type); |
| 759 | Value *LhsStart = Builder.CreatePtrToInt(V: LhsStartGEP, DestTy: I64Type); |
| 760 | Value *LhsEndGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: ExtEnd); |
| 761 | Value *RhsEndGEP = Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: ExtEnd); |
| 762 | Value *LhsEnd = Builder.CreatePtrToInt(V: LhsEndGEP, DestTy: I64Type); |
| 763 | Value *RhsEnd = Builder.CreatePtrToInt(V: RhsEndGEP, DestTy: I64Type); |
| 764 | |
| 765 | const uint64_t MinPageSize = TTI->getMinPageSize().value(); |
| 766 | const uint64_t AddrShiftAmt = llvm::Log2_64(Value: MinPageSize); |
| 767 | Value *LhsStartPage = Builder.CreateLShr(LHS: LhsStart, RHS: AddrShiftAmt); |
| 768 | Value *LhsEndPage = Builder.CreateLShr(LHS: LhsEnd, RHS: AddrShiftAmt); |
| 769 | Value *RhsStartPage = Builder.CreateLShr(LHS: RhsStart, RHS: AddrShiftAmt); |
| 770 | Value *RhsEndPage = Builder.CreateLShr(LHS: RhsEnd, RHS: AddrShiftAmt); |
| 771 | Value *LhsPageCmp = Builder.CreateICmpNE(LHS: LhsStartPage, RHS: LhsEndPage); |
| 772 | Value *RhsPageCmp = Builder.CreateICmpNE(LHS: RhsStartPage, RHS: RhsEndPage); |
| 773 | |
| 774 | Value *CombinedPageCmp = Builder.CreateOr(LHS: LhsPageCmp, RHS: RhsPageCmp); |
| 775 | BranchInst *CombinedPageCmpCmpBr = BranchInst::Create( |
| 776 | IfTrue: LoopPreHeaderBlock, IfFalse: VectorLoopPreheaderBlock, Cond: CombinedPageCmp); |
| 777 | CombinedPageCmpCmpBr->setMetadata( |
| 778 | KindID: LLVMContext::MD_prof, Node: MDBuilder(CombinedPageCmpCmpBr->getContext()) |
| 779 | .createBranchWeights(TrueWeight: 10, FalseWeight: 90)); |
| 780 | Builder.Insert(I: CombinedPageCmpCmpBr); |
| 781 | |
| 782 | DTU.applyUpdates( |
| 783 | Updates: {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock}, |
| 784 | {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}}); |
| 785 | |
| 786 | // Set up the vector loop preheader, i.e. calculate initial loop predicate, |
| 787 | // zero-extend MaxLen to 64-bits, determine the number of vector elements |
| 788 | // processed in each iteration, etc. |
| 789 | Builder.SetInsertPoint(VectorLoopPreheaderBlock); |
| 790 | |
| 791 | // At this point we know two things must be true: |
| 792 | // 1. Start <= End |
| 793 | // 2. ExtMaxLen <= MinPageSize due to the page checks. |
| 794 | // Therefore, we know that we can use a 64-bit induction variable that |
| 795 | // starts from 0 -> ExtMaxLen and it will not overflow. |
| 796 | Value *VectorLoopRes = nullptr; |
| 797 | switch (VectorizeStyle) { |
| 798 | case LoopIdiomVectorizeStyle::Masked: |
| 799 | VectorLoopRes = |
| 800 | createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd); |
| 801 | break; |
| 802 | case LoopIdiomVectorizeStyle::Predicated: |
| 803 | VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB, |
| 804 | ExtStart, ExtEnd); |
| 805 | break; |
| 806 | } |
| 807 | |
| 808 | Builder.Insert(I: BranchInst::Create(IfTrue: EndBlock)); |
| 809 | |
| 810 | DTU.applyUpdates( |
| 811 | Updates: {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}}); |
| 812 | |
| 813 | // Generate code for scalar loop. |
| 814 | Builder.SetInsertPoint(LoopPreHeaderBlock); |
| 815 | Builder.Insert(I: BranchInst::Create(IfTrue: LoopStartBlock)); |
| 816 | |
| 817 | DTU.applyUpdates( |
| 818 | Updates: {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}}); |
| 819 | |
| 820 | Builder.SetInsertPoint(LoopStartBlock); |
| 821 | PHINode *IndexPhi = Builder.CreatePHI(Ty: ResType, NumReservedValues: 2, Name: "mismatch_index" ); |
| 822 | IndexPhi->addIncoming(V: Start, BB: LoopPreHeaderBlock); |
| 823 | |
| 824 | // Otherwise compare the values |
| 825 | // Load bytes from each array and compare them. |
| 826 | Value *GepOffset = Builder.CreateZExt(V: IndexPhi, DestTy: I64Type); |
| 827 | |
| 828 | Value *LhsGep = |
| 829 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrA, IdxList: GepOffset, Name: "" , NW: GEPA->isInBounds()); |
| 830 | Value *LhsLoad = Builder.CreateLoad(Ty: LoadType, Ptr: LhsGep); |
| 831 | |
| 832 | Value *RhsGep = |
| 833 | Builder.CreateGEP(Ty: LoadType, Ptr: PtrB, IdxList: GepOffset, Name: "" , NW: GEPB->isInBounds()); |
| 834 | Value *RhsLoad = Builder.CreateLoad(Ty: LoadType, Ptr: RhsGep); |
| 835 | |
| 836 | Value *MatchCmp = Builder.CreateICmpEQ(LHS: LhsLoad, RHS: RhsLoad); |
| 837 | // If we have a mismatch then exit the loop ... |
| 838 | BranchInst *MatchCmpBr = BranchInst::Create(IfTrue: LoopIncBlock, IfFalse: EndBlock, Cond: MatchCmp); |
| 839 | Builder.Insert(I: MatchCmpBr); |
| 840 | |
| 841 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, LoopStartBlock, LoopIncBlock}, |
| 842 | {DominatorTree::Insert, LoopStartBlock, EndBlock}}); |
| 843 | |
| 844 | // Have we reached the maximum permitted length for the loop? |
| 845 | Builder.SetInsertPoint(LoopIncBlock); |
| 846 | Value *PhiInc = Builder.CreateAdd(LHS: IndexPhi, RHS: ConstantInt::get(Ty: ResType, V: 1), Name: "" , |
| 847 | /*HasNUW=*/Index->hasNoUnsignedWrap(), |
| 848 | /*HasNSW=*/Index->hasNoSignedWrap()); |
| 849 | IndexPhi->addIncoming(V: PhiInc, BB: LoopIncBlock); |
| 850 | Value *IVCmp = Builder.CreateICmpEQ(LHS: PhiInc, RHS: MaxLen); |
| 851 | BranchInst *IVCmpBr = BranchInst::Create(IfTrue: EndBlock, IfFalse: LoopStartBlock, Cond: IVCmp); |
| 852 | Builder.Insert(I: IVCmpBr); |
| 853 | |
| 854 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, LoopIncBlock, EndBlock}, |
| 855 | {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}}); |
| 856 | |
| 857 | // In the end block we need to insert a PHI node to deal with three cases: |
| 858 | // 1. We didn't find a mismatch in the scalar loop, so we return MaxLen. |
| 859 | // 2. We exitted the scalar loop early due to a mismatch and need to return |
| 860 | // the index that we found. |
| 861 | // 3. We didn't find a mismatch in the vector loop, so we return MaxLen. |
| 862 | // 4. We exitted the vector loop early due to a mismatch and need to return |
| 863 | // the index that we found. |
| 864 | Builder.SetInsertPoint(TheBB: EndBlock, IP: EndBlock->getFirstInsertionPt()); |
| 865 | PHINode *ResPhi = Builder.CreatePHI(Ty: ResType, NumReservedValues: 4, Name: "mismatch_result" ); |
| 866 | ResPhi->addIncoming(V: MaxLen, BB: LoopIncBlock); |
| 867 | ResPhi->addIncoming(V: IndexPhi, BB: LoopStartBlock); |
| 868 | ResPhi->addIncoming(V: MaxLen, BB: VectorLoopIncBlock); |
| 869 | ResPhi->addIncoming(V: VectorLoopRes, BB: VectorLoopMismatchBlock); |
| 870 | |
| 871 | Value *FinalRes = Builder.CreateTrunc(V: ResPhi, DestTy: ResType); |
| 872 | |
| 873 | if (VerifyLoops) { |
| 874 | ScalarLoop->verifyLoop(); |
| 875 | VectorLoop->verifyLoop(); |
| 876 | if (!VectorLoop->isRecursivelyLCSSAForm(DT: *DT, LI: *LI)) |
| 877 | report_fatal_error(reason: "Loops must remain in LCSSA form!" ); |
| 878 | if (!ScalarLoop->isRecursivelyLCSSAForm(DT: *DT, LI: *LI)) |
| 879 | report_fatal_error(reason: "Loops must remain in LCSSA form!" ); |
| 880 | } |
| 881 | |
| 882 | return FinalRes; |
| 883 | } |
| 884 | |
| 885 | void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA, |
| 886 | GetElementPtrInst *GEPB, |
| 887 | PHINode *IndPhi, Value *MaxLen, |
| 888 | Instruction *Index, Value *Start, |
| 889 | bool IncIdx, BasicBlock *FoundBB, |
| 890 | BasicBlock *EndBB) { |
| 891 | |
| 892 | // Insert the byte compare code at the end of the preheader block |
| 893 | BasicBlock * = CurLoop->getLoopPreheader(); |
| 894 | BasicBlock * = CurLoop->getHeader(); |
| 895 | BranchInst *PHBranch = cast<BranchInst>(Val: Preheader->getTerminator()); |
| 896 | IRBuilder<> Builder(PHBranch); |
| 897 | DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); |
| 898 | Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc()); |
| 899 | |
| 900 | // Increment the pointer if this was done before the loads in the loop. |
| 901 | if (IncIdx) |
| 902 | Start = Builder.CreateAdd(LHS: Start, RHS: ConstantInt::get(Ty: Start->getType(), V: 1)); |
| 903 | |
| 904 | Value *ByteCmpRes = |
| 905 | expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen); |
| 906 | |
| 907 | // Replaces uses of index & induction Phi with intrinsic (we already |
| 908 | // checked that the the first instruction of Header is the Phi above). |
| 909 | assert(IndPhi->hasOneUse() && "Index phi node has more than one use!" ); |
| 910 | Index->replaceAllUsesWith(V: ByteCmpRes); |
| 911 | |
| 912 | assert(PHBranch->isUnconditional() && |
| 913 | "Expected preheader to terminate with an unconditional branch." ); |
| 914 | |
| 915 | // If no mismatch was found, we can jump to the end block. Create a |
| 916 | // new basic block for the compare instruction. |
| 917 | auto *CmpBB = BasicBlock::Create(Context&: Preheader->getContext(), Name: "byte.compare" , |
| 918 | Parent: Preheader->getParent()); |
| 919 | CmpBB->moveBefore(MovePos: EndBB); |
| 920 | |
| 921 | // Replace the branch in the preheader with an always-true conditional branch. |
| 922 | // This ensures there is still a reference to the original loop. |
| 923 | Builder.CreateCondBr(Cond: Builder.getTrue(), True: CmpBB, False: Header); |
| 924 | PHBranch->eraseFromParent(); |
| 925 | |
| 926 | BasicBlock *MismatchEnd = cast<Instruction>(Val: ByteCmpRes)->getParent(); |
| 927 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, MismatchEnd, CmpBB}}); |
| 928 | |
| 929 | // Create the branch to either the end or found block depending on the value |
| 930 | // returned by the intrinsic. |
| 931 | Builder.SetInsertPoint(CmpBB); |
| 932 | if (FoundBB != EndBB) { |
| 933 | Value *FoundCmp = Builder.CreateICmpEQ(LHS: ByteCmpRes, RHS: MaxLen); |
| 934 | Builder.CreateCondBr(Cond: FoundCmp, True: EndBB, False: FoundBB); |
| 935 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, CmpBB, FoundBB}, |
| 936 | {DominatorTree::Insert, CmpBB, EndBB}}); |
| 937 | |
| 938 | } else { |
| 939 | Builder.CreateBr(Dest: FoundBB); |
| 940 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, CmpBB, FoundBB}}); |
| 941 | } |
| 942 | |
| 943 | auto fixSuccessorPhis = [&](BasicBlock *SuccBB) { |
| 944 | for (PHINode &PN : SuccBB->phis()) { |
| 945 | // At this point we've already replaced all uses of the result from the |
| 946 | // loop with ByteCmp. Look through the incoming values to find ByteCmp, |
| 947 | // meaning this is a Phi collecting the results of the byte compare. |
| 948 | bool ResPhi = false; |
| 949 | for (Value *Op : PN.incoming_values()) |
| 950 | if (Op == ByteCmpRes) { |
| 951 | ResPhi = true; |
| 952 | break; |
| 953 | } |
| 954 | |
| 955 | // Any PHI that depended upon the result of the byte compare needs a new |
| 956 | // incoming value from CmpBB. This is because the original loop will get |
| 957 | // deleted. |
| 958 | if (ResPhi) |
| 959 | PN.addIncoming(V: ByteCmpRes, BB: CmpBB); |
| 960 | else { |
| 961 | // There should be no other outside uses of other values in the |
| 962 | // original loop. Any incoming values should either: |
| 963 | // 1. Be for blocks outside the loop, which aren't interesting. Or .. |
| 964 | // 2. These are from blocks in the loop with values defined outside |
| 965 | // the loop. We should a similar incoming value from CmpBB. |
| 966 | for (BasicBlock *BB : PN.blocks()) |
| 967 | if (CurLoop->contains(BB)) { |
| 968 | PN.addIncoming(V: PN.getIncomingValueForBlock(BB), BB: CmpBB); |
| 969 | break; |
| 970 | } |
| 971 | } |
| 972 | } |
| 973 | }; |
| 974 | |
| 975 | // Ensure all Phis in the successors of CmpBB have an incoming value from it. |
| 976 | fixSuccessorPhis(EndBB); |
| 977 | if (EndBB != FoundBB) |
| 978 | fixSuccessorPhis(FoundBB); |
| 979 | |
| 980 | // The new CmpBB block isn't part of the loop, but will need to be added to |
| 981 | // the outer loop if there is one. |
| 982 | if (!CurLoop->isOutermost()) |
| 983 | CurLoop->getParentLoop()->addBasicBlockToLoop(NewBB: CmpBB, LI&: *LI); |
| 984 | |
| 985 | if (VerifyLoops && CurLoop->getParentLoop()) { |
| 986 | CurLoop->getParentLoop()->verifyLoop(); |
| 987 | if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(DT: *DT, LI: *LI)) |
| 988 | report_fatal_error(reason: "Loops must remain in LCSSA form!" ); |
| 989 | } |
| 990 | } |
| 991 | |
| 992 | bool LoopIdiomVectorize::recognizeFindFirstByte() { |
| 993 | // Currently the transformation only works on scalable vector types, although |
| 994 | // there is no fundamental reason why it cannot be made to work for fixed |
| 995 | // vectors. We also need to know the target's minimum page size in order to |
| 996 | // generate runtime memory checks to ensure the vector version won't fault. |
| 997 | if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() || |
| 998 | DisableFindFirstByte) |
| 999 | return false; |
| 1000 | |
| 1001 | // Define some constants we need throughout. |
| 1002 | BasicBlock * = CurLoop->getHeader(); |
| 1003 | LLVMContext &Ctx = Header->getContext(); |
| 1004 | |
| 1005 | // We are expecting the four blocks defined below: Header, MatchBB, InnerBB, |
| 1006 | // and OuterBB. For now, we will bail our for almost anything else. The Four |
| 1007 | // blocks contain one nested loop. |
| 1008 | if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 4 || |
| 1009 | CurLoop->getSubLoops().size() != 1) |
| 1010 | return false; |
| 1011 | |
| 1012 | auto *InnerLoop = CurLoop->getSubLoops().front(); |
| 1013 | PHINode *IndPhi = dyn_cast<PHINode>(Val: &Header->front()); |
| 1014 | if (!IndPhi || IndPhi->getNumIncomingValues() != 2) |
| 1015 | return false; |
| 1016 | |
| 1017 | // Check instruction counts. |
| 1018 | auto LoopBlocks = CurLoop->getBlocks(); |
| 1019 | if (LoopBlocks[0]->sizeWithoutDebug() > 3 || |
| 1020 | LoopBlocks[1]->sizeWithoutDebug() > 4 || |
| 1021 | LoopBlocks[2]->sizeWithoutDebug() > 3 || |
| 1022 | LoopBlocks[3]->sizeWithoutDebug() > 3) |
| 1023 | return false; |
| 1024 | |
| 1025 | // Check that no instruction other than IndPhi has outside uses. |
| 1026 | for (BasicBlock *BB : LoopBlocks) |
| 1027 | for (Instruction &I : *BB) |
| 1028 | if (&I != IndPhi) |
| 1029 | for (User *U : I.users()) |
| 1030 | if (!CurLoop->contains(Inst: cast<Instruction>(Val: U))) |
| 1031 | return false; |
| 1032 | |
| 1033 | // Match the branch instruction in the header. We are expecting an |
| 1034 | // unconditional branch to the inner loop. |
| 1035 | // |
| 1036 | // Header: |
| 1037 | // %14 = phi ptr [ %24, %OuterBB ], [ %3, %Header.preheader ] |
| 1038 | // %15 = load i8, ptr %14, align 1 |
| 1039 | // br label %MatchBB |
| 1040 | BasicBlock *MatchBB; |
| 1041 | if (!match(V: Header->getTerminator(), P: m_UnconditionalBr(Succ&: MatchBB)) || |
| 1042 | !InnerLoop->contains(BB: MatchBB)) |
| 1043 | return false; |
| 1044 | |
| 1045 | // MatchBB should be the entrypoint into the inner loop containing the |
| 1046 | // comparison between a search element and a needle. |
| 1047 | // |
| 1048 | // MatchBB: |
| 1049 | // %20 = phi ptr [ %7, %Header ], [ %17, %InnerBB ] |
| 1050 | // %21 = load i8, ptr %20, align 1 |
| 1051 | // %22 = icmp eq i8 %15, %21 |
| 1052 | // br i1 %22, label %ExitSucc, label %InnerBB |
| 1053 | BasicBlock *ExitSucc, *InnerBB; |
| 1054 | Value *LoadSearch, *LoadNeedle; |
| 1055 | CmpPredicate MatchPred; |
| 1056 | if (!match(V: MatchBB->getTerminator(), |
| 1057 | P: m_Br(C: m_ICmp(Pred&: MatchPred, L: m_Value(V&: LoadSearch), R: m_Value(V&: LoadNeedle)), |
| 1058 | T: m_BasicBlock(V&: ExitSucc), F: m_BasicBlock(V&: InnerBB))) || |
| 1059 | MatchPred != ICmpInst::ICMP_EQ || !InnerLoop->contains(BB: InnerBB)) |
| 1060 | return false; |
| 1061 | |
| 1062 | // We expect outside uses of `IndPhi' in ExitSucc (and only there). |
| 1063 | for (User *U : IndPhi->users()) |
| 1064 | if (!CurLoop->contains(Inst: cast<Instruction>(Val: U))) { |
| 1065 | auto *PN = dyn_cast<PHINode>(Val: U); |
| 1066 | if (!PN || PN->getParent() != ExitSucc) |
| 1067 | return false; |
| 1068 | } |
| 1069 | |
| 1070 | // Match the loads and check they are simple. |
| 1071 | Value *Search, *Needle; |
| 1072 | if (!match(V: LoadSearch, P: m_Load(Op: m_Value(V&: Search))) || |
| 1073 | !match(V: LoadNeedle, P: m_Load(Op: m_Value(V&: Needle))) || |
| 1074 | !cast<LoadInst>(Val: LoadSearch)->isSimple() || |
| 1075 | !cast<LoadInst>(Val: LoadNeedle)->isSimple()) |
| 1076 | return false; |
| 1077 | |
| 1078 | // Check we are loading valid characters. |
| 1079 | Type *CharTy = LoadSearch->getType(); |
| 1080 | if (!CharTy->isIntegerTy() || LoadNeedle->getType() != CharTy) |
| 1081 | return false; |
| 1082 | |
| 1083 | // Pick the vectorisation factor based on CharTy, work out the cost of the |
| 1084 | // match intrinsic and decide if we should use it. |
| 1085 | // Note: For the time being we assume 128-bit vectors. |
| 1086 | unsigned VF = 128 / CharTy->getIntegerBitWidth(); |
| 1087 | SmallVector<Type *> Args = { |
| 1088 | ScalableVectorType::get(ElementType: CharTy, MinNumElts: VF), FixedVectorType::get(ElementType: CharTy, NumElts: VF), |
| 1089 | ScalableVectorType::get(ElementType: Type::getInt1Ty(C&: Ctx), MinNumElts: VF)}; |
| 1090 | IntrinsicCostAttributes Attrs(Intrinsic::experimental_vector_match, Args[2], |
| 1091 | Args); |
| 1092 | if (TTI->getIntrinsicInstrCost(ICA: Attrs, CostKind: TTI::TCK_SizeAndLatency) > 4) |
| 1093 | return false; |
| 1094 | |
| 1095 | // The loads come from two PHIs, each with two incoming values. |
| 1096 | PHINode *PSearch = dyn_cast<PHINode>(Val: Search); |
| 1097 | PHINode *PNeedle = dyn_cast<PHINode>(Val: Needle); |
| 1098 | if (!PSearch || PSearch->getNumIncomingValues() != 2 || !PNeedle || |
| 1099 | PNeedle->getNumIncomingValues() != 2) |
| 1100 | return false; |
| 1101 | |
| 1102 | // One PHI comes from the outer loop (PSearch), the other one from the inner |
| 1103 | // loop (PNeedle). PSearch effectively corresponds to IndPhi. |
| 1104 | if (InnerLoop->contains(Inst: PSearch)) |
| 1105 | std::swap(a&: PSearch, b&: PNeedle); |
| 1106 | if (PSearch != &Header->front() || PNeedle != &MatchBB->front()) |
| 1107 | return false; |
| 1108 | |
| 1109 | // The incoming values of both PHI nodes should be a gep of 1. |
| 1110 | Value *SearchStart = PSearch->getIncomingValue(i: 0); |
| 1111 | Value *SearchIndex = PSearch->getIncomingValue(i: 1); |
| 1112 | if (CurLoop->contains(BB: PSearch->getIncomingBlock(i: 0))) |
| 1113 | std::swap(a&: SearchStart, b&: SearchIndex); |
| 1114 | |
| 1115 | Value *NeedleStart = PNeedle->getIncomingValue(i: 0); |
| 1116 | Value *NeedleIndex = PNeedle->getIncomingValue(i: 1); |
| 1117 | if (InnerLoop->contains(BB: PNeedle->getIncomingBlock(i: 0))) |
| 1118 | std::swap(a&: NeedleStart, b&: NeedleIndex); |
| 1119 | |
| 1120 | // Match the GEPs. |
| 1121 | if (!match(V: SearchIndex, P: m_GEP(Ops: m_Specific(V: PSearch), Ops: m_One())) || |
| 1122 | !match(V: NeedleIndex, P: m_GEP(Ops: m_Specific(V: PNeedle), Ops: m_One()))) |
| 1123 | return false; |
| 1124 | |
| 1125 | // Check the GEPs result type matches `CharTy'. |
| 1126 | GetElementPtrInst *GEPSearch = cast<GetElementPtrInst>(Val: SearchIndex); |
| 1127 | GetElementPtrInst *GEPNeedle = cast<GetElementPtrInst>(Val: NeedleIndex); |
| 1128 | if (GEPSearch->getResultElementType() != CharTy || |
| 1129 | GEPNeedle->getResultElementType() != CharTy) |
| 1130 | return false; |
| 1131 | |
| 1132 | // InnerBB should increment the address of the needle pointer. |
| 1133 | // |
| 1134 | // InnerBB: |
| 1135 | // %17 = getelementptr inbounds i8, ptr %20, i64 1 |
| 1136 | // %18 = icmp eq ptr %17, %10 |
| 1137 | // br i1 %18, label %OuterBB, label %MatchBB |
| 1138 | BasicBlock *OuterBB; |
| 1139 | Value *NeedleEnd; |
| 1140 | if (!match(V: InnerBB->getTerminator(), |
| 1141 | P: m_Br(C: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ, L: m_Specific(V: GEPNeedle), |
| 1142 | R: m_Value(V&: NeedleEnd)), |
| 1143 | T: m_BasicBlock(V&: OuterBB), F: m_Specific(V: MatchBB))) || |
| 1144 | !CurLoop->contains(BB: OuterBB)) |
| 1145 | return false; |
| 1146 | |
| 1147 | // OuterBB should increment the address of the search element pointer. |
| 1148 | // |
| 1149 | // OuterBB: |
| 1150 | // %24 = getelementptr inbounds i8, ptr %14, i64 1 |
| 1151 | // %25 = icmp eq ptr %24, %6 |
| 1152 | // br i1 %25, label %ExitFail, label %Header |
| 1153 | BasicBlock *ExitFail; |
| 1154 | Value *SearchEnd; |
| 1155 | if (!match(V: OuterBB->getTerminator(), |
| 1156 | P: m_Br(C: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ, L: m_Specific(V: GEPSearch), |
| 1157 | R: m_Value(V&: SearchEnd)), |
| 1158 | T: m_BasicBlock(V&: ExitFail), F: m_Specific(V: Header)))) |
| 1159 | return false; |
| 1160 | |
| 1161 | if (!CurLoop->isLoopInvariant(V: SearchStart) || |
| 1162 | !CurLoop->isLoopInvariant(V: SearchEnd) || |
| 1163 | !CurLoop->isLoopInvariant(V: NeedleStart) || |
| 1164 | !CurLoop->isLoopInvariant(V: NeedleEnd)) |
| 1165 | return false; |
| 1166 | |
| 1167 | LLVM_DEBUG(dbgs() << "Found idiom in loop: \n" << *CurLoop << "\n\n" ); |
| 1168 | |
| 1169 | transformFindFirstByte(IndPhi, VF, CharTy, ExitSucc, ExitFail, SearchStart, |
| 1170 | SearchEnd, NeedleStart, NeedleEnd); |
| 1171 | return true; |
| 1172 | } |
| 1173 | |
| 1174 | Value *LoopIdiomVectorize::expandFindFirstByte( |
| 1175 | IRBuilder<> &Builder, DomTreeUpdater &DTU, unsigned VF, Type *CharTy, |
| 1176 | BasicBlock *ExitSucc, BasicBlock *ExitFail, Value *SearchStart, |
| 1177 | Value *SearchEnd, Value *NeedleStart, Value *NeedleEnd) { |
| 1178 | // Set up some types and constants that we intend to reuse. |
| 1179 | auto *PtrTy = Builder.getPtrTy(); |
| 1180 | auto *I64Ty = Builder.getInt64Ty(); |
| 1181 | auto *PredVTy = ScalableVectorType::get(ElementType: Builder.getInt1Ty(), MinNumElts: VF); |
| 1182 | auto *CharVTy = ScalableVectorType::get(ElementType: CharTy, MinNumElts: VF); |
| 1183 | auto *ConstVF = ConstantInt::get(Ty: I64Ty, V: VF); |
| 1184 | |
| 1185 | // Other common arguments. |
| 1186 | BasicBlock * = CurLoop->getLoopPreheader(); |
| 1187 | LLVMContext &Ctx = Preheader->getContext(); |
| 1188 | Value *Passthru = ConstantInt::getNullValue(Ty: CharVTy); |
| 1189 | |
| 1190 | // Split block in the original loop preheader. |
| 1191 | // SPH is the new preheader to the old scalar loop. |
| 1192 | BasicBlock *SPH = SplitBlock(Old: Preheader, SplitPt: Preheader->getTerminator(), DT, LI, |
| 1193 | MSSAU: nullptr, BBName: "scalar_preheader" ); |
| 1194 | |
| 1195 | // Create the blocks that we're going to use. |
| 1196 | // |
| 1197 | // We will have the following loops: |
| 1198 | // (O) Outer loop where we iterate over the elements of the search array. |
| 1199 | // (I) Inner loop where we iterate over the elements of the needle array. |
| 1200 | // |
| 1201 | // Overall, the blocks do the following: |
| 1202 | // (0) Check if the arrays can't cross page boundaries. If so go to (1), |
| 1203 | // otherwise fall back to the original scalar loop. |
| 1204 | // (1) Load the search array. Go to (2). |
| 1205 | // (2) (a) Load the needle array. |
| 1206 | // (b) Splat the first element to the inactive lanes. |
| 1207 | // (c) Check if any elements match. If so go to (3), otherwise go to (4). |
| 1208 | // (3) Compute the index of the first match and exit. |
| 1209 | // (4) Check if we've reached the end of the needle array. If not loop back to |
| 1210 | // (2), otherwise go to (5). |
| 1211 | // (5) Check if we've reached the end of the search array. If not loop back to |
| 1212 | // (1), otherwise exit. |
| 1213 | // Blocks (0,3) are not part of any loop. Blocks (1,5) and (2,4) belong to |
| 1214 | // the outer and inner loops, respectively. |
| 1215 | BasicBlock *BB0 = BasicBlock::Create(Context&: Ctx, Name: "mem_check" , Parent: SPH->getParent(), InsertBefore: SPH); |
| 1216 | BasicBlock *BB1 = |
| 1217 | BasicBlock::Create(Context&: Ctx, Name: "find_first_vec_header" , Parent: SPH->getParent(), InsertBefore: SPH); |
| 1218 | BasicBlock *BB2 = |
| 1219 | BasicBlock::Create(Context&: Ctx, Name: "match_check_vec" , Parent: SPH->getParent(), InsertBefore: SPH); |
| 1220 | BasicBlock *BB3 = |
| 1221 | BasicBlock::Create(Context&: Ctx, Name: "calculate_match" , Parent: SPH->getParent(), InsertBefore: SPH); |
| 1222 | BasicBlock *BB4 = |
| 1223 | BasicBlock::Create(Context&: Ctx, Name: "needle_check_vec" , Parent: SPH->getParent(), InsertBefore: SPH); |
| 1224 | BasicBlock *BB5 = |
| 1225 | BasicBlock::Create(Context&: Ctx, Name: "search_check_vec" , Parent: SPH->getParent(), InsertBefore: SPH); |
| 1226 | |
| 1227 | // Update LoopInfo with the new loops. |
| 1228 | auto OuterLoop = LI->AllocateLoop(); |
| 1229 | auto InnerLoop = LI->AllocateLoop(); |
| 1230 | |
| 1231 | if (auto ParentLoop = CurLoop->getParentLoop()) { |
| 1232 | ParentLoop->addBasicBlockToLoop(NewBB: BB0, LI&: *LI); |
| 1233 | ParentLoop->addChildLoop(NewChild: OuterLoop); |
| 1234 | ParentLoop->addBasicBlockToLoop(NewBB: BB3, LI&: *LI); |
| 1235 | } else { |
| 1236 | LI->addTopLevelLoop(New: OuterLoop); |
| 1237 | } |
| 1238 | |
| 1239 | // Add the inner loop to the outer. |
| 1240 | OuterLoop->addChildLoop(NewChild: InnerLoop); |
| 1241 | |
| 1242 | // Add the new basic blocks to the corresponding loops. |
| 1243 | OuterLoop->addBasicBlockToLoop(NewBB: BB1, LI&: *LI); |
| 1244 | OuterLoop->addBasicBlockToLoop(NewBB: BB5, LI&: *LI); |
| 1245 | InnerLoop->addBasicBlockToLoop(NewBB: BB2, LI&: *LI); |
| 1246 | InnerLoop->addBasicBlockToLoop(NewBB: BB4, LI&: *LI); |
| 1247 | |
| 1248 | // Update the terminator added by SplitBlock to branch to the first block. |
| 1249 | Preheader->getTerminator()->setSuccessor(Idx: 0, BB: BB0); |
| 1250 | DTU.applyUpdates(Updates: {{DominatorTree::Delete, Preheader, SPH}, |
| 1251 | {DominatorTree::Insert, Preheader, BB0}}); |
| 1252 | |
| 1253 | // (0) Check if we could be crossing a page boundary; if so, fallback to the |
| 1254 | // old scalar loops. Also create a predicate of VF elements to be used in the |
| 1255 | // vector loops. |
| 1256 | Builder.SetInsertPoint(BB0); |
| 1257 | Value *ISearchStart = |
| 1258 | Builder.CreatePtrToInt(V: SearchStart, DestTy: I64Ty, Name: "search_start_int" ); |
| 1259 | Value *ISearchEnd = |
| 1260 | Builder.CreatePtrToInt(V: SearchEnd, DestTy: I64Ty, Name: "search_end_int" ); |
| 1261 | Value *INeedleStart = |
| 1262 | Builder.CreatePtrToInt(V: NeedleStart, DestTy: I64Ty, Name: "needle_start_int" ); |
| 1263 | Value *INeedleEnd = |
| 1264 | Builder.CreatePtrToInt(V: NeedleEnd, DestTy: I64Ty, Name: "needle_end_int" ); |
| 1265 | Value *PredVF = |
| 1266 | Builder.CreateIntrinsic(ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Ty}, |
| 1267 | Args: {ConstantInt::get(Ty: I64Ty, V: 0), ConstVF}); |
| 1268 | |
| 1269 | const uint64_t MinPageSize = TTI->getMinPageSize().value(); |
| 1270 | const uint64_t AddrShiftAmt = llvm::Log2_64(Value: MinPageSize); |
| 1271 | Value *SearchStartPage = |
| 1272 | Builder.CreateLShr(LHS: ISearchStart, RHS: AddrShiftAmt, Name: "search_start_page" ); |
| 1273 | Value *SearchEndPage = |
| 1274 | Builder.CreateLShr(LHS: ISearchEnd, RHS: AddrShiftAmt, Name: "search_end_page" ); |
| 1275 | Value *NeedleStartPage = |
| 1276 | Builder.CreateLShr(LHS: INeedleStart, RHS: AddrShiftAmt, Name: "needle_start_page" ); |
| 1277 | Value *NeedleEndPage = |
| 1278 | Builder.CreateLShr(LHS: INeedleEnd, RHS: AddrShiftAmt, Name: "needle_end_page" ); |
| 1279 | Value *SearchPageCmp = |
| 1280 | Builder.CreateICmpNE(LHS: SearchStartPage, RHS: SearchEndPage, Name: "search_page_cmp" ); |
| 1281 | Value *NeedlePageCmp = |
| 1282 | Builder.CreateICmpNE(LHS: NeedleStartPage, RHS: NeedleEndPage, Name: "needle_page_cmp" ); |
| 1283 | |
| 1284 | Value *CombinedPageCmp = |
| 1285 | Builder.CreateOr(LHS: SearchPageCmp, RHS: NeedlePageCmp, Name: "combined_page_cmp" ); |
| 1286 | BranchInst *CombinedPageBr = Builder.CreateCondBr(Cond: CombinedPageCmp, True: SPH, False: BB1); |
| 1287 | CombinedPageBr->setMetadata(KindID: LLVMContext::MD_prof, |
| 1288 | Node: MDBuilder(Ctx).createBranchWeights(TrueWeight: 10, FalseWeight: 90)); |
| 1289 | DTU.applyUpdates( |
| 1290 | Updates: {{DominatorTree::Insert, BB0, SPH}, {DominatorTree::Insert, BB0, BB1}}); |
| 1291 | |
| 1292 | // (1) Load the search array and branch to the inner loop. |
| 1293 | Builder.SetInsertPoint(BB1); |
| 1294 | PHINode *Search = Builder.CreatePHI(Ty: PtrTy, NumReservedValues: 2, Name: "psearch" ); |
| 1295 | Value *PredSearch = Builder.CreateIntrinsic( |
| 1296 | ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Ty}, |
| 1297 | Args: {Builder.CreatePtrToInt(V: Search, DestTy: I64Ty), ISearchEnd}, FMFSource: nullptr, |
| 1298 | Name: "search_pred" ); |
| 1299 | PredSearch = Builder.CreateAnd(LHS: PredVF, RHS: PredSearch, Name: "search_masked" ); |
| 1300 | Value *LoadSearch = Builder.CreateMaskedLoad( |
| 1301 | Ty: CharVTy, Ptr: Search, Alignment: Align(1), Mask: PredSearch, PassThru: Passthru, Name: "search_load_vec" ); |
| 1302 | Builder.CreateBr(Dest: BB2); |
| 1303 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, BB1, BB2}}); |
| 1304 | |
| 1305 | // (2) Inner loop. |
| 1306 | Builder.SetInsertPoint(BB2); |
| 1307 | PHINode *Needle = Builder.CreatePHI(Ty: PtrTy, NumReservedValues: 2, Name: "pneedle" ); |
| 1308 | |
| 1309 | // (2.a) Load the needle array. |
| 1310 | Value *PredNeedle = Builder.CreateIntrinsic( |
| 1311 | ID: Intrinsic::get_active_lane_mask, Types: {PredVTy, I64Ty}, |
| 1312 | Args: {Builder.CreatePtrToInt(V: Needle, DestTy: I64Ty), INeedleEnd}, FMFSource: nullptr, |
| 1313 | Name: "needle_pred" ); |
| 1314 | PredNeedle = Builder.CreateAnd(LHS: PredVF, RHS: PredNeedle, Name: "needle_masked" ); |
| 1315 | Value *LoadNeedle = Builder.CreateMaskedLoad( |
| 1316 | Ty: CharVTy, Ptr: Needle, Alignment: Align(1), Mask: PredNeedle, PassThru: Passthru, Name: "needle_load_vec" ); |
| 1317 | |
| 1318 | // (2.b) Splat the first element to the inactive lanes. |
| 1319 | Value *Needle0 = |
| 1320 | Builder.CreateExtractElement(Vec: LoadNeedle, Idx: uint64_t(0), Name: "needle0" ); |
| 1321 | Value *Needle0Splat = Builder.CreateVectorSplat(EC: ElementCount::getScalable(MinVal: VF), |
| 1322 | V: Needle0, Name: "needle0" ); |
| 1323 | LoadNeedle = Builder.CreateSelect(C: PredNeedle, True: LoadNeedle, False: Needle0Splat, |
| 1324 | Name: "needle_splat" ); |
| 1325 | LoadNeedle = Builder.CreateExtractVector( |
| 1326 | DstType: FixedVectorType::get(ElementType: CharTy, NumElts: VF), SrcVec: LoadNeedle, Idx: uint64_t(0), Name: "needle_vec" ); |
| 1327 | |
| 1328 | // (2.c) Test if there's a match. |
| 1329 | Value *MatchPred = Builder.CreateIntrinsic( |
| 1330 | ID: Intrinsic::experimental_vector_match, Types: {CharVTy, LoadNeedle->getType()}, |
| 1331 | Args: {LoadSearch, LoadNeedle, PredSearch}, FMFSource: nullptr, Name: "match_pred" ); |
| 1332 | Value *IfAnyMatch = Builder.CreateOrReduce(Src: MatchPred); |
| 1333 | Builder.CreateCondBr(Cond: IfAnyMatch, True: BB3, False: BB4); |
| 1334 | DTU.applyUpdates( |
| 1335 | Updates: {{DominatorTree::Insert, BB2, BB3}, {DominatorTree::Insert, BB2, BB4}}); |
| 1336 | |
| 1337 | // (3) We found a match. Compute the index of its location and exit. |
| 1338 | Builder.SetInsertPoint(BB3); |
| 1339 | PHINode *MatchLCSSA = Builder.CreatePHI(Ty: PtrTy, NumReservedValues: 1, Name: "match_start" ); |
| 1340 | PHINode *MatchPredLCSSA = |
| 1341 | Builder.CreatePHI(Ty: MatchPred->getType(), NumReservedValues: 1, Name: "match_vec" ); |
| 1342 | Value *MatchCnt = Builder.CreateIntrinsic( |
| 1343 | ID: Intrinsic::experimental_cttz_elts, Types: {I64Ty, MatchPred->getType()}, |
| 1344 | Args: {MatchPredLCSSA, /*ZeroIsPoison=*/Builder.getInt1(V: true)}, FMFSource: nullptr, |
| 1345 | Name: "match_idx" ); |
| 1346 | Value *MatchVal = |
| 1347 | Builder.CreateGEP(Ty: CharTy, Ptr: MatchLCSSA, IdxList: MatchCnt, Name: "match_res" ); |
| 1348 | Builder.CreateBr(Dest: ExitSucc); |
| 1349 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, BB3, ExitSucc}}); |
| 1350 | |
| 1351 | // (4) Check if we've reached the end of the needle array. |
| 1352 | Builder.SetInsertPoint(BB4); |
| 1353 | Value *NextNeedle = |
| 1354 | Builder.CreateGEP(Ty: CharTy, Ptr: Needle, IdxList: ConstVF, Name: "needle_next_vec" ); |
| 1355 | Builder.CreateCondBr(Cond: Builder.CreateICmpULT(LHS: NextNeedle, RHS: NeedleEnd), True: BB2, False: BB5); |
| 1356 | DTU.applyUpdates( |
| 1357 | Updates: {{DominatorTree::Insert, BB4, BB2}, {DominatorTree::Insert, BB4, BB5}}); |
| 1358 | |
| 1359 | // (5) Check if we've reached the end of the search array. |
| 1360 | Builder.SetInsertPoint(BB5); |
| 1361 | Value *NextSearch = |
| 1362 | Builder.CreateGEP(Ty: CharTy, Ptr: Search, IdxList: ConstVF, Name: "search_next_vec" ); |
| 1363 | Builder.CreateCondBr(Cond: Builder.CreateICmpULT(LHS: NextSearch, RHS: SearchEnd), True: BB1, |
| 1364 | False: ExitFail); |
| 1365 | DTU.applyUpdates(Updates: {{DominatorTree::Insert, BB5, BB1}, |
| 1366 | {DominatorTree::Insert, BB5, ExitFail}}); |
| 1367 | |
| 1368 | // Set up the PHI nodes. |
| 1369 | Search->addIncoming(V: SearchStart, BB: BB0); |
| 1370 | Search->addIncoming(V: NextSearch, BB: BB5); |
| 1371 | Needle->addIncoming(V: NeedleStart, BB: BB1); |
| 1372 | Needle->addIncoming(V: NextNeedle, BB: BB4); |
| 1373 | // These are needed to retain LCSSA form. |
| 1374 | MatchLCSSA->addIncoming(V: Search, BB: BB2); |
| 1375 | MatchPredLCSSA->addIncoming(V: MatchPred, BB: BB2); |
| 1376 | |
| 1377 | if (VerifyLoops) { |
| 1378 | OuterLoop->verifyLoop(); |
| 1379 | InnerLoop->verifyLoop(); |
| 1380 | if (!OuterLoop->isRecursivelyLCSSAForm(DT: *DT, LI: *LI)) |
| 1381 | report_fatal_error(reason: "Loops must remain in LCSSA form!" ); |
| 1382 | } |
| 1383 | |
| 1384 | return MatchVal; |
| 1385 | } |
| 1386 | |
| 1387 | void LoopIdiomVectorize::transformFindFirstByte( |
| 1388 | PHINode *IndPhi, unsigned VF, Type *CharTy, BasicBlock *ExitSucc, |
| 1389 | BasicBlock *ExitFail, Value *SearchStart, Value *SearchEnd, |
| 1390 | Value *NeedleStart, Value *NeedleEnd) { |
| 1391 | // Insert the find first byte code at the end of the preheader block. |
| 1392 | BasicBlock * = CurLoop->getLoopPreheader(); |
| 1393 | BranchInst *PHBranch = cast<BranchInst>(Val: Preheader->getTerminator()); |
| 1394 | IRBuilder<> Builder(PHBranch); |
| 1395 | DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); |
| 1396 | Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc()); |
| 1397 | |
| 1398 | Value *MatchVal = |
| 1399 | expandFindFirstByte(Builder, DTU, VF, CharTy, ExitSucc, ExitFail, |
| 1400 | SearchStart, SearchEnd, NeedleStart, NeedleEnd); |
| 1401 | |
| 1402 | assert(PHBranch->isUnconditional() && |
| 1403 | "Expected preheader to terminate with an unconditional branch." ); |
| 1404 | |
| 1405 | // Add new incoming values with the result of the transformation to PHINodes |
| 1406 | // of ExitSucc that use IndPhi. |
| 1407 | for (auto *U : llvm::make_early_inc_range(Range: IndPhi->users())) { |
| 1408 | auto *PN = dyn_cast<PHINode>(Val: U); |
| 1409 | if (PN && PN->getParent() == ExitSucc) |
| 1410 | PN->addIncoming(V: MatchVal, BB: cast<Instruction>(Val: MatchVal)->getParent()); |
| 1411 | } |
| 1412 | |
| 1413 | if (VerifyLoops && CurLoop->getParentLoop()) { |
| 1414 | CurLoop->getParentLoop()->verifyLoop(); |
| 1415 | if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(DT: *DT, LI: *LI)) |
| 1416 | report_fatal_error(reason: "Loops must remain in LCSSA form!" ); |
| 1417 | } |
| 1418 | } |
| 1419 | |