1//===- ReduceArguments.cpp - Specialized Delta Pass -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a function which calls the Generic Delta pass in order
10// to reduce uninteresting Arguments from declared and defined functions.
11//
12//===----------------------------------------------------------------------===//
13
14#include "ReduceArguments.h"
15#include "Utils.h"
16#include "llvm/ADT/SmallVector.h"
17#include "llvm/IR/FMF.h"
18#include "llvm/IR/Instructions.h"
19#include "llvm/IR/Intrinsics.h"
20#include "llvm/IR/Operator.h"
21#include "llvm/Transforms/Utils/BasicBlockUtils.h"
22#include "llvm/Transforms/Utils/Cloning.h"
23#include <set>
24#include <vector>
25
26using namespace llvm;
27
28static bool callingConvRequiresArgument(const Function &F,
29 const Argument &Arg) {
30 switch (F.getCallingConv()) {
31 case CallingConv::X86_INTR:
32 // If there are any arguments, the first one must by byval.
33 return Arg.getArgNo() == 0 && F.arg_size() != 1;
34 default:
35 return false;
36 }
37
38 llvm_unreachable("covered calling conv switch");
39}
40
41/// Goes over OldF calls and replaces them with a call to NewF
42static void replaceFunctionCalls(Function &OldF, Function &NewF,
43 const std::set<int> &ArgIndexesToKeep) {
44 LLVMContext &Ctx = OldF.getContext();
45
46 const auto &Users = OldF.users();
47 for (auto I = Users.begin(), E = Users.end(); I != E; )
48 if (auto *CI = dyn_cast<CallInst>(Val: *I++)) {
49 // Skip uses in call instructions where OldF isn't the called function
50 // (e.g. if OldF is an argument of the call).
51 if (CI->getCalledFunction() != &OldF)
52 continue;
53 SmallVector<Value *, 8> Args;
54 SmallVector<AttrBuilder, 8> ArgAttrs;
55
56 for (auto ArgI = CI->arg_begin(), E = CI->arg_end(); ArgI != E; ++ArgI) {
57 unsigned ArgIdx = ArgI - CI->arg_begin();
58 if (ArgIndexesToKeep.count(x: ArgIdx)) {
59 Args.push_back(Elt: *ArgI);
60 ArgAttrs.emplace_back(Args&: Ctx, Args: CI->getParamAttributes(ArgNo: ArgIdx));
61 }
62 }
63
64 SmallVector<OperandBundleDef, 2> OpBundles;
65 CI->getOperandBundlesAsDefs(Defs&: OpBundles);
66
67 CallInst *NewCI = CallInst::Create(Func: &NewF, Args, Bundles: OpBundles);
68 NewCI->setCallingConv(CI->getCallingConv());
69
70 AttrBuilder CallSiteAttrs(Ctx, CI->getAttributes().getFnAttrs());
71 NewCI->setAttributes(
72 AttributeList::get(C&: Ctx, Index: AttributeList::FunctionIndex, B: CallSiteAttrs));
73 NewCI->addRetAttrs(B: AttrBuilder(Ctx, CI->getRetAttributes()));
74
75 unsigned AttrIdx = 0;
76 for (auto ArgI = NewCI->arg_begin(), E = NewCI->arg_end(); ArgI != E;
77 ++ArgI, ++AttrIdx)
78 NewCI->addParamAttrs(ArgNo: AttrIdx, B: ArgAttrs[AttrIdx]);
79
80 if (auto *FPOp = dyn_cast<FPMathOperator>(Val: NewCI))
81 cast<Instruction>(Val: FPOp)->setFastMathFlags(CI->getFastMathFlags());
82
83 NewCI->copyMetadata(SrcInst: *CI);
84
85 if (!CI->use_empty())
86 CI->replaceAllUsesWith(V: NewCI);
87 ReplaceInstWithInst(From: CI, To: NewCI);
88 }
89}
90
91/// Returns whether or not this function should be considered a candidate for
92/// argument removal. Currently, functions with no arguments and intrinsics are
93/// not considered. Intrinsics aren't considered because their signatures are
94/// fixed.
95static bool shouldRemoveArguments(const Function &F) {
96 return !F.arg_empty() && !F.isIntrinsic();
97}
98
99static bool allFuncUsersRewritable(const Function &F) {
100 for (const Use &U : F.uses()) {
101 const CallBase *CB = dyn_cast<CallBase>(Val: U.getUser());
102 if (!CB || !CB->isCallee(U: &U))
103 continue;
104
105 // TODO: Handle all CallBase cases.
106 if (!isa<CallInst>(Val: CB))
107 return false;
108 }
109
110 return true;
111}
112
113/// Removes out-of-chunk arguments from functions, and modifies their calls
114/// accordingly. It also removes allocations of out-of-chunk arguments.
115void llvm::reduceArgumentsDeltaPass(Oracle &O, ReducerWorkItem &WorkItem) {
116 Module &Program = WorkItem.getModule();
117 std::vector<Argument *> InitArgsToKeep;
118 std::vector<Function *> Funcs;
119
120 // Get inside-chunk arguments, as well as their parent function
121 for (auto &F : Program) {
122 if (!shouldRemoveArguments(F))
123 continue;
124 if (!allFuncUsersRewritable(F))
125 continue;
126 Funcs.push_back(x: &F);
127 for (auto &A : F.args()) {
128 if (callingConvRequiresArgument(F, Arg: A) || O.shouldKeep())
129 InitArgsToKeep.push_back(x: &A);
130 }
131 }
132
133 // We create a vector first, then convert it to a set, so that we don't have
134 // to pay the cost of rebalancing the set frequently if the order we insert
135 // the elements doesn't match the order they should appear inside the set.
136 std::set<Argument *> ArgsToKeep(InitArgsToKeep.begin(), InitArgsToKeep.end());
137
138 for (auto *F : Funcs) {
139 ValueToValueMapTy VMap;
140 std::vector<WeakVH> InstToDelete;
141 for (auto &A : F->args())
142 if (!ArgsToKeep.count(x: &A)) {
143 // By adding undesired arguments to the VMap, CloneFunction will remove
144 // them from the resulting Function
145 VMap[&A] = getDefaultValue(T: A.getType());
146 for (auto *U : A.users())
147 if (auto *I = dyn_cast<Instruction>(Val: *&U))
148 InstToDelete.push_back(x: I);
149 }
150 // Delete any (unique) instruction that uses the argument
151 for (Value *V : InstToDelete) {
152 if (!V)
153 continue;
154 auto *I = cast<Instruction>(Val: V);
155 I->replaceAllUsesWith(V: getDefaultValue(T: I->getType()));
156 if (!I->isTerminator())
157 I->eraseFromParent();
158 }
159
160 // No arguments to reduce
161 if (VMap.empty())
162 continue;
163
164 std::set<int> ArgIndexesToKeep;
165 for (const auto &[Index, Arg] : enumerate(First: F->args()))
166 if (ArgsToKeep.count(x: &Arg))
167 ArgIndexesToKeep.insert(x: Index);
168
169 auto *ClonedFunc = CloneFunction(F, VMap);
170 // In order to preserve function order, we move Clone after old Function
171 ClonedFunc->takeName(V: F);
172 ClonedFunc->removeFromParent();
173 Program.getFunctionList().insertAfter(where: F->getIterator(), New: ClonedFunc);
174
175 replaceFunctionCalls(OldF&: *F, NewF&: *ClonedFunc, ArgIndexesToKeep);
176 F->replaceAllUsesWith(V: ClonedFunc);
177 F->eraseFromParent();
178 }
179}
180