1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "ReduceOperandsToArgs.h"
10#include "Utils.h"
11#include "llvm/ADT/Sequence.h"
12#include "llvm/IR/Constants.h"
13#include "llvm/IR/InstIterator.h"
14#include "llvm/IR/InstrTypes.h"
15#include "llvm/IR/Instructions.h"
16#include "llvm/IR/Operator.h"
17#include "llvm/Transforms/Utils/BasicBlockUtils.h"
18#include "llvm/Transforms/Utils/Cloning.h"
19
20using namespace llvm;
21
22static bool canReplaceFunction(const Function &F) {
23 // TODO: Add controls to avoid ABI breaks (e.g. don't break main)
24 return true;
25}
26
27static bool canReduceUse(Use &Op) {
28 Value *Val = Op.get();
29 Type *Ty = Val->getType();
30
31 // Only replace operands that can be passed-by-value.
32 if (!Ty->isFirstClassType())
33 return false;
34
35 // Don't pass labels/metadata as arguments.
36 if (Ty->isLabelTy() || Ty->isMetadataTy() || Ty->isTokenTy())
37 return false;
38
39 // No need to replace values that are already arguments.
40 if (isa<Argument>(Val))
41 return false;
42
43 // Do not replace literals.
44 if (isa<ConstantData>(Val))
45 return false;
46
47 // Do not convert direct function calls to indirect calls.
48 if (auto *CI = dyn_cast<CallBase>(Val: Op.getUser()))
49 if (&CI->getCalledOperandUse() == &Op)
50 return false;
51
52 return true;
53}
54
55/// Goes over OldF calls and replaces them with a call to NewF.
56static void replaceFunctionCalls(Function *OldF, Function *NewF) {
57 SmallVector<CallBase *> Callers;
58 for (Use &U : OldF->uses()) {
59 auto *CI = dyn_cast<CallBase>(Val: U.getUser());
60 if (!CI || !CI->isCallee(U: &U)) // RAUW can handle these fine.
61 continue;
62
63 Function *CalledF = CI->getCalledFunction();
64 if (CalledF == OldF) {
65 Callers.push_back(Elt: CI);
66 } else {
67 // The call may have undefined behavior by calling a function with a
68 // mismatched signature. In this case, do not bother adjusting the
69 // callsites to pad with any new arguments.
70
71 // TODO: Better QoI to try to add new arguments to the end, and ignore
72 // existing mismatches.
73 assert(!CalledF && CI->getCalledOperand()->stripPointerCasts() == OldF &&
74 "only expected call and function signature mismatch");
75 }
76 }
77
78 // Call arguments for NewF.
79 SmallVector<Value *> Args(NewF->arg_size(), nullptr);
80
81 // Fill up the additional parameters with default values.
82 for (auto ArgIdx : llvm::seq<size_t>(Begin: OldF->arg_size(), End: NewF->arg_size())) {
83 Type *NewArgTy = NewF->getArg(i: ArgIdx)->getType();
84 Args[ArgIdx] = getDefaultValue(T: NewArgTy);
85 }
86
87 for (CallBase *CI : Callers) {
88 // Preserve the original function arguments.
89 for (auto Z : zip_first(t: CI->args(), u&: Args))
90 std::get<1>(t&: Z) = std::get<0>(t&: Z);
91
92 // Also preserve operand bundles.
93 SmallVector<OperandBundleDef> OperandBundles;
94 CI->getOperandBundlesAsDefs(Defs&: OperandBundles);
95
96 // Create the new function call.
97 CallBase *NewCI;
98 if (auto *II = dyn_cast<InvokeInst>(Val: CI)) {
99 NewCI = InvokeInst::Create(Func: NewF, IfNormal: II->getNormalDest(), IfException: II->getUnwindDest(),
100 Args, Bundles: OperandBundles, NameStr: CI->getName());
101 } else {
102 assert(isa<CallInst>(CI));
103 NewCI = CallInst::Create(Func: NewF, Args, Bundles: OperandBundles, NameStr: CI->getName());
104 }
105 NewCI->setCallingConv(NewF->getCallingConv());
106 NewCI->setAttributes(CI->getAttributes());
107
108 if (isa<FPMathOperator>(Val: NewCI))
109 NewCI->setFastMathFlags(CI->getFastMathFlags());
110
111 NewCI->copyMetadata(SrcInst: *CI);
112
113 // Do the replacement for this use.
114 if (!CI->use_empty())
115 CI->replaceAllUsesWith(V: NewCI);
116 ReplaceInstWithInst(From: CI, To: NewCI);
117 }
118}
119
120/// Add a new function argument to @p F for each use in @OpsToReplace, and
121/// replace those operand values with the new function argument.
122static void substituteOperandWithArgument(Function *OldF,
123 ArrayRef<Use *> OpsToReplace) {
124 if (OpsToReplace.empty())
125 return;
126
127 SetVector<Value *> UniqueValues;
128 for (Use *Op : OpsToReplace)
129 UniqueValues.insert(X: Op->get());
130
131 // Determine the new function's signature.
132 SmallVector<Type *> NewArgTypes(OldF->getFunctionType()->params());
133 size_t ArgOffset = NewArgTypes.size();
134 for (Value *V : UniqueValues)
135 NewArgTypes.push_back(Elt: V->getType());
136 FunctionType *FTy =
137 FunctionType::get(Result: OldF->getFunctionType()->getReturnType(), Params: NewArgTypes,
138 isVarArg: OldF->getFunctionType()->isVarArg());
139
140 // Create the new function...
141 Function *NewF = Function::Create(
142 Ty: FTy, Linkage: OldF->getLinkage(), AddrSpace: OldF->getAddressSpace(), N: "", M: OldF->getParent());
143
144 // In order to preserve function order, we move NewF behind OldF
145 NewF->removeFromParent();
146 OldF->getParent()->getFunctionList().insertAfter(where: OldF->getIterator(), New: NewF);
147
148 // Preserve the parameters of OldF.
149 ValueToValueMapTy VMap;
150 for (auto Z : zip_first(t: OldF->args(), u: NewF->args())) {
151 Argument &OldArg = std::get<0>(t&: Z);
152 Argument &NewArg = std::get<1>(t&: Z);
153
154 NewArg.takeName(V: &OldArg); // Copy the name over...
155 VMap[&OldArg] = &NewArg; // Add mapping to VMap
156 }
157
158 LLVMContext &Ctx = OldF->getContext();
159
160 // Adjust the new parameters.
161 ValueToValueMapTy OldValMap;
162 for (auto Z : zip_first(t&: UniqueValues, u: drop_begin(RangeOrContainer: NewF->args(), N: ArgOffset))) {
163 Value *OldVal = std::get<0>(t&: Z);
164 Argument &NewArg = std::get<1>(t&: Z);
165
166 NewArg.setName(OldVal->getName());
167 OldValMap[OldVal] = &NewArg;
168 }
169
170 SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned.
171 CloneFunctionInto(NewFunc: NewF, OldFunc: OldF, VMap, Changes: CloneFunctionChangeType::LocalChangesOnly,
172 Returns, NameSuffix: "", /*CodeInfo=*/nullptr);
173
174 // Replace the actual operands.
175 for (Use *Op : OpsToReplace) {
176 Argument *NewArg = cast<Argument>(Val: OldValMap.lookup(Val: Op->get()));
177 auto *NewUser = cast<Instruction>(Val: VMap.lookup(Val: Op->getUser()));
178
179 // Try to preserve any information contained metadata annotations as the
180 // equivalent parameter attributes if possible.
181 if (auto *MDSrcInst = dyn_cast<Instruction>(Val: Op)) {
182 AttrBuilder AB(Ctx);
183 NewArg->addAttrs(B&: AB.addFromEquivalentMetadata(I: *MDSrcInst));
184 }
185
186 if (PHINode *NewPhi = dyn_cast<PHINode>(Val: NewUser)) {
187 PHINode *OldPhi = cast<PHINode>(Val: Op->getUser());
188 BasicBlock *OldBB = OldPhi->getIncomingBlock(U: *Op);
189 NewPhi->setIncomingValueForBlock(BB: cast<BasicBlock>(Val: VMap.lookup(Val: OldBB)),
190 V: NewArg);
191 } else
192 NewUser->setOperand(i: Op->getOperandNo(), Val: NewArg);
193 }
194
195 // Replace all OldF uses with NewF.
196 replaceFunctionCalls(OldF, NewF);
197
198 NewF->takeName(V: OldF);
199 OldF->replaceAllUsesWith(V: NewF);
200 OldF->eraseFromParent();
201}
202
203void llvm::reduceOperandsToArgsDeltaPass(Oracle &O, ReducerWorkItem &WorkItem) {
204 Module &Program = WorkItem.getModule();
205
206 SmallVector<Use *> OperandsToReduce;
207 for (Function &F : make_early_inc_range(Range: Program.functions())) {
208 if (!canReplaceFunction(F))
209 continue;
210 OperandsToReduce.clear();
211 for (Instruction &I : instructions(F: &F)) {
212 for (Use &Op : I.operands()) {
213 if (!canReduceUse(Op))
214 continue;
215 if (O.shouldKeep())
216 continue;
217
218 OperandsToReduce.push_back(Elt: &Op);
219 }
220 }
221
222 substituteOperandWithArgument(OldF: &F, OpsToReplace: OperandsToReduce);
223 }
224}
225