1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "ReduceOperandsToArgs.h"
10#include "Delta.h"
11#include "Utils.h"
12#include "llvm/ADT/Sequence.h"
13#include "llvm/IR/Constants.h"
14#include "llvm/IR/InstIterator.h"
15#include "llvm/IR/InstrTypes.h"
16#include "llvm/IR/Instructions.h"
17#include "llvm/Transforms/Utils/BasicBlockUtils.h"
18#include "llvm/Transforms/Utils/Cloning.h"
19
20using namespace llvm;
21
22static bool canReplaceFunction(Function *F) {
23 return all_of(Range: F->uses(), P: [](Use &Op) {
24 if (auto *CI = dyn_cast<CallBase>(Val: Op.getUser()))
25 return &CI->getCalledOperandUse() == &Op;
26 return false;
27 });
28}
29
30static bool canReduceUse(Use &Op) {
31 Value *Val = Op.get();
32 Type *Ty = Val->getType();
33
34 // Only replace operands that can be passed-by-value.
35 if (!Ty->isFirstClassType())
36 return false;
37
38 // Don't pass labels/metadata as arguments.
39 if (Ty->isLabelTy() || Ty->isMetadataTy())
40 return false;
41
42 // No need to replace values that are already arguments.
43 if (isa<Argument>(Val))
44 return false;
45
46 // Do not replace literals.
47 if (isa<ConstantData>(Val))
48 return false;
49
50 // Do not convert direct function calls to indirect calls.
51 if (auto *CI = dyn_cast<CallBase>(Val: Op.getUser()))
52 if (&CI->getCalledOperandUse() == &Op)
53 return false;
54
55 return true;
56}
57
58/// Goes over OldF calls and replaces them with a call to NewF.
59static void replaceFunctionCalls(Function *OldF, Function *NewF) {
60 SmallVector<CallBase *> Callers;
61 for (Use &U : OldF->uses()) {
62 auto *CI = cast<CallBase>(Val: U.getUser());
63 assert(&U == &CI->getCalledOperandUse());
64 assert(CI->getCalledFunction() == OldF);
65 Callers.push_back(Elt: CI);
66 }
67
68 // Call arguments for NewF.
69 SmallVector<Value *> Args(NewF->arg_size(), nullptr);
70
71 // Fill up the additional parameters with default values.
72 for (auto ArgIdx : llvm::seq<size_t>(Begin: OldF->arg_size(), End: NewF->arg_size())) {
73 Type *NewArgTy = NewF->getArg(i: ArgIdx)->getType();
74 Args[ArgIdx] = getDefaultValue(T: NewArgTy);
75 }
76
77 for (CallBase *CI : Callers) {
78 // Preserve the original function arguments.
79 for (auto Z : zip_first(t: CI->args(), u&: Args))
80 std::get<1>(t&: Z) = std::get<0>(t&: Z);
81
82 // Also preserve operand bundles.
83 SmallVector<OperandBundleDef> OperandBundles;
84 CI->getOperandBundlesAsDefs(Defs&: OperandBundles);
85
86 // Create the new function call.
87 CallBase *NewCI;
88 if (auto *II = dyn_cast<InvokeInst>(Val: CI)) {
89 NewCI = InvokeInst::Create(Func: NewF, IfNormal: cast<InvokeInst>(Val: II)->getNormalDest(),
90 IfException: cast<InvokeInst>(Val: II)->getUnwindDest(), Args,
91 Bundles: OperandBundles, NameStr: CI->getName());
92 } else {
93 assert(isa<CallInst>(CI));
94 NewCI = CallInst::Create(Func: NewF, Args, Bundles: OperandBundles, NameStr: CI->getName());
95 }
96 NewCI->setCallingConv(NewF->getCallingConv());
97
98 // Do the replacement for this use.
99 if (!CI->use_empty())
100 CI->replaceAllUsesWith(V: NewCI);
101 ReplaceInstWithInst(From: CI, To: NewCI);
102 }
103}
104
105/// Add a new function argument to @p F for each use in @OpsToReplace, and
106/// replace those operand values with the new function argument.
107static void substituteOperandWithArgument(Function *OldF,
108 ArrayRef<Use *> OpsToReplace) {
109 if (OpsToReplace.empty())
110 return;
111
112 SetVector<Value *> UniqueValues;
113 for (Use *Op : OpsToReplace)
114 UniqueValues.insert(X: Op->get());
115
116 // Determine the new function's signature.
117 SmallVector<Type *> NewArgTypes;
118 llvm::append_range(C&: NewArgTypes, R: OldF->getFunctionType()->params());
119 size_t ArgOffset = NewArgTypes.size();
120 for (Value *V : UniqueValues)
121 NewArgTypes.push_back(Elt: V->getType());
122 FunctionType *FTy =
123 FunctionType::get(Result: OldF->getFunctionType()->getReturnType(), Params: NewArgTypes,
124 isVarArg: OldF->getFunctionType()->isVarArg());
125
126 // Create the new function...
127 Function *NewF =
128 Function::Create(Ty: FTy, Linkage: OldF->getLinkage(), AddrSpace: OldF->getAddressSpace(),
129 N: OldF->getName(), M: OldF->getParent());
130
131 // In order to preserve function order, we move NewF behind OldF
132 NewF->removeFromParent();
133 OldF->getParent()->getFunctionList().insertAfter(where: OldF->getIterator(), New: NewF);
134
135 // Preserve the parameters of OldF.
136 ValueToValueMapTy VMap;
137 for (auto Z : zip_first(t: OldF->args(), u: NewF->args())) {
138 Argument &OldArg = std::get<0>(t&: Z);
139 Argument &NewArg = std::get<1>(t&: Z);
140
141 NewArg.setName(OldArg.getName()); // Copy the name over...
142 VMap[&OldArg] = &NewArg; // Add mapping to VMap
143 }
144
145 // Adjust the new parameters.
146 ValueToValueMapTy OldValMap;
147 for (auto Z : zip_first(t&: UniqueValues, u: drop_begin(RangeOrContainer: NewF->args(), N: ArgOffset))) {
148 Value *OldVal = std::get<0>(t&: Z);
149 Argument &NewArg = std::get<1>(t&: Z);
150
151 NewArg.setName(OldVal->getName());
152 OldValMap[OldVal] = &NewArg;
153 }
154
155 SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned.
156 CloneFunctionInto(NewFunc: NewF, OldFunc: OldF, VMap, Changes: CloneFunctionChangeType::LocalChangesOnly,
157 Returns, NameSuffix: "", /*CodeInfo=*/nullptr);
158
159 // Replace the actual operands.
160 for (Use *Op : OpsToReplace) {
161 Value *NewArg = OldValMap.lookup(Val: Op->get());
162 auto *NewUser = cast<Instruction>(Val: VMap.lookup(Val: Op->getUser()));
163
164 if (PHINode *NewPhi = dyn_cast<PHINode>(Val: NewUser)) {
165 PHINode *OldPhi = cast<PHINode>(Val: Op->getUser());
166 BasicBlock *OldBB = OldPhi->getIncomingBlock(U: *Op);
167 NewPhi->setIncomingValueForBlock(BB: cast<BasicBlock>(Val: VMap.lookup(Val: OldBB)),
168 V: NewArg);
169 } else
170 NewUser->setOperand(i: Op->getOperandNo(), Val: NewArg);
171 }
172
173 // Replace all OldF uses with NewF.
174 replaceFunctionCalls(OldF, NewF);
175
176 // Rename NewF to OldF's name.
177 std::string FName = OldF->getName().str();
178 OldF->replaceAllUsesWith(V: NewF);
179 OldF->eraseFromParent();
180 NewF->setName(FName);
181}
182
183static void reduceOperandsToArgs(Oracle &O, ReducerWorkItem &WorkItem) {
184 Module &Program = WorkItem.getModule();
185
186 SmallVector<Use *> OperandsToReduce;
187 for (Function &F : make_early_inc_range(Range: Program.functions())) {
188 if (!canReplaceFunction(F: &F))
189 continue;
190 OperandsToReduce.clear();
191 for (Instruction &I : instructions(F: &F)) {
192 for (Use &Op : I.operands()) {
193 if (!canReduceUse(Op))
194 continue;
195 if (O.shouldKeep())
196 continue;
197
198 OperandsToReduce.push_back(Elt: &Op);
199 }
200 }
201
202 substituteOperandWithArgument(OldF: &F, OpsToReplace: OperandsToReduce);
203 }
204}
205
206void llvm::reduceOperandsToArgsDeltaPass(TestRunner &Test) {
207 runDeltaPass(Test, ExtractChunksFromModule: reduceOperandsToArgs,
208 Message: "Converting operands to function arguments");
209}
210