| 1 | //===-- VPlanPredicator.cpp - VPlan predicator ----------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | /// |
| 9 | /// \file |
| 10 | /// This file implements predication for VPlans. |
| 11 | /// |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "VPRecipeBuilder.h" |
| 15 | #include "VPlan.h" |
| 16 | #include "VPlanCFG.h" |
| 17 | #include "VPlanTransforms.h" |
| 18 | #include "VPlanUtils.h" |
| 19 | #include "llvm/ADT/PostOrderIterator.h" |
| 20 | |
| 21 | using namespace llvm; |
| 22 | |
| 23 | namespace { |
| 24 | class VPPredicator { |
| 25 | /// Builder to construct recipes to compute masks. |
| 26 | VPBuilder Builder; |
| 27 | |
| 28 | /// When we if-convert we need to create edge masks. We have to cache values |
| 29 | /// so that we don't end up with exponential recursion/IR. |
| 30 | using EdgeMaskCacheTy = |
| 31 | DenseMap<std::pair<const VPBasicBlock *, const VPBasicBlock *>, |
| 32 | VPValue *>; |
| 33 | using BlockMaskCacheTy = DenseMap<VPBasicBlock *, VPValue *>; |
| 34 | EdgeMaskCacheTy EdgeMaskCache; |
| 35 | |
| 36 | BlockMaskCacheTy BlockMaskCache; |
| 37 | |
| 38 | /// Create an edge mask for every destination of cases and/or default. |
| 39 | void createSwitchEdgeMasks(VPInstruction *SI); |
| 40 | |
| 41 | /// Computes and return the predicate of the edge between \p Src and \p Dst, |
| 42 | /// possibly inserting new recipes at \p Dst (using Builder's insertion point) |
| 43 | VPValue *createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst); |
| 44 | |
| 45 | /// Returns the *entry* mask for \p VPBB. |
| 46 | VPValue *getBlockInMask(VPBasicBlock *VPBB) const { |
| 47 | return BlockMaskCache.lookup(Val: VPBB); |
| 48 | } |
| 49 | |
| 50 | /// Record \p Mask as the *entry* mask of \p VPBB, which is expected to not |
| 51 | /// already have a mask. |
| 52 | void setBlockInMask(VPBasicBlock *VPBB, VPValue *Mask) { |
| 53 | // TODO: Include the masks as operands in the predicated VPlan directly to |
| 54 | // avoid keeping the map of masks beyond the predication transform. |
| 55 | assert(!getBlockInMask(VPBB) && "Mask already set" ); |
| 56 | BlockMaskCache[VPBB] = Mask; |
| 57 | } |
| 58 | |
| 59 | /// Record \p Mask as the mask of the edge from \p Src to \p Dst. The edge is |
| 60 | /// expected to not have a mask already. |
| 61 | VPValue *setEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst, |
| 62 | VPValue *Mask) { |
| 63 | assert(Src != Dst && "Src and Dst must be different" ); |
| 64 | assert(!getEdgeMask(Src, Dst) && "Mask already set" ); |
| 65 | return EdgeMaskCache[{Src, Dst}] = Mask; |
| 66 | } |
| 67 | |
| 68 | public: |
| 69 | /// Returns the precomputed predicate of the edge from \p Src to \p Dst. |
| 70 | VPValue *getEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst) const { |
| 71 | return EdgeMaskCache.lookup(Val: {Src, Dst}); |
| 72 | } |
| 73 | |
| 74 | /// Compute and return the mask for the vector loop header block. |
| 75 | void createHeaderMask(VPBasicBlock *, bool FoldTail); |
| 76 | |
| 77 | /// Compute and return the predicate of \p VPBB, assuming that the header |
| 78 | /// block of the loop is set to True, or to the loop mask when tail folding. |
| 79 | VPValue *createBlockInMask(VPBasicBlock *VPBB); |
| 80 | |
| 81 | /// Convert phi recipes in \p VPBB to VPBlendRecipes. |
| 82 | void convertPhisToBlends(VPBasicBlock *VPBB); |
| 83 | |
| 84 | const BlockMaskCacheTy getBlockMaskCache() const { return BlockMaskCache; } |
| 85 | }; |
| 86 | } // namespace |
| 87 | |
| 88 | VPValue *VPPredicator::createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst) { |
| 89 | assert(is_contained(Dst->getPredecessors(), Src) && "Invalid edge" ); |
| 90 | |
| 91 | // Look for cached value. |
| 92 | VPValue *EdgeMask = getEdgeMask(Src, Dst); |
| 93 | if (EdgeMask) |
| 94 | return EdgeMask; |
| 95 | |
| 96 | VPValue *SrcMask = getBlockInMask(VPBB: Src); |
| 97 | |
| 98 | // If there's a single successor, there's no terminator recipe. |
| 99 | if (Src->getNumSuccessors() == 1) |
| 100 | return setEdgeMask(Src, Dst, Mask: SrcMask); |
| 101 | |
| 102 | auto *Term = cast<VPInstruction>(Val: Src->getTerminator()); |
| 103 | if (Term->getOpcode() == Instruction::Switch) { |
| 104 | createSwitchEdgeMasks(SI: Term); |
| 105 | return getEdgeMask(Src, Dst); |
| 106 | } |
| 107 | |
| 108 | assert(Term->getOpcode() == VPInstruction::BranchOnCond && |
| 109 | "Unsupported terminator" ); |
| 110 | if (Src->getSuccessors()[0] == Src->getSuccessors()[1]) |
| 111 | return setEdgeMask(Src, Dst, Mask: SrcMask); |
| 112 | |
| 113 | EdgeMask = Term->getOperand(N: 0); |
| 114 | assert(EdgeMask && "No Edge Mask found for condition" ); |
| 115 | |
| 116 | if (Src->getSuccessors()[0] != Dst) |
| 117 | EdgeMask = Builder.createNot(Operand: EdgeMask, DL: Term->getDebugLoc()); |
| 118 | |
| 119 | if (SrcMask) { // Otherwise block in-mask is all-one, no need to AND. |
| 120 | // The bitwise 'And' of SrcMask and EdgeMask introduces new UB if SrcMask |
| 121 | // is false and EdgeMask is poison. Avoid that by using 'LogicalAnd' |
| 122 | // instead which generates 'select i1 SrcMask, i1 EdgeMask, i1 false'. |
| 123 | EdgeMask = Builder.createLogicalAnd(LHS: SrcMask, RHS: EdgeMask, DL: Term->getDebugLoc()); |
| 124 | } |
| 125 | |
| 126 | return setEdgeMask(Src, Dst, Mask: EdgeMask); |
| 127 | } |
| 128 | |
| 129 | VPValue *VPPredicator::createBlockInMask(VPBasicBlock *VPBB) { |
| 130 | // Start inserting after the block's phis, which be replaced by blends later. |
| 131 | Builder.setInsertPoint(TheBB: VPBB, IP: VPBB->getFirstNonPhi()); |
| 132 | // All-one mask is modelled as no-mask following the convention for masked |
| 133 | // load/store/gather/scatter. Initialize BlockMask to no-mask. |
| 134 | VPValue *BlockMask = nullptr; |
| 135 | // This is the block mask. We OR all unique incoming edges. |
| 136 | for (auto *Predecessor : SetVector<VPBlockBase *>( |
| 137 | VPBB->getPredecessors().begin(), VPBB->getPredecessors().end())) { |
| 138 | VPValue *EdgeMask = createEdgeMask(Src: cast<VPBasicBlock>(Val: Predecessor), Dst: VPBB); |
| 139 | if (!EdgeMask) { // Mask of predecessor is all-one so mask of block is |
| 140 | // too. |
| 141 | setBlockInMask(VPBB, Mask: EdgeMask); |
| 142 | return EdgeMask; |
| 143 | } |
| 144 | |
| 145 | if (!BlockMask) { // BlockMask has its initial nullptr value. |
| 146 | BlockMask = EdgeMask; |
| 147 | continue; |
| 148 | } |
| 149 | |
| 150 | BlockMask = Builder.createOr(LHS: BlockMask, RHS: EdgeMask, DL: {}); |
| 151 | } |
| 152 | |
| 153 | setBlockInMask(VPBB, Mask: BlockMask); |
| 154 | return BlockMask; |
| 155 | } |
| 156 | |
| 157 | void VPPredicator::(VPBasicBlock *, bool FoldTail) { |
| 158 | if (!FoldTail) { |
| 159 | setBlockInMask(VPBB: HeaderVPBB, Mask: nullptr); |
| 160 | return; |
| 161 | } |
| 162 | |
| 163 | // Introduce the early-exit compare IV <= BTC to form header block mask. |
| 164 | // This is used instead of IV < TC because TC may wrap, unlike BTC. Start by |
| 165 | // constructing the desired canonical IV in the header block as its first |
| 166 | // non-phi instructions. |
| 167 | |
| 168 | auto &Plan = *HeaderVPBB->getPlan(); |
| 169 | auto *IV = new VPWidenCanonicalIVRecipe(Plan.getCanonicalIV()); |
| 170 | Builder.setInsertPoint(TheBB: HeaderVPBB, IP: HeaderVPBB->getFirstNonPhi()); |
| 171 | Builder.insert(R: IV); |
| 172 | |
| 173 | VPValue *BTC = Plan.getOrCreateBackedgeTakenCount(); |
| 174 | VPValue *BlockMask = Builder.createICmp(Pred: CmpInst::ICMP_ULE, A: IV, B: BTC); |
| 175 | setBlockInMask(VPBB: HeaderVPBB, Mask: BlockMask); |
| 176 | } |
| 177 | |
| 178 | void VPPredicator::createSwitchEdgeMasks(VPInstruction *SI) { |
| 179 | VPBasicBlock *Src = SI->getParent(); |
| 180 | |
| 181 | // Create masks where SI is a switch. We create masks for all edges from SI's |
| 182 | // parent block at the same time. This is more efficient, as we can create and |
| 183 | // collect compares for all cases once. |
| 184 | VPValue *Cond = SI->getOperand(N: 0); |
| 185 | VPBasicBlock *DefaultDst = cast<VPBasicBlock>(Val: Src->getSuccessors()[0]); |
| 186 | MapVector<VPBasicBlock *, SmallVector<VPValue *>> Dst2Compares; |
| 187 | for (const auto &[Idx, Succ] : |
| 188 | enumerate(First: ArrayRef(Src->getSuccessors()).drop_front())) { |
| 189 | VPBasicBlock *Dst = cast<VPBasicBlock>(Val: Succ); |
| 190 | assert(!getEdgeMask(Src, Dst) && "Edge masks already created" ); |
| 191 | // Cases whose destination is the same as default are redundant and can |
| 192 | // be ignored - they will get there anyhow. |
| 193 | if (Dst == DefaultDst) |
| 194 | continue; |
| 195 | auto &Compares = Dst2Compares[Dst]; |
| 196 | VPValue *V = SI->getOperand(N: Idx + 1); |
| 197 | Compares.push_back(Elt: Builder.createICmp(Pred: CmpInst::ICMP_EQ, A: Cond, B: V)); |
| 198 | } |
| 199 | |
| 200 | // We need to handle 2 separate cases below for all entries in Dst2Compares, |
| 201 | // which excludes destinations matching the default destination. |
| 202 | VPValue *SrcMask = getBlockInMask(VPBB: Src); |
| 203 | VPValue *DefaultMask = nullptr; |
| 204 | for (const auto &[Dst, Conds] : Dst2Compares) { |
| 205 | // 1. Dst is not the default destination. Dst is reached if any of the |
| 206 | // cases with destination == Dst are taken. Join the conditions for each |
| 207 | // case whose destination == Dst using an OR. |
| 208 | VPValue *Mask = Conds[0]; |
| 209 | for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front()) |
| 210 | Mask = Builder.createOr(LHS: Mask, RHS: V); |
| 211 | if (SrcMask) |
| 212 | Mask = Builder.createLogicalAnd(LHS: SrcMask, RHS: Mask); |
| 213 | setEdgeMask(Src, Dst, Mask); |
| 214 | |
| 215 | // 2. Create the mask for the default destination, which is reached if |
| 216 | // none of the cases with destination != default destination are taken. |
| 217 | // Join the conditions for each case where the destination is != Dst using |
| 218 | // an OR and negate it. |
| 219 | DefaultMask = DefaultMask ? Builder.createOr(LHS: DefaultMask, RHS: Mask) : Mask; |
| 220 | } |
| 221 | |
| 222 | if (DefaultMask) { |
| 223 | DefaultMask = Builder.createNot(Operand: DefaultMask); |
| 224 | if (SrcMask) |
| 225 | DefaultMask = Builder.createLogicalAnd(LHS: SrcMask, RHS: DefaultMask); |
| 226 | } |
| 227 | setEdgeMask(Src, Dst: DefaultDst, Mask: DefaultMask); |
| 228 | } |
| 229 | |
| 230 | void VPPredicator::convertPhisToBlends(VPBasicBlock *VPBB) { |
| 231 | SmallVector<VPWidenPHIRecipe *> Phis; |
| 232 | for (VPRecipeBase &R : VPBB->phis()) |
| 233 | Phis.push_back(Elt: cast<VPWidenPHIRecipe>(Val: &R)); |
| 234 | for (VPWidenPHIRecipe *PhiR : Phis) { |
| 235 | // The non-header Phi is converted into a Blend recipe below, |
| 236 | // so we don't have to worry about the insertion order and we can just use |
| 237 | // the builder. At this point we generate the predication tree. There may |
| 238 | // be duplications since this is a simple recursive scan, but future |
| 239 | // optimizations will clean it up. |
| 240 | |
| 241 | SmallVector<VPValue *, 2> OperandsWithMask; |
| 242 | unsigned NumIncoming = PhiR->getNumIncoming(); |
| 243 | for (unsigned In = 0; In < NumIncoming; In++) { |
| 244 | const VPBasicBlock *Pred = PhiR->getIncomingBlock(Idx: In); |
| 245 | OperandsWithMask.push_back(Elt: PhiR->getIncomingValue(Idx: In)); |
| 246 | VPValue *EdgeMask = getEdgeMask(Src: Pred, Dst: VPBB); |
| 247 | if (!EdgeMask) { |
| 248 | assert(In == 0 && "Both null and non-null edge masks found" ); |
| 249 | assert(all_equal(PhiR->operands()) && |
| 250 | "Distinct incoming values with one having a full mask" ); |
| 251 | break; |
| 252 | } |
| 253 | OperandsWithMask.push_back(Elt: EdgeMask); |
| 254 | } |
| 255 | PHINode *IRPhi = cast<PHINode>(Val: PhiR->getUnderlyingValue()); |
| 256 | auto *Blend = new VPBlendRecipe(IRPhi, OperandsWithMask); |
| 257 | Builder.insert(R: Blend); |
| 258 | PhiR->replaceAllUsesWith(New: Blend); |
| 259 | PhiR->eraseFromParent(); |
| 260 | } |
| 261 | } |
| 262 | |
| 263 | DenseMap<VPBasicBlock *, VPValue *> |
| 264 | VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan, bool FoldTail) { |
| 265 | VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion(); |
| 266 | // Scan the body of the loop in a topological order to visit each basic block |
| 267 | // after having visited its predecessor basic blocks. |
| 268 | VPBasicBlock * = LoopRegion->getEntryBasicBlock(); |
| 269 | ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> RPOT( |
| 270 | Header); |
| 271 | VPPredicator Predicator; |
| 272 | for (VPBlockBase *VPB : RPOT) { |
| 273 | // Non-outer regions with VPBBs only are supported at the moment. |
| 274 | auto *VPBB = cast<VPBasicBlock>(Val: VPB); |
| 275 | // Introduce the mask for VPBB, which may introduce needed edge masks, and |
| 276 | // convert all phi recipes of VPBB to blend recipes unless VPBB is the |
| 277 | // header. |
| 278 | if (VPBB == Header) { |
| 279 | Predicator.createHeaderMask(HeaderVPBB: Header, FoldTail); |
| 280 | continue; |
| 281 | } |
| 282 | |
| 283 | Predicator.createBlockInMask(VPBB); |
| 284 | Predicator.convertPhisToBlends(VPBB); |
| 285 | } |
| 286 | |
| 287 | // Linearize the blocks of the loop into one serial chain. |
| 288 | VPBlockBase *PrevVPBB = nullptr; |
| 289 | for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Range: RPOT)) { |
| 290 | auto Successors = to_vector(Range&: VPBB->getSuccessors()); |
| 291 | if (Successors.size() > 1) |
| 292 | VPBB->getTerminator()->eraseFromParent(); |
| 293 | |
| 294 | // Flatten the CFG in the loop. To do so, first disconnect VPBB from its |
| 295 | // successors. Then connect VPBB to the previously visited VPBB. |
| 296 | for (auto *Succ : Successors) |
| 297 | VPBlockUtils::disconnectBlocks(From: VPBB, To: Succ); |
| 298 | if (PrevVPBB) |
| 299 | VPBlockUtils::connectBlocks(From: PrevVPBB, To: VPBB); |
| 300 | |
| 301 | PrevVPBB = VPBB; |
| 302 | } |
| 303 | return Predicator.getBlockMaskCache(); |
| 304 | } |
| 305 | |