//===-- VPlanPredicator.cpp - VPlan predicator ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements predication for VPlans.
///
//===----------------------------------------------------------------------===//
| 13 | |
| 14 | #include "VPRecipeBuilder.h" |
| 15 | #include "VPlan.h" |
| 16 | #include "VPlanCFG.h" |
| 17 | #include "VPlanPatternMatch.h" |
| 18 | #include "VPlanTransforms.h" |
| 19 | #include "VPlanUtils.h" |
| 20 | #include "llvm/ADT/PostOrderIterator.h" |
| 21 | |
| 22 | using namespace llvm; |
| 23 | using namespace VPlanPatternMatch; |
| 24 | |
| 25 | namespace { |
| 26 | class VPPredicator { |
| 27 | /// Builder to construct recipes to compute masks. |
| 28 | VPBuilder Builder; |
| 29 | |
| 30 | /// When we if-convert we need to create edge masks. We have to cache values |
| 31 | /// so that we don't end up with exponential recursion/IR. |
| 32 | using EdgeMaskCacheTy = |
| 33 | DenseMap<std::pair<const VPBasicBlock *, const VPBasicBlock *>, |
| 34 | VPValue *>; |
| 35 | using BlockMaskCacheTy = DenseMap<VPBasicBlock *, VPValue *>; |
| 36 | EdgeMaskCacheTy EdgeMaskCache; |
| 37 | |
| 38 | BlockMaskCacheTy BlockMaskCache; |
| 39 | |
| 40 | /// Create an edge mask for every destination of cases and/or default. |
| 41 | void createSwitchEdgeMasks(VPInstruction *SI); |
| 42 | |
| 43 | /// Computes and return the predicate of the edge between \p Src and \p Dst, |
| 44 | /// possibly inserting new recipes at \p Dst (using Builder's insertion point) |
| 45 | VPValue *createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst); |
| 46 | |
| 47 | /// Record \p Mask as the *entry* mask of \p VPBB, which is expected to not |
| 48 | /// already have a mask. |
| 49 | void setBlockInMask(VPBasicBlock *VPBB, VPValue *Mask) { |
| 50 | // TODO: Include the masks as operands in the predicated VPlan directly to |
| 51 | // avoid keeping the map of masks beyond the predication transform. |
| 52 | assert(!getBlockInMask(VPBB) && "Mask already set" ); |
| 53 | BlockMaskCache[VPBB] = Mask; |
| 54 | } |
| 55 | |
| 56 | /// Record \p Mask as the mask of the edge from \p Src to \p Dst. The edge is |
| 57 | /// expected to not have a mask already. |
| 58 | VPValue *setEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst, |
| 59 | VPValue *Mask) { |
| 60 | assert(Src != Dst && "Src and Dst must be different" ); |
| 61 | assert(!getEdgeMask(Src, Dst) && "Mask already set" ); |
| 62 | return EdgeMaskCache[{Src, Dst}] = Mask; |
| 63 | } |
| 64 | |
| 65 | public: |
| 66 | /// Returns the *entry* mask for \p VPBB. |
| 67 | VPValue *getBlockInMask(VPBasicBlock *VPBB) const { |
| 68 | return BlockMaskCache.lookup(Val: VPBB); |
| 69 | } |
| 70 | |
| 71 | /// Returns the precomputed predicate of the edge from \p Src to \p Dst. |
| 72 | VPValue *getEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst) const { |
| 73 | return EdgeMaskCache.lookup(Val: {Src, Dst}); |
| 74 | } |
| 75 | |
| 76 | /// Compute and return the mask for the vector loop header block. |
| 77 | void createHeaderMask(VPBasicBlock *, bool FoldTail); |
| 78 | |
| 79 | /// Compute the predicate of \p VPBB, assuming that the header block of the |
| 80 | /// loop is set to True, or to the loop mask when tail folding. |
| 81 | void createBlockInMask(VPBasicBlock *VPBB); |
| 82 | |
| 83 | /// Convert phi recipes in \p VPBB to VPBlendRecipes. |
| 84 | void convertPhisToBlends(VPBasicBlock *VPBB); |
| 85 | }; |
| 86 | } // namespace |
| 87 | |
| 88 | VPValue *VPPredicator::createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst) { |
| 89 | assert(is_contained(Dst->getPredecessors(), Src) && "Invalid edge" ); |
| 90 | |
| 91 | // Look for cached value. |
| 92 | VPValue *EdgeMask = getEdgeMask(Src, Dst); |
| 93 | if (EdgeMask) |
| 94 | return EdgeMask; |
| 95 | |
| 96 | VPValue *SrcMask = getBlockInMask(VPBB: Src); |
| 97 | |
| 98 | // If there's a single successor, there's no terminator recipe. |
| 99 | if (Src->getNumSuccessors() == 1) |
| 100 | return setEdgeMask(Src, Dst, Mask: SrcMask); |
| 101 | |
| 102 | auto *Term = cast<VPInstruction>(Val: Src->getTerminator()); |
| 103 | if (Term->getOpcode() == Instruction::Switch) { |
| 104 | createSwitchEdgeMasks(SI: Term); |
| 105 | return getEdgeMask(Src, Dst); |
| 106 | } |
| 107 | |
| 108 | assert(Term->getOpcode() == VPInstruction::BranchOnCond && |
| 109 | "Unsupported terminator" ); |
| 110 | if (Src->getSuccessors()[0] == Src->getSuccessors()[1]) |
| 111 | return setEdgeMask(Src, Dst, Mask: SrcMask); |
| 112 | |
| 113 | EdgeMask = Term->getOperand(N: 0); |
| 114 | assert(EdgeMask && "No Edge Mask found for condition" ); |
| 115 | |
| 116 | if (Src->getSuccessors()[0] != Dst) |
| 117 | EdgeMask = Builder.createNot(Operand: EdgeMask, DL: Term->getDebugLoc()); |
| 118 | |
| 119 | if (SrcMask) { // Otherwise block in-mask is all-one, no need to AND. |
| 120 | // The bitwise 'And' of SrcMask and EdgeMask introduces new UB if SrcMask |
| 121 | // is false and EdgeMask is poison. Avoid that by using 'LogicalAnd' |
| 122 | // instead which generates 'select i1 SrcMask, i1 EdgeMask, i1 false'. |
| 123 | EdgeMask = Builder.createLogicalAnd(LHS: SrcMask, RHS: EdgeMask, DL: Term->getDebugLoc()); |
| 124 | } |
| 125 | |
| 126 | return setEdgeMask(Src, Dst, Mask: EdgeMask); |
| 127 | } |
| 128 | |
| 129 | void VPPredicator::createBlockInMask(VPBasicBlock *VPBB) { |
| 130 | // Start inserting after the block's phis, which be replaced by blends later. |
| 131 | Builder.setInsertPoint(TheBB: VPBB, IP: VPBB->getFirstNonPhi()); |
| 132 | // All-one mask is modelled as no-mask following the convention for masked |
| 133 | // load/store/gather/scatter. Initialize BlockMask to no-mask. |
| 134 | VPValue *BlockMask = nullptr; |
| 135 | // This is the block mask. We OR all unique incoming edges. |
| 136 | for (auto *Predecessor : SetVector<VPBlockBase *>( |
| 137 | VPBB->getPredecessors().begin(), VPBB->getPredecessors().end())) { |
| 138 | VPValue *EdgeMask = createEdgeMask(Src: cast<VPBasicBlock>(Val: Predecessor), Dst: VPBB); |
| 139 | if (!EdgeMask) { // Mask of predecessor is all-one so mask of block is |
| 140 | // too. |
| 141 | setBlockInMask(VPBB, Mask: EdgeMask); |
| 142 | return; |
| 143 | } |
| 144 | |
| 145 | if (!BlockMask) { // BlockMask has its initial nullptr value. |
| 146 | BlockMask = EdgeMask; |
| 147 | continue; |
| 148 | } |
| 149 | |
| 150 | BlockMask = Builder.createOr(LHS: BlockMask, RHS: EdgeMask, DL: {}); |
| 151 | } |
| 152 | |
| 153 | setBlockInMask(VPBB, Mask: BlockMask); |
| 154 | } |
| 155 | |
| 156 | void VPPredicator::(VPBasicBlock *, bool FoldTail) { |
| 157 | if (!FoldTail) { |
| 158 | setBlockInMask(VPBB: HeaderVPBB, Mask: nullptr); |
| 159 | return; |
| 160 | } |
| 161 | |
| 162 | // Introduce the early-exit compare IV <= BTC to form header block mask. |
| 163 | // This is used instead of IV < TC because TC may wrap, unlike BTC. Start by |
| 164 | // constructing the desired canonical IV in the header block as its first |
| 165 | // non-phi instructions. |
| 166 | |
| 167 | auto &Plan = *HeaderVPBB->getPlan(); |
| 168 | auto *IV = |
| 169 | new VPWidenCanonicalIVRecipe(HeaderVPBB->getParent()->getCanonicalIV()); |
| 170 | Builder.setInsertPoint(TheBB: HeaderVPBB, IP: HeaderVPBB->getFirstNonPhi()); |
| 171 | Builder.insert(R: IV); |
| 172 | |
| 173 | VPValue *BTC = Plan.getOrCreateBackedgeTakenCount(); |
| 174 | VPValue *BlockMask = Builder.createICmp(Pred: CmpInst::ICMP_ULE, A: IV, B: BTC); |
| 175 | setBlockInMask(VPBB: HeaderVPBB, Mask: BlockMask); |
| 176 | } |
| 177 | |
| 178 | void VPPredicator::createSwitchEdgeMasks(VPInstruction *SI) { |
| 179 | VPBasicBlock *Src = SI->getParent(); |
| 180 | |
| 181 | // Create masks where SI is a switch. We create masks for all edges from SI's |
| 182 | // parent block at the same time. This is more efficient, as we can create and |
| 183 | // collect compares for all cases once. |
| 184 | VPValue *Cond = SI->getOperand(N: 0); |
| 185 | VPBasicBlock *DefaultDst = cast<VPBasicBlock>(Val: Src->getSuccessors()[0]); |
| 186 | MapVector<VPBasicBlock *, SmallVector<VPValue *>> Dst2Compares; |
| 187 | for (const auto &[Idx, Succ] : enumerate(First: drop_begin(RangeOrContainer&: Src->getSuccessors()))) { |
| 188 | VPBasicBlock *Dst = cast<VPBasicBlock>(Val: Succ); |
| 189 | assert(!getEdgeMask(Src, Dst) && "Edge masks already created" ); |
| 190 | // Cases whose destination is the same as default are redundant and can |
| 191 | // be ignored - they will get there anyhow. |
| 192 | if (Dst == DefaultDst) |
| 193 | continue; |
| 194 | auto &Compares = Dst2Compares[Dst]; |
| 195 | VPValue *V = SI->getOperand(N: Idx + 1); |
| 196 | Compares.push_back(Elt: Builder.createICmp(Pred: CmpInst::ICMP_EQ, A: Cond, B: V)); |
| 197 | } |
| 198 | |
| 199 | // We need to handle 2 separate cases below for all entries in Dst2Compares, |
| 200 | // which excludes destinations matching the default destination. |
| 201 | VPValue *SrcMask = getBlockInMask(VPBB: Src); |
| 202 | VPValue *DefaultMask = nullptr; |
| 203 | for (const auto &[Dst, Conds] : Dst2Compares) { |
| 204 | // 1. Dst is not the default destination. Dst is reached if any of the |
| 205 | // cases with destination == Dst are taken. Join the conditions for each |
| 206 | // case whose destination == Dst using an OR. |
| 207 | VPValue *Mask = Conds[0]; |
| 208 | for (VPValue *V : drop_begin(RangeOrContainer: Conds)) |
| 209 | Mask = Builder.createOr(LHS: Mask, RHS: V); |
| 210 | if (SrcMask) |
| 211 | Mask = Builder.createLogicalAnd(LHS: SrcMask, RHS: Mask); |
| 212 | setEdgeMask(Src, Dst, Mask); |
| 213 | |
| 214 | // 2. Create the mask for the default destination, which is reached if |
| 215 | // none of the cases with destination != default destination are taken. |
| 216 | // Join the conditions for each case where the destination is != Dst using |
| 217 | // an OR and negate it. |
| 218 | DefaultMask = DefaultMask ? Builder.createOr(LHS: DefaultMask, RHS: Mask) : Mask; |
| 219 | } |
| 220 | |
| 221 | if (DefaultMask) { |
| 222 | DefaultMask = Builder.createNot(Operand: DefaultMask); |
| 223 | if (SrcMask) |
| 224 | DefaultMask = Builder.createLogicalAnd(LHS: SrcMask, RHS: DefaultMask); |
| 225 | } else { |
| 226 | // There are no destinations other than the default destination, so this is |
| 227 | // an unconditional branch. |
| 228 | DefaultMask = SrcMask; |
| 229 | } |
| 230 | setEdgeMask(Src, Dst: DefaultDst, Mask: DefaultMask); |
| 231 | } |
| 232 | |
| 233 | void VPPredicator::convertPhisToBlends(VPBasicBlock *VPBB) { |
| 234 | SmallVector<VPPhi *> Phis; |
| 235 | for (VPRecipeBase &R : VPBB->phis()) |
| 236 | Phis.push_back(Elt: cast<VPPhi>(Val: &R)); |
| 237 | for (VPPhi *PhiR : Phis) { |
| 238 | // The non-header Phi is converted into a Blend recipe below, |
| 239 | // so we don't have to worry about the insertion order and we can just use |
| 240 | // the builder. At this point we generate the predication tree. There may |
| 241 | // be duplications since this is a simple recursive scan, but future |
| 242 | // optimizations will clean it up. |
| 243 | |
| 244 | auto NotPoison = make_filter_range(Range: PhiR->incoming_values(), Pred: [](VPValue *V) { |
| 245 | return !match(V, P: m_Poison()); |
| 246 | }); |
| 247 | if (all_equal(Range&: NotPoison)) { |
| 248 | PhiR->replaceAllUsesWith(New: NotPoison.empty() ? PhiR->getIncomingValue(Idx: 0) |
| 249 | : *NotPoison.begin()); |
| 250 | PhiR->eraseFromParent(); |
| 251 | continue; |
| 252 | } |
| 253 | |
| 254 | SmallVector<VPValue *, 2> OperandsWithMask; |
| 255 | for (const auto &[InVPV, InVPBB] : PhiR->incoming_values_and_blocks()) { |
| 256 | OperandsWithMask.push_back(Elt: InVPV); |
| 257 | OperandsWithMask.push_back(Elt: getEdgeMask(Src: InVPBB, Dst: VPBB)); |
| 258 | } |
| 259 | PHINode *IRPhi = cast_or_null<PHINode>(Val: PhiR->getUnderlyingValue()); |
| 260 | auto *Blend = |
| 261 | new VPBlendRecipe(IRPhi, OperandsWithMask, *PhiR, PhiR->getDebugLoc()); |
| 262 | Builder.insert(R: Blend); |
| 263 | PhiR->replaceAllUsesWith(New: Blend); |
| 264 | PhiR->eraseFromParent(); |
| 265 | } |
| 266 | } |
| 267 | |
| 268 | void VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan, bool FoldTail) { |
| 269 | VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion(); |
| 270 | // Scan the body of the loop in a topological order to visit each basic block |
| 271 | // after having visited its predecessor basic blocks. |
| 272 | VPBasicBlock * = LoopRegion->getEntryBasicBlock(); |
| 273 | ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> RPOT( |
| 274 | Header); |
| 275 | VPPredicator Predicator; |
| 276 | for (VPBlockBase *VPB : RPOT) { |
| 277 | // Non-outer regions with VPBBs only are supported at the moment. |
| 278 | auto *VPBB = cast<VPBasicBlock>(Val: VPB); |
| 279 | // Introduce the mask for VPBB, which may introduce needed edge masks, and |
| 280 | // convert all phi recipes of VPBB to blend recipes unless VPBB is the |
| 281 | // header. |
| 282 | if (VPBB == Header) { |
| 283 | Predicator.createHeaderMask(HeaderVPBB: Header, FoldTail); |
| 284 | } else { |
| 285 | Predicator.createBlockInMask(VPBB); |
| 286 | Predicator.convertPhisToBlends(VPBB); |
| 287 | } |
| 288 | |
| 289 | VPValue *BlockMask = Predicator.getBlockInMask(VPBB); |
| 290 | if (!BlockMask) |
| 291 | continue; |
| 292 | |
| 293 | // Mask all VPInstructions in the block. |
| 294 | for (VPRecipeBase &R : *VPBB) { |
| 295 | if (auto *VPI = dyn_cast<VPInstruction>(Val: &R)) |
| 296 | VPI->addMask(Mask: BlockMask); |
| 297 | } |
| 298 | } |
| 299 | |
| 300 | // Linearize the blocks of the loop into one serial chain. |
| 301 | VPBlockBase *PrevVPBB = nullptr; |
| 302 | for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Range: RPOT)) { |
| 303 | auto Successors = to_vector(Range&: VPBB->getSuccessors()); |
| 304 | if (Successors.size() > 1) |
| 305 | VPBB->getTerminator()->eraseFromParent(); |
| 306 | |
| 307 | // Flatten the CFG in the loop. To do so, first disconnect VPBB from its |
| 308 | // successors. Then connect VPBB to the previously visited VPBB. |
| 309 | for (auto *Succ : Successors) |
| 310 | VPBlockUtils::disconnectBlocks(From: VPBB, To: Succ); |
| 311 | if (PrevVPBB) |
| 312 | VPBlockUtils::connectBlocks(From: PrevVPBB, To: VPBB); |
| 313 | |
| 314 | PrevVPBB = VPBB; |
| 315 | } |
| 316 | |
| 317 | // If we folded the tail and introduced a header mask, any extract of the |
| 318 | // last element must be updated to extract from the last active lane of the |
| 319 | // header mask instead (i.e., the lane corresponding to the last active |
| 320 | // iteration). |
| 321 | if (FoldTail) { |
| 322 | assert(Plan.getExitBlocks().size() == 1 && |
| 323 | "only a single-exit block is supported currently" ); |
| 324 | assert(Plan.getExitBlocks().front()->getSinglePredecessor() == |
| 325 | Plan.getMiddleBlock() && |
| 326 | "the exit block must have middle block as single predecessor" ); |
| 327 | |
| 328 | VPBuilder B(Plan.getMiddleBlock()->getTerminator()); |
| 329 | for (VPRecipeBase &R : *Plan.getMiddleBlock()) { |
| 330 | VPValue *Op; |
| 331 | if (!match(V: &R, P: m_CombineOr( |
| 332 | L: m_ExitingIVValue(Op0: m_VPValue(), Op1: m_VPValue(V&: Op)), |
| 333 | R: m_ExtractLastLane(Op0: m_ExtractLastPart(Op0: m_VPValue(V&: Op)))))) |
| 334 | continue; |
| 335 | |
| 336 | // Compute the index of the last active lane. |
| 337 | VPValue * = Predicator.getBlockInMask(VPBB: Header); |
| 338 | VPValue *LastActiveLane = |
| 339 | B.createNaryOp(Opcode: VPInstruction::LastActiveLane, Operands: HeaderMask); |
| 340 | auto *Ext = |
| 341 | B.createNaryOp(Opcode: VPInstruction::ExtractLane, Operands: {LastActiveLane, Op}); |
| 342 | R.getVPSingleValue()->replaceAllUsesWith(New: Ext); |
| 343 | } |
| 344 | } |
| 345 | } |
| 346 | |