| 1 | //===---- EVLIndVarSimplify.cpp - Optimize vectorized loops w/ EVL IV------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This pass optimizes a vectorized loop with canonical IV to using EVL-based |
| 10 | // IV if it was tail-folded by predicated EVL. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "llvm/Transforms/Vectorize/EVLIndVarSimplify.h" |
| 15 | #include "llvm/ADT/Statistic.h" |
| 16 | #include "llvm/Analysis/IVDescriptors.h" |
| 17 | #include "llvm/Analysis/LoopInfo.h" |
| 18 | #include "llvm/Analysis/LoopPass.h" |
| 19 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" |
| 20 | #include "llvm/Analysis/ScalarEvolution.h" |
| 21 | #include "llvm/Analysis/ScalarEvolutionExpressions.h" |
| 22 | #include "llvm/Analysis/ValueTracking.h" |
| 23 | #include "llvm/IR/IRBuilder.h" |
| 24 | #include "llvm/IR/PatternMatch.h" |
| 25 | #include "llvm/Support/CommandLine.h" |
| 26 | #include "llvm/Support/Debug.h" |
| 27 | #include "llvm/Support/MathExtras.h" |
| 28 | #include "llvm/Support/raw_ostream.h" |
| 29 | #include "llvm/Transforms/Scalar/LoopPassManager.h" |
| 30 | #include "llvm/Transforms/Utils/Local.h" |
| 31 | |
| 32 | #define DEBUG_TYPE "evl-iv-simplify" |
| 33 | |
| 34 | using namespace llvm; |
| 35 | |
| 36 | STATISTIC(NumEliminatedCanonicalIV, "Number of canonical IVs we eliminated" ); |
| 37 | |
| 38 | static cl::opt<bool> EnableEVLIndVarSimplify( |
| 39 | "enable-evl-indvar-simplify" , |
| 40 | cl::desc("Enable EVL-based induction variable simplify Pass" ), cl::Hidden, |
| 41 | cl::init(Val: true)); |
| 42 | |
| 43 | namespace { |
| 44 | struct EVLIndVarSimplifyImpl { |
| 45 | ScalarEvolution &SE; |
| 46 | OptimizationRemarkEmitter *ORE = nullptr; |
| 47 | |
| 48 | EVLIndVarSimplifyImpl(LoopStandardAnalysisResults &LAR, |
| 49 | OptimizationRemarkEmitter *ORE) |
| 50 | : SE(LAR.SE), ORE(ORE) {} |
| 51 | |
| 52 | /// Returns true if modify the loop. |
| 53 | bool run(Loop &L); |
| 54 | }; |
| 55 | } // anonymous namespace |
| 56 | |
| 57 | /// Returns the constant part of vectorization factor from the induction |
| 58 | /// variable's step value SCEV expression. |
| 59 | static uint32_t getVFFromIndVar(const SCEV *Step, const Function &F) { |
| 60 | if (!Step) |
| 61 | return 0U; |
| 62 | |
| 63 | // Looking for loops with IV step value in the form of `(<constant VF> x |
| 64 | // vscale)`. |
| 65 | if (const auto *Mul = dyn_cast<SCEVMulExpr>(Val: Step)) { |
| 66 | if (Mul->getNumOperands() == 2) { |
| 67 | const SCEV *LHS = Mul->getOperand(i: 0); |
| 68 | const SCEV *RHS = Mul->getOperand(i: 1); |
| 69 | if (const auto *Const = dyn_cast<SCEVConstant>(Val: LHS); |
| 70 | Const && isa<SCEVVScale>(Val: RHS)) { |
| 71 | uint64_t V = Const->getAPInt().getLimitedValue(); |
| 72 | if (llvm::isUInt<32>(x: V)) |
| 73 | return V; |
| 74 | } |
| 75 | } |
| 76 | } |
| 77 | |
| 78 | // If not, see if the vscale_range of the parent function is a fixed value, |
| 79 | // which makes the step value to be replaced by a constant. |
| 80 | if (F.hasFnAttribute(Kind: Attribute::VScaleRange)) |
| 81 | if (const auto *ConstStep = dyn_cast<SCEVConstant>(Val: Step)) { |
| 82 | APInt V = ConstStep->getAPInt().abs(); |
| 83 | ConstantRange CR = llvm::getVScaleRange(F: &F, BitWidth: 64); |
| 84 | if (const APInt *Fixed = CR.getSingleElement()) { |
| 85 | V = V.zextOrTrunc(width: Fixed->getBitWidth()); |
| 86 | uint64_t VF = V.udiv(RHS: *Fixed).getLimitedValue(); |
| 87 | if (VF && llvm::isUInt<32>(x: VF) && |
| 88 | // Make sure step is divisible by vscale. |
| 89 | V.urem(RHS: *Fixed).isZero()) |
| 90 | return VF; |
| 91 | } |
| 92 | } |
| 93 | |
| 94 | return 0U; |
| 95 | } |
| 96 | |
| 97 | bool EVLIndVarSimplifyImpl::run(Loop &L) { |
| 98 | if (!EnableEVLIndVarSimplify) |
| 99 | return false; |
| 100 | |
| 101 | if (!getBooleanLoopAttribute(TheLoop: &L, Name: "llvm.loop.isvectorized" )) |
| 102 | return false; |
| 103 | const MDOperand *EVLMD = |
| 104 | findStringMetadataForLoop(TheLoop: &L, Name: "llvm.loop.isvectorized.tailfoldingstyle" ) |
| 105 | .value_or(u: nullptr); |
| 106 | if (!EVLMD || !EVLMD->equalsStr(Str: "evl" )) |
| 107 | return false; |
| 108 | |
| 109 | BasicBlock *LatchBlock = L.getLoopLatch(); |
| 110 | ICmpInst *OrigLatchCmp = L.getLatchCmpInst(); |
| 111 | if (!LatchBlock || !OrigLatchCmp) |
| 112 | return false; |
| 113 | |
| 114 | InductionDescriptor IVD; |
| 115 | PHINode *IndVar = L.getInductionVariable(SE); |
| 116 | if (!IndVar || !L.getInductionDescriptor(SE, IndDesc&: IVD)) { |
| 117 | const char *Reason = (IndVar ? "induction descriptor is not available" |
| 118 | : "cannot recognize induction variable" ); |
| 119 | LLVM_DEBUG(dbgs() << "Cannot retrieve IV from loop " << L.getName() |
| 120 | << " because" << Reason << "\n" ); |
| 121 | if (ORE) { |
| 122 | ORE->emit(RemarkBuilder: [&]() { |
| 123 | return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar" , |
| 124 | L.getStartLoc(), L.getHeader()) |
| 125 | << "Cannot retrieve IV because " << ore::NV("Reason" , Reason); |
| 126 | }); |
| 127 | } |
| 128 | return false; |
| 129 | } |
| 130 | |
| 131 | BasicBlock *InitBlock, *BackEdgeBlock; |
| 132 | if (!L.getIncomingAndBackEdge(Incoming&: InitBlock, Backedge&: BackEdgeBlock)) { |
| 133 | LLVM_DEBUG(dbgs() << "Expect unique incoming and backedge in " |
| 134 | << L.getName() << "\n" ); |
| 135 | if (ORE) { |
| 136 | ORE->emit(RemarkBuilder: [&]() { |
| 137 | return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure" , |
| 138 | L.getStartLoc(), L.getHeader()) |
| 139 | << "Does not have a unique incoming and backedge" ; |
| 140 | }); |
| 141 | } |
| 142 | return false; |
| 143 | } |
| 144 | |
| 145 | // Retrieve the loop bounds. |
| 146 | std::optional<Loop::LoopBounds> Bounds = L.getBounds(SE); |
| 147 | if (!Bounds) { |
| 148 | LLVM_DEBUG(dbgs() << "Could not obtain the bounds for loop " << L.getName() |
| 149 | << "\n" ); |
| 150 | if (ORE) { |
| 151 | ORE->emit(RemarkBuilder: [&]() { |
| 152 | return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure" , |
| 153 | L.getStartLoc(), L.getHeader()) |
| 154 | << "Could not obtain the loop bounds" ; |
| 155 | }); |
| 156 | } |
| 157 | return false; |
| 158 | } |
| 159 | Value *CanonicalIVInit = &Bounds->getInitialIVValue(); |
| 160 | Value *CanonicalIVFinal = &Bounds->getFinalIVValue(); |
| 161 | |
| 162 | const SCEV *StepV = IVD.getStep(); |
| 163 | uint32_t VF = getVFFromIndVar(Step: StepV, F: *L.getHeader()->getParent()); |
| 164 | if (!VF) { |
| 165 | LLVM_DEBUG(dbgs() << "Could not infer VF from IndVar step '" << *StepV |
| 166 | << "'\n" ); |
| 167 | if (ORE) { |
| 168 | ORE->emit(RemarkBuilder: [&]() { |
| 169 | return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar" , |
| 170 | L.getStartLoc(), L.getHeader()) |
| 171 | << "Could not infer VF from IndVar step " |
| 172 | << ore::NV("Step" , StepV); |
| 173 | }); |
| 174 | } |
| 175 | return false; |
| 176 | } |
| 177 | LLVM_DEBUG(dbgs() << "Using VF=" << VF << " for loop " << L.getName() |
| 178 | << "\n" ); |
| 179 | |
| 180 | // Try to find the EVL-based induction variable. |
| 181 | using namespace PatternMatch; |
| 182 | BasicBlock *BB = IndVar->getParent(); |
| 183 | |
| 184 | Value *EVLIndVar = nullptr; |
| 185 | Value *RemTC = nullptr; |
| 186 | Value *TC = nullptr; |
| 187 | auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>( |
| 188 | Op0: m_Value(V&: RemTC), Op1: m_SpecificInt(V: VF), |
| 189 | /*Scalable=*/Op2: m_SpecificInt(V: 1)); |
| 190 | for (PHINode &PN : BB->phis()) { |
| 191 | if (&PN == IndVar) |
| 192 | continue; |
| 193 | |
| 194 | // Check 1: it has to contain both incoming (init) & backedge blocks |
| 195 | // from IndVar. |
| 196 | if (PN.getBasicBlockIndex(BB: InitBlock) < 0 || |
| 197 | PN.getBasicBlockIndex(BB: BackEdgeBlock) < 0) |
| 198 | continue; |
| 199 | // Check 2: EVL index is always increasing, thus its inital value has to be |
| 200 | // equal to either the initial IV value (when the canonical IV is also |
| 201 | // increasing) or the last IV value (when canonical IV is decreasing). |
| 202 | Value *Init = PN.getIncomingValueForBlock(BB: InitBlock); |
| 203 | using Direction = Loop::LoopBounds::Direction; |
| 204 | switch (Bounds->getDirection()) { |
| 205 | case Direction::Increasing: |
| 206 | if (Init != CanonicalIVInit) |
| 207 | continue; |
| 208 | break; |
| 209 | case Direction::Decreasing: |
| 210 | if (Init != CanonicalIVFinal) |
| 211 | continue; |
| 212 | break; |
| 213 | case Direction::Unknown: |
| 214 | // To be more permissive and see if either the initial or final IV value |
| 215 | // matches PN's init value. |
| 216 | if (Init != CanonicalIVInit && Init != CanonicalIVFinal) |
| 217 | continue; |
| 218 | break; |
| 219 | } |
| 220 | Value *RecValue = PN.getIncomingValueForBlock(BB: BackEdgeBlock); |
| 221 | assert(RecValue && "expect recurrent IndVar value" ); |
| 222 | |
| 223 | LLVM_DEBUG(dbgs() << "Found candidate PN of EVL-based IndVar: " << PN |
| 224 | << "\n" ); |
| 225 | |
| 226 | // Check 3: Pattern match to find the EVL-based index and total trip count |
| 227 | // (TC). |
| 228 | if (match(V: RecValue, |
| 229 | P: m_c_Add(L: m_ZExtOrSelf(Op: IntrinsicMatch), R: m_Specific(V: &PN))) && |
| 230 | match(V: RemTC, P: m_Sub(L: m_Value(V&: TC), R: m_Specific(V: &PN)))) { |
| 231 | EVLIndVar = RecValue; |
| 232 | break; |
| 233 | } |
| 234 | } |
| 235 | |
| 236 | if (!EVLIndVar || !TC) |
| 237 | return false; |
| 238 | |
| 239 | LLVM_DEBUG(dbgs() << "Using " << *EVLIndVar << " for EVL-based IndVar\n" ); |
| 240 | if (ORE) { |
| 241 | ORE->emit(RemarkBuilder: [&]() { |
| 242 | DebugLoc DL; |
| 243 | BasicBlock *Region = nullptr; |
| 244 | if (auto *I = dyn_cast<Instruction>(Val: EVLIndVar)) { |
| 245 | DL = I->getDebugLoc(); |
| 246 | Region = I->getParent(); |
| 247 | } else { |
| 248 | DL = L.getStartLoc(); |
| 249 | Region = L.getHeader(); |
| 250 | } |
| 251 | return OptimizationRemark(DEBUG_TYPE, "UseEVLIndVar" , DL, Region) |
| 252 | << "Using " << ore::NV("EVLIndVar" , EVLIndVar) |
| 253 | << " for EVL-based IndVar" ; |
| 254 | }); |
| 255 | } |
| 256 | |
| 257 | // Create an EVL-based comparison and replace the branch to use it as |
| 258 | // predicate. |
| 259 | |
| 260 | // Loop::getLatchCmpInst check at the beginning of this function has ensured |
| 261 | // that latch block ends in a conditional branch. |
| 262 | auto *LatchBranch = cast<BranchInst>(Val: LatchBlock->getTerminator()); |
| 263 | assert(LatchBranch->isConditional() && |
| 264 | "expect the loop latch to be ended with a conditional branch" ); |
| 265 | ICmpInst::Predicate Pred; |
| 266 | if (LatchBranch->getSuccessor(i: 0) == L.getHeader()) |
| 267 | Pred = ICmpInst::ICMP_NE; |
| 268 | else |
| 269 | Pred = ICmpInst::ICMP_EQ; |
| 270 | |
| 271 | IRBuilder<> Builder(OrigLatchCmp); |
| 272 | auto *NewLatchCmp = Builder.CreateICmp(P: Pred, LHS: EVLIndVar, RHS: TC); |
| 273 | OrigLatchCmp->replaceAllUsesWith(V: NewLatchCmp); |
| 274 | |
| 275 | // llvm::RecursivelyDeleteDeadPHINode only deletes cycles whose values are |
| 276 | // not used outside the cycles. However, in this case the now-RAUW-ed |
| 277 | // OrigLatchCmp will be considered a use outside the cycle while in reality |
| 278 | // it's practically dead. Thus we need to remove it before calling |
| 279 | // RecursivelyDeleteDeadPHINode. |
| 280 | (void)RecursivelyDeleteTriviallyDeadInstructions(V: OrigLatchCmp); |
| 281 | if (llvm::RecursivelyDeleteDeadPHINode(PN: IndVar)) |
| 282 | LLVM_DEBUG(dbgs() << "Removed original IndVar\n" ); |
| 283 | |
| 284 | ++NumEliminatedCanonicalIV; |
| 285 | |
| 286 | return true; |
| 287 | } |
| 288 | |
| 289 | PreservedAnalyses EVLIndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &LAM, |
| 290 | LoopStandardAnalysisResults &AR, |
| 291 | LPMUpdater &U) { |
| 292 | Function &F = *L.getHeader()->getParent(); |
| 293 | auto &FAMProxy = LAM.getResult<FunctionAnalysisManagerLoopProxy>(IR&: L, ExtraArgs&: AR); |
| 294 | OptimizationRemarkEmitter *ORE = |
| 295 | FAMProxy.getCachedResult<OptimizationRemarkEmitterAnalysis>(IR&: F); |
| 296 | |
| 297 | if (EVLIndVarSimplifyImpl(AR, ORE).run(L)) |
| 298 | return PreservedAnalyses::allInSet<CFGAnalyses>(); |
| 299 | return PreservedAnalyses::all(); |
| 300 | } |
| 301 | |