1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Try to reduce a function by inserting new return instructions. Try to insert
// an early return for each instruction or argument value at that point. This
// requires mutating the return type, or finding values with a compatible type.
12//
13//===----------------------------------------------------------------------===//
14
15#define DEBUG_TYPE "llvm-reduce"
16
17#include "ReduceValuesToReturn.h"
18
19#include "Delta.h"
20#include "Utils.h"
21#include "llvm/IR/AttributeMask.h"
22#include "llvm/IR/Attributes.h"
23#include "llvm/IR/CFG.h"
24#include "llvm/IR/Instructions.h"
25#include "llvm/Support/Debug.h"
26#include "llvm/Transforms/Utils/BasicBlockUtils.h"
27
28using namespace llvm;
29
30/// Return true if it is legal to emit a copy of the function with a non-void
31/// return type.
32static bool canUseNonVoidReturnType(const Function &F) {
33 // Functions with sret arguments must return void.
34 return !F.hasStructRetAttr() &&
35 CallingConv::supportsNonVoidReturnType(CC: F.getCallingConv());
36}
37
38/// Return true if it's legal to replace a function return type to use \p Ty.
39static bool isReallyValidReturnType(Type *Ty) {
40 return FunctionType::isValidReturnType(RetTy: Ty) && !Ty->isTokenTy() &&
41 Ty->isFirstClassType();
42}
43
44/// Insert a ret inst after \p NewRetValue, which returns the value it produces.
45static void rewriteFuncWithReturnType(Function &OldF, Value *NewRetValue) {
46 Type *NewRetTy = NewRetValue->getType();
47 FunctionType *OldFuncTy = OldF.getFunctionType();
48
49 FunctionType *NewFuncTy =
50 FunctionType::get(Result: NewRetTy, Params: OldFuncTy->params(), isVarArg: OldFuncTy->isVarArg());
51
52 LLVMContext &Ctx = OldF.getContext();
53 BasicBlock &EntryBB = OldF.getEntryBlock();
54 Instruction *NewRetI = dyn_cast<Instruction>(Val: NewRetValue);
55 BasicBlock *NewRetBlock = NewRetI ? NewRetI->getParent() : &EntryBB;
56
57 BasicBlock::iterator NewValIt =
58 NewRetI ? std::next(x: NewRetI->getIterator()) : EntryBB.begin();
59
60 Type *OldRetTy = OldFuncTy->getReturnType();
61
62 // Hack up any return values in other blocks, we can't leave them as returning OldRetTy.
63 if (OldRetTy != NewRetTy) {
64 for (BasicBlock &OtherRetBB : OldF) {
65 if (&OtherRetBB != NewRetBlock) {
66 auto *OrigRI = dyn_cast<ReturnInst>(Val: OtherRetBB.getTerminator());
67 if (!OrigRI)
68 continue;
69
70 OrigRI->eraseFromParent();
71 ReturnInst::Create(C&: Ctx, retVal: getDefaultValue(T: NewRetTy), InsertBefore: &OtherRetBB);
72 }
73 }
74 }
75
76 // If we're returning an instruction, split the basic block so we can let
77 // simpleSimplifyCFG cleanup the successors.
78 BasicBlock *TailBB = NewRetBlock->splitBasicBlock(I: NewValIt);
79
80 // Replace the unconditional branch splitBasicBlock created
81 NewRetBlock->getTerminator()->eraseFromParent();
82 ReturnInst::Create(C&: Ctx, retVal: NewRetValue, InsertBefore: NewRetBlock);
83
84 // Now prune any CFG edges we have to deal with.
85 simpleSimplifyCFG(F&: OldF, BBs: {TailBB}, /*FoldBlockIntoPredecessor=*/false);
86
87 // Drop the incompatible attributes before we copy over to the new function.
88 if (OldRetTy != NewRetTy) {
89 AttributeList AL = OldF.getAttributes();
90 AttributeMask IncompatibleAttrs =
91 AttributeFuncs::typeIncompatible(Ty: NewRetTy, AS: AL.getRetAttrs());
92 OldF.removeRetAttrs(Attrs: IncompatibleAttrs);
93 }
94
95 // Now we need to remove any returned attributes from parameters.
96 for (Argument &A : OldF.args())
97 OldF.removeParamAttr(ArgNo: A.getArgNo(), Kind: Attribute::Returned);
98
99 Function *NewF =
100 Function::Create(Ty: NewFuncTy, Linkage: OldF.getLinkage(), AddrSpace: OldF.getAddressSpace(), N: "",
101 M: OldF.getParent());
102
103 NewF->removeFromParent();
104 OldF.getParent()->getFunctionList().insertAfter(where: OldF.getIterator(), New: NewF);
105 NewF->takeName(V: &OldF);
106 NewF->copyAttributesFrom(Src: &OldF);
107
108 // Adjust the callsite uses to the new return type. We pre-filtered cases
109 // where the original call type was incorrectly non-void.
110 for (User *U : make_early_inc_range(Range: OldF.users())) {
111 if (auto *CB = dyn_cast<CallBase>(Val: U);
112 CB && CB->getCalledOperand() == &OldF) {
113 if (CB->getType()->isVoidTy()) {
114 FunctionType *CallType = CB->getFunctionType();
115
116 // The callsite may not match the new function type, in an undefined
117 // behavior way. Only mutate the local return type.
118 FunctionType *NewCallType = FunctionType::get(
119 Result: NewRetTy, Params: CallType->params(), isVarArg: CallType->isVarArg());
120
121 CB->mutateType(Ty: NewRetTy);
122 CB->setCalledFunction(FTy: NewCallType, Fn: NewF);
123 } else {
124 assert(CB->getType() == NewRetTy &&
125 "only handle exact return type match with non-void returns");
126 }
127 }
128 }
129
130 NewF->splice(ToIt: NewF->begin(), FromF: &OldF);
131 OldF.replaceAllUsesWith(V: NewF);
132
133 // Preserve the parameters of OldF.
134 for (auto Z : zip_first(t: OldF.args(), u: NewF->args())) {
135 Argument &OldArg = std::get<0>(t&: Z);
136 Argument &NewArg = std::get<1>(t&: Z);
137
138 OldArg.replaceAllUsesWith(V: &NewArg);
139 NewArg.takeName(V: &OldArg);
140 }
141
142 OldF.eraseFromParent();
143}
144
145// Check if all the callsites of the void function are void, or happen to
146// incorrectly use the new return type.
147//
148// TODO: We could make better effort to handle call type mismatches.
149static bool canReplaceFuncUsers(const Function &F, Type *NewRetTy) {
150 for (const Use &U : F.uses()) {
151 const CallBase *CB = dyn_cast<CallBase>(Val: U.getUser());
152 if (!CB)
153 continue;
154
155 // Normal pointer uses are trivially replacable.
156 if (!CB->isCallee(U: &U))
157 continue;
158
159 // We can trivially replace the correct void call sites.
160 if (CB->getType()->isVoidTy())
161 continue;
162
163 // We can trivially replace the call if the return type happened to match
164 // the new return type.
165 if (CB->getType() == NewRetTy)
166 continue;
167
168 // TODO: If all callsites have no uses, we could mutate the type of all the
169 // callsites. This will complicate the visit and rewrite ordering though.
170 LLVM_DEBUG(dbgs() << "Cannot replace used callsite with wrong type: " << *CB
171 << '\n');
172 return false;
173 }
174
175 return true;
176}
177
178/// Return true if it's worthwhile replacing the non-void return value of \p BB
179/// with \p Replacement
180static bool shouldReplaceNonVoidReturnValue(const BasicBlock &BB,
181 const Value *Replacement) {
182 if (const auto *RI = dyn_cast<ReturnInst>(Val: BB.getTerminator()))
183 return RI->getReturnValue() != Replacement;
184 return true;
185}
186
187static bool shouldForwardValueToReturn(const BasicBlock &BB, const Value *V,
188 Type *RetTy) {
189 if (!isReallyValidReturnType(Ty: V->getType()))
190 return false;
191
192 return (RetTy->isVoidTy() || shouldReplaceNonVoidReturnValue(BB, Replacement: V)) &&
193 canReplaceFuncUsers(F: *BB.getParent(), NewRetTy: V->getType());
194}
195
196static bool tryForwardingInstructionsToReturn(
197 Function &F, Oracle &O,
198 std::vector<std::pair<Function *, Value *>> &FuncsToReplace) {
199
200 // TODO: Should we try to expand returns to aggregate for function that
201 // already have a return value?
202 Type *RetTy = F.getReturnType();
203
204 for (BasicBlock &BB : F) {
205 // Skip the terminator, we can't insert a second terminator to return its
206 // value.
207 for (Instruction &I : make_range(x: BB.begin(), y: std::prev(x: BB.end()))) {
208 if (shouldForwardValueToReturn(BB, V: &I, RetTy) && !O.shouldKeep()) {
209 FuncsToReplace.emplace_back(args: &F, args: &I);
210 return true;
211 }
212 }
213 }
214
215 return false;
216}
217
218static bool tryForwardingArgumentsToReturn(
219 Function &F, Oracle &O,
220 std::vector<std::pair<Function *, Value *>> &FuncsToReplace) {
221
222 Type *RetTy = F.getReturnType();
223 BasicBlock &EntryBB = F.getEntryBlock();
224
225 for (Argument &A : F.args()) {
226 if (shouldForwardValueToReturn(BB: EntryBB, V: &A, RetTy) && !O.shouldKeep()) {
227 FuncsToReplace.emplace_back(args: &F, args: &A);
228 return true;
229 }
230 }
231
232 return false;
233}
234
235void llvm::reduceArgumentsToReturnDeltaPass(Oracle &O,
236 ReducerWorkItem &WorkItem) {
237 Module &Program = WorkItem.getModule();
238
239 // We're going to chaotically hack on the other users of the function in other
240 // functions, so we need to collect a worklist of returns to replace.
241 std::vector<std::pair<Function *, Value *>> FuncsToReplace;
242
243 for (Function &F : Program.functions()) {
244 if (!F.isDeclaration() && canUseNonVoidReturnType(F))
245 tryForwardingArgumentsToReturn(F, O, FuncsToReplace);
246 }
247
248 for (auto [F, NewRetVal] : FuncsToReplace)
249 rewriteFuncWithReturnType(OldF&: *F, NewRetValue: NewRetVal);
250}
251
252void llvm::reduceInstructionsToReturnDeltaPass(Oracle &O,
253 ReducerWorkItem &WorkItem) {
254 Module &Program = WorkItem.getModule();
255
256 // We're going to chaotically hack on the other users of the function in other
257 // functions, so we need to collect a worklist of returns to replace.
258 std::vector<std::pair<Function *, Value *>> FuncsToReplace;
259
260 for (Function &F : Program.functions()) {
261 if (!F.isDeclaration() && canUseNonVoidReturnType(F))
262 tryForwardingInstructionsToReturn(F, O, FuncsToReplace);
263 }
264
265 for (auto [F, NewRetVal] : FuncsToReplace)
266 rewriteFuncWithReturnType(OldF&: *F, NewRetValue: NewRetVal);
267}
268