| 1 | //===- ReduceArguments.cpp - Specialized Delta Pass -----------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements a function which calls the Generic Delta pass in order |
| 10 | // to reduce uninteresting Arguments from declared and defined functions. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "ReduceArguments.h" |
| 15 | #include "Utils.h" |
| 16 | #include "llvm/ADT/SmallVector.h" |
| 17 | #include "llvm/IR/FMF.h" |
| 18 | #include "llvm/IR/Instructions.h" |
| 19 | #include "llvm/IR/Intrinsics.h" |
| 20 | #include "llvm/IR/Operator.h" |
| 21 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
| 22 | #include "llvm/Transforms/Utils/Cloning.h" |
| 23 | #include <set> |
| 24 | #include <vector> |
| 25 | |
| 26 | using namespace llvm; |
| 27 | |
| 28 | static bool callingConvRequiresArgument(const Function &F, |
| 29 | const Argument &Arg) { |
| 30 | switch (F.getCallingConv()) { |
| 31 | case CallingConv::X86_INTR: |
| 32 | // If there are any arguments, the first one must by byval. |
| 33 | return Arg.getArgNo() == 0 && F.arg_size() != 1; |
| 34 | default: |
| 35 | return false; |
| 36 | } |
| 37 | |
| 38 | llvm_unreachable("covered calling conv switch" ); |
| 39 | } |
| 40 | |
| 41 | /// Goes over OldF calls and replaces them with a call to NewF |
| 42 | static void replaceFunctionCalls(Function &OldF, Function &NewF, |
| 43 | const std::set<int> &ArgIndexesToKeep) { |
| 44 | LLVMContext &Ctx = OldF.getContext(); |
| 45 | |
| 46 | const auto &Users = OldF.users(); |
| 47 | for (auto I = Users.begin(), E = Users.end(); I != E; ) |
| 48 | if (auto *CI = dyn_cast<CallInst>(Val: *I++)) { |
| 49 | // Skip uses in call instructions where OldF isn't the called function |
| 50 | // (e.g. if OldF is an argument of the call). |
| 51 | if (CI->getCalledFunction() != &OldF) |
| 52 | continue; |
| 53 | SmallVector<Value *, 8> Args; |
| 54 | SmallVector<AttrBuilder, 8> ArgAttrs; |
| 55 | |
| 56 | for (auto ArgI = CI->arg_begin(), E = CI->arg_end(); ArgI != E; ++ArgI) { |
| 57 | unsigned ArgIdx = ArgI - CI->arg_begin(); |
| 58 | if (ArgIndexesToKeep.count(x: ArgIdx)) { |
| 59 | Args.push_back(Elt: *ArgI); |
| 60 | ArgAttrs.emplace_back(Args&: Ctx, Args: CI->getParamAttributes(ArgNo: ArgIdx)); |
| 61 | } |
| 62 | } |
| 63 | |
| 64 | SmallVector<OperandBundleDef, 2> OpBundles; |
| 65 | CI->getOperandBundlesAsDefs(Defs&: OpBundles); |
| 66 | |
| 67 | CallInst *NewCI = CallInst::Create(Func: &NewF, Args, Bundles: OpBundles); |
| 68 | NewCI->setCallingConv(CI->getCallingConv()); |
| 69 | |
| 70 | AttrBuilder CallSiteAttrs(Ctx, CI->getAttributes().getFnAttrs()); |
| 71 | NewCI->setAttributes( |
| 72 | AttributeList::get(C&: Ctx, Index: AttributeList::FunctionIndex, B: CallSiteAttrs)); |
| 73 | NewCI->addRetAttrs(B: AttrBuilder(Ctx, CI->getRetAttributes())); |
| 74 | |
| 75 | unsigned AttrIdx = 0; |
| 76 | for (auto ArgI = NewCI->arg_begin(), E = NewCI->arg_end(); ArgI != E; |
| 77 | ++ArgI, ++AttrIdx) |
| 78 | NewCI->addParamAttrs(ArgNo: AttrIdx, B: ArgAttrs[AttrIdx]); |
| 79 | |
| 80 | if (auto *FPOp = dyn_cast<FPMathOperator>(Val: NewCI)) |
| 81 | cast<Instruction>(Val: FPOp)->setFastMathFlags(CI->getFastMathFlags()); |
| 82 | |
| 83 | NewCI->copyMetadata(SrcInst: *CI); |
| 84 | |
| 85 | if (!CI->use_empty()) |
| 86 | CI->replaceAllUsesWith(V: NewCI); |
| 87 | ReplaceInstWithInst(From: CI, To: NewCI); |
| 88 | } |
| 89 | } |
| 90 | |
| 91 | /// Returns whether or not this function should be considered a candidate for |
| 92 | /// argument removal. Currently, functions with no arguments and intrinsics are |
| 93 | /// not considered. Intrinsics aren't considered because their signatures are |
| 94 | /// fixed. |
| 95 | static bool shouldRemoveArguments(const Function &F) { |
| 96 | return !F.arg_empty() && !F.isIntrinsic(); |
| 97 | } |
| 98 | |
| 99 | static bool allFuncUsersRewritable(const Function &F) { |
| 100 | for (const Use &U : F.uses()) { |
| 101 | const CallBase *CB = dyn_cast<CallBase>(Val: U.getUser()); |
| 102 | if (!CB || !CB->isCallee(U: &U)) |
| 103 | continue; |
| 104 | |
| 105 | // TODO: Handle all CallBase cases. |
| 106 | if (!isa<CallInst>(Val: CB)) |
| 107 | return false; |
| 108 | } |
| 109 | |
| 110 | return true; |
| 111 | } |
| 112 | |
| 113 | /// Removes out-of-chunk arguments from functions, and modifies their calls |
| 114 | /// accordingly. It also removes allocations of out-of-chunk arguments. |
| 115 | void llvm::reduceArgumentsDeltaPass(Oracle &O, ReducerWorkItem &WorkItem) { |
| 116 | Module &Program = WorkItem.getModule(); |
| 117 | std::vector<Argument *> InitArgsToKeep; |
| 118 | std::vector<Function *> Funcs; |
| 119 | |
| 120 | // Get inside-chunk arguments, as well as their parent function |
| 121 | for (auto &F : Program) { |
| 122 | if (!shouldRemoveArguments(F)) |
| 123 | continue; |
| 124 | if (!allFuncUsersRewritable(F)) |
| 125 | continue; |
| 126 | Funcs.push_back(x: &F); |
| 127 | for (auto &A : F.args()) { |
| 128 | if (callingConvRequiresArgument(F, Arg: A) || O.shouldKeep()) |
| 129 | InitArgsToKeep.push_back(x: &A); |
| 130 | } |
| 131 | } |
| 132 | |
| 133 | // We create a vector first, then convert it to a set, so that we don't have |
| 134 | // to pay the cost of rebalancing the set frequently if the order we insert |
| 135 | // the elements doesn't match the order they should appear inside the set. |
| 136 | std::set<Argument *> ArgsToKeep(InitArgsToKeep.begin(), InitArgsToKeep.end()); |
| 137 | |
| 138 | for (auto *F : Funcs) { |
| 139 | ValueToValueMapTy VMap; |
| 140 | std::vector<WeakVH> InstToDelete; |
| 141 | for (auto &A : F->args()) |
| 142 | if (!ArgsToKeep.count(x: &A)) { |
| 143 | // By adding undesired arguments to the VMap, CloneFunction will remove |
| 144 | // them from the resulting Function |
| 145 | VMap[&A] = getDefaultValue(T: A.getType()); |
| 146 | for (auto *U : A.users()) |
| 147 | if (auto *I = dyn_cast<Instruction>(Val: *&U)) |
| 148 | InstToDelete.push_back(x: I); |
| 149 | } |
| 150 | // Delete any (unique) instruction that uses the argument |
| 151 | for (Value *V : InstToDelete) { |
| 152 | if (!V) |
| 153 | continue; |
| 154 | auto *I = cast<Instruction>(Val: V); |
| 155 | I->replaceAllUsesWith(V: getDefaultValue(T: I->getType())); |
| 156 | if (!I->isTerminator()) |
| 157 | I->eraseFromParent(); |
| 158 | } |
| 159 | |
| 160 | // No arguments to reduce |
| 161 | if (VMap.empty()) |
| 162 | continue; |
| 163 | |
| 164 | std::set<int> ArgIndexesToKeep; |
| 165 | for (const auto &[Index, Arg] : enumerate(First: F->args())) |
| 166 | if (ArgsToKeep.count(x: &Arg)) |
| 167 | ArgIndexesToKeep.insert(x: Index); |
| 168 | |
| 169 | auto *ClonedFunc = CloneFunction(F, VMap); |
| 170 | // In order to preserve function order, we move Clone after old Function |
| 171 | ClonedFunc->takeName(V: F); |
| 172 | ClonedFunc->removeFromParent(); |
| 173 | Program.getFunctionList().insertAfter(where: F->getIterator(), New: ClonedFunc); |
| 174 | |
| 175 | replaceFunctionCalls(OldF&: *F, NewF&: *ClonedFunc, ArgIndexesToKeep); |
| 176 | F->replaceAllUsesWith(V: ClonedFunc); |
| 177 | F->eraseFromParent(); |
| 178 | } |
| 179 | } |
| 180 | |