| 1 | //===-- VPlanPredicator.cpp - VPlan predicator ----------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | /// |
| 9 | /// \file |
| 10 | /// This file implements predication for VPlans. |
| 11 | /// |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "VPRecipeBuilder.h" |
| 15 | #include "VPlan.h" |
| 16 | #include "VPlanCFG.h" |
| 17 | #include "VPlanPatternMatch.h" |
| 18 | #include "VPlanTransforms.h" |
| 19 | #include "VPlanUtils.h" |
| 20 | #include "llvm/ADT/PostOrderIterator.h" |
| 21 | |
| 22 | using namespace llvm; |
| 23 | using namespace VPlanPatternMatch; |
| 24 | |
| 25 | namespace { |
| 26 | class VPPredicator { |
| 27 | /// Builder to construct recipes to compute masks. |
| 28 | VPBuilder Builder; |
| 29 | |
| 30 | /// When we if-convert we need to create edge masks. We have to cache values |
| 31 | /// so that we don't end up with exponential recursion/IR. |
| 32 | using EdgeMaskCacheTy = |
| 33 | DenseMap<std::pair<const VPBasicBlock *, const VPBasicBlock *>, |
| 34 | VPValue *>; |
| 35 | using BlockMaskCacheTy = DenseMap<VPBasicBlock *, VPValue *>; |
| 36 | EdgeMaskCacheTy EdgeMaskCache; |
| 37 | |
| 38 | BlockMaskCacheTy BlockMaskCache; |
| 39 | |
| 40 | /// Create an edge mask for every destination of cases and/or default. |
| 41 | void createSwitchEdgeMasks(VPInstruction *SI); |
| 42 | |
| 43 | /// Computes and return the predicate of the edge between \p Src and \p Dst, |
| 44 | /// possibly inserting new recipes at \p Dst (using Builder's insertion point) |
| 45 | VPValue *createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst); |
| 46 | |
| 47 | /// Record \p Mask as the *entry* mask of \p VPBB, which is expected to not |
| 48 | /// already have a mask. |
| 49 | void setBlockInMask(VPBasicBlock *VPBB, VPValue *Mask) { |
| 50 | // TODO: Include the masks as operands in the predicated VPlan directly to |
| 51 | // avoid keeping the map of masks beyond the predication transform. |
| 52 | assert(!getBlockInMask(VPBB) && "Mask already set" ); |
| 53 | BlockMaskCache[VPBB] = Mask; |
| 54 | } |
| 55 | |
| 56 | /// Record \p Mask as the mask of the edge from \p Src to \p Dst. The edge is |
| 57 | /// expected to not have a mask already. |
| 58 | VPValue *setEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst, |
| 59 | VPValue *Mask) { |
| 60 | assert(Src != Dst && "Src and Dst must be different" ); |
| 61 | assert(!getEdgeMask(Src, Dst) && "Mask already set" ); |
| 62 | return EdgeMaskCache[{Src, Dst}] = Mask; |
| 63 | } |
| 64 | |
| 65 | public: |
| 66 | /// Returns the *entry* mask for \p VPBB. |
| 67 | VPValue *getBlockInMask(VPBasicBlock *VPBB) const { |
| 68 | return BlockMaskCache.lookup(Val: VPBB); |
| 69 | } |
| 70 | |
| 71 | /// Returns the precomputed predicate of the edge from \p Src to \p Dst. |
| 72 | VPValue *getEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst) const { |
| 73 | return EdgeMaskCache.lookup(Val: {Src, Dst}); |
| 74 | } |
| 75 | |
| 76 | /// Compute and return the mask for the vector loop header block. |
| 77 | void createHeaderMask(VPBasicBlock *, bool FoldTail); |
| 78 | |
| 79 | /// Compute and return the predicate of \p VPBB, assuming that the header |
| 80 | /// block of the loop is set to True, or to the loop mask when tail folding. |
| 81 | VPValue *createBlockInMask(VPBasicBlock *VPBB); |
| 82 | |
| 83 | /// Convert phi recipes in \p VPBB to VPBlendRecipes. |
| 84 | void convertPhisToBlends(VPBasicBlock *VPBB); |
| 85 | |
| 86 | const BlockMaskCacheTy getBlockMaskCache() const { return BlockMaskCache; } |
| 87 | }; |
| 88 | } // namespace |
| 89 | |
| 90 | VPValue *VPPredicator::createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst) { |
| 91 | assert(is_contained(Dst->getPredecessors(), Src) && "Invalid edge" ); |
| 92 | |
| 93 | // Look for cached value. |
| 94 | VPValue *EdgeMask = getEdgeMask(Src, Dst); |
| 95 | if (EdgeMask) |
| 96 | return EdgeMask; |
| 97 | |
| 98 | VPValue *SrcMask = getBlockInMask(VPBB: Src); |
| 99 | |
| 100 | // If there's a single successor, there's no terminator recipe. |
| 101 | if (Src->getNumSuccessors() == 1) |
| 102 | return setEdgeMask(Src, Dst, Mask: SrcMask); |
| 103 | |
| 104 | auto *Term = cast<VPInstruction>(Val: Src->getTerminator()); |
| 105 | if (Term->getOpcode() == Instruction::Switch) { |
| 106 | createSwitchEdgeMasks(SI: Term); |
| 107 | return getEdgeMask(Src, Dst); |
| 108 | } |
| 109 | |
| 110 | assert(Term->getOpcode() == VPInstruction::BranchOnCond && |
| 111 | "Unsupported terminator" ); |
| 112 | if (Src->getSuccessors()[0] == Src->getSuccessors()[1]) |
| 113 | return setEdgeMask(Src, Dst, Mask: SrcMask); |
| 114 | |
| 115 | EdgeMask = Term->getOperand(N: 0); |
| 116 | assert(EdgeMask && "No Edge Mask found for condition" ); |
| 117 | |
| 118 | if (Src->getSuccessors()[0] != Dst) |
| 119 | EdgeMask = Builder.createNot(Operand: EdgeMask, DL: Term->getDebugLoc()); |
| 120 | |
| 121 | if (SrcMask) { // Otherwise block in-mask is all-one, no need to AND. |
| 122 | // The bitwise 'And' of SrcMask and EdgeMask introduces new UB if SrcMask |
| 123 | // is false and EdgeMask is poison. Avoid that by using 'LogicalAnd' |
| 124 | // instead which generates 'select i1 SrcMask, i1 EdgeMask, i1 false'. |
| 125 | EdgeMask = Builder.createLogicalAnd(LHS: SrcMask, RHS: EdgeMask, DL: Term->getDebugLoc()); |
| 126 | } |
| 127 | |
| 128 | return setEdgeMask(Src, Dst, Mask: EdgeMask); |
| 129 | } |
| 130 | |
| 131 | VPValue *VPPredicator::createBlockInMask(VPBasicBlock *VPBB) { |
| 132 | // Start inserting after the block's phis, which be replaced by blends later. |
| 133 | Builder.setInsertPoint(TheBB: VPBB, IP: VPBB->getFirstNonPhi()); |
| 134 | // All-one mask is modelled as no-mask following the convention for masked |
| 135 | // load/store/gather/scatter. Initialize BlockMask to no-mask. |
| 136 | VPValue *BlockMask = nullptr; |
| 137 | // This is the block mask. We OR all unique incoming edges. |
| 138 | for (auto *Predecessor : SetVector<VPBlockBase *>( |
| 139 | VPBB->getPredecessors().begin(), VPBB->getPredecessors().end())) { |
| 140 | VPValue *EdgeMask = createEdgeMask(Src: cast<VPBasicBlock>(Val: Predecessor), Dst: VPBB); |
| 141 | if (!EdgeMask) { // Mask of predecessor is all-one so mask of block is |
| 142 | // too. |
| 143 | setBlockInMask(VPBB, Mask: EdgeMask); |
| 144 | return EdgeMask; |
| 145 | } |
| 146 | |
| 147 | if (!BlockMask) { // BlockMask has its initial nullptr value. |
| 148 | BlockMask = EdgeMask; |
| 149 | continue; |
| 150 | } |
| 151 | |
| 152 | BlockMask = Builder.createOr(LHS: BlockMask, RHS: EdgeMask, DL: {}); |
| 153 | } |
| 154 | |
| 155 | setBlockInMask(VPBB, Mask: BlockMask); |
| 156 | return BlockMask; |
| 157 | } |
| 158 | |
| 159 | void VPPredicator::(VPBasicBlock *, bool FoldTail) { |
| 160 | if (!FoldTail) { |
| 161 | setBlockInMask(VPBB: HeaderVPBB, Mask: nullptr); |
| 162 | return; |
| 163 | } |
| 164 | |
| 165 | // Introduce the early-exit compare IV <= BTC to form header block mask. |
| 166 | // This is used instead of IV < TC because TC may wrap, unlike BTC. Start by |
| 167 | // constructing the desired canonical IV in the header block as its first |
| 168 | // non-phi instructions. |
| 169 | |
| 170 | auto &Plan = *HeaderVPBB->getPlan(); |
| 171 | auto *IV = |
| 172 | new VPWidenCanonicalIVRecipe(HeaderVPBB->getParent()->getCanonicalIV()); |
| 173 | Builder.setInsertPoint(TheBB: HeaderVPBB, IP: HeaderVPBB->getFirstNonPhi()); |
| 174 | Builder.insert(R: IV); |
| 175 | |
| 176 | VPValue *BTC = Plan.getOrCreateBackedgeTakenCount(); |
| 177 | VPValue *BlockMask = Builder.createICmp(Pred: CmpInst::ICMP_ULE, A: IV, B: BTC); |
| 178 | setBlockInMask(VPBB: HeaderVPBB, Mask: BlockMask); |
| 179 | } |
| 180 | |
| 181 | void VPPredicator::createSwitchEdgeMasks(VPInstruction *SI) { |
| 182 | VPBasicBlock *Src = SI->getParent(); |
| 183 | |
| 184 | // Create masks where SI is a switch. We create masks for all edges from SI's |
| 185 | // parent block at the same time. This is more efficient, as we can create and |
| 186 | // collect compares for all cases once. |
| 187 | VPValue *Cond = SI->getOperand(N: 0); |
| 188 | VPBasicBlock *DefaultDst = cast<VPBasicBlock>(Val: Src->getSuccessors()[0]); |
| 189 | MapVector<VPBasicBlock *, SmallVector<VPValue *>> Dst2Compares; |
| 190 | for (const auto &[Idx, Succ] : enumerate(First: drop_begin(RangeOrContainer&: Src->getSuccessors()))) { |
| 191 | VPBasicBlock *Dst = cast<VPBasicBlock>(Val: Succ); |
| 192 | assert(!getEdgeMask(Src, Dst) && "Edge masks already created" ); |
| 193 | // Cases whose destination is the same as default are redundant and can |
| 194 | // be ignored - they will get there anyhow. |
| 195 | if (Dst == DefaultDst) |
| 196 | continue; |
| 197 | auto &Compares = Dst2Compares[Dst]; |
| 198 | VPValue *V = SI->getOperand(N: Idx + 1); |
| 199 | Compares.push_back(Elt: Builder.createICmp(Pred: CmpInst::ICMP_EQ, A: Cond, B: V)); |
| 200 | } |
| 201 | |
| 202 | // We need to handle 2 separate cases below for all entries in Dst2Compares, |
| 203 | // which excludes destinations matching the default destination. |
| 204 | VPValue *SrcMask = getBlockInMask(VPBB: Src); |
| 205 | VPValue *DefaultMask = nullptr; |
| 206 | for (const auto &[Dst, Conds] : Dst2Compares) { |
| 207 | // 1. Dst is not the default destination. Dst is reached if any of the |
| 208 | // cases with destination == Dst are taken. Join the conditions for each |
| 209 | // case whose destination == Dst using an OR. |
| 210 | VPValue *Mask = Conds[0]; |
| 211 | for (VPValue *V : drop_begin(RangeOrContainer: Conds)) |
| 212 | Mask = Builder.createOr(LHS: Mask, RHS: V); |
| 213 | if (SrcMask) |
| 214 | Mask = Builder.createLogicalAnd(LHS: SrcMask, RHS: Mask); |
| 215 | setEdgeMask(Src, Dst, Mask); |
| 216 | |
| 217 | // 2. Create the mask for the default destination, which is reached if |
| 218 | // none of the cases with destination != default destination are taken. |
| 219 | // Join the conditions for each case where the destination is != Dst using |
| 220 | // an OR and negate it. |
| 221 | DefaultMask = DefaultMask ? Builder.createOr(LHS: DefaultMask, RHS: Mask) : Mask; |
| 222 | } |
| 223 | |
| 224 | if (DefaultMask) { |
| 225 | DefaultMask = Builder.createNot(Operand: DefaultMask); |
| 226 | if (SrcMask) |
| 227 | DefaultMask = Builder.createLogicalAnd(LHS: SrcMask, RHS: DefaultMask); |
| 228 | } |
| 229 | setEdgeMask(Src, Dst: DefaultDst, Mask: DefaultMask); |
| 230 | } |
| 231 | |
| 232 | void VPPredicator::convertPhisToBlends(VPBasicBlock *VPBB) { |
| 233 | SmallVector<VPPhi *> Phis; |
| 234 | for (VPRecipeBase &R : VPBB->phis()) |
| 235 | Phis.push_back(Elt: cast<VPPhi>(Val: &R)); |
| 236 | for (VPPhi *PhiR : Phis) { |
| 237 | // The non-header Phi is converted into a Blend recipe below, |
| 238 | // so we don't have to worry about the insertion order and we can just use |
| 239 | // the builder. At this point we generate the predication tree. There may |
| 240 | // be duplications since this is a simple recursive scan, but future |
| 241 | // optimizations will clean it up. |
| 242 | |
| 243 | if (all_equal(Range: PhiR->incoming_values())) { |
| 244 | PhiR->replaceAllUsesWith(New: PhiR->getIncomingValue(Idx: 0)); |
| 245 | PhiR->eraseFromParent(); |
| 246 | continue; |
| 247 | } |
| 248 | |
| 249 | SmallVector<VPValue *, 2> OperandsWithMask; |
| 250 | for (const auto &[InVPV, InVPBB] : PhiR->incoming_values_and_blocks()) { |
| 251 | OperandsWithMask.push_back(Elt: InVPV); |
| 252 | OperandsWithMask.push_back(Elt: getEdgeMask(Src: InVPBB, Dst: VPBB)); |
| 253 | } |
| 254 | PHINode *IRPhi = cast_or_null<PHINode>(Val: PhiR->getUnderlyingValue()); |
| 255 | auto *Blend = |
| 256 | new VPBlendRecipe(IRPhi, OperandsWithMask, PhiR->getDebugLoc()); |
| 257 | Builder.insert(R: Blend); |
| 258 | PhiR->replaceAllUsesWith(New: Blend); |
| 259 | PhiR->eraseFromParent(); |
| 260 | } |
| 261 | } |
| 262 | |
| 263 | DenseMap<VPBasicBlock *, VPValue *> |
| 264 | VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan, bool FoldTail) { |
| 265 | VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion(); |
| 266 | // Scan the body of the loop in a topological order to visit each basic block |
| 267 | // after having visited its predecessor basic blocks. |
| 268 | VPBasicBlock * = LoopRegion->getEntryBasicBlock(); |
| 269 | ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> RPOT( |
| 270 | Header); |
| 271 | VPPredicator Predicator; |
| 272 | for (VPBlockBase *VPB : RPOT) { |
| 273 | // Non-outer regions with VPBBs only are supported at the moment. |
| 274 | auto *VPBB = cast<VPBasicBlock>(Val: VPB); |
| 275 | // Introduce the mask for VPBB, which may introduce needed edge masks, and |
| 276 | // convert all phi recipes of VPBB to blend recipes unless VPBB is the |
| 277 | // header. |
| 278 | if (VPBB == Header) { |
| 279 | Predicator.createHeaderMask(HeaderVPBB: Header, FoldTail); |
| 280 | continue; |
| 281 | } |
| 282 | |
| 283 | Predicator.createBlockInMask(VPBB); |
| 284 | Predicator.convertPhisToBlends(VPBB); |
| 285 | } |
| 286 | |
| 287 | // Linearize the blocks of the loop into one serial chain. |
| 288 | VPBlockBase *PrevVPBB = nullptr; |
| 289 | for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Range: RPOT)) { |
| 290 | auto Successors = to_vector(Range&: VPBB->getSuccessors()); |
| 291 | if (Successors.size() > 1) |
| 292 | VPBB->getTerminator()->eraseFromParent(); |
| 293 | |
| 294 | // Flatten the CFG in the loop. To do so, first disconnect VPBB from its |
| 295 | // successors. Then connect VPBB to the previously visited VPBB. |
| 296 | for (auto *Succ : Successors) |
| 297 | VPBlockUtils::disconnectBlocks(From: VPBB, To: Succ); |
| 298 | if (PrevVPBB) |
| 299 | VPBlockUtils::connectBlocks(From: PrevVPBB, To: VPBB); |
| 300 | |
| 301 | PrevVPBB = VPBB; |
| 302 | } |
| 303 | |
| 304 | // If we folded the tail and introduced a header mask, any extract of the |
| 305 | // last element must be updated to extract from the last active lane of the |
| 306 | // header mask instead (i.e., the lane corresponding to the last active |
| 307 | // iteration). |
| 308 | if (FoldTail) { |
| 309 | assert(Plan.getExitBlocks().size() == 1 && |
| 310 | "only a single-exit block is supported currently" ); |
| 311 | assert(Plan.getExitBlocks().front()->getSinglePredecessor() == |
| 312 | Plan.getMiddleBlock() && |
| 313 | "the exit block must have middle block as single predecessor" ); |
| 314 | |
| 315 | VPBuilder B(Plan.getMiddleBlock()->getTerminator()); |
| 316 | for (VPRecipeBase &R : *Plan.getMiddleBlock()) { |
| 317 | VPValue *Op; |
| 318 | if (!match(V: &R, P: m_ExtractLastLane(Op0: m_ExtractLastPart(Op0: m_VPValue(V&: Op))))) |
| 319 | continue; |
| 320 | |
| 321 | // Compute the index of the last active lane. |
| 322 | VPValue * = Predicator.getBlockInMask(VPBB: Header); |
| 323 | VPValue *LastActiveLane = |
| 324 | B.createNaryOp(Opcode: VPInstruction::LastActiveLane, Operands: HeaderMask); |
| 325 | auto *Ext = |
| 326 | B.createNaryOp(Opcode: VPInstruction::ExtractLane, Operands: {LastActiveLane, Op}); |
| 327 | R.getVPSingleValue()->replaceAllUsesWith(New: Ext); |
| 328 | } |
| 329 | } |
| 330 | return Predicator.getBlockMaskCache(); |
| 331 | } |
| 332 | |