1//===-- SPIRVStructurizer.cpp ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//===----------------------------------------------------------------------===//
10
11#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
12#include "SPIRV.h"
13#include "SPIRVStructurizerWrapper.h"
14#include "SPIRVSubtarget.h"
15#include "SPIRVUtils.h"
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/SmallPtrSet.h"
18#include "llvm/Analysis/LoopInfo.h"
19#include "llvm/IR/CFG.h"
20#include "llvm/IR/Dominators.h"
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/IntrinsicInst.h"
23#include "llvm/IR/Intrinsics.h"
24#include "llvm/IR/IntrinsicsSPIRV.h"
25#include "llvm/IR/LegacyPassManager.h"
26#include "llvm/InitializePasses.h"
27#include "llvm/Transforms/Utils.h"
28#include "llvm/Transforms/Utils/Cloning.h"
29#include "llvm/Transforms/Utils/LoopSimplify.h"
30#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
31#include <stack>
32
33using namespace llvm;
34using namespace SPIRV;
35
36using BlockSet = SmallPtrSet<BasicBlock *, 0>;
37using Edge = std::pair<BasicBlock *, BasicBlock *>;
38
39// Helper function to do a partial order visit from the block |Start|, calling
40// |Op| on each visited node.
41static void partialOrderVisit(BasicBlock &Start,
42 std::function<bool(BasicBlock *)> Op) {
43 PartialOrderingVisitor V(*Start.getParent());
44 V.partialOrderVisit(Start, Op: std::move(Op));
45}
46
47// Returns the exact convergence region in the tree defined by `Node` for which
48// `BB` is the header, nullptr otherwise.
49static const ConvergenceRegion *
50getRegionForHeader(const ConvergenceRegion *Node, BasicBlock *BB) {
51 if (Node->Entry == BB)
52 return Node;
53
54 for (auto *Child : Node->Children) {
55 const auto *CR = getRegionForHeader(Node: Child, BB);
56 if (CR != nullptr)
57 return CR;
58 }
59 return nullptr;
60}
61
62// Returns the single BasicBlock exiting the convergence region `CR`,
63// nullptr if no such exit exists.
64static BasicBlock *getExitFor(const ConvergenceRegion *CR) {
65 SmallPtrSet<BasicBlock *, 0> ExitTargets;
66 for (BasicBlock *Exit : CR->Exits) {
67 for (BasicBlock *Successor : successors(BB: Exit)) {
68 if (CR->Blocks.count(Ptr: Successor) == 0)
69 ExitTargets.insert(Ptr: Successor);
70 }
71 }
72
73 assert(ExitTargets.size() <= 1);
74 if (ExitTargets.size() == 0)
75 return nullptr;
76
77 return *ExitTargets.begin();
78}
79
80// Returns the merge block designated by I if I is a merge instruction, nullptr
81// otherwise.
82static BasicBlock *getDesignatedMergeBlock(Instruction *I) {
83 IntrinsicInst *II = dyn_cast_or_null<IntrinsicInst>(Val: I);
84 if (II == nullptr)
85 return nullptr;
86
87 if (II->getIntrinsicID() != Intrinsic::spv_loop_merge &&
88 II->getIntrinsicID() != Intrinsic::spv_selection_merge)
89 return nullptr;
90
91 BlockAddress *BA = cast<BlockAddress>(Val: II->getOperand(i_nocapture: 0));
92 return BA->getBasicBlock();
93}
94
95// Returns the continue block designated by I if I is an OpLoopMerge, nullptr
96// otherwise.
97static BasicBlock *getDesignatedContinueBlock(Instruction *I) {
98 IntrinsicInst *II = dyn_cast_or_null<IntrinsicInst>(Val: I);
99 if (II == nullptr)
100 return nullptr;
101
102 if (II->getIntrinsicID() != Intrinsic::spv_loop_merge)
103 return nullptr;
104
105 BlockAddress *BA = cast<BlockAddress>(Val: II->getOperand(i_nocapture: 1));
106 return BA->getBasicBlock();
107}
108
109// Returns true if Header has one merge instruction which designated Merge as
110// merge block.
111static bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge) {
112 for (auto &I : Header) {
113 BasicBlock *MB = getDesignatedMergeBlock(I: &I);
114 if (MB == &Merge)
115 return true;
116 }
117 return false;
118}
119
120// Returns true if the BB has one OpLoopMerge instruction.
121static bool hasLoopMergeInstruction(BasicBlock &BB) {
122 for (auto &I : BB)
123 if (getDesignatedContinueBlock(I: &I))
124 return true;
125 return false;
126}
127
128// Returns true is I is an OpSelectionMerge or OpLoopMerge instruction, false
129// otherwise.
130static bool isMergeInstruction(Instruction *I) {
131 return getDesignatedMergeBlock(I) != nullptr;
132}
133
134// Returns all blocks in F having at least one OpLoopMerge or OpSelectionMerge
135// instruction.
136static SmallPtrSet<BasicBlock *, 2> getHeaderBlocks(Function &F) {
137 SmallPtrSet<BasicBlock *, 2> Output;
138 for (BasicBlock &BB : F) {
139 for (Instruction &I : BB) {
140 if (getDesignatedMergeBlock(I: &I) != nullptr)
141 Output.insert(Ptr: &BB);
142 }
143 }
144 return Output;
145}
146
147// Returns all basic blocks in |F| referenced by at least 1
148// OpSelectionMerge/OpLoopMerge instruction.
149static SmallPtrSet<BasicBlock *, 2> getMergeBlocks(Function &F) {
150 SmallPtrSet<BasicBlock *, 2> Output;
151 for (BasicBlock &BB : F) {
152 for (Instruction &I : BB) {
153 BasicBlock *MB = getDesignatedMergeBlock(I: &I);
154 if (MB != nullptr)
155 Output.insert(Ptr: MB);
156 }
157 }
158 return Output;
159}
160
161// Return all the merge instructions contained in BB.
162// Note: the SPIR-V spec doesn't allow a single BB to contain more than 1 merge
163// instruction, but this can happen while we structurize the CFG.
164static std::vector<Instruction *> getMergeInstructions(BasicBlock &BB) {
165 std::vector<Instruction *> Output;
166 for (Instruction &I : BB)
167 if (isMergeInstruction(I: &I))
168 Output.push_back(x: &I);
169 return Output;
170}
171
172// Returns all basic blocks in |F| referenced as continue target by at least 1
173// OpLoopMerge instruction.
174static SmallPtrSet<BasicBlock *, 2> getContinueBlocks(Function &F) {
175 SmallPtrSet<BasicBlock *, 2> Output;
176 for (BasicBlock &BB : F) {
177 for (Instruction &I : BB) {
178 BasicBlock *MB = getDesignatedContinueBlock(I: &I);
179 if (MB != nullptr)
180 Output.insert(Ptr: MB);
181 }
182 }
183 return Output;
184}
185
186// Do a preorder traversal of the CFG starting from the BB |Start|.
187// point. Calls |op| on each basic block encountered during the traversal.
188static void visit(BasicBlock &Start, std::function<bool(BasicBlock *)> op) {
189 std::stack<BasicBlock *> ToVisit;
190 SmallPtrSet<BasicBlock *, 8> Seen;
191
192 ToVisit.push(x: &Start);
193 Seen.insert(Ptr: ToVisit.top());
194 while (ToVisit.size() != 0) {
195 BasicBlock *BB = ToVisit.top();
196 ToVisit.pop();
197
198 if (!op(BB))
199 continue;
200
201 for (auto Succ : successors(BB)) {
202 if (Seen.contains(Ptr: Succ))
203 continue;
204 ToVisit.push(x: Succ);
205 Seen.insert(Ptr: Succ);
206 }
207 }
208}
209
210// Replaces the conditional and unconditional branch targets of |BB| by
211// |NewTarget| if the target was |OldTarget|. This function also makes sure the
212// associated merge instruction gets updated accordingly.
213static void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
214 BasicBlock *NewTarget) {
215 auto *BI = cast<CondBrInst>(Val: BB->getTerminator());
216
217 // 1. Replace all matching successors.
218 for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
219 if (BI->getSuccessor(i) == OldTarget)
220 BI->setSuccessor(idx: i, NewSucc: NewTarget);
221 }
222
223 // Branch had 2 successors, maybe now both are the same?
224 if (BI->getSuccessor(i: 0) != BI->getSuccessor(i: 1))
225 return;
226
227 // Note: we may end up here because the original IR had such branches.
228 // This means Target is not necessarily equal to NewTarget.
229 IRBuilder<> Builder(BB);
230 Builder.SetInsertPoint(BI);
231 Builder.CreateBr(Dest: BI->getSuccessor(i: 0));
232 BI->eraseFromParent();
233
234 // The branch was the only instruction, nothing else to do.
235 if (BB->size() == 1)
236 return;
237
238 // Otherwise, we need to check: was there an OpSelectionMerge before this
239 // branch? If we removed the OpBranchConditional, we must also remove the
240 // OpSelectionMerge. This is not valid for OpLoopMerge:
241 IntrinsicInst *II =
242 dyn_cast<IntrinsicInst>(Val: BB->getTerminator()->getPrevNode());
243 if (!II || II->getIntrinsicID() != Intrinsic::spv_selection_merge)
244 return;
245
246 Constant *C = cast<Constant>(Val: II->getOperand(i_nocapture: 0));
247 II->eraseFromParent();
248 if (!C->isConstantUsed())
249 C->destroyConstant();
250}
251
252// Replaces the target of branch instruction in |BB| with |NewTarget| if it
253// was |OldTarget|. This function also fixes the associated merge instruction.
254// Note: this function does not simplify branching instructions, it only updates
255// targets. See also: simplifyBranches.
256static void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
257 BasicBlock *NewTarget) {
258 auto *T = BB->getTerminator();
259 if (isa<ReturnInst>(Val: T))
260 return;
261 if (auto *BI = dyn_cast<UncondBrInst>(Val: T)) {
262 if (BI->getSuccessor() == OldTarget)
263 BI->setSuccessor(NewTarget);
264 return;
265 }
266
267 if (isa<CondBrInst>(Val: T))
268 return replaceIfBranchTargets(BB, OldTarget, NewTarget);
269
270 if (auto *SI = dyn_cast<SwitchInst>(Val: T)) {
271 for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
272 if (SI->getSuccessor(idx: i) == OldTarget)
273 SI->setSuccessor(idx: i, NewSucc: NewTarget);
274 }
275 return;
276 }
277
278 assert(false && "Unhandled terminator type.");
279}
280
281namespace {
282// Given a reducible CFG, produces a structurized CFG in the SPIR-V sense,
283// adding merge instructions when required.
284class SPIRVStructurizer : public FunctionPass {
285 struct DivergentConstruct;
286 // Represents a list of condition/loops/switch constructs.
287 // See SPIR-V 2.11.2. Structured Control-flow Constructs for the list of
288 // constructs.
289 using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
290
291 // Represents a divergent construct in the SPIR-V sense.
292 // Such constructs are represented by a header (entry), a merge block (exit),
293 // and possibly a continue block (back-edge). A construct can contain other
294 // constructs, but their boundaries do not cross.
295 struct DivergentConstruct {
296 BasicBlock *Header = nullptr;
297 BasicBlock *Merge = nullptr;
298 BasicBlock *Continue = nullptr;
299
300 DivergentConstruct *Parent = nullptr;
301 ConstructList Children;
302 };
303
304 // An helper class to clean the construct boundaries.
305 // It is used to gather the list of blocks that should belong to each
306 // divergent construct, and possibly modify CFG edges when exits would cross
307 // the boundary of multiple constructs.
308 struct Splitter {
309 Function &F;
310 LoopInfo &LI;
311 DomTreeBuilder::BBDomTree DT;
312 DomTreeBuilder::BBPostDomTree PDT;
313
314 Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); }
315
316 void invalidate() {
317 PDT.recalculate(Func&: F);
318 DT.recalculate(Func&: F);
319 }
320
321 // Returns the list of blocks that belong to a SPIR-V loop construct,
322 // including the continue construct.
323 std::vector<BasicBlock *> getLoopConstructBlocks(BasicBlock *Header,
324 BasicBlock *Merge) {
325 assert(DT.dominates(Header, Merge));
326 std::vector<BasicBlock *> Output;
327 partialOrderVisit(Start&: *Header, Op: [&](BasicBlock *BB) {
328 if (BB == Merge)
329 return false;
330 if (DT.dominates(A: Merge, B: BB) || !DT.dominates(A: Header, B: BB))
331 return false;
332 Output.push_back(x: BB);
333 return true;
334 });
335 return Output;
336 }
337
338 // Returns the list of blocks that belong to a SPIR-V selection construct.
339 std::vector<BasicBlock *>
340 getSelectionConstructBlocks(DivergentConstruct *Node) {
341 assert(DT.dominates(Node->Header, Node->Merge));
342 BlockSet OutsideBlocks;
343 OutsideBlocks.insert(Ptr: Node->Merge);
344
345 for (DivergentConstruct *It = Node->Parent; It != nullptr;
346 It = It->Parent) {
347 OutsideBlocks.insert(Ptr: It->Merge);
348 if (It->Continue)
349 OutsideBlocks.insert(Ptr: It->Continue);
350 }
351
352 std::vector<BasicBlock *> Output;
353 partialOrderVisit(Start&: *Node->Header, Op: [&](BasicBlock *BB) {
354 if (OutsideBlocks.count(Ptr: BB) != 0)
355 return false;
356 if (DT.dominates(A: Node->Merge, B: BB) || !DT.dominates(A: Node->Header, B: BB))
357 return false;
358 Output.push_back(x: BB);
359 return true;
360 });
361 return Output;
362 }
363
364 // Returns the list of blocks that belong to a SPIR-V switch construct.
365 std::vector<BasicBlock *> getSwitchConstructBlocks(BasicBlock *Header,
366 BasicBlock *Merge) {
367 assert(DT.dominates(Header, Merge));
368
369 std::vector<BasicBlock *> Output;
370 partialOrderVisit(Start&: *Header, Op: [&](BasicBlock *BB) {
371 // the blocks structurally dominated by a switch header,
372 if (!DT.dominates(A: Header, B: BB))
373 return false;
374 // excluding blocks structurally dominated by the switch header’s merge
375 // block.
376 if (DT.dominates(A: Merge, B: BB) || BB == Merge)
377 return false;
378 Output.push_back(x: BB);
379 return true;
380 });
381 return Output;
382 }
383
384 // Returns the list of blocks that belong to a SPIR-V case construct.
385 std::vector<BasicBlock *> getCaseConstructBlocks(BasicBlock *Target,
386 BasicBlock *Merge) {
387 assert(DT.dominates(Target, Merge));
388
389 std::vector<BasicBlock *> Output;
390 partialOrderVisit(Start&: *Target, Op: [&](BasicBlock *BB) {
391 // the blocks structurally dominated by an OpSwitch Target or Default
392 // block
393 if (!DT.dominates(A: Target, B: BB))
394 return false;
395 // excluding the blocks structurally dominated by the OpSwitch
396 // construct’s corresponding merge block.
397 if (DT.dominates(A: Merge, B: BB) || BB == Merge)
398 return false;
399 Output.push_back(x: BB);
400 return true;
401 });
402 return Output;
403 }
404
405 // Splits the given edges by recreating proxy nodes so that the destination
406 // has unique incoming edges from this region.
407 //
408 // clang-format off
409 //
410 // In SPIR-V, constructs must have a single exit/merge.
411 // Given nodes A and B in the construct, a node C outside, and the following edges.
412 // A -> C
413 // B -> C
414 //
415 // In such cases, we must create a new exit node D, that belong to the construct to make is viable:
416 // A -> D -> C
417 // B -> D -> C
418 //
419 // This is fine (assuming C has no PHI nodes), but requires handling the merge instruction here.
420 // By adding a proxy node, we create a regular divergent shape which can easily be regularized later on.
421 // A -> D -> D1 -> C
422 // B -> D -> D2 -> C
423 //
424 // A, B, D belongs to the construct. D is the exit. D1 and D2 are empty.
425 //
426 // clang-format on
427 std::vector<Edge>
428 createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
429 SmallPtrSet<BasicBlock *, 0> Seen;
430 std::vector<Edge> Output;
431 Output.reserve(n: Edges.size());
432
433 for (auto &[Src, Dst] : Edges) {
434 auto [Iterator, Inserted] = Seen.insert(Ptr: Src);
435 if (!Inserted) {
436 // Src already a source node. Cannot have 2 edges from A to B.
437 // Creating alias source block.
438 BasicBlock *NewSrc = BasicBlock::Create(
439 Context&: F.getContext(), Name: Src->getName() + ".new.src", Parent: &F);
440 replaceBranchTargets(BB: Src, OldTarget: Dst, NewTarget: NewSrc);
441 IRBuilder<> Builder(NewSrc);
442 Builder.CreateBr(Dest: Dst);
443 Src = NewSrc;
444 }
445
446 Output.emplace_back(args&: Src, args&: Dst);
447 }
448
449 return Output;
450 }
451
452 // Given a construct defined by |Header|, and a list of exiting edges
453 // |Edges|, creates a new single exit node, fixing up those edges.
454 BasicBlock *createSingleExitNode(BasicBlock *Header,
455 std::vector<Edge> &Edges) {
456
457 std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
458
459 std::vector<BasicBlock *> Dsts;
460 DenseMap<BasicBlock *, ConstantInt *> DstToIndex;
461 auto NewExit = BasicBlock::Create(Context&: F.getContext(),
462 Name: Header->getName() + ".new.exit", Parent: &F);
463 IRBuilder<> ExitBuilder(NewExit);
464 for (auto &[Src, Dst] : FixedEdges) {
465 if (DstToIndex.count(Val: Dst) != 0)
466 continue;
467 DstToIndex.try_emplace(Key: Dst, Args: ExitBuilder.getInt32(C: DstToIndex.size()));
468 Dsts.push_back(x: Dst);
469 }
470
471 if (Dsts.size() == 1) {
472 for (auto &[Src, Dst] : FixedEdges) {
473 replaceBranchTargets(BB: Src, OldTarget: Dst, NewTarget: NewExit);
474 }
475 ExitBuilder.CreateBr(Dest: Dsts[0]);
476 return NewExit;
477 }
478
479 AllocaInst *Variable = createVariable(F, Type: ExitBuilder.getInt32Ty());
480 for (auto &[Src, Dst] : FixedEdges) {
481 IRBuilder<> B2(Src);
482 B2.SetInsertPoint(Src->getFirstInsertionPt());
483 B2.CreateStore(Val: DstToIndex[Dst], Ptr: Variable);
484 replaceBranchTargets(BB: Src, OldTarget: Dst, NewTarget: NewExit);
485 }
486
487 Value *Load = ExitBuilder.CreateLoad(Ty: ExitBuilder.getInt32Ty(), Ptr: Variable);
488
489 // If we can avoid an OpSwitch, generate an OpBranch. Reason is some
490 // OpBranch are allowed to exist without a new OpSelectionMerge if one of
491 // the branch is the parent's merge node, while OpSwitches are not.
492 if (Dsts.size() == 2) {
493 Value *Condition =
494 ExitBuilder.CreateCmp(Pred: CmpInst::ICMP_EQ, LHS: DstToIndex[Dsts[0]], RHS: Load);
495 ExitBuilder.CreateCondBr(Cond: Condition, True: Dsts[0], False: Dsts[1]);
496 return NewExit;
497 }
498
499 SwitchInst *Sw = ExitBuilder.CreateSwitch(V: Load, Dest: Dsts[0], NumCases: Dsts.size() - 1);
500 for (BasicBlock *BB : drop_begin(RangeOrContainer&: Dsts))
501 Sw->addCase(OnVal: DstToIndex[BB], Dest: BB);
502 return NewExit;
503 }
504 };
505
506 // Creates a new basic block in F with a single OpUnreachable instruction.
507 BasicBlock *CreateUnreachable(Function &F) {
508 BasicBlock *BB = BasicBlock::Create(Context&: F.getContext(), Name: "unreachable", Parent: &F);
509 IRBuilder<> Builder(BB);
510 Builder.CreateUnreachable();
511 return BB;
512 }
513
514 // Add OpLoopMerge instruction on cycles.
515 bool addMergeForLoops(Function &F) {
516 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
517 auto *TopLevelRegion =
518 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
519 .getRegionInfo()
520 .getTopLevelRegion();
521
522 bool Modified = false;
523 for (auto &BB : F) {
524 // Not a loop header. Ignoring for now.
525 if (!LI.isLoopHeader(BB: &BB))
526 continue;
527 auto *L = LI.getLoopFor(BB: &BB);
528
529 // This loop header is not the entrance of a convergence region. Ignoring
530 // this block.
531 auto *CR = getRegionForHeader(Node: TopLevelRegion, BB: &BB);
532 if (CR == nullptr)
533 continue;
534
535 IRBuilder<> Builder(&BB);
536
537 auto *Merge = getExitFor(CR);
538 // We are indeed in a loop, but there are no exits (infinite loop).
539 // This could be caused by a bad shader, but also could be an artifact
540 // from an earlier optimization. It is not always clear if structurally
541 // reachable means runtime reachable, so we cannot error-out. What we must
542 // do however is to make is legal on the SPIR-V point of view, hence
543 // adding an unreachable merge block.
544 if (Merge == nullptr) {
545 UncondBrInst *Br = cast<UncondBrInst>(Val: BB.getTerminator());
546 Merge = CreateUnreachable(F);
547 Builder.SetInsertPoint(Br);
548 Builder.CreateCondBr(Cond: Builder.getFalse(), True: Merge, False: Br->getSuccessor(i: 0));
549 Br->eraseFromParent();
550 }
551
552 auto *Continue = L->getLoopLatch();
553
554 Builder.SetInsertPoint(BB.getTerminator());
555 auto MergeAddress = BlockAddress::get(F: Merge->getParent(), BB: Merge);
556 auto ContinueAddress = BlockAddress::get(F: Continue->getParent(), BB: Continue);
557 SmallVector<Value *, 2> Args = {MergeAddress, ContinueAddress};
558 SmallVector<unsigned, 1> LoopControlImms =
559 getSpirvLoopControlOperandsFromLoopMetadata(L);
560 for (unsigned Imm : LoopControlImms)
561 Args.emplace_back(Args: ConstantInt::get(Ty: Builder.getInt32Ty(), V: Imm));
562 Builder.CreateIntrinsic(ID: Intrinsic::spv_loop_merge, Args: {Args});
563 Modified = true;
564 }
565
566 return Modified;
567 }
568
569 // Adds an OpSelectionMerge to the immediate dominator or each node with an
570 // in-degree of 2 or more which is not already the merge target of an
571 // OpLoopMerge/OpSelectionMerge.
572 bool addMergeForNodesWithMultiplePredecessors(Function &F) {
573 DomTreeBuilder::BBDomTree DT;
574 DT.recalculate(Func&: F);
575
576 bool Modified = false;
577 for (auto &BB : F) {
578 if (pred_size(BB: &BB) <= 1)
579 continue;
580
581 if (hasLoopMergeInstruction(BB) && pred_size(BB: &BB) <= 2)
582 continue;
583
584 assert(DT.getNode(&BB)->getIDom());
585 BasicBlock *Header = DT.getNode(BB: &BB)->getIDom()->getBlock();
586
587 if (isDefinedAsSelectionMergeBy(Header&: *Header, Merge&: BB))
588 continue;
589
590 IRBuilder<> Builder(Header);
591 Builder.SetInsertPoint(Header->getTerminator());
592
593 auto MergeAddress = BlockAddress::get(F: BB.getParent(), BB: &BB);
594 createOpSelectMerge(Builder: &Builder, MergeAddress);
595
596 Modified = true;
597 }
598
599 return Modified;
600 }
601
602 // When a block has multiple OpSelectionMerge/OpLoopMerge instructions, sorts
603 // them to put the "largest" first. A merge instruction is defined as larger
604 // than another when its target merge block post-dominates the other target's
605 // merge block. (This ordering should match the nesting ordering of the source
606 // HLSL).
607 bool sortSelectionMerge(Function &F, BasicBlock &Block) {
608 std::vector<Instruction *> MergeInstructions;
609 for (Instruction &I : Block)
610 if (isMergeInstruction(I: &I))
611 MergeInstructions.push_back(x: &I);
612
613 if (MergeInstructions.size() <= 1)
614 return false;
615
616 Instruction *InsertionPoint = *MergeInstructions.begin();
617
618 PartialOrderingVisitor Visitor(F);
619 std::sort(first: MergeInstructions.begin(), last: MergeInstructions.end(),
620 comp: [&Visitor](Instruction *Left, Instruction *Right) {
621 if (Left == Right)
622 return false;
623 BasicBlock *RightMerge = getDesignatedMergeBlock(I: Right);
624 BasicBlock *LeftMerge = getDesignatedMergeBlock(I: Left);
625 return !Visitor.compare(LHS: RightMerge, RHS: LeftMerge);
626 });
627
628 for (Instruction *I : MergeInstructions) {
629 I->moveBefore(InsertPos: InsertionPoint->getIterator());
630 InsertionPoint = I;
631 }
632
633 return true;
634 }
635
636 // Sorts selection merge headers in |F|.
637 // A is sorted before B if the merge block designated by B is an ancestor of
638 // the one designated by A.
639 bool sortSelectionMergeHeaders(Function &F) {
640 bool Modified = false;
641 for (BasicBlock &BB : F) {
642 Modified |= sortSelectionMerge(F, Block&: BB);
643 }
644 return Modified;
645 }
646
647 // Split basic blocks containing multiple OpLoopMerge/OpSelectionMerge
648 // instructions so each basic block contains only a single merge instruction.
649 bool splitBlocksWithMultipleHeaders(Function &F) {
650 std::stack<BasicBlock *> Work;
651 for (auto &BB : F) {
652 std::vector<Instruction *> MergeInstructions = getMergeInstructions(BB);
653 if (MergeInstructions.size() <= 1)
654 continue;
655 Work.push(x: &BB);
656 }
657
658 const bool Modified = Work.size() > 0;
659 while (Work.size() > 0) {
660 BasicBlock *Header = Work.top();
661 Work.pop();
662
663 std::vector<Instruction *> MergeInstructions =
664 getMergeInstructions(BB&: *Header);
665 for (unsigned i = 1; i < MergeInstructions.size(); i++) {
666 BasicBlock *NewBlock =
667 Header->splitBasicBlock(I: MergeInstructions[i], BBName: "new.header");
668
669 if (getDesignatedContinueBlock(I: MergeInstructions[0]) == nullptr) {
670 BasicBlock *Unreachable = CreateUnreachable(F);
671
672 Instruction *Term = Header->getTerminator();
673 IRBuilder<> Builder(Header);
674 Builder.SetInsertPoint(Term);
675 Builder.CreateCondBr(Cond: Builder.getTrue(), True: NewBlock, False: Unreachable);
676 Term->eraseFromParent();
677 }
678
679 Header = NewBlock;
680 }
681 }
682
683 return Modified;
684 }
685
686 // Adds an OpSelectionMerge to each block with an out-degree >= 2 which
687 // doesn't already have an OpSelectionMerge.
688 bool addMergeForDivergentBlocks(Function &F) {
689 DomTreeBuilder::BBPostDomTree PDT;
690 PDT.recalculate(Func&: F);
691 bool Modified = false;
692
693 auto MergeBlocks = getMergeBlocks(F);
694 auto ContinueBlocks = getContinueBlocks(F);
695
696 for (auto &BB : F) {
697 if (getMergeInstructions(BB).size() != 0)
698 continue;
699
700 std::vector<BasicBlock *> Candidates;
701 for (BasicBlock *Successor : successors(BB: &BB)) {
702 if (MergeBlocks.contains(Ptr: Successor))
703 continue;
704 if (ContinueBlocks.contains(Ptr: Successor))
705 continue;
706 Candidates.push_back(x: Successor);
707 }
708
709 if (Candidates.size() <= 1)
710 continue;
711
712 Modified = true;
713 BasicBlock *Merge = Candidates[0];
714
715 auto MergeAddress = BlockAddress::get(F: Merge->getParent(), BB: Merge);
716 IRBuilder<> Builder(&BB);
717 Builder.SetInsertPoint(BB.getTerminator());
718 createOpSelectMerge(Builder: &Builder, MergeAddress);
719 }
720
721 return Modified;
722 }
723
724 // Gather all the exit nodes for the construct header by |Header| and
725 // containing the blocks |Construct|.
726 std::vector<Edge> getExitsFrom(const BlockSet &Construct,
727 BasicBlock &Header) {
728 std::vector<Edge> Output;
729 visit(Start&: Header, op: [&](BasicBlock *Item) {
730 if (Construct.count(Ptr: Item) == 0)
731 return false;
732
733 for (BasicBlock *Successor : successors(BB: Item)) {
734 if (Construct.count(Ptr: Successor) == 0)
735 Output.emplace_back(args&: Item, args&: Successor);
736 }
737 return true;
738 });
739
740 return Output;
741 }
742
743 // Build a divergent construct tree searching from |BB|.
744 // If |Parent| is not null, this tree is attached to the parent's tree.
745 void constructDivergentConstruct(BlockSet &Visited, Splitter &S,
746 BasicBlock *BB, DivergentConstruct *Parent) {
747 if (Visited.count(Ptr: BB) != 0)
748 return;
749 Visited.insert(Ptr: BB);
750
751 auto MIS = getMergeInstructions(BB&: *BB);
752 if (MIS.size() == 0) {
753 for (BasicBlock *Successor : successors(BB))
754 constructDivergentConstruct(Visited, S, BB: Successor, Parent);
755 return;
756 }
757
758 assert(MIS.size() == 1);
759 Instruction *MI = MIS[0];
760
761 BasicBlock *Merge = getDesignatedMergeBlock(I: MI);
762 BasicBlock *Continue = getDesignatedContinueBlock(I: MI);
763
764 auto Output = std::make_unique<DivergentConstruct>();
765 Output->Header = BB;
766 Output->Merge = Merge;
767 Output->Continue = Continue;
768 Output->Parent = Parent;
769
770 constructDivergentConstruct(Visited, S, BB: Merge, Parent);
771 if (Continue)
772 constructDivergentConstruct(Visited, S, BB: Continue, Parent: Output.get());
773
774 for (BasicBlock *Successor : successors(BB))
775 constructDivergentConstruct(Visited, S, BB: Successor, Parent: Output.get());
776
777 if (Parent)
778 Parent->Children.emplace_back(args: std::move(Output));
779 }
780
781 // Returns the blocks belonging to the divergent construct |Node|.
782 BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) {
783 assert(Node->Header && Node->Merge);
784
785 if (Node->Continue) {
786 auto LoopBlocks = S.getLoopConstructBlocks(Header: Node->Header, Merge: Node->Merge);
787 return BlockSet(LoopBlocks.begin(), LoopBlocks.end());
788 }
789
790 auto SelectionBlocks = S.getSelectionConstructBlocks(Node);
791 return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end());
792 }
793
794 // Fixup the construct |Node| to respect a set of rules defined by the SPIR-V
795 // spec.
796 bool fixupConstruct(Splitter &S, DivergentConstruct *Node) {
797 bool Modified = false;
798 for (auto &Child : Node->Children)
799 Modified |= fixupConstruct(S, Node: Child.get());
800
801 // This construct is the root construct. Does not represent any real
802 // construct, just a way to access the first level of the forest.
803 if (Node->Parent == nullptr)
804 return Modified;
805
806 // This node's parent is the root. Meaning this is a top-level construct.
807 // There can be multiple exists, but all are guaranteed to exit at most 1
808 // construct since we are at first level.
809 if (Node->Parent->Header == nullptr)
810 return Modified;
811
812 // Health check for the structure.
813 assert(Node->Header && Node->Merge);
814 assert(Node->Parent->Header && Node->Parent->Merge);
815
816 BlockSet ConstructBlocks = getConstructBlocks(S, Node);
817 auto Edges = getExitsFrom(Construct: ConstructBlocks, Header&: *Node->Header);
818
819 // No edges exiting the construct.
820 if (Edges.size() < 1)
821 return Modified;
822
823 bool HasBadEdge = Node->Merge == Node->Parent->Merge ||
824 Node->Merge == Node->Parent->Continue;
825 // BasicBlock *Target = Edges[0].second;
826 for (auto &[Src, Dst] : Edges) {
827 // - Breaking from a selection construct: S is a selection construct, S is
828 // the innermost structured
829 // control-flow construct containing A, and B is the merge block for S
830 // - Breaking from the innermost loop: S is the innermost loop construct
831 // containing A,
832 // and B is the merge block for S
833 if (Node->Merge == Dst)
834 continue;
835
836 // Entering the innermost loop’s continue construct: S is the innermost
837 // loop construct containing A, and B is the continue target for S
838 if (Node->Continue == Dst)
839 continue;
840
841 // TODO: what about cases branching to another case in the switch? Seems
842 // to work, but need to double check.
843 HasBadEdge = true;
844 }
845
846 if (!HasBadEdge)
847 return Modified;
848
849 // Create a single exit node gathering all exit edges.
850 BasicBlock *NewExit = S.createSingleExitNode(Header: Node->Header, Edges);
851
852 // Fixup this construct's merge node to point to the new exit.
853 // Note: this algorithm fixes inner-most divergence construct first. So
854 // recursive structures sharing a single merge node are fixed from the
855 // inside toward the outside.
856 auto MergeInstructions = getMergeInstructions(BB&: *Node->Header);
857 assert(MergeInstructions.size() == 1);
858 Instruction *I = MergeInstructions[0];
859 BlockAddress *BA = cast<BlockAddress>(Val: I->getOperand(i: 0));
860 if (BA->getBasicBlock() == Node->Merge) {
861 auto MergeAddress = BlockAddress::get(F: NewExit->getParent(), BB: NewExit);
862 I->setOperand(i: 0, Val: MergeAddress);
863 }
864
865 // Clean up of the possible dangling BockAddr operands to prevent MIR
866 // comments about "address of removed block taken".
867 if (!BA->isConstantUsed())
868 BA->destroyConstant();
869
870 Node->Merge = NewExit;
871 // Regenerate the dom trees.
872 S.invalidate();
873 return true;
874 }
875
876 bool splitCriticalEdges(Function &F) {
877 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
878 Splitter S(F, LI);
879
880 DivergentConstruct Root;
881 BlockSet Visited;
882 constructDivergentConstruct(Visited, S, BB: &*F.begin(), Parent: &Root);
883 return fixupConstruct(S, Node: &Root);
884 }
885
886 // Simplify branches when possible:
887 // - if the 2 sides of a conditional branch are the same, transforms it to an
888 // unconditional branch.
889 // - if a switch has only 2 distinct successors, converts it to a conditional
890 // branch.
891 bool simplifyBranches(Function &F) {
892 bool Modified = false;
893
894 for (BasicBlock &BB : F) {
895 SwitchInst *SI = dyn_cast<SwitchInst>(Val: BB.getTerminator());
896 if (!SI)
897 continue;
898 if (SI->getNumCases() > 1)
899 continue;
900
901 Modified = true;
902 IRBuilder<> Builder(&BB);
903 Builder.SetInsertPoint(SI);
904
905 if (SI->getNumCases() == 0) {
906 Builder.CreateBr(Dest: SI->getDefaultDest());
907 } else {
908 Value *Condition =
909 Builder.CreateCmp(Pred: CmpInst::ICMP_EQ, LHS: SI->getCondition(),
910 RHS: SI->case_begin()->getCaseValue());
911 Builder.CreateCondBr(Cond: Condition, True: SI->case_begin()->getCaseSuccessor(),
912 False: SI->getDefaultDest());
913 }
914 SI->eraseFromParent();
915 }
916
917 return Modified;
918 }
919
920 // Makes sure every case target in |F| is unique. If 2 cases branch to the
921 // same basic block, one of the targets is updated so it jumps to a new basic
922 // block ending with a single unconditional branch to the original target.
923 bool splitSwitchCases(Function &F) {
924 bool Modified = false;
925
926 for (BasicBlock &BB : F) {
927 SwitchInst *SI = dyn_cast<SwitchInst>(Val: BB.getTerminator());
928 if (!SI)
929 continue;
930
931 BlockSet Seen;
932 Seen.insert(Ptr: SI->getDefaultDest());
933
934 auto It = SI->case_begin();
935 while (It != SI->case_end()) {
936 BasicBlock *Target = It->getCaseSuccessor();
937 if (Seen.count(Ptr: Target) == 0) {
938 Seen.insert(Ptr: Target);
939 ++It;
940 continue;
941 }
942
943 Modified = true;
944 BasicBlock *NewTarget =
945 BasicBlock::Create(Context&: F.getContext(), Name: "new.sw.case", Parent: &F);
946 IRBuilder<> Builder(NewTarget);
947 Builder.CreateBr(Dest: Target);
948 SI->addCase(OnVal: It->getCaseValue(), Dest: NewTarget);
949 It = SI->removeCase(I: It);
950 }
951 }
952
953 return Modified;
954 }
955
956 // Removes blocks not contributing to any structured CFG. This assumes there
957 // is no PHI nodes.
958 bool removeUselessBlocks(Function &F) {
959 std::vector<BasicBlock *> ToRemove;
960
961 auto MergeBlocks = getMergeBlocks(F);
962 auto ContinueBlocks = getContinueBlocks(F);
963
964 for (BasicBlock &BB : F) {
965 if (BB.size() != 1)
966 continue;
967
968 if (isa<ReturnInst>(Val: BB.getTerminator()))
969 continue;
970
971 if (MergeBlocks.count(Ptr: &BB) != 0 || ContinueBlocks.count(Ptr: &BB) != 0)
972 continue;
973
974 if (BB.getUniqueSuccessor() == nullptr)
975 continue;
976
977 BasicBlock *Successor = BB.getUniqueSuccessor();
978 std::vector<BasicBlock *> Predecessors(predecessors(BB: &BB).begin(),
979 predecessors(BB: &BB).end());
980 for (BasicBlock *Predecessor : Predecessors)
981 replaceBranchTargets(BB: Predecessor, OldTarget: &BB, NewTarget: Successor);
982 ToRemove.push_back(x: &BB);
983 }
984
985 for (BasicBlock *BB : ToRemove)
986 BB->eraseFromParent();
987
988 return ToRemove.size() != 0;
989 }
990
991 bool addHeaderToRemainingDivergentDAG(Function &F) {
992 bool Modified = false;
993
994 auto MergeBlocks = getMergeBlocks(F);
995 auto ContinueBlocks = getContinueBlocks(F);
996 auto HeaderBlocks = getHeaderBlocks(F);
997
998 DomTreeBuilder::BBDomTree DT;
999 DomTreeBuilder::BBPostDomTree PDT;
1000 PDT.recalculate(Func&: F);
1001 DT.recalculate(Func&: F);
1002
1003 for (BasicBlock &BB : F) {
1004 if (HeaderBlocks.count(Ptr: &BB) != 0)
1005 continue;
1006 if (succ_size(BB: &BB) < 2)
1007 continue;
1008
1009 size_t CandidateEdges = 0;
1010 for (BasicBlock *Successor : successors(BB: &BB)) {
1011 if (MergeBlocks.count(Ptr: Successor) != 0 ||
1012 ContinueBlocks.count(Ptr: Successor) != 0)
1013 continue;
1014 if (HeaderBlocks.count(Ptr: Successor) != 0)
1015 continue;
1016 CandidateEdges += 1;
1017 }
1018
1019 if (CandidateEdges <= 1)
1020 continue;
1021
1022 BasicBlock *Header = &BB;
1023 BasicBlock *Merge = PDT.getNode(BB: &BB)->getIDom()->getBlock();
1024
1025 bool HasBadBlock = false;
1026 visit(Start&: *Header, op: [&](const BasicBlock *Node) {
1027 if (DT.dominates(A: Header, B: Node))
1028 return false;
1029 if (PDT.dominates(A: Merge, B: Node))
1030 return false;
1031 if (Node == Header || Node == Merge)
1032 return true;
1033
1034 HasBadBlock |= MergeBlocks.count(Ptr: Node) != 0 ||
1035 ContinueBlocks.count(Ptr: Node) != 0 ||
1036 HeaderBlocks.count(Ptr: Node) != 0;
1037 return !HasBadBlock;
1038 });
1039
1040 if (HasBadBlock)
1041 continue;
1042
1043 Modified = true;
1044
1045 if (Merge == nullptr) {
1046 Merge = *successors(BB: Header).begin();
1047 IRBuilder<> Builder(Header);
1048 Builder.SetInsertPoint(Header->getTerminator());
1049
1050 auto MergeAddress = BlockAddress::get(F: Merge->getParent(), BB: Merge);
1051 createOpSelectMerge(Builder: &Builder, MergeAddress);
1052 continue;
1053 }
1054
1055 Instruction *SplitInstruction = Merge->getTerminator();
1056 if (isMergeInstruction(I: SplitInstruction->getPrevNode()))
1057 SplitInstruction = SplitInstruction->getPrevNode();
1058 BasicBlock *NewMerge =
1059 Merge->splitBasicBlockBefore(I: SplitInstruction, BBName: "new.merge");
1060
1061 IRBuilder<> Builder(Header);
1062 Builder.SetInsertPoint(Header->getTerminator());
1063
1064 auto MergeAddress = BlockAddress::get(F: NewMerge->getParent(), BB: NewMerge);
1065 createOpSelectMerge(Builder: &Builder, MergeAddress);
1066 }
1067
1068 return Modified;
1069 }
1070
1071public:
1072 static char ID;
1073
1074 SPIRVStructurizer() : FunctionPass(ID) {}
1075
1076 bool runOnFunction(Function &F) override {
1077 bool Modified = false;
1078
1079 // In LLVM, Switches are allowed to have several cases branching to the same
1080 // basic block. This is allowed in SPIR-V, but can make structurizing SPIR-V
1081 // harder, so first remove edge cases.
1082 Modified |= splitSwitchCases(F);
1083
1084 // LLVM allows conditional branches to have both side jumping to the same
1085 // block. It also allows switched to have a single default, or just one
1086 // case. Cleaning this up now.
1087 Modified |= simplifyBranches(F);
1088
1089 // At this state, we should have a reducible CFG with cycles.
1090 // STEP 1: Adding OpLoopMerge instructions to loop headers.
1091 Modified |= addMergeForLoops(F);
1092
1093 // STEP 2: adding OpSelectionMerge to each node with an in-degree >= 2.
1094 Modified |= addMergeForNodesWithMultiplePredecessors(F);
1095
1096 // STEP 3:
1097 // Sort selection merge, the largest construct goes first.
1098 // This simplifies the next step.
1099 Modified |= sortSelectionMergeHeaders(F);
1100
1101 // STEP 4: As this stage, we can have a single basic block with multiple
1102 // OpLoopMerge/OpSelectionMerge instructions. Splitting this block so each
1103 // BB has a single merge instruction.
1104 Modified |= splitBlocksWithMultipleHeaders(F);
1105
1106 // STEP 5: In the previous steps, we added merge blocks the loops and
1107 // natural merge blocks (in-degree >= 2). What remains are conditions with
1108 // an exiting branch (return, unreachable). In such case, we must start from
1109 // the header, and add headers to divergent construct with no headers.
1110 Modified |= addMergeForDivergentBlocks(F);
1111
1112 // STEP 6: At this stage, we have several divergent construct defines by a
1113 // header and a merge block. But their boundaries have no constraints: a
1114 // construct exit could be outside of the parents' construct exit. Such
1115 // edges are called critical edges. What we need is to split those edges
1116 // into several parts. Each part exiting the parent's construct by its merge
1117 // block.
1118 Modified |= splitCriticalEdges(F);
1119
1120 // STEP 7: The previous steps possibly created a lot of "proxy" blocks.
1121 // Blocks with a single unconditional branch, used to create a valid
1122 // divergent construct tree. Some nodes are still requires (e.g: nodes
1123 // allowing a valid exit through the parent's merge block). But some are
1124 // left-overs of past transformations, and could cause actual validation
1125 // issues. E.g: the SPIR-V spec allows a construct to break to the parents
1126 // loop construct without an OpSelectionMerge, but this requires a straight
1127 // jump. If a proxy block lies between the conditional branch and the
1128 // parent's merge, the CFG is not valid.
1129 Modified |= removeUselessBlocks(F);
1130
1131 // STEP 8: Final fix-up steps: our tree boundaries are correct, but some
1132 // blocks are branching with no header. Those are often simple conditional
1133 // branches with 1 or 2 returning edges. Adding a header for those.
1134 Modified |= addHeaderToRemainingDivergentDAG(F);
1135
1136 // STEP 9: sort basic blocks to match both the LLVM & SPIR-V requirements.
1137 Modified |= sortBlocks(F);
1138
1139 return Modified;
1140 }
1141
1142 void getAnalysisUsage(AnalysisUsage &AU) const override {
1143 AU.addRequired<DominatorTreeWrapperPass>();
1144 AU.addRequired<LoopInfoWrapperPass>();
1145 AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
1146
1147 AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
1148 FunctionPass::getAnalysisUsage(AU);
1149 }
1150
1151 void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) {
1152 Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator();
1153
1154 MDNode *MDNode = BBTerminatorInst->getMetadata(Kind: "hlsl.controlflow.hint");
1155
1156 ConstantInt *BranchHint = ConstantInt::get(Ty: Builder->getInt32Ty(), V: 0);
1157
1158 if (MDNode) {
1159 assert(MDNode->getNumOperands() == 2 &&
1160 "invalid metadata hlsl.controlflow.hint");
1161 BranchHint = mdconst::extract<ConstantInt>(MD: MDNode->getOperand(I: 1));
1162 }
1163
1164 SmallVector<Value *, 2> Args = {MergeAddress, BranchHint};
1165
1166 Builder->CreateIntrinsic(ID: Intrinsic::spv_selection_merge,
1167 OverloadTypes: {MergeAddress->getType()}, Args);
1168 }
1169};
1170} // anonymous namespace
1171
1172char SPIRVStructurizer::ID = 0;
1173
1174INITIALIZE_PASS_BEGIN(SPIRVStructurizer, "spirv-structurizer",
1175 "structurize SPIRV", false, false)
1176INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
1177INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
1178INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
1179INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)
1180
1181INITIALIZE_PASS_END(SPIRVStructurizer, "spirv-structurizer",
1182 "structurize SPIRV", false, false)
1183
1184FunctionPass *llvm::createSPIRVStructurizerPass() {
1185 return new SPIRVStructurizer();
1186}
1187
1188PreservedAnalyses SPIRVStructurizerWrapper::run(Function &F,
1189 FunctionAnalysisManager &AF) {
1190
1191 auto FPM = legacy::FunctionPassManager(F.getParent());
1192 FPM.add(P: createSPIRVStructurizerPass());
1193
1194 if (!FPM.run(F))
1195 return PreservedAnalyses::all();
1196 PreservedAnalyses PA;
1197 PA.preserveSet<CFGAnalyses>();
1198 return PA;
1199}
1200