1//===-- NVPTXLowerUnreachable.cpp - Lower unreachables to exit =====--===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// PTX does not have a notion of `unreachable`, which results in emitted basic
10// blocks having an edge to the next block:
11//
12// block1:
13// call @does_not_return();
14// // unreachable
15// block2:
16// // ptxas will create a CFG edge from block1 to block2
17//
18// This may result in significant changes to the control flow graph, e.g., when
19// LLVM moves unreachable blocks to the end of the function. That's a problem
20// in the context of divergent control flow, as `ptxas` uses the CFG to
// determine divergent regions, and some instructions may not be executed
22// divergently.
23//
24// For example, `bar.sync` is not allowed to be executed divergently on Pascal
25// or earlier. If we start with the following:
26//
27// entry:
28// // start of divergent region
29// @%p0 bra cont;
30// @%p1 bra unlikely;
31// ...
32// bra.uni cont;
33// unlikely:
34// ...
35// // unreachable
36// cont:
37// // end of divergent region
38// bar.sync 0;
39// bra.uni exit;
40// exit:
41// ret;
42//
43// it is transformed by the branch-folder and block-placement passes to:
44//
45// entry:
46// // start of divergent region
47// @%p0 bra cont;
48// @%p1 bra unlikely;
49// ...
50// bra.uni cont;
51// cont:
52// bar.sync 0;
53// bra.uni exit;
54// unlikely:
55// ...
56// // unreachable
57// exit:
58// // end of divergent region
59// ret;
60//
61// After moving the `unlikely` block to the end of the function, it has an edge
62// to the `exit` block, which widens the divergent region and makes the
63// `bar.sync` instruction happen divergently.
64//
// To work around this, we add an `exit` instruction before every `unreachable`,
// as `ptxas` understands that exit terminates the CFG. We only do this if
// `unreachable` is not lowered to `trap`, which has the same effect (although
// with current versions of `ptxas` only because it is emitted as
// `trap; exit;`).
69//
70//===----------------------------------------------------------------------===//
71
72#include "NVPTX.h"
73#include "llvm/IR/Function.h"
74#include "llvm/IR/InlineAsm.h"
75#include "llvm/IR/Instructions.h"
76#include "llvm/IR/Type.h"
77#include "llvm/Pass.h"
78
79using namespace llvm;
80
81namespace llvm {
82void initializeNVPTXLowerUnreachablePass(PassRegistry &);
83}
84
85namespace {
86class NVPTXLowerUnreachable : public FunctionPass {
87 StringRef getPassName() const override;
88 bool runOnFunction(Function &F) override;
89 bool isLoweredToTrap(const UnreachableInst &I) const;
90
91public:
92 static char ID; // Pass identification, replacement for typeid
93 NVPTXLowerUnreachable(bool TrapUnreachable, bool NoTrapAfterNoreturn)
94 : FunctionPass(ID), TrapUnreachable(TrapUnreachable),
95 NoTrapAfterNoreturn(NoTrapAfterNoreturn) {}
96
97private:
98 bool TrapUnreachable;
99 bool NoTrapAfterNoreturn;
100};
101} // namespace
102
// Storage for the pass identifier declared in the class.
// NOTE(review): LLVM's legacy pass infrastructure is documented to identify a
// pass by the address of ID, so the initial value should be irrelevant —
// confirm before relying on the concrete value.
char NVPTXLowerUnreachable::ID = 1;

// Registers the pass under the argument name "nvptx-lower-unreachable" with
// the description "Lower Unreachable".
INITIALIZE_PASS(NVPTXLowerUnreachable, "nvptx-lower-unreachable",
                "Lower Unreachable", /*cfg=*/false, /*analysis=*/false)
107
108StringRef NVPTXLowerUnreachable::getPassName() const {
109 return "add an exit instruction before every unreachable";
110}
111
112// =============================================================================
113// Returns whether a `trap` intrinsic should be emitted before I.
114//
115// This is a copy of the logic in SelectionDAGBuilder::visitUnreachable().
116// =============================================================================
117bool NVPTXLowerUnreachable::isLoweredToTrap(const UnreachableInst &I) const {
118 if (!TrapUnreachable)
119 return false;
120 if (!NoTrapAfterNoreturn)
121 return true;
122 const CallInst *Call = dyn_cast_or_null<CallInst>(Val: I.getPrevNode());
123 return Call && Call->doesNotReturn();
124}
125
126// =============================================================================
127// Main function for this pass.
128// =============================================================================
129bool NVPTXLowerUnreachable::runOnFunction(Function &F) {
130 if (skipFunction(F))
131 return false;
132 // Early out iff isLoweredToTrap() always returns true.
133 if (TrapUnreachable && !NoTrapAfterNoreturn)
134 return false;
135
136 LLVMContext &C = F.getContext();
137 FunctionType *ExitFTy = FunctionType::get(Result: Type::getVoidTy(C), isVarArg: false);
138 InlineAsm *Exit = InlineAsm::get(Ty: ExitFTy, AsmString: "exit;", Constraints: "", hasSideEffects: true);
139
140 bool Changed = false;
141 for (auto &BB : F)
142 for (auto &I : BB) {
143 if (auto unreachableInst = dyn_cast<UnreachableInst>(Val: &I)) {
144 if (isLoweredToTrap(I: *unreachableInst))
145 continue; // trap is emitted as `trap; exit;`.
146 CallInst::Create(Ty: ExitFTy, F: Exit, NameStr: "", InsertBefore: unreachableInst->getIterator());
147 Changed = true;
148 }
149 }
150 return Changed;
151}
152
153FunctionPass *llvm::createNVPTXLowerUnreachablePass(bool TrapUnreachable,
154 bool NoTrapAfterNoreturn) {
155 return new NVPTXLowerUnreachable(TrapUnreachable, NoTrapAfterNoreturn);
156}
157