1//===-------- SplitModuleByCategory.cpp - split a module by categories ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// See comments in the header.
9//===----------------------------------------------------------------------===//
10
11#include "llvm/Transforms/Utils/SplitModuleByCategory.h"
12#include "llvm/ADT/SetVector.h"
13#include "llvm/ADT/SmallPtrSet.h"
14#include "llvm/ADT/StringExtras.h"
15#include "llvm/IR/Function.h"
16#include "llvm/IR/InstIterator.h"
17#include "llvm/IR/Instructions.h"
18#include "llvm/IR/Module.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Transforms/Utils/Cloning.h"
21
22#include <map>
23#include <utility>
24
25using namespace llvm;
26
27#define DEBUG_TYPE "split-module-by-category"
28
29namespace {
30
31// A vector that contains a group of function with the same category.
32using EntryPointSet = SetVector<const Function *>;
33
34/// Represents a group of functions with one category.
35struct EntryPointGroup {
36 int ID;
37 EntryPointSet Functions;
38
39 EntryPointGroup() = default;
40
41 EntryPointGroup(int ID, EntryPointSet &&Functions = EntryPointSet())
42 : ID(ID), Functions(std::move(Functions)) {}
43
44 void clear() { Functions.clear(); }
45
46#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
47 LLVM_DUMP_METHOD void dump() const {
48 constexpr size_t INDENT = 4;
49 dbgs().indent(INDENT) << "ENTRY POINTS"
50 << " " << ID << " {\n";
51 for (const Function *F : Functions)
52 dbgs().indent(INDENT) << " " << F->getName() << "\n";
53
54 dbgs().indent(INDENT) << "}\n";
55 }
56#endif
57};
58
59/// Annotates an llvm::Module with information necessary to perform and track
60/// the result of code (llvm::Module instances) splitting:
61/// - entry points group from the module.
62class ModuleDesc {
63 std::unique_ptr<Module> M;
64 EntryPointGroup EntryPoints;
65
66public:
67 ModuleDesc(std::unique_ptr<Module> M,
68 EntryPointGroup &&EntryPoints = EntryPointGroup())
69 : M(std::move(M)), EntryPoints(std::move(EntryPoints)) {
70 assert(this->M && "Module should be non-null");
71 }
72
73 Module &getModule() { return *M; }
74 const Module &getModule() const { return *M; }
75
76 std::unique_ptr<Module> releaseModule() {
77 EntryPoints.clear();
78 return std::move(M);
79 }
80
81#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
82 LLVM_DUMP_METHOD void dump() const {
83 dbgs() << "ModuleDesc[" << M->getName() << "] {\n";
84 EntryPoints.dump();
85 dbgs() << "}\n";
86 }
87#endif
88};
89
90// Represents "dependency" or "use" graph of global objects (functions and
91// global variables) in a module. It is used during code split to
92// understand which global variables and functions (other than entry points)
93// should be included into a split module.
94//
95// Nodes of the graph represent LLVM's GlobalObjects, edges "A" -> "B" represent
96// the fact that if "A" is included into a module, then "B" should be included
97// as well.
98//
99// Examples of dependencies which are represented in this graph:
100// - Function FA calls function FB
101// - Function FA uses global variable GA
102// - Global variable GA references (initialized with) function FB
103// - Function FA stores address of a function FB somewhere
104//
105// The following cases are treated as dependencies between global objects:
106// 1. Global object A is used by a global object B in any way (store,
107// bitcast, phi node, call, etc.): "A" -> "B" edge will be added to the
108// graph;
109// 2. function A performs an indirect call of a function with signature S and
110// there is a function B with signature S. "A" -> "B" edge will be added to
111// the graph;
112class DependencyGraph {
113public:
114 using GlobalSet = SmallPtrSet<const GlobalValue *, 16>;
115
116 DependencyGraph(const Module &M) {
117 // Group functions by their signature to handle case (2) described above
118 DenseMap<const FunctionType *, DependencyGraph::GlobalSet>
119 FuncTypeToFuncsMap;
120 for (const Function &F : M.functions()) {
121 // Kernels can't be called (either directly or indirectly).
122 if (F.hasKernelCallingConv())
123 continue;
124
125 FuncTypeToFuncsMap[F.getFunctionType()].insert(Ptr: &F);
126 }
127
128 for (const Function &F : M.functions()) {
129 // case (1), see comment above the class definition
130 for (const Value *U : F.users())
131 addUserToGraphRecursively(Root: cast<const User>(Val: U), V: &F);
132
133 // case (2), see comment above the class definition
134 for (const Instruction &I : instructions(F)) {
135 const CallBase *CB = dyn_cast<CallBase>(Val: &I);
136 if (!CB || !CB->isIndirectCall()) // Direct calls were handled above
137 continue;
138
139 const FunctionType *Signature = CB->getFunctionType();
140 GlobalSet &PotentialCallees = FuncTypeToFuncsMap[Signature];
141 Graph[&F].insert(I: PotentialCallees.begin(), E: PotentialCallees.end());
142 }
143 }
144
145 // And every global variable (but their handling is a bit simpler)
146 for (const GlobalVariable &GV : M.globals())
147 for (const Value *U : GV.users())
148 addUserToGraphRecursively(Root: cast<const User>(Val: U), V: &GV);
149 }
150
151 iterator_range<GlobalSet::const_iterator>
152 dependencies(const GlobalValue *Val) const {
153 auto It = Graph.find(Val);
154 return (It == Graph.end())
155 ? make_range(x: EmptySet.begin(), y: EmptySet.end())
156 : make_range(x: It->second.begin(), y: It->second.end());
157 }
158
159private:
160 void addUserToGraphRecursively(const User *Root, const GlobalValue *V) {
161 SmallVector<const User *, 8> WorkList;
162 WorkList.push_back(Elt: Root);
163
164 while (!WorkList.empty()) {
165 const User *U = WorkList.pop_back_val();
166 if (const auto *I = dyn_cast<const Instruction>(Val: U)) {
167 const Function *UFunc = I->getFunction();
168 Graph[UFunc].insert(Ptr: V);
169 } else if (isa<const Constant>(Val: U)) {
170 if (const auto *GV = dyn_cast<const GlobalVariable>(Val: U))
171 Graph[GV].insert(Ptr: V);
172 // This could be a global variable or some constant expression (like
173 // bitcast or gep). We trace users of this constant further to reach
174 // global objects they are used by and add them to the graph.
175 for (const User *UU : U->users())
176 WorkList.push_back(Elt: UU);
177 } else {
178 llvm_unreachable("Unhandled type of function user");
179 }
180 }
181 }
182
183 DenseMap<const GlobalValue *, GlobalSet> Graph;
184 SmallPtrSet<const GlobalValue *, 1> EmptySet;
185};
186
187void collectFunctionsAndGlobalVariablesToExtract(
188 SetVector<const GlobalValue *> &GVs, const Module &M,
189 const EntryPointGroup &ModuleEntryPoints, const DependencyGraph &DG) {
190 // We start with module entry points
191 for (const Function *F : ModuleEntryPoints.Functions)
192 GVs.insert(X: F);
193
194 // Non-discardable global variables are also include into the initial set
195 for (const GlobalVariable &GV : M.globals())
196 if (!GV.isDiscardableIfUnused())
197 GVs.insert(X: &GV);
198
199 // GVs has SetVector type. This type inserts a value only if it is not yet
200 // present there. So, recursion is not expected here.
201 size_t Idx = 0;
202 while (Idx < GVs.size()) {
203 const GlobalValue *Obj = GVs[Idx++];
204
205 for (const GlobalValue *Dep : DG.dependencies(Val: Obj)) {
206 if (const auto *Func = dyn_cast<const Function>(Val: Dep)) {
207 if (!Func->isDeclaration())
208 GVs.insert(X: Func);
209 } else {
210 GVs.insert(X: Dep); // Global variables are added unconditionally
211 }
212 }
213 }
214}
215
216ModuleDesc extractSubModule(const Module &M,
217 const SetVector<const GlobalValue *> &GVs,
218 EntryPointGroup &&ModuleEntryPoints) {
219 ValueToValueMapTy VMap;
220 // Clone definitions only for needed globals. Others will be added as
221 // declarations and removed later.
222 std::unique_ptr<Module> SubM = CloneModule(
223 M, VMap, ShouldCloneDefinition: [&](const GlobalValue *GV) { return GVs.contains(key: GV); });
224 // Replace entry points with cloned ones.
225 EntryPointSet NewEPs;
226 const EntryPointSet &EPs = ModuleEntryPoints.Functions;
227 llvm::for_each(
228 Range: EPs, F: [&](const Function *F) { NewEPs.insert(X: cast<Function>(Val&: VMap[F])); });
229 ModuleEntryPoints.Functions = std::move(NewEPs);
230 return ModuleDesc{std::move(SubM), std::move(ModuleEntryPoints)};
231}
232
233// The function produces a copy of input LLVM IR module M with only those
234// functions and globals that can be called from entry points that are specified
235// in ModuleEntryPoints vector, in addition to the entry point functions.
236ModuleDesc extractCallGraph(const Module &M,
237 EntryPointGroup &&ModuleEntryPoints,
238 const DependencyGraph &DG) {
239 SetVector<const GlobalValue *> GVs;
240 collectFunctionsAndGlobalVariablesToExtract(GVs, M, ModuleEntryPoints, DG);
241
242 ModuleDesc SplitM = extractSubModule(M, GVs, ModuleEntryPoints: std::move(ModuleEntryPoints));
243 LLVM_DEBUG(SplitM.dump());
244 return SplitM;
245}
246
247using EntryPointGroupVec = SmallVector<EntryPointGroup>;
248
249/// Module Splitter.
250/// It gets a module and a collection of entry points groups.
251/// Each group specifies subset entry points from input module that should be
252/// included in a split module.
253class ModuleSplitter {
254private:
255 std::unique_ptr<Module> M;
256 EntryPointGroupVec Groups;
257 DependencyGraph DG;
258
259private:
260 EntryPointGroup drawEntryPointGroup() {
261 assert(Groups.size() > 0 && "Reached end of entry point groups list.");
262 EntryPointGroup Group = std::move(Groups.back());
263 Groups.pop_back();
264 return Group;
265 }
266
267public:
268 ModuleSplitter(std::unique_ptr<Module> Module, EntryPointGroupVec &&GroupVec)
269 : M(std::move(Module)), Groups(std::move(GroupVec)), DG(*M) {
270 assert(!Groups.empty() && "Entry points groups collection is empty!");
271 }
272
273 /// Gets next subsequence of entry points in an input module and provides
274 /// split submodule containing these entry points and their dependencies.
275 ModuleDesc getNextSplit() {
276 return extractCallGraph(M: *M, ModuleEntryPoints: drawEntryPointGroup(), DG);
277 }
278
279 /// Check that there are still submodules to split.
280 bool hasMoreSplits() const { return Groups.size() > 0; }
281};
282
283EntryPointGroupVec selectEntryPointGroups(
284 const Module &M, function_ref<std::optional<int>(const Function &F)> EPC) {
285 // std::map is used here to ensure stable ordering of entry point groups,
286 // which is based on their contents, this greatly helps LIT tests
287 // Note: EPC is allowed to return big identifiers. Therefore, we use
288 // std::map + SmallVector approach here.
289 std::map<int, EntryPointSet> EntryPointsMap;
290
291 for (const auto &F : M.functions())
292 if (std::optional<int> Category = EPC(F); Category)
293 EntryPointsMap[*Category].insert(X: &F);
294
295 EntryPointGroupVec Groups;
296 Groups.reserve(N: EntryPointsMap.size());
297 for (auto &[Key, EntryPoints] : EntryPointsMap)
298 Groups.emplace_back(Args: Key, Args: std::move(EntryPoints));
299
300 return Groups;
301}
302
303} // namespace
304
305void llvm::splitModuleTransitiveFromEntryPoints(
306 std::unique_ptr<Module> M,
307 function_ref<std::optional<int>(const Function &F)> EntryPointCategorizer,
308 function_ref<void(std::unique_ptr<Module> Part)> Callback) {
309 EntryPointGroupVec Groups = selectEntryPointGroups(M: *M, EPC: EntryPointCategorizer);
310 ModuleSplitter Splitter(std::move(M), std::move(Groups));
311 while (Splitter.hasMoreSplits()) {
312 ModuleDesc MD = Splitter.getNextSplit();
313 Callback(MD.releaseModule());
314 }
315}
316