1//===-- VPlanPredicator.cpp - VPlan predicator ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements predication for VPlans.
11///
12//===----------------------------------------------------------------------===//
13
14#include "VPRecipeBuilder.h"
15#include "VPlan.h"
16#include "VPlanCFG.h"
17#include "VPlanDominatorTree.h"
18#include "VPlanPatternMatch.h"
19#include "VPlanTransforms.h"
20#include "VPlanUtils.h"
21#include "llvm/ADT/PostOrderIterator.h"
22
23using namespace llvm;
24using namespace VPlanPatternMatch;
25
26namespace {
27class VPPredicator {
28 /// Builder to construct recipes to compute masks.
29 VPBuilder Builder;
30
31 /// Dominator tree for the VPlan.
32 VPDominatorTree VPDT;
33
34 /// Post-dominator tree for the VPlan.
35 VPPostDominatorTree VPPDT;
36
37 /// When we if-convert we need to create edge masks. We have to cache values
38 /// so that we don't end up with exponential recursion/IR.
39 using EdgeMaskCacheTy =
40 DenseMap<std::pair<const VPBasicBlock *, const VPBasicBlock *>,
41 VPValue *>;
42 using BlockMaskCacheTy = DenseMap<const VPBasicBlock *, VPValue *>;
43 EdgeMaskCacheTy EdgeMaskCache;
44
45 BlockMaskCacheTy BlockMaskCache;
46
47 /// Create an edge mask for every destination of cases and/or default.
48 void createSwitchEdgeMasks(const VPInstruction *SI);
49
50 /// Computes and return the predicate of the edge between \p Src and \p Dst,
51 /// possibly inserting new recipes at \p Dst (using Builder's insertion point)
52 VPValue *createEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst);
53
54 /// Record \p Mask as the *entry* mask of \p VPBB, which is expected to not
55 /// already have a mask.
56 void setBlockInMask(const VPBasicBlock *VPBB, VPValue *Mask) {
57 // TODO: Include the masks as operands in the predicated VPlan directly to
58 // avoid keeping the map of masks beyond the predication transform.
59 assert(!getBlockInMask(VPBB) && "Mask already set");
60 BlockMaskCache[VPBB] = Mask;
61 }
62
63 /// Record \p Mask as the mask of the edge from \p Src to \p Dst. The edge is
64 /// expected to not have a mask already.
65 VPValue *setEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst,
66 VPValue *Mask) {
67 assert(Src != Dst && "Src and Dst must be different");
68 assert(!getEdgeMask(Src, Dst) && "Mask already set");
69 return EdgeMaskCache[{Src, Dst}] = Mask;
70 }
71
72 /// Returns where to insert new masks in \p VPBB.
73 VPBasicBlock::iterator getMaskInsertPoint(VPBasicBlock *VPBB) {
74 if (VPValue *Mask = getBlockInMask(VPBB))
75 if (VPRecipeBase *MaskR = Mask->getDefiningRecipe())
76 if (MaskR->getParent() == VPBB) // In-mask may be the IDom's.
77 return std::next(x: MaskR->getIterator());
78 return VPBB->getFirstNonPhi();
79 }
80
81public:
82 VPPredicator(VPlan &Plan) : VPDT(Plan), VPPDT(Plan) {}
83
84 /// Returns the *entry* mask for \p VPBB.
85 VPValue *getBlockInMask(const VPBasicBlock *VPBB) const {
86 return BlockMaskCache.lookup(Val: VPBB);
87 }
88
89 /// Returns the precomputed predicate of the edge from \p Src to \p Dst.
90 VPValue *getEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst) const {
91 return EdgeMaskCache.lookup(Val: {Src, Dst});
92 }
93
94 /// Compute the predicate of \p VPBB.
95 void createBlockInMask(VPBasicBlock *VPBB);
96
97 /// Convert phi recipes in \p VPBB to VPBlendRecipes.
98 void convertPhisToBlends(VPBasicBlock *VPBB);
99};
100} // namespace
101
102VPValue *VPPredicator::createEdgeMask(const VPBasicBlock *Src,
103 const VPBasicBlock *Dst) {
104 assert(is_contained(Dst->getPredecessors(), Src) && "Invalid edge");
105
106 // Look for cached value.
107 VPValue *EdgeMask = getEdgeMask(Src, Dst);
108 if (EdgeMask)
109 return EdgeMask;
110
111 VPValue *SrcMask = getBlockInMask(VPBB: Src);
112
113 // If there's a single successor, there's no terminator recipe.
114 if (Src->getNumSuccessors() == 1)
115 return setEdgeMask(Src, Dst, Mask: SrcMask);
116
117 auto *Term = cast<VPInstruction>(Val: Src->getTerminator());
118 if (Term->getOpcode() == Instruction::Switch) {
119 createSwitchEdgeMasks(SI: Term);
120 return getEdgeMask(Src, Dst);
121 }
122
123 assert(Term->getOpcode() == VPInstruction::BranchOnCond &&
124 "Unsupported terminator");
125 if (Src->getSuccessors()[0] == Src->getSuccessors()[1])
126 return setEdgeMask(Src, Dst, Mask: SrcMask);
127
128 EdgeMask = Term->getOperand(N: 0);
129 assert(EdgeMask && "No Edge Mask found for condition");
130
131 if (Src->getSuccessors()[0] != Dst)
132 EdgeMask = Builder.createNot(Operand: EdgeMask, DL: Term->getDebugLoc());
133
134 if (SrcMask) { // Otherwise block in-mask is all-one, no need to AND.
135 // The bitwise 'And' of SrcMask and EdgeMask introduces new UB if SrcMask
136 // is false and EdgeMask is poison. Avoid that by using 'LogicalAnd'
137 // instead which generates 'select i1 SrcMask, i1 EdgeMask, i1 false'.
138 EdgeMask = Builder.createLogicalAnd(LHS: SrcMask, RHS: EdgeMask, DL: Term->getDebugLoc());
139 }
140
141 return setEdgeMask(Src, Dst, Mask: EdgeMask);
142}
143
144void VPPredicator::createBlockInMask(VPBasicBlock *VPBB) {
145 // Start inserting after the block's phis, which be replaced by blends later.
146 Builder.setInsertPoint(TheBB: VPBB, IP: VPBB->getFirstNonPhi());
147
148 // Reuse the mask of the immediate dominator if the VPBB post-dominates the
149 // immediate dominator.
150 auto *IDom = VPDT.getNode(BB: VPBB)->getIDom();
151 assert(IDom && "Block in loop must have immediate dominator");
152 auto *IDomBB = cast<VPBasicBlock>(Val: IDom->getBlock());
153 if (VPPDT.properlyDominates(A: VPBB, B: IDomBB)) {
154 setBlockInMask(VPBB, Mask: getBlockInMask(VPBB: IDomBB));
155 return;
156 }
157 // All-one mask is modelled as no-mask following the convention for masked
158 // load/store/gather/scatter. Initialize BlockMask to no-mask.
159 VPValue *BlockMask = nullptr;
160 // This is the block mask. We OR all unique incoming edges.
161 for (auto *Predecessor : SetVector<VPBlockBase *>(
162 VPBB->getPredecessors().begin(), VPBB->getPredecessors().end())) {
163 VPValue *EdgeMask = createEdgeMask(Src: cast<VPBasicBlock>(Val: Predecessor), Dst: VPBB);
164 if (!EdgeMask) { // Mask of predecessor is all-one so mask of block is
165 // too.
166 setBlockInMask(VPBB, Mask: EdgeMask);
167 return;
168 }
169
170 if (!BlockMask) { // BlockMask has its initial nullptr value.
171 BlockMask = EdgeMask;
172 continue;
173 }
174
175 BlockMask = Builder.createOr(LHS: BlockMask, RHS: EdgeMask, DL: {});
176 }
177
178 setBlockInMask(VPBB, Mask: BlockMask);
179}
180
181void VPPredicator::createSwitchEdgeMasks(const VPInstruction *SI) {
182 const VPBasicBlock *Src = SI->getParent();
183
184 // Create masks where SI is a switch. We create masks for all edges from SI's
185 // parent block at the same time. This is more efficient, as we can create and
186 // collect compares for all cases once.
187 VPValue *Cond = SI->getOperand(N: 0);
188 VPBasicBlock *DefaultDst = cast<VPBasicBlock>(Val: Src->getSuccessors()[0]);
189 MapVector<VPBasicBlock *, SmallVector<VPValue *>> Dst2Compares;
190 for (const auto &[Idx, Succ] : enumerate(First: drop_begin(RangeOrContainer: Src->getSuccessors()))) {
191 VPBasicBlock *Dst = cast<VPBasicBlock>(Val: Succ);
192 assert(!getEdgeMask(Src, Dst) && "Edge masks already created");
193 // Cases whose destination is the same as default are redundant and can
194 // be ignored - they will get there anyhow.
195 if (Dst == DefaultDst)
196 continue;
197 auto &Compares = Dst2Compares[Dst];
198 VPValue *V = SI->getOperand(N: Idx + 1);
199 Compares.push_back(Elt: Builder.createICmp(Pred: CmpInst::ICMP_EQ, A: Cond, B: V));
200 }
201
202 // We need to handle 2 separate cases below for all entries in Dst2Compares,
203 // which excludes destinations matching the default destination.
204 VPValue *SrcMask = getBlockInMask(VPBB: Src);
205 VPValue *DefaultMask = nullptr;
206 for (const auto &[Dst, Conds] : Dst2Compares) {
207 // 1. Dst is not the default destination. Dst is reached if any of the
208 // cases with destination == Dst are taken. Join the conditions for each
209 // case whose destination == Dst using an OR.
210 VPValue *Mask = Conds[0];
211 for (VPValue *V : drop_begin(RangeOrContainer: Conds))
212 Mask = Builder.createOr(LHS: Mask, RHS: V);
213 if (SrcMask)
214 Mask = Builder.createLogicalAnd(LHS: SrcMask, RHS: Mask);
215 setEdgeMask(Src, Dst, Mask);
216
217 // 2. Create the mask for the default destination, which is reached if
218 // none of the cases with destination != default destination are taken.
219 // Join the conditions for each case where the destination is != Dst using
220 // an OR and negate it.
221 DefaultMask = DefaultMask ? Builder.createOr(LHS: DefaultMask, RHS: Mask) : Mask;
222 }
223
224 if (DefaultMask) {
225 DefaultMask = Builder.createNot(Operand: DefaultMask);
226 if (SrcMask)
227 DefaultMask = Builder.createLogicalAnd(LHS: SrcMask, RHS: DefaultMask);
228 } else {
229 // There are no destinations other than the default destination, so this is
230 // an unconditional branch.
231 DefaultMask = SrcMask;
232 }
233 setEdgeMask(Src, Dst: DefaultDst, Mask: DefaultMask);
234}
235
236void VPPredicator::convertPhisToBlends(VPBasicBlock *VPBB) {
237 Builder.setInsertPoint(TheBB: VPBB, IP: getMaskInsertPoint(VPBB));
238
239 SmallVector<VPPhi *> Phis;
240 for (VPRecipeBase &R : VPBB->phis())
241 Phis.push_back(Elt: cast<VPPhi>(Val: &R));
242 for (VPPhi *PhiR : Phis) {
243 // The non-header Phi is converted into a Blend recipe below,
244 // so we don't have to worry about the insertion order and we can just use
245 // the builder. At this point we generate the predication tree. There may
246 // be duplications since this is a simple recursive scan, but future
247 // optimizations will clean it up.
248
249 auto NotPoison = make_filter_range(Range: PhiR->incoming_values(), Pred: [](VPValue *V) {
250 return !match(V, P: m_Poison());
251 });
252 if (all_equal(Range&: NotPoison)) {
253 PhiR->replaceAllUsesWith(New: NotPoison.empty() ? PhiR->getIncomingValue(Idx: 0)
254 : *NotPoison.begin());
255 PhiR->eraseFromParent();
256 continue;
257 }
258
259 SmallVector<VPValue *, 2> OperandsWithMask;
260 for (const auto &[InVPV, InVPBB] : PhiR->incoming_values_and_blocks()) {
261 OperandsWithMask.push_back(Elt: InVPV);
262 OperandsWithMask.push_back(Elt: createEdgeMask(Src: InVPBB, Dst: VPBB));
263 }
264 PHINode *IRPhi = cast_or_null<PHINode>(Val: PhiR->getUnderlyingValue());
265 auto *Blend =
266 new VPBlendRecipe(IRPhi, OperandsWithMask, *PhiR, PhiR->getDebugLoc());
267 Builder.insert(R: Blend);
268 PhiR->replaceAllUsesWith(New: Blend);
269 PhiR->eraseFromParent();
270 }
271}
272
273void VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan) {
274 // Nested loop regions (outer-loop vectorization) are not supported yet.
275 if (Plan.isOuterLoop())
276 return;
277 VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion();
278 // Scan the body of the loop in a topological order to visit each basic block
279 // after having visited its predecessor basic blocks.
280 VPBasicBlock *Header = LoopRegion->getEntryBasicBlock();
281 ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> RPOT(
282 Header);
283 VPPredicator Predicator(Plan);
284 for (VPBlockBase *VPB : RPOT) {
285 // Non-outer regions with VPBBs only are supported at the moment.
286 auto *VPBB = cast<VPBasicBlock>(Val: VPB);
287 // Introduce the mask for VPBB, which may introduce needed edge masks, and
288 // convert all phi recipes of VPBB to blend recipes unless VPBB is the
289 // header.
290 if (VPBB != Header)
291 Predicator.createBlockInMask(VPBB);
292
293 VPValue *BlockMask = Predicator.getBlockInMask(VPBB);
294 if (!BlockMask)
295 continue;
296
297 // Mask all VPInstructions in the block.
298 for (VPRecipeBase &R : *VPBB) {
299 if (auto *VPI = dyn_cast<VPInstruction>(Val: &R))
300 VPI->addMask(Mask: BlockMask);
301 }
302 }
303
304 for (VPBlockBase *VPBB : reverse(C&: RPOT))
305 if (VPBB != Header)
306 Predicator.convertPhisToBlends(VPBB: cast<VPBasicBlock>(Val: VPBB));
307
308 // Linearize the blocks of the loop into one serial chain.
309 VPBlockBase *PrevVPBB = nullptr;
310 for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Range&: RPOT)) {
311 auto Successors = to_vector(Range&: VPBB->getSuccessors());
312 if (Successors.size() > 1)
313 VPBB->getTerminator()->eraseFromParent();
314
315 // Flatten the CFG in the loop. To do so, first disconnect VPBB from its
316 // successors. Then connect VPBB to the previously visited VPBB.
317 for (auto *Succ : Successors)
318 VPBlockUtils::disconnectBlocks(From: VPBB, To: Succ);
319 if (PrevVPBB)
320 VPBlockUtils::connectBlocks(From: PrevVPBB, To: VPBB);
321
322 PrevVPBB = VPBB;
323 }
324}
325