1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "ReduceOperandsToArgs.h"
10#include "Utils.h"
11#include "llvm/ADT/Sequence.h"
12#include "llvm/IR/Constants.h"
13#include "llvm/IR/InstIterator.h"
14#include "llvm/IR/InstrTypes.h"
15#include "llvm/IR/Instructions.h"
16#include "llvm/IR/IntrinsicInst.h"
17#include "llvm/IR/Operator.h"
18#include "llvm/Transforms/Utils/BasicBlockUtils.h"
19#include "llvm/Transforms/Utils/Cloning.h"
20
21using namespace llvm;
22
23static bool canReplaceFunction(const Function &F) {
24 // TODO: Add controls to avoid ABI breaks (e.g. don't break main)
25 return true;
26}
27
28static bool canReduceUse(Use &Op) {
29 Value *Val = Op.get();
30 Type *Ty = Val->getType();
31
32 // Only replace operands that can be passed-by-value.
33 if (!Ty->isFirstClassType())
34 return false;
35
36 // Don't pass labels/metadata as arguments.
37 if (Ty->isLabelTy() || Ty->isMetadataTy() || Ty->isTokenTy())
38 return false;
39
40 // No need to replace values that are already arguments.
41 if (isa<Argument>(Val))
42 return false;
43
44 // Do not replace literals.
45 if (isa<ConstantData>(Val))
46 return false;
47
48 // Do not convert direct function calls to indirect calls.
49 if (auto *CI = dyn_cast<CallBase>(Val: Op.getUser()))
50 if (&CI->getCalledOperandUse() == &Op)
51 return false;
52
53 // lifetime.start/lifetime.end require alloca argument.
54 if (isa<LifetimeIntrinsic>(Val: Op.getUser()))
55 return false;
56
57 return true;
58}
59
60/// Goes over OldF calls and replaces them with a call to NewF.
61static void replaceFunctionCalls(Function *OldF, Function *NewF) {
62 SmallVector<CallBase *> Callers;
63 for (Use &U : OldF->uses()) {
64 auto *CI = dyn_cast<CallBase>(Val: U.getUser());
65 if (!CI || !CI->isCallee(U: &U)) // RAUW can handle these fine.
66 continue;
67
68 Function *CalledF = CI->getCalledFunction();
69 if (CalledF == OldF) {
70 Callers.push_back(Elt: CI);
71 } else {
72 // The call may have undefined behavior by calling a function with a
73 // mismatched signature. In this case, do not bother adjusting the
74 // callsites to pad with any new arguments.
75
76 // TODO: Better QoI to try to add new arguments to the end, and ignore
77 // existing mismatches.
78 assert(!CalledF && CI->getCalledOperand()->stripPointerCasts() == OldF &&
79 "only expected call and function signature mismatch");
80 }
81 }
82
83 // Call arguments for NewF.
84 SmallVector<Value *> Args(NewF->arg_size(), nullptr);
85
86 // Fill up the additional parameters with default values.
87 for (auto ArgIdx : llvm::seq<size_t>(Begin: OldF->arg_size(), End: NewF->arg_size())) {
88 Type *NewArgTy = NewF->getArg(i: ArgIdx)->getType();
89 Args[ArgIdx] = getDefaultValue(T: NewArgTy);
90 }
91
92 for (CallBase *CI : Callers) {
93 // Preserve the original function arguments.
94 for (auto Z : zip_first(t: CI->args(), u&: Args))
95 std::get<1>(t&: Z) = std::get<0>(t&: Z);
96
97 // Also preserve operand bundles.
98 SmallVector<OperandBundleDef> OperandBundles;
99 CI->getOperandBundlesAsDefs(Defs&: OperandBundles);
100
101 // Create the new function call.
102 CallBase *NewCI;
103 if (auto *II = dyn_cast<InvokeInst>(Val: CI)) {
104 NewCI = InvokeInst::Create(Func: NewF, IfNormal: II->getNormalDest(), IfException: II->getUnwindDest(),
105 Args, Bundles: OperandBundles, NameStr: CI->getName());
106 } else {
107 assert(isa<CallInst>(CI));
108 NewCI = CallInst::Create(Func: NewF, Args, Bundles: OperandBundles, NameStr: CI->getName());
109 }
110 NewCI->setCallingConv(NewF->getCallingConv());
111 NewCI->setAttributes(CI->getAttributes());
112
113 if (isa<FPMathOperator>(Val: NewCI))
114 NewCI->setFastMathFlags(CI->getFastMathFlags());
115
116 NewCI->copyMetadata(SrcInst: *CI);
117
118 // Do the replacement for this use.
119 if (!CI->use_empty())
120 CI->replaceAllUsesWith(V: NewCI);
121 ReplaceInstWithInst(From: CI, To: NewCI);
122 }
123}
124
125/// Add a new function argument to @p F for each use in @OpsToReplace, and
126/// replace those operand values with the new function argument.
127static void substituteOperandWithArgument(Function *OldF,
128 ArrayRef<Use *> OpsToReplace) {
129 if (OpsToReplace.empty())
130 return;
131
132 SetVector<Value *> UniqueValues;
133 for (Use *Op : OpsToReplace)
134 UniqueValues.insert(X: Op->get());
135
136 // Determine the new function's signature.
137 SmallVector<Type *> NewArgTypes(OldF->getFunctionType()->params());
138 size_t ArgOffset = NewArgTypes.size();
139 for (Value *V : UniqueValues)
140 NewArgTypes.push_back(Elt: V->getType());
141 FunctionType *FTy =
142 FunctionType::get(Result: OldF->getFunctionType()->getReturnType(), Params: NewArgTypes,
143 isVarArg: OldF->getFunctionType()->isVarArg());
144
145 // Create the new function...
146 Function *NewF = Function::Create(
147 Ty: FTy, Linkage: OldF->getLinkage(), AddrSpace: OldF->getAddressSpace(), N: "", M: OldF->getParent());
148
149 // In order to preserve function order, we move NewF behind OldF
150 NewF->removeFromParent();
151 OldF->getParent()->getFunctionList().insertAfter(where: OldF->getIterator(), New: NewF);
152
153 // Preserve the parameters of OldF.
154 ValueToValueMapTy VMap;
155 for (auto Z : zip_first(t: OldF->args(), u: NewF->args())) {
156 Argument &OldArg = std::get<0>(t&: Z);
157 Argument &NewArg = std::get<1>(t&: Z);
158
159 NewArg.takeName(V: &OldArg); // Copy the name over...
160 VMap[&OldArg] = &NewArg; // Add mapping to VMap
161 }
162
163 LLVMContext &Ctx = OldF->getContext();
164
165 // Adjust the new parameters.
166 ValueToValueMapTy OldValMap;
167 for (auto Z : zip_first(t&: UniqueValues, u: drop_begin(RangeOrContainer: NewF->args(), N: ArgOffset))) {
168 Value *OldVal = std::get<0>(t&: Z);
169 Argument &NewArg = std::get<1>(t&: Z);
170
171 NewArg.setName(OldVal->getName());
172 OldValMap[OldVal] = &NewArg;
173 }
174
175 SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned.
176 CloneFunctionInto(NewFunc: NewF, OldFunc: OldF, VMap, Changes: CloneFunctionChangeType::LocalChangesOnly,
177 Returns, NameSuffix: "", /*CodeInfo=*/nullptr);
178
179 // Replace the actual operands.
180 for (Use *Op : OpsToReplace) {
181 Argument *NewArg = cast<Argument>(Val: OldValMap.lookup(Val: Op->get()));
182 auto *NewUser = cast<Instruction>(Val: VMap.lookup(Val: Op->getUser()));
183
184 // Try to preserve any information contained metadata annotations as the
185 // equivalent parameter attributes if possible.
186 if (auto *MDSrcInst = dyn_cast<Instruction>(Val: Op)) {
187 AttrBuilder AB(Ctx);
188 NewArg->addAttrs(B&: AB.addFromEquivalentMetadata(I: *MDSrcInst));
189 }
190
191 if (PHINode *NewPhi = dyn_cast<PHINode>(Val: NewUser)) {
192 PHINode *OldPhi = cast<PHINode>(Val: Op->getUser());
193 BasicBlock *OldBB = OldPhi->getIncomingBlock(U: *Op);
194 NewPhi->setIncomingValueForBlock(BB: cast<BasicBlock>(Val: VMap.lookup(Val: OldBB)),
195 V: NewArg);
196 } else
197 NewUser->setOperand(i: Op->getOperandNo(), Val: NewArg);
198 }
199
200 // Replace all OldF uses with NewF.
201 replaceFunctionCalls(OldF, NewF);
202
203 NewF->takeName(V: OldF);
204 OldF->replaceAllUsesWith(V: NewF);
205 OldF->eraseFromParent();
206}
207
208void llvm::reduceOperandsToArgsDeltaPass(Oracle &O, ReducerWorkItem &WorkItem) {
209 Module &Program = WorkItem.getModule();
210
211 SmallVector<Use *> OperandsToReduce;
212 for (Function &F : make_early_inc_range(Range: Program.functions())) {
213 if (!canReplaceFunction(F))
214 continue;
215 OperandsToReduce.clear();
216 for (Instruction &I : instructions(F: &F)) {
217 for (Use &Op : I.operands()) {
218 if (!canReduceUse(Op))
219 continue;
220 if (O.shouldKeep())
221 continue;
222
223 OperandsToReduce.push_back(Elt: &Op);
224 }
225 }
226
227 substituteOperandWithArgument(OldF: &F, OpsToReplace: OperandsToReduce);
228 }
229}
230