//===- MVETailPredication.cpp - MVE Tail Predication ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
/// branches to help accelerate DSP applications. These two extensions,
/// combined with a new form of predication called tail-predication, can be used
/// to provide implicit vector predication within a low-overhead loop.
/// This is implicit because the predicate of active/inactive lanes is
/// calculated by hardware, and thus does not need to be explicitly passed
/// to vector instructions. The instructions responsible for this are the
/// DLSTP and WLSTP instructions, which set up a tail-predicated loop and the
/// total number of data elements processed by the loop. The loop-end
/// LETP instruction is responsible for decrementing and setting the remaining
/// elements to be processed and generating the mask of active lanes.
///
/// The HardwareLoops pass inserts intrinsics identifying loops that the
/// backend will attempt to convert into a low-overhead loop. The vectorizer is
/// responsible for generating a vectorized loop in which the lanes are
/// predicated upon a get.active.lane.mask intrinsic. This pass looks at these
/// get.active.lane.mask intrinsics and attempts to convert them to VCTP
/// instructions. This will be picked up by the ARM Low-overhead loop pass later
/// in the backend, which performs the final transformation to a DLSTP or WLSTP
/// tail-predicated loop.
//
//===----------------------------------------------------------------------===//
| 32 |  | 
|---|
| 33 | #include "ARM.h" | 
|---|
| 34 | #include "ARMSubtarget.h" | 
|---|
| 35 | #include "ARMTargetTransformInfo.h" | 
|---|
| 36 | #include "llvm/Analysis/LoopInfo.h" | 
|---|
| 37 | #include "llvm/Analysis/LoopPass.h" | 
|---|
| 38 | #include "llvm/Analysis/ScalarEvolution.h" | 
|---|
| 39 | #include "llvm/Analysis/ScalarEvolutionExpressions.h" | 
|---|
| 40 | #include "llvm/Analysis/TargetLibraryInfo.h" | 
|---|
| 41 | #include "llvm/Analysis/TargetTransformInfo.h" | 
|---|
| 42 | #include "llvm/Analysis/ValueTracking.h" | 
|---|
| 43 | #include "llvm/CodeGen/TargetPassConfig.h" | 
|---|
| 44 | #include "llvm/IR/IRBuilder.h" | 
|---|
| 45 | #include "llvm/IR/Instructions.h" | 
|---|
| 46 | #include "llvm/IR/IntrinsicsARM.h" | 
|---|
| 47 | #include "llvm/Support/Debug.h" | 
|---|
| 48 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" | 
|---|
| 49 | #include "llvm/Transforms/Utils/Local.h" | 
|---|
| 50 | #include "llvm/Transforms/Utils/LoopUtils.h" | 
|---|
| 51 | #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" | 
|---|
| 52 |  | 
|---|
| 53 | using namespace llvm; | 
|---|
| 54 |  | 
|---|
| 55 | #define DEBUG_TYPE "mve-tail-predication" | 
|---|
| 56 | #define DESC "Transform predicated vector loops to use MVE tail predication" | 
|---|
| 57 |  | 
|---|
| 58 | cl::opt<TailPredication::Mode> EnableTailPredication( | 
|---|
| 59 | "tail-predication", cl::desc( "MVE tail-predication pass options"), | 
|---|
| 60 | cl::init(Val: TailPredication::Enabled), | 
|---|
| 61 | cl::values(clEnumValN(TailPredication::Disabled, "disabled", | 
|---|
| 62 | "Don't tail-predicate loops"), | 
|---|
| 63 | clEnumValN(TailPredication::EnabledNoReductions, | 
|---|
| 64 | "enabled-no-reductions", | 
|---|
| 65 | "Enable tail-predication, but not for reduction loops"), | 
|---|
| 66 | clEnumValN(TailPredication::Enabled, | 
|---|
| 67 | "enabled", | 
|---|
| 68 | "Enable tail-predication, including reduction loops"), | 
|---|
| 69 | clEnumValN(TailPredication::ForceEnabledNoReductions, | 
|---|
| 70 | "force-enabled-no-reductions", | 
|---|
| 71 | "Enable tail-predication, but not for reduction loops, " | 
|---|
| 72 | "and force this which might be unsafe"), | 
|---|
| 73 | clEnumValN(TailPredication::ForceEnabled, | 
|---|
| 74 | "force-enabled", | 
|---|
| 75 | "Enable tail-predication, including reduction loops, " | 
|---|
| 76 | "and force this which might be unsafe"))); | 
|---|
| 77 |  | 
|---|
| 78 |  | 
|---|
| 79 | namespace { | 
|---|
| 80 |  | 
|---|
| 81 | class MVETailPredication : public LoopPass { | 
|---|
| 82 | SmallVector<IntrinsicInst*, 4> MaskedInsts; | 
|---|
| 83 | Loop *L = nullptr; | 
|---|
| 84 | ScalarEvolution *SE = nullptr; | 
|---|
| 85 | TargetTransformInfo *TTI = nullptr; | 
|---|
| 86 | const ARMSubtarget *ST = nullptr; | 
|---|
| 87 |  | 
|---|
| 88 | public: | 
|---|
| 89 | static char ID; | 
|---|
| 90 |  | 
|---|
| 91 | MVETailPredication() : LoopPass(ID) { } | 
|---|
| 92 |  | 
|---|
| 93 | void getAnalysisUsage(AnalysisUsage &AU) const override { | 
|---|
| 94 | AU.addRequired<ScalarEvolutionWrapperPass>(); | 
|---|
| 95 | AU.addRequired<LoopInfoWrapperPass>(); | 
|---|
| 96 | AU.addRequired<TargetPassConfig>(); | 
|---|
| 97 | AU.addRequired<TargetTransformInfoWrapperPass>(); | 
|---|
| 98 | AU.addPreserved<LoopInfoWrapperPass>(); | 
|---|
| 99 | AU.setPreservesCFG(); | 
|---|
| 100 | } | 
|---|
| 101 |  | 
|---|
| 102 | bool runOnLoop(Loop *L, LPPassManager&) override; | 
|---|
| 103 |  | 
|---|
| 104 | private: | 
|---|
| 105 | /// Perform the relevant checks on the loop and convert active lane masks if | 
|---|
| 106 | /// possible. | 
|---|
| 107 | bool TryConvertActiveLaneMask(Value *TripCount); | 
|---|
| 108 |  | 
|---|
| 109 | /// Perform several checks on the arguments of @llvm.get.active.lane.mask | 
|---|
| 110 | /// intrinsic. E.g., check that the loop induction variable and the element | 
|---|
| 111 | /// count are of the form we expect, and also perform overflow checks for | 
|---|
| 112 | /// the new expressions that are created. | 
|---|
| 113 | const SCEV *IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, Value *TripCount); | 
|---|
| 114 |  | 
|---|
| 115 | /// Insert the intrinsic to represent the effect of tail predication. | 
|---|
| 116 | void InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask, Value *Start); | 
|---|
| 117 | }; | 
|---|
| 118 |  | 
|---|
| 119 | } // end namespace | 
|---|
| 120 |  | 
|---|
| 121 | bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) { | 
|---|
| 122 | if (skipLoop(L) || !EnableTailPredication) | 
|---|
| 123 | return false; | 
|---|
| 124 |  | 
|---|
| 125 | MaskedInsts.clear(); | 
|---|
| 126 | Function &F = *L->getHeader()->getParent(); | 
|---|
| 127 | auto &TPC = getAnalysis<TargetPassConfig>(); | 
|---|
| 128 | auto &TM = TPC.getTM<TargetMachine>(); | 
|---|
| 129 | ST = &TM.getSubtarget<ARMSubtarget>(F); | 
|---|
| 130 | TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); | 
|---|
| 131 | SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); | 
|---|
| 132 | this->L = L; | 
|---|
| 133 |  | 
|---|
| 134 | // The MVE and LOB extensions are combined to enable tail-predication, but | 
|---|
| 135 | // there's nothing preventing us from generating VCTP instructions for v8.1m. | 
|---|
| 136 | if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) { | 
|---|
| 137 | LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n"); | 
|---|
| 138 | return false; | 
|---|
| 139 | } | 
|---|
| 140 |  | 
|---|
| 141 | BasicBlock * = L->getLoopPreheader(); | 
|---|
| 142 | if (!Preheader) | 
|---|
| 143 | return false; | 
|---|
| 144 |  | 
|---|
| 145 | auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* { | 
|---|
| 146 | for (auto &I : *BB) { | 
|---|
| 147 | auto *Call = dyn_cast<IntrinsicInst>(Val: &I); | 
|---|
| 148 | if (!Call) | 
|---|
| 149 | continue; | 
|---|
| 150 |  | 
|---|
| 151 | Intrinsic::ID ID = Call->getIntrinsicID(); | 
|---|
| 152 | if (ID == Intrinsic::start_loop_iterations || | 
|---|
| 153 | ID == Intrinsic::test_start_loop_iterations) | 
|---|
| 154 | return cast<IntrinsicInst>(Val: &I); | 
|---|
| 155 | } | 
|---|
| 156 | return nullptr; | 
|---|
| 157 | }; | 
|---|
| 158 |  | 
|---|
| 159 | // Look for the hardware loop intrinsic that sets the iteration count. | 
|---|
| 160 | IntrinsicInst *Setup = FindLoopIterations(Preheader); | 
|---|
| 161 |  | 
|---|
| 162 | // The test.set iteration could live in the pre-preheader. | 
|---|
| 163 | if (!Setup) { | 
|---|
| 164 | if (!Preheader->getSinglePredecessor()) | 
|---|
| 165 | return false; | 
|---|
| 166 | Setup = FindLoopIterations(Preheader->getSinglePredecessor()); | 
|---|
| 167 | if (!Setup) | 
|---|
| 168 | return false; | 
|---|
| 169 | } | 
|---|
| 170 |  | 
|---|
| 171 | LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: "<< *L << *Setup << "\n"); | 
|---|
| 172 |  | 
|---|
| 173 | bool Changed = TryConvertActiveLaneMask(TripCount: Setup->getArgOperand(i: 0)); | 
|---|
| 174 |  | 
|---|
| 175 | return Changed; | 
|---|
| 176 | } | 
|---|
| 177 |  | 
|---|
| 178 | // The active lane intrinsic has this form: | 
|---|
| 179 | // | 
|---|
| 180 | //    @llvm.get.active.lane.mask(IV, TC) | 
|---|
| 181 | // | 
|---|
| 182 | // Here we perform checks that this intrinsic behaves as expected, | 
|---|
| 183 | // which means: | 
|---|
| 184 | // | 
|---|
| 185 | // 1) Check that the TripCount (TC) belongs to this loop (originally). | 
|---|
| 186 | // 2) The element count (TC) needs to be sufficiently large that the decrement | 
|---|
| 187 | //    of element counter doesn't overflow, which means that we need to prove: | 
|---|
| 188 | //        ceil(ElementCount / VectorWidth) >= TripCount | 
|---|
| 189 | //    by rounding up ElementCount up: | 
|---|
| 190 | //        ((ElementCount + (VectorWidth - 1)) / VectorWidth | 
|---|
| 191 | //    and evaluate if expression isKnownNonNegative: | 
|---|
| 192 | //        (((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount | 
|---|
| 193 | // 3) The IV must be an induction phi with an increment equal to the | 
|---|
| 194 | //    vector width. | 
|---|
| 195 | const SCEV *MVETailPredication::IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, | 
|---|
| 196 | Value *TripCount) { | 
|---|
| 197 | bool ForceTailPredication = | 
|---|
| 198 | EnableTailPredication == TailPredication::ForceEnabledNoReductions || | 
|---|
| 199 | EnableTailPredication == TailPredication::ForceEnabled; | 
|---|
| 200 |  | 
|---|
| 201 | Value *ElemCount = ActiveLaneMask->getOperand(i_nocapture: 1); | 
|---|
| 202 | bool Changed = false; | 
|---|
| 203 | if (!L->makeLoopInvariant(V: ElemCount, Changed)) | 
|---|
| 204 | return nullptr; | 
|---|
| 205 |  | 
|---|
| 206 | const SCEV *EC = SE->getSCEV(V: ElemCount); | 
|---|
| 207 | const SCEV *TC = SE->getSCEV(V: TripCount); | 
|---|
| 208 | int VectorWidth = | 
|---|
| 209 | cast<FixedVectorType>(Val: ActiveLaneMask->getType())->getNumElements(); | 
|---|
| 210 | if (VectorWidth != 2 && VectorWidth != 4 && VectorWidth != 8 && | 
|---|
| 211 | VectorWidth != 16) | 
|---|
| 212 | return nullptr; | 
|---|
| 213 | ConstantInt *ConstElemCount = nullptr; | 
|---|
| 214 |  | 
|---|
| 215 | // 1) Smoke tests that the original scalar loop TripCount (TC) belongs to | 
|---|
| 216 | // this loop.  The scalar tripcount corresponds the number of elements | 
|---|
| 217 | // processed by the loop, so we will refer to that from this point on. | 
|---|
| 218 | if (!SE->isLoopInvariant(S: EC, L)) { | 
|---|
| 219 | LLVM_DEBUG(dbgs() << "ARM TP: element count must be loop invariant.\n"); | 
|---|
| 220 | return nullptr; | 
|---|
| 221 | } | 
|---|
| 222 |  | 
|---|
| 223 | // 2) Find out if IV is an induction phi. Note that we can't use Loop | 
|---|
| 224 | // helpers here to get the induction variable, because the hardware loop is | 
|---|
| 225 | // no longer in loopsimplify form, and also the hwloop intrinsic uses a | 
|---|
| 226 | // different counter. Using SCEV, we check that the induction is of the | 
|---|
| 227 | // form i = i + 4, where the increment must be equal to the VectorWidth. | 
|---|
| 228 | auto *IV = ActiveLaneMask->getOperand(i_nocapture: 0); | 
|---|
| 229 | const SCEV *IVExpr = SE->getSCEV(V: IV); | 
|---|
| 230 | auto *AddExpr = dyn_cast<SCEVAddRecExpr>(Val: IVExpr); | 
|---|
| 231 |  | 
|---|
| 232 | if (!AddExpr) { | 
|---|
| 233 | LLVM_DEBUG(dbgs() << "ARM TP: induction not an add expr: "; IVExpr->dump()); | 
|---|
| 234 | return nullptr; | 
|---|
| 235 | } | 
|---|
| 236 | // Check that this AddRec is associated with this loop. | 
|---|
| 237 | if (AddExpr->getLoop() != L) { | 
|---|
| 238 | LLVM_DEBUG(dbgs() << "ARM TP: phi not part of this loop\n"); | 
|---|
| 239 | return nullptr; | 
|---|
| 240 | } | 
|---|
| 241 | auto *Step = dyn_cast<SCEVConstant>(Val: AddExpr->getOperand(i: 1)); | 
|---|
| 242 | if (!Step) { | 
|---|
| 243 | LLVM_DEBUG(dbgs() << "ARM TP: induction step is not a constant: "; | 
|---|
| 244 | AddExpr->getOperand(1)->dump()); | 
|---|
| 245 | return nullptr; | 
|---|
| 246 | } | 
|---|
| 247 | auto StepValue = Step->getValue()->getSExtValue(); | 
|---|
| 248 | if (VectorWidth != StepValue) { | 
|---|
| 249 | LLVM_DEBUG(dbgs() << "ARM TP: Step value "<< StepValue | 
|---|
| 250 | << " doesn't match vector width "<< VectorWidth << "\n"); | 
|---|
| 251 | return nullptr; | 
|---|
| 252 | } | 
|---|
| 253 |  | 
|---|
| 254 | if ((ConstElemCount = dyn_cast<ConstantInt>(Val: ElemCount))) { | 
|---|
| 255 | ConstantInt *TC = dyn_cast<ConstantInt>(Val: TripCount); | 
|---|
| 256 | if (!TC) { | 
|---|
| 257 | LLVM_DEBUG(dbgs() << "ARM TP: Constant tripcount expected in " | 
|---|
| 258 | "set.loop.iterations\n"); | 
|---|
| 259 | return nullptr; | 
|---|
| 260 | } | 
|---|
| 261 |  | 
|---|
| 262 | // Calculate 2 tripcount values and check that they are consistent with | 
|---|
| 263 | // each other. The TripCount for a predicated vector loop body is | 
|---|
| 264 | // ceil(ElementCount/Width), or floor((ElementCount+Width-1)/Width) as we | 
|---|
| 265 | // work it out here. | 
|---|
| 266 | uint64_t TC1 = TC->getZExtValue(); | 
|---|
| 267 | uint64_t TC2 = | 
|---|
| 268 | (ConstElemCount->getZExtValue() + VectorWidth - 1) / VectorWidth; | 
|---|
| 269 |  | 
|---|
| 270 | // If the tripcount values are inconsistent, we can't insert the VCTP and | 
|---|
| 271 | // trigger tail-predication; keep the intrinsic as a get.active.lane.mask | 
|---|
| 272 | // and legalize this. | 
|---|
| 273 | if (TC1 != TC2) { | 
|---|
| 274 | LLVM_DEBUG(dbgs() << "ARM TP: inconsistent constant tripcount values: " | 
|---|
| 275 | << TC1 << " from set.loop.iterations, and " | 
|---|
| 276 | << TC2 << " from get.active.lane.mask\n"); | 
|---|
| 277 | return nullptr; | 
|---|
| 278 | } | 
|---|
| 279 | } else if (!ForceTailPredication) { | 
|---|
| 280 | // 3) We need to prove that the sub expression that we create in the | 
|---|
| 281 | // tail-predicated loop body, which calculates the remaining elements to be | 
|---|
| 282 | // processed, is non-negative, i.e. it doesn't overflow: | 
|---|
| 283 | // | 
|---|
| 284 | //   ((ElementCount + VectorWidth - 1) / VectorWidth) - TripCount >= 0 | 
|---|
| 285 | // | 
|---|
| 286 | // This is true if: | 
|---|
| 287 | // | 
|---|
| 288 | //    TripCount == (ElementCount + VectorWidth - 1) / VectorWidth | 
|---|
| 289 | // | 
|---|
| 290 | // which what we will be using here. | 
|---|
| 291 | // | 
|---|
| 292 | const SCEV *VW = | 
|---|
| 293 | SE->getSCEV(V: ConstantInt::get(Ty: TripCount->getType(), V: VectorWidth)); | 
|---|
| 294 | // ElementCount + (VW-1): | 
|---|
| 295 | const SCEV *Start = AddExpr->getStart(); | 
|---|
| 296 | const SCEV *ECPlusVWMinus1 = SE->getAddExpr( | 
|---|
| 297 | LHS: EC, | 
|---|
| 298 | RHS: SE->getSCEV(V: ConstantInt::get(Ty: TripCount->getType(), V: VectorWidth - 1))); | 
|---|
| 299 |  | 
|---|
| 300 | // Ceil = ElementCount + (VW-1) / VW | 
|---|
| 301 | const SCEV *Ceil = SE->getUDivExpr(LHS: ECPlusVWMinus1, RHS: VW); | 
|---|
| 302 |  | 
|---|
| 303 | // Prevent unused variable warnings with TC | 
|---|
| 304 | (void)TC; | 
|---|
| 305 | LLVM_DEBUG({ | 
|---|
| 306 | dbgs() << "ARM TP: Analysing overflow behaviour for:\n"; | 
|---|
| 307 | dbgs() << "ARM TP: - TripCount = "<< *TC << "\n"; | 
|---|
| 308 | dbgs() << "ARM TP: - ElemCount = "<< *EC << "\n"; | 
|---|
| 309 | dbgs() << "ARM TP: - Start = "<< *Start << "\n"; | 
|---|
| 310 | dbgs() << "ARM TP: - BETC = "<< *SE->getBackedgeTakenCount(L) << "\n"; | 
|---|
| 311 | dbgs() << "ARM TP: - VecWidth =  "<< VectorWidth << "\n"; | 
|---|
| 312 | dbgs() << "ARM TP: - (ElemCount+VW-1) / VW = "<< *Ceil << "\n"; | 
|---|
| 313 | }); | 
|---|
| 314 |  | 
|---|
| 315 | // As an example, almost all the tripcount expressions (produced by the | 
|---|
| 316 | // vectoriser) look like this: | 
|---|
| 317 | // | 
|---|
| 318 | //   TC = ((-4 + (4 * ((3 + %N) /u 4))<nuw> - start) /u 4) | 
|---|
| 319 | // | 
|---|
| 320 | // and "ElementCount + (VW-1) / VW": | 
|---|
| 321 | // | 
|---|
| 322 | //   Ceil = ((3 + %N) /u 4) | 
|---|
| 323 | // | 
|---|
| 324 | // Check for equality of TC and Ceil by calculating SCEV expression | 
|---|
| 325 | // TC - Ceil and test it for zero. | 
|---|
| 326 | // | 
|---|
| 327 | const SCEV *Div = SE->getUDivExpr( | 
|---|
| 328 | LHS: SE->getAddExpr(Op0: SE->getMulExpr(LHS: Ceil, RHS: VW), Op1: SE->getNegativeSCEV(V: VW), | 
|---|
| 329 | Op2: SE->getNegativeSCEV(V: Start)), | 
|---|
| 330 | RHS: VW); | 
|---|
| 331 | const SCEV *Sub = SE->getMinusSCEV(LHS: SE->getBackedgeTakenCount(L), RHS: Div); | 
|---|
| 332 | LLVM_DEBUG(dbgs() << "ARM TP: - Sub       = "; Sub->dump()); | 
|---|
| 333 |  | 
|---|
| 334 | // Use context sensitive facts about the path to the loop to refine.  This | 
|---|
| 335 | // comes up as the backedge taken count can incorporate context sensitive | 
|---|
| 336 | // reasoning, and our RHS just above doesn't. | 
|---|
| 337 | Sub = SE->applyLoopGuards(Expr: Sub, L); | 
|---|
| 338 | LLVM_DEBUG(dbgs() << "ARM TP: - (Guarded) = "; Sub->dump()); | 
|---|
| 339 |  | 
|---|
| 340 | if (!Sub->isZero()) { | 
|---|
| 341 | LLVM_DEBUG(dbgs() << "ARM TP: possible overflow in sub expression.\n"); | 
|---|
| 342 | return nullptr; | 
|---|
| 343 | } | 
|---|
| 344 | } | 
|---|
| 345 |  | 
|---|
| 346 | // Check that the start value is a multiple of the VectorWidth. | 
|---|
| 347 | // TODO: This could do with a method to check if the scev is a multiple of | 
|---|
| 348 | // VectorWidth. For the moment we just check for constants, muls and unknowns | 
|---|
| 349 | // (which use MaskedValueIsZero and seems to be the most common). | 
|---|
| 350 | if (auto *BaseC = dyn_cast<SCEVConstant>(Val: AddExpr->getStart())) { | 
|---|
| 351 | if (BaseC->getAPInt().urem(RHS: VectorWidth) == 0) | 
|---|
| 352 | return SE->getMinusSCEV(LHS: EC, RHS: BaseC); | 
|---|
| 353 | } else if (auto *BaseV = dyn_cast<SCEVUnknown>(Val: AddExpr->getStart())) { | 
|---|
| 354 | Type *Ty = BaseV->getType(); | 
|---|
| 355 | APInt Mask = APInt::getLowBitsSet(numBits: Ty->getPrimitiveSizeInBits(), | 
|---|
| 356 | loBitsSet: Log2_64(Value: VectorWidth)); | 
|---|
| 357 | if (MaskedValueIsZero(V: BaseV->getValue(), Mask, | 
|---|
| 358 | SQ: L->getHeader()->getDataLayout())) | 
|---|
| 359 | return SE->getMinusSCEV(LHS: EC, RHS: BaseV); | 
|---|
| 360 | } else if (auto *BaseMul = dyn_cast<SCEVMulExpr>(Val: AddExpr->getStart())) { | 
|---|
| 361 | if (auto *BaseC = dyn_cast<SCEVConstant>(Val: BaseMul->getOperand(i: 0))) | 
|---|
| 362 | if (BaseC->getAPInt().urem(RHS: VectorWidth) == 0) | 
|---|
| 363 | return SE->getMinusSCEV(LHS: EC, RHS: BaseC); | 
|---|
| 364 | if (auto *BaseC = dyn_cast<SCEVConstant>(Val: BaseMul->getOperand(i: 1))) | 
|---|
| 365 | if (BaseC->getAPInt().urem(RHS: VectorWidth) == 0) | 
|---|
| 366 | return SE->getMinusSCEV(LHS: EC, RHS: BaseC); | 
|---|
| 367 | } | 
|---|
| 368 |  | 
|---|
| 369 | LLVM_DEBUG( | 
|---|
| 370 | dbgs() << "ARM TP: induction base is not know to be a multiple of VF: " | 
|---|
| 371 | << *AddExpr->getOperand(0) << "\n"); | 
|---|
| 372 | return nullptr; | 
|---|
| 373 | } | 
|---|
| 374 |  | 
|---|
| 375 | void MVETailPredication::InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask, | 
|---|
| 376 | Value *Start) { | 
|---|
| 377 | IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); | 
|---|
| 378 | Module *M = L->getHeader()->getModule(); | 
|---|
| 379 | Type *Ty = IntegerType::get(C&: M->getContext(), NumBits: 32); | 
|---|
| 380 | unsigned VectorWidth = | 
|---|
| 381 | cast<FixedVectorType>(Val: ActiveLaneMask->getType())->getNumElements(); | 
|---|
| 382 |  | 
|---|
| 383 | // Insert a phi to count the number of elements processed by the loop. | 
|---|
| 384 | Builder.SetInsertPoint(TheBB: L->getHeader(), IP: L->getHeader()->getFirstNonPHIIt()); | 
|---|
| 385 | PHINode *Processed = Builder.CreatePHI(Ty, NumReservedValues: 2); | 
|---|
| 386 | Processed->addIncoming(V: Start, BB: L->getLoopPreheader()); | 
|---|
| 387 |  | 
|---|
| 388 | // Replace @llvm.get.active.mask() with the ARM specific VCTP intrinic, and | 
|---|
| 389 | // thus represent the effect of tail predication. | 
|---|
| 390 | Builder.SetInsertPoint(ActiveLaneMask); | 
|---|
| 391 | ConstantInt *Factor = ConstantInt::get(Ty: cast<IntegerType>(Val: Ty), V: VectorWidth); | 
|---|
| 392 |  | 
|---|
| 393 | Intrinsic::ID VCTPID; | 
|---|
| 394 | switch (VectorWidth) { | 
|---|
| 395 | default: | 
|---|
| 396 | llvm_unreachable( "unexpected number of lanes"); | 
|---|
| 397 | case 2:  VCTPID = Intrinsic::arm_mve_vctp64; break; | 
|---|
| 398 | case 4:  VCTPID = Intrinsic::arm_mve_vctp32; break; | 
|---|
| 399 | case 8:  VCTPID = Intrinsic::arm_mve_vctp16; break; | 
|---|
| 400 | case 16: VCTPID = Intrinsic::arm_mve_vctp8; break; | 
|---|
| 401 | } | 
|---|
| 402 | Value *VCTPCall = Builder.CreateIntrinsic(ID: VCTPID, Args: Processed); | 
|---|
| 403 | ActiveLaneMask->replaceAllUsesWith(V: VCTPCall); | 
|---|
| 404 |  | 
|---|
| 405 | // Add the incoming value to the new phi. | 
|---|
| 406 | // TODO: This add likely already exists in the loop. | 
|---|
| 407 | Value *Remaining = Builder.CreateSub(LHS: Processed, RHS: Factor); | 
|---|
| 408 | Processed->addIncoming(V: Remaining, BB: L->getLoopLatch()); | 
|---|
| 409 | LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: " | 
|---|
| 410 | << *Processed << "\n" | 
|---|
| 411 | << "ARM TP: Inserted VCTP: "<< *VCTPCall << "\n"); | 
|---|
| 412 | } | 
|---|
| 413 |  | 
|---|
| 414 | bool MVETailPredication::TryConvertActiveLaneMask(Value *TripCount) { | 
|---|
| 415 | SmallVector<IntrinsicInst *, 4> ActiveLaneMasks; | 
|---|
| 416 | for (auto *BB : L->getBlocks()) | 
|---|
| 417 | for (auto &I : *BB) | 
|---|
| 418 | if (auto *Int = dyn_cast<IntrinsicInst>(Val: &I)) | 
|---|
| 419 | if (Int->getIntrinsicID() == Intrinsic::get_active_lane_mask) | 
|---|
| 420 | ActiveLaneMasks.push_back(Elt: Int); | 
|---|
| 421 |  | 
|---|
| 422 | if (ActiveLaneMasks.empty()) | 
|---|
| 423 | return false; | 
|---|
| 424 |  | 
|---|
| 425 | LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n"); | 
|---|
| 426 |  | 
|---|
| 427 | for (auto *ActiveLaneMask : ActiveLaneMasks) { | 
|---|
| 428 | LLVM_DEBUG(dbgs() << "ARM TP: Found active lane mask: " | 
|---|
| 429 | << *ActiveLaneMask << "\n"); | 
|---|
| 430 |  | 
|---|
| 431 | const SCEV *StartSCEV = IsSafeActiveMask(ActiveLaneMask, TripCount); | 
|---|
| 432 | if (!StartSCEV) { | 
|---|
| 433 | LLVM_DEBUG(dbgs() << "ARM TP: Not safe to insert VCTP.\n"); | 
|---|
| 434 | return false; | 
|---|
| 435 | } | 
|---|
| 436 | LLVM_DEBUG(dbgs() << "ARM TP: Safe to insert VCTP. Start is "<< *StartSCEV | 
|---|
| 437 | << "\n"); | 
|---|
| 438 | SCEVExpander Expander(*SE, L->getHeader()->getDataLayout(), | 
|---|
| 439 | "start"); | 
|---|
| 440 | Instruction *Ins = L->getLoopPreheader()->getTerminator(); | 
|---|
| 441 | Value *Start = Expander.expandCodeFor(SH: StartSCEV, Ty: StartSCEV->getType(), I: Ins); | 
|---|
| 442 | LLVM_DEBUG(dbgs() << "ARM TP: Created start value "<< *Start << "\n"); | 
|---|
| 443 | InsertVCTPIntrinsic(ActiveLaneMask, Start); | 
|---|
| 444 | } | 
|---|
| 445 |  | 
|---|
| 446 | // Remove dead instructions and now dead phis. | 
|---|
| 447 | for (auto *II : ActiveLaneMasks) | 
|---|
| 448 | RecursivelyDeleteTriviallyDeadInstructions(V: II); | 
|---|
| 449 | for (auto *I : L->blocks()) | 
|---|
| 450 | DeleteDeadPHIs(BB: I); | 
|---|
| 451 | return true; | 
|---|
| 452 | } | 
|---|
| 453 |  | 
|---|
| 454 | Pass *llvm::createMVETailPredicationPass() { | 
|---|
| 455 | return new MVETailPredication(); | 
|---|
| 456 | } | 
|---|
| 457 |  | 
|---|
| 458 | char MVETailPredication::ID = 0; | 
|---|
| 459 |  | 
|---|
| 460 | INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false) | 
|---|
| 461 | INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false) | 
|---|
| 462 |  | 
|---|