1//===-------- SplitModuleByCategory.cpp - split a module by categories ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// See comments in the header.
9//===----------------------------------------------------------------------===//
10
11#include "llvm/Transforms/Utils/SplitModuleByCategory.h"
12#include "llvm/ADT/SetVector.h"
13#include "llvm/ADT/SmallPtrSet.h"
14#include "llvm/ADT/StringExtras.h"
15#include "llvm/IR/Function.h"
16#include "llvm/IR/InstIterator.h"
17#include "llvm/IR/Instructions.h"
18#include "llvm/IR/Module.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Transforms/Utils/Cloning.h"
21
22#include <map>
23#include <utility>
24
25using namespace llvm;
26
27#define DEBUG_TYPE "split-module-by-category"
28
29namespace {
30
31// A vector that contains a group of function with the same category.
32using EntryPointSet = SetVector<const Function *>;
33
34/// Represents a group of functions with one category.
35struct EntryPointGroup {
36 int ID;
37 EntryPointSet Functions;
38
39 EntryPointGroup() = default;
40
41 EntryPointGroup(int ID, EntryPointSet &&Functions = EntryPointSet())
42 : ID(ID), Functions(std::move(Functions)) {}
43
44 void clear() { Functions.clear(); }
45
46#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
47 LLVM_DUMP_METHOD void dump() const {
48 constexpr size_t INDENT = 4;
49 dbgs().indent(INDENT) << "ENTRY POINTS"
50 << " " << ID << " {\n";
51 for (const Function *F : Functions)
52 dbgs().indent(INDENT) << " " << F->getName() << "\n";
53
54 dbgs().indent(INDENT) << "}\n";
55 }
56#endif
57};
58
59/// Annotates an llvm::Module with information necessary to perform and track
60/// the result of code (llvm::Module instances) splitting:
61/// - entry points group from the module.
62class ModuleDesc {
63 std::unique_ptr<Module> M;
64 EntryPointGroup EntryPoints;
65
66public:
67 ModuleDesc(std::unique_ptr<Module> M,
68 EntryPointGroup &&EntryPoints = EntryPointGroup())
69 : M(std::move(M)), EntryPoints(std::move(EntryPoints)) {
70 assert(this->M && "Module should be non-null");
71 }
72
73 Module &getModule() { return *M; }
74 const Module &getModule() const { return *M; }
75
76 std::unique_ptr<Module> releaseModule() {
77 EntryPoints.clear();
78 return std::move(M);
79 }
80
81#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
82 LLVM_DUMP_METHOD void dump() const {
83 dbgs() << "ModuleDesc[" << M->getName() << "] {\n";
84 EntryPoints.dump();
85 dbgs() << "}\n";
86 }
87#endif
88};
89
90bool isKernel(const Function &F) {
91 return F.getCallingConv() == CallingConv::SPIR_KERNEL ||
92 F.getCallingConv() == CallingConv::AMDGPU_KERNEL ||
93 F.getCallingConv() == CallingConv::PTX_Kernel;
94}
95
96// Represents "dependency" or "use" graph of global objects (functions and
97// global variables) in a module. It is used during code split to
98// understand which global variables and functions (other than entry points)
99// should be included into a split module.
100//
101// Nodes of the graph represent LLVM's GlobalObjects, edges "A" -> "B" represent
102// the fact that if "A" is included into a module, then "B" should be included
103// as well.
104//
105// Examples of dependencies which are represented in this graph:
106// - Function FA calls function FB
107// - Function FA uses global variable GA
108// - Global variable GA references (initialized with) function FB
109// - Function FA stores address of a function FB somewhere
110//
111// The following cases are treated as dependencies between global objects:
112// 1. Global object A is used by a global object B in any way (store,
113// bitcast, phi node, call, etc.): "A" -> "B" edge will be added to the
114// graph;
115// 2. function A performs an indirect call of a function with signature S and
116// there is a function B with signature S. "A" -> "B" edge will be added to
117// the graph;
118class DependencyGraph {
119public:
120 using GlobalSet = SmallPtrSet<const GlobalValue *, 16>;
121
122 DependencyGraph(const Module &M) {
123 // Group functions by their signature to handle case (2) described above
124 DenseMap<const FunctionType *, DependencyGraph::GlobalSet>
125 FuncTypeToFuncsMap;
126 for (const Function &F : M.functions()) {
127 // Kernels can't be called (either directly or indirectly).
128 if (isKernel(F))
129 continue;
130
131 FuncTypeToFuncsMap[F.getFunctionType()].insert(Ptr: &F);
132 }
133
134 for (const Function &F : M.functions()) {
135 // case (1), see comment above the class definition
136 for (const Value *U : F.users())
137 addUserToGraphRecursively(Root: cast<const User>(Val: U), V: &F);
138
139 // case (2), see comment above the class definition
140 for (const Instruction &I : instructions(F)) {
141 const CallBase *CB = dyn_cast<CallBase>(Val: &I);
142 if (!CB || !CB->isIndirectCall()) // Direct calls were handled above
143 continue;
144
145 const FunctionType *Signature = CB->getFunctionType();
146 GlobalSet &PotentialCallees = FuncTypeToFuncsMap[Signature];
147 Graph[&F].insert(I: PotentialCallees.begin(), E: PotentialCallees.end());
148 }
149 }
150
151 // And every global variable (but their handling is a bit simpler)
152 for (const GlobalVariable &GV : M.globals())
153 for (const Value *U : GV.users())
154 addUserToGraphRecursively(Root: cast<const User>(Val: U), V: &GV);
155 }
156
157 iterator_range<GlobalSet::const_iterator>
158 dependencies(const GlobalValue *Val) const {
159 auto It = Graph.find(Val);
160 return (It == Graph.end())
161 ? make_range(x: EmptySet.begin(), y: EmptySet.end())
162 : make_range(x: It->second.begin(), y: It->second.end());
163 }
164
165private:
166 void addUserToGraphRecursively(const User *Root, const GlobalValue *V) {
167 SmallVector<const User *, 8> WorkList;
168 WorkList.push_back(Elt: Root);
169
170 while (!WorkList.empty()) {
171 const User *U = WorkList.pop_back_val();
172 if (const auto *I = dyn_cast<const Instruction>(Val: U)) {
173 const Function *UFunc = I->getFunction();
174 Graph[UFunc].insert(Ptr: V);
175 } else if (isa<const Constant>(Val: U)) {
176 if (const auto *GV = dyn_cast<const GlobalVariable>(Val: U))
177 Graph[GV].insert(Ptr: V);
178 // This could be a global variable or some constant expression (like
179 // bitcast or gep). We trace users of this constant further to reach
180 // global objects they are used by and add them to the graph.
181 for (const User *UU : U->users())
182 WorkList.push_back(Elt: UU);
183 } else {
184 llvm_unreachable("Unhandled type of function user");
185 }
186 }
187 }
188
189 DenseMap<const GlobalValue *, GlobalSet> Graph;
190 SmallPtrSet<const GlobalValue *, 1> EmptySet;
191};
192
193void collectFunctionsAndGlobalVariablesToExtract(
194 SetVector<const GlobalValue *> &GVs, const Module &M,
195 const EntryPointGroup &ModuleEntryPoints, const DependencyGraph &DG) {
196 // We start with module entry points
197 for (const Function *F : ModuleEntryPoints.Functions)
198 GVs.insert(X: F);
199
200 // Non-discardable global variables are also include into the initial set
201 for (const GlobalVariable &GV : M.globals())
202 if (!GV.isDiscardableIfUnused())
203 GVs.insert(X: &GV);
204
205 // GVs has SetVector type. This type inserts a value only if it is not yet
206 // present there. So, recursion is not expected here.
207 size_t Idx = 0;
208 while (Idx < GVs.size()) {
209 const GlobalValue *Obj = GVs[Idx++];
210
211 for (const GlobalValue *Dep : DG.dependencies(Val: Obj)) {
212 if (const auto *Func = dyn_cast<const Function>(Val: Dep)) {
213 if (!Func->isDeclaration())
214 GVs.insert(X: Func);
215 } else {
216 GVs.insert(X: Dep); // Global variables are added unconditionally
217 }
218 }
219 }
220}
221
222ModuleDesc extractSubModule(const Module &M,
223 const SetVector<const GlobalValue *> &GVs,
224 EntryPointGroup &&ModuleEntryPoints) {
225 ValueToValueMapTy VMap;
226 // Clone definitions only for needed globals. Others will be added as
227 // declarations and removed later.
228 std::unique_ptr<Module> SubM = CloneModule(
229 M, VMap, ShouldCloneDefinition: [&](const GlobalValue *GV) { return GVs.contains(key: GV); });
230 // Replace entry points with cloned ones.
231 EntryPointSet NewEPs;
232 const EntryPointSet &EPs = ModuleEntryPoints.Functions;
233 llvm::for_each(
234 Range: EPs, F: [&](const Function *F) { NewEPs.insert(X: cast<Function>(Val&: VMap[F])); });
235 ModuleEntryPoints.Functions = std::move(NewEPs);
236 return ModuleDesc{std::move(SubM), std::move(ModuleEntryPoints)};
237}
238
239// The function produces a copy of input LLVM IR module M with only those
240// functions and globals that can be called from entry points that are specified
241// in ModuleEntryPoints vector, in addition to the entry point functions.
242ModuleDesc extractCallGraph(const Module &M,
243 EntryPointGroup &&ModuleEntryPoints,
244 const DependencyGraph &DG) {
245 SetVector<const GlobalValue *> GVs;
246 collectFunctionsAndGlobalVariablesToExtract(GVs, M, ModuleEntryPoints, DG);
247
248 ModuleDesc SplitM = extractSubModule(M, GVs, ModuleEntryPoints: std::move(ModuleEntryPoints));
249 LLVM_DEBUG(SplitM.dump());
250 return SplitM;
251}
252
253using EntryPointGroupVec = SmallVector<EntryPointGroup>;
254
255/// Module Splitter.
256/// It gets a module and a collection of entry points groups.
257/// Each group specifies subset entry points from input module that should be
258/// included in a split module.
259class ModuleSplitter {
260private:
261 std::unique_ptr<Module> M;
262 EntryPointGroupVec Groups;
263 DependencyGraph DG;
264
265private:
266 EntryPointGroup drawEntryPointGroup() {
267 assert(Groups.size() > 0 && "Reached end of entry point groups list.");
268 EntryPointGroup Group = std::move(Groups.back());
269 Groups.pop_back();
270 return Group;
271 }
272
273public:
274 ModuleSplitter(std::unique_ptr<Module> Module, EntryPointGroupVec &&GroupVec)
275 : M(std::move(Module)), Groups(std::move(GroupVec)), DG(*M) {
276 assert(!Groups.empty() && "Entry points groups collection is empty!");
277 }
278
279 /// Gets next subsequence of entry points in an input module and provides
280 /// split submodule containing these entry points and their dependencies.
281 ModuleDesc getNextSplit() {
282 return extractCallGraph(M: *M, ModuleEntryPoints: drawEntryPointGroup(), DG);
283 }
284
285 /// Check that there are still submodules to split.
286 bool hasMoreSplits() const { return Groups.size() > 0; }
287};
288
289EntryPointGroupVec selectEntryPointGroups(
290 const Module &M, function_ref<std::optional<int>(const Function &F)> EPC) {
291 // std::map is used here to ensure stable ordering of entry point groups,
292 // which is based on their contents, this greatly helps LIT tests
293 // Note: EPC is allowed to return big identifiers. Therefore, we use
294 // std::map + SmallVector approach here.
295 std::map<int, EntryPointSet> EntryPointsMap;
296
297 for (const auto &F : M.functions())
298 if (std::optional<int> Category = EPC(F); Category)
299 EntryPointsMap[*Category].insert(X: &F);
300
301 EntryPointGroupVec Groups;
302 Groups.reserve(N: EntryPointsMap.size());
303 for (auto &[Key, EntryPoints] : EntryPointsMap)
304 Groups.emplace_back(Args: Key, Args: std::move(EntryPoints));
305
306 return Groups;
307}
308
309} // namespace
310
311void llvm::splitModuleTransitiveFromEntryPoints(
312 std::unique_ptr<Module> M,
313 function_ref<std::optional<int>(const Function &F)> EntryPointCategorizer,
314 function_ref<void(std::unique_ptr<Module> Part)> Callback) {
315 EntryPointGroupVec Groups = selectEntryPointGroups(M: *M, EPC: EntryPointCategorizer);
316 ModuleSplitter Splitter(std::move(M), std::move(Groups));
317 while (Splitter.hasMoreSplits()) {
318 ModuleDesc MD = Splitter.getNextSplit();
319 Callback(MD.releaseModule());
320 }
321}
322