1//===-- VPlanPredicator.cpp - VPlan predicator ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements predication for VPlans.
11///
12//===----------------------------------------------------------------------===//
13
14#include "VPRecipeBuilder.h"
15#include "VPlan.h"
16#include "VPlanCFG.h"
17#include "VPlanTransforms.h"
18#include "VPlanUtils.h"
19#include "llvm/ADT/PostOrderIterator.h"
20
21using namespace llvm;
22
23namespace {
24class VPPredicator {
25 /// Builder to construct recipes to compute masks.
26 VPBuilder Builder;
27
28 /// When we if-convert we need to create edge masks. We have to cache values
29 /// so that we don't end up with exponential recursion/IR.
30 using EdgeMaskCacheTy =
31 DenseMap<std::pair<const VPBasicBlock *, const VPBasicBlock *>,
32 VPValue *>;
33 using BlockMaskCacheTy = DenseMap<VPBasicBlock *, VPValue *>;
34 EdgeMaskCacheTy EdgeMaskCache;
35
36 BlockMaskCacheTy BlockMaskCache;
37
38 /// Create an edge mask for every destination of cases and/or default.
39 void createSwitchEdgeMasks(VPInstruction *SI);
40
41 /// Computes and return the predicate of the edge between \p Src and \p Dst,
42 /// possibly inserting new recipes at \p Dst (using Builder's insertion point)
43 VPValue *createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst);
44
45 /// Returns the *entry* mask for \p VPBB.
46 VPValue *getBlockInMask(VPBasicBlock *VPBB) const {
47 return BlockMaskCache.lookup(Val: VPBB);
48 }
49
50 /// Record \p Mask as the *entry* mask of \p VPBB, which is expected to not
51 /// already have a mask.
52 void setBlockInMask(VPBasicBlock *VPBB, VPValue *Mask) {
53 // TODO: Include the masks as operands in the predicated VPlan directly to
54 // avoid keeping the map of masks beyond the predication transform.
55 assert(!getBlockInMask(VPBB) && "Mask already set");
56 BlockMaskCache[VPBB] = Mask;
57 }
58
59 /// Record \p Mask as the mask of the edge from \p Src to \p Dst. The edge is
60 /// expected to not have a mask already.
61 VPValue *setEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst,
62 VPValue *Mask) {
63 assert(Src != Dst && "Src and Dst must be different");
64 assert(!getEdgeMask(Src, Dst) && "Mask already set");
65 return EdgeMaskCache[{Src, Dst}] = Mask;
66 }
67
68public:
69 /// Returns the precomputed predicate of the edge from \p Src to \p Dst.
70 VPValue *getEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst) const {
71 return EdgeMaskCache.lookup(Val: {Src, Dst});
72 }
73
74 /// Compute and return the mask for the vector loop header block.
75 void createHeaderMask(VPBasicBlock *HeaderVPBB, bool FoldTail);
76
77 /// Compute and return the predicate of \p VPBB, assuming that the header
78 /// block of the loop is set to True, or to the loop mask when tail folding.
79 VPValue *createBlockInMask(VPBasicBlock *VPBB);
80
81 /// Convert phi recipes in \p VPBB to VPBlendRecipes.
82 void convertPhisToBlends(VPBasicBlock *VPBB);
83
84 const BlockMaskCacheTy getBlockMaskCache() const { return BlockMaskCache; }
85};
86} // namespace
87
88VPValue *VPPredicator::createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst) {
89 assert(is_contained(Dst->getPredecessors(), Src) && "Invalid edge");
90
91 // Look for cached value.
92 VPValue *EdgeMask = getEdgeMask(Src, Dst);
93 if (EdgeMask)
94 return EdgeMask;
95
96 VPValue *SrcMask = getBlockInMask(VPBB: Src);
97
98 // If there's a single successor, there's no terminator recipe.
99 if (Src->getNumSuccessors() == 1)
100 return setEdgeMask(Src, Dst, Mask: SrcMask);
101
102 auto *Term = cast<VPInstruction>(Val: Src->getTerminator());
103 if (Term->getOpcode() == Instruction::Switch) {
104 createSwitchEdgeMasks(SI: Term);
105 return getEdgeMask(Src, Dst);
106 }
107
108 assert(Term->getOpcode() == VPInstruction::BranchOnCond &&
109 "Unsupported terminator");
110 if (Src->getSuccessors()[0] == Src->getSuccessors()[1])
111 return setEdgeMask(Src, Dst, Mask: SrcMask);
112
113 EdgeMask = Term->getOperand(N: 0);
114 assert(EdgeMask && "No Edge Mask found for condition");
115
116 if (Src->getSuccessors()[0] != Dst)
117 EdgeMask = Builder.createNot(Operand: EdgeMask, DL: Term->getDebugLoc());
118
119 if (SrcMask) { // Otherwise block in-mask is all-one, no need to AND.
120 // The bitwise 'And' of SrcMask and EdgeMask introduces new UB if SrcMask
121 // is false and EdgeMask is poison. Avoid that by using 'LogicalAnd'
122 // instead which generates 'select i1 SrcMask, i1 EdgeMask, i1 false'.
123 EdgeMask = Builder.createLogicalAnd(LHS: SrcMask, RHS: EdgeMask, DL: Term->getDebugLoc());
124 }
125
126 return setEdgeMask(Src, Dst, Mask: EdgeMask);
127}
128
129VPValue *VPPredicator::createBlockInMask(VPBasicBlock *VPBB) {
130 // Start inserting after the block's phis, which be replaced by blends later.
131 Builder.setInsertPoint(TheBB: VPBB, IP: VPBB->getFirstNonPhi());
132 // All-one mask is modelled as no-mask following the convention for masked
133 // load/store/gather/scatter. Initialize BlockMask to no-mask.
134 VPValue *BlockMask = nullptr;
135 // This is the block mask. We OR all unique incoming edges.
136 for (auto *Predecessor : SetVector<VPBlockBase *>(
137 VPBB->getPredecessors().begin(), VPBB->getPredecessors().end())) {
138 VPValue *EdgeMask = createEdgeMask(Src: cast<VPBasicBlock>(Val: Predecessor), Dst: VPBB);
139 if (!EdgeMask) { // Mask of predecessor is all-one so mask of block is
140 // too.
141 setBlockInMask(VPBB, Mask: EdgeMask);
142 return EdgeMask;
143 }
144
145 if (!BlockMask) { // BlockMask has its initial nullptr value.
146 BlockMask = EdgeMask;
147 continue;
148 }
149
150 BlockMask = Builder.createOr(LHS: BlockMask, RHS: EdgeMask, DL: {});
151 }
152
153 setBlockInMask(VPBB, Mask: BlockMask);
154 return BlockMask;
155}
156
157void VPPredicator::createHeaderMask(VPBasicBlock *HeaderVPBB, bool FoldTail) {
158 if (!FoldTail) {
159 setBlockInMask(VPBB: HeaderVPBB, Mask: nullptr);
160 return;
161 }
162
163 // Introduce the early-exit compare IV <= BTC to form header block mask.
164 // This is used instead of IV < TC because TC may wrap, unlike BTC. Start by
165 // constructing the desired canonical IV in the header block as its first
166 // non-phi instructions.
167
168 auto &Plan = *HeaderVPBB->getPlan();
169 auto *IV = new VPWidenCanonicalIVRecipe(Plan.getCanonicalIV());
170 Builder.setInsertPoint(TheBB: HeaderVPBB, IP: HeaderVPBB->getFirstNonPhi());
171 Builder.insert(R: IV);
172
173 VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
174 VPValue *BlockMask = Builder.createICmp(Pred: CmpInst::ICMP_ULE, A: IV, B: BTC);
175 setBlockInMask(VPBB: HeaderVPBB, Mask: BlockMask);
176}
177
178void VPPredicator::createSwitchEdgeMasks(VPInstruction *SI) {
179 VPBasicBlock *Src = SI->getParent();
180
181 // Create masks where SI is a switch. We create masks for all edges from SI's
182 // parent block at the same time. This is more efficient, as we can create and
183 // collect compares for all cases once.
184 VPValue *Cond = SI->getOperand(N: 0);
185 VPBasicBlock *DefaultDst = cast<VPBasicBlock>(Val: Src->getSuccessors()[0]);
186 MapVector<VPBasicBlock *, SmallVector<VPValue *>> Dst2Compares;
187 for (const auto &[Idx, Succ] :
188 enumerate(First: ArrayRef(Src->getSuccessors()).drop_front())) {
189 VPBasicBlock *Dst = cast<VPBasicBlock>(Val: Succ);
190 assert(!getEdgeMask(Src, Dst) && "Edge masks already created");
191 // Cases whose destination is the same as default are redundant and can
192 // be ignored - they will get there anyhow.
193 if (Dst == DefaultDst)
194 continue;
195 auto &Compares = Dst2Compares[Dst];
196 VPValue *V = SI->getOperand(N: Idx + 1);
197 Compares.push_back(Elt: Builder.createICmp(Pred: CmpInst::ICMP_EQ, A: Cond, B: V));
198 }
199
200 // We need to handle 2 separate cases below for all entries in Dst2Compares,
201 // which excludes destinations matching the default destination.
202 VPValue *SrcMask = getBlockInMask(VPBB: Src);
203 VPValue *DefaultMask = nullptr;
204 for (const auto &[Dst, Conds] : Dst2Compares) {
205 // 1. Dst is not the default destination. Dst is reached if any of the
206 // cases with destination == Dst are taken. Join the conditions for each
207 // case whose destination == Dst using an OR.
208 VPValue *Mask = Conds[0];
209 for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front())
210 Mask = Builder.createOr(LHS: Mask, RHS: V);
211 if (SrcMask)
212 Mask = Builder.createLogicalAnd(LHS: SrcMask, RHS: Mask);
213 setEdgeMask(Src, Dst, Mask);
214
215 // 2. Create the mask for the default destination, which is reached if
216 // none of the cases with destination != default destination are taken.
217 // Join the conditions for each case where the destination is != Dst using
218 // an OR and negate it.
219 DefaultMask = DefaultMask ? Builder.createOr(LHS: DefaultMask, RHS: Mask) : Mask;
220 }
221
222 if (DefaultMask) {
223 DefaultMask = Builder.createNot(Operand: DefaultMask);
224 if (SrcMask)
225 DefaultMask = Builder.createLogicalAnd(LHS: SrcMask, RHS: DefaultMask);
226 }
227 setEdgeMask(Src, Dst: DefaultDst, Mask: DefaultMask);
228}
229
230void VPPredicator::convertPhisToBlends(VPBasicBlock *VPBB) {
231 SmallVector<VPWidenPHIRecipe *> Phis;
232 for (VPRecipeBase &R : VPBB->phis())
233 Phis.push_back(Elt: cast<VPWidenPHIRecipe>(Val: &R));
234 for (VPWidenPHIRecipe *PhiR : Phis) {
235 // The non-header Phi is converted into a Blend recipe below,
236 // so we don't have to worry about the insertion order and we can just use
237 // the builder. At this point we generate the predication tree. There may
238 // be duplications since this is a simple recursive scan, but future
239 // optimizations will clean it up.
240
241 SmallVector<VPValue *, 2> OperandsWithMask;
242 unsigned NumIncoming = PhiR->getNumIncoming();
243 for (unsigned In = 0; In < NumIncoming; In++) {
244 const VPBasicBlock *Pred = PhiR->getIncomingBlock(Idx: In);
245 OperandsWithMask.push_back(Elt: PhiR->getIncomingValue(Idx: In));
246 VPValue *EdgeMask = getEdgeMask(Src: Pred, Dst: VPBB);
247 if (!EdgeMask) {
248 assert(In == 0 && "Both null and non-null edge masks found");
249 assert(all_equal(PhiR->operands()) &&
250 "Distinct incoming values with one having a full mask");
251 break;
252 }
253 OperandsWithMask.push_back(Elt: EdgeMask);
254 }
255 PHINode *IRPhi = cast<PHINode>(Val: PhiR->getUnderlyingValue());
256 auto *Blend = new VPBlendRecipe(IRPhi, OperandsWithMask);
257 Builder.insert(R: Blend);
258 PhiR->replaceAllUsesWith(New: Blend);
259 PhiR->eraseFromParent();
260 }
261}
262
263DenseMap<VPBasicBlock *, VPValue *>
264VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan, bool FoldTail) {
265 VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion();
266 // Scan the body of the loop in a topological order to visit each basic block
267 // after having visited its predecessor basic blocks.
268 VPBasicBlock *Header = LoopRegion->getEntryBasicBlock();
269 ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> RPOT(
270 Header);
271 VPPredicator Predicator;
272 for (VPBlockBase *VPB : RPOT) {
273 // Non-outer regions with VPBBs only are supported at the moment.
274 auto *VPBB = cast<VPBasicBlock>(Val: VPB);
275 // Introduce the mask for VPBB, which may introduce needed edge masks, and
276 // convert all phi recipes of VPBB to blend recipes unless VPBB is the
277 // header.
278 if (VPBB == Header) {
279 Predicator.createHeaderMask(HeaderVPBB: Header, FoldTail);
280 continue;
281 }
282
283 Predicator.createBlockInMask(VPBB);
284 Predicator.convertPhisToBlends(VPBB);
285 }
286
287 // Linearize the blocks of the loop into one serial chain.
288 VPBlockBase *PrevVPBB = nullptr;
289 for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Range: RPOT)) {
290 auto Successors = to_vector(Range&: VPBB->getSuccessors());
291 if (Successors.size() > 1)
292 VPBB->getTerminator()->eraseFromParent();
293
294 // Flatten the CFG in the loop. To do so, first disconnect VPBB from its
295 // successors. Then connect VPBB to the previously visited VPBB.
296 for (auto *Succ : Successors)
297 VPBlockUtils::disconnectBlocks(From: VPBB, To: Succ);
298 if (PrevVPBB)
299 VPBlockUtils::connectBlocks(From: PrevVPBB, To: VPBB);
300
301 PrevVPBB = VPBB;
302 }
303 return Predicator.getBlockMaskCache();
304}
305