1//===-- NVPTXLowerUnreachable.cpp - Lower unreachables to exit =====--===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// PTX does not have a notion of `unreachable`, which results in emitted basic
10// blocks having an edge to the next block:
11//
12// block1:
13// call @does_not_return();
14// // unreachable
15// block2:
16// // ptxas will create a CFG edge from block1 to block2
17//
18// This may result in significant changes to the control flow graph, e.g., when
19// LLVM moves unreachable blocks to the end of the function. That's a problem
20// in the context of divergent control flow, as `ptxas` uses the CFG to
// determine divergent regions, and some instructions may not be executed
22// divergently.
23//
24// For example, `bar.sync` is not allowed to be executed divergently on Pascal
25// or earlier. If we start with the following:
26//
27// entry:
28// // start of divergent region
29// @%p0 bra cont;
30// @%p1 bra unlikely;
31// ...
32// bra.uni cont;
33// unlikely:
34// ...
35// // unreachable
36// cont:
37// // end of divergent region
38// bar.sync 0;
39// bra.uni exit;
40// exit:
41// ret;
42//
43// it is transformed by the branch-folder and block-placement passes to:
44//
45// entry:
46// // start of divergent region
47// @%p0 bra cont;
48// @%p1 bra unlikely;
49// ...
50// bra.uni cont;
51// cont:
52// bar.sync 0;
53// bra.uni exit;
54// unlikely:
55// ...
56// // unreachable
57// exit:
58// // end of divergent region
59// ret;
60//
61// After moving the `unlikely` block to the end of the function, it has an edge
62// to the `exit` block, which widens the divergent region and makes the
63// `bar.sync` instruction happen divergently.
64//
65// To work around this, we add an `exit` instruction before every `unreachable`,
// as `ptxas` understands that exit terminates the CFG. We only do this if
// `unreachable` is not lowered to `trap`, which has the same effect (although
// with current versions of `ptxas` only because it is emitted as
// `trap; exit;`).
69//
70//===----------------------------------------------------------------------===//
71
72#include "NVPTX.h"
73#include "llvm/IR/Function.h"
74#include "llvm/IR/InlineAsm.h"
75#include "llvm/IR/Instructions.h"
76#include "llvm/IR/Type.h"
77#include "llvm/Pass.h"
78
79using namespace llvm;
80
81namespace {
82class NVPTXLowerUnreachable : public FunctionPass {
83 StringRef getPassName() const override;
84 bool runOnFunction(Function &F) override;
85 bool isLoweredToTrap(const UnreachableInst &I) const;
86
87public:
88 static char ID; // Pass identification, replacement for typeid
89 NVPTXLowerUnreachable(bool TrapUnreachable, bool NoTrapAfterNoreturn)
90 : FunctionPass(ID), TrapUnreachable(TrapUnreachable),
91 NoTrapAfterNoreturn(NoTrapAfterNoreturn) {}
92
93private:
94 bool TrapUnreachable;
95 bool NoTrapAfterNoreturn;
96};
97} // namespace
98
99char NVPTXLowerUnreachable::ID = 1;
100
101INITIALIZE_PASS(NVPTXLowerUnreachable, "nvptx-lower-unreachable",
102 "Lower Unreachable", false, false)
103
104StringRef NVPTXLowerUnreachable::getPassName() const {
105 return "add an exit instruction before every unreachable";
106}
107
108// =============================================================================
109// Returns whether a `trap` intrinsic would be emitted before I.
110//
111// This is a copy of the logic in SelectionDAGBuilder::visitUnreachable().
112// =============================================================================
113bool NVPTXLowerUnreachable::isLoweredToTrap(const UnreachableInst &I) const {
114 if (const auto *Call = dyn_cast_or_null<CallInst>(Val: I.getPrevNode())) {
115 // We've already emitted a non-continuable trap.
116 if (Call->isNonContinuableTrap())
117 return true;
118
119 // No traps are emitted for calls that do not return
120 // when this option is enabled.
121 if (NoTrapAfterNoreturn && Call->doesNotReturn())
122 return false;
123 }
124
125 // In all other cases, we will generate a trap if TrapUnreachable is set.
126 return TrapUnreachable;
127}
128
129// =============================================================================
130// Main function for this pass.
131// =============================================================================
132bool NVPTXLowerUnreachable::runOnFunction(Function &F) {
133 if (skipFunction(F))
134 return false;
135 // Early out iff isLoweredToTrap() always returns true.
136 if (TrapUnreachable && !NoTrapAfterNoreturn)
137 return false;
138
139 LLVMContext &C = F.getContext();
140 FunctionType *ExitFTy = FunctionType::get(Result: Type::getVoidTy(C), isVarArg: false);
141 InlineAsm *Exit = InlineAsm::get(Ty: ExitFTy, AsmString: "exit;", Constraints: "", hasSideEffects: true);
142
143 bool Changed = false;
144 for (auto &BB : F)
145 for (auto &I : BB) {
146 if (auto unreachableInst = dyn_cast<UnreachableInst>(Val: &I)) {
147 if (isLoweredToTrap(I: *unreachableInst))
148 continue; // trap is emitted as `trap; exit;`.
149 CallInst::Create(Ty: ExitFTy, F: Exit, NameStr: "", InsertBefore: unreachableInst->getIterator());
150 Changed = true;
151 }
152 }
153 return Changed;
154}
155
156FunctionPass *llvm::createNVPTXLowerUnreachablePass(bool TrapUnreachable,
157 bool NoTrapAfterNoreturn) {
158 return new NVPTXLowerUnreachable(TrapUnreachable, NoTrapAfterNoreturn);
159}
160