1//===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// This pass transforms all Call or Invoke instructions that are annotated
11// "coro_elide_safe" to call the `.noalloc` variant of coroutine instead.
12// The frame of the callee coroutine is allocated inside the caller. A pointer
13// to the allocated frame will be passed into the `.noalloc` ramp function.
14//
15//===----------------------------------------------------------------------===//
16
17#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
18
19#include "llvm/Analysis/CGSCCPassManager.h"
20#include "llvm/Analysis/LazyCallGraph.h"
21#include "llvm/Analysis/OptimizationRemarkEmitter.h"
22#include "llvm/IR/Analysis.h"
23#include "llvm/IR/IRBuilder.h"
24#include "llvm/IR/Instruction.h"
25#include "llvm/IR/Module.h"
26#include "llvm/IR/PassManager.h"
27#include "llvm/Support/BranchProbability.h"
28#include "llvm/Support/CommandLine.h"
29#include "llvm/Support/FileSystem.h"
30#include "llvm/Transforms/Utils/CallGraphUpdater.h"
31#include "llvm/Transforms/Utils/Cloning.h"
32
33#include <cassert>
34
35using namespace llvm;
36
37#define DEBUG_TYPE "coro-annotation-elide"
38
39static cl::opt<float> CoroElideBranchRatio(
40 "coro-elide-branch-ratio", cl::init(Val: 0.55), cl::Hidden,
41 cl::desc("Minimum BranchProbability to consider a elide a coroutine."));
42extern cl::opt<unsigned> MinBlockCounterExecution;
43
44static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
45 for (Instruction &I : F->getEntryBlock())
46 if (!isa<AllocaInst>(Val: &I))
47 return &I;
48 llvm_unreachable("no terminator in the entry block");
49}
50
51// Create an alloca in the caller, using FrameSize and FrameAlign as the callee
52// coroutine's activation frame.
53static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize,
54 Align FrameAlign) {
55 LLVMContext &C = Caller->getContext();
56 BasicBlock::iterator InsertPt =
57 getFirstNonAllocaInTheEntryBlock(F: Caller)->getIterator();
58 const DataLayout &DL = Caller->getDataLayout();
59 auto FrameTy = ArrayType::get(ElementType: Type::getInt8Ty(C), NumElements: FrameSize);
60 auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
61 Frame->setAlignment(FrameAlign);
62 return Frame;
63}
64
65// Given a call or invoke instruction to the elide safe coroutine, this function
66// does the following:
67// - Allocate a frame for the callee coroutine in the caller using alloca.
68// - Replace the old CB with a new Call or Invoke to `NewCallee`, with the
69// pointer to the frame as an additional argument to NewCallee.
70static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
71 uint64_t FrameSize, Align FrameAlign) {
72 // TODO: generate the lifetime intrinsics for the new frame. This will require
73 // introduction of two pesudo lifetime intrinsics in the frontend around the
74 // `co_await` expression and convert them to real lifetime intrinsics here.
75 auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign);
76 auto NewCBInsertPt = CB->getIterator();
77 llvm::CallBase *NewCB = nullptr;
78 SmallVector<Value *, 4> NewArgs;
79 NewArgs.append(in_start: CB->arg_begin(), in_end: CB->arg_end());
80 NewArgs.push_back(Elt: FramePtr);
81
82 if (auto *CI = dyn_cast<CallInst>(Val: CB)) {
83 auto *NewCI = CallInst::Create(Ty: NewCallee->getFunctionType(), Func: NewCallee,
84 Args: NewArgs, NameStr: "", InsertBefore: NewCBInsertPt);
85 NewCI->setTailCallKind(CI->getTailCallKind());
86 NewCB = NewCI;
87 } else if (auto *II = dyn_cast<InvokeInst>(Val: CB)) {
88 NewCB = InvokeInst::Create(Ty: NewCallee->getFunctionType(), Func: NewCallee,
89 IfNormal: II->getNormalDest(), IfException: II->getUnwindDest(),
90 Args: NewArgs, Bundles: {}, NameStr: "", InsertBefore: NewCBInsertPt);
91 } else {
92 llvm_unreachable("CallBase should either be Call or Invoke!");
93 }
94
95 NewCB->setCalledFunction(FTy: NewCallee->getFunctionType(), Fn: NewCallee);
96 NewCB->setCallingConv(CB->getCallingConv());
97 NewCB->setAttributes(CB->getAttributes());
98 NewCB->setDebugLoc(CB->getDebugLoc());
99 std::copy(first: CB->bundle_op_info_begin(), last: CB->bundle_op_info_end(),
100 result: NewCB->bundle_op_info_begin());
101
102 NewCB->removeFnAttr(Kind: llvm::Attribute::CoroElideSafe);
103 CB->replaceAllUsesWith(V: NewCB);
104
105 InlineFunctionInfo IFI;
106 InlineResult IR = InlineFunction(CB&: *NewCB, IFI);
107 if (IR.isSuccess()) {
108 CB->eraseFromParent();
109 } else {
110 NewCB->replaceAllUsesWith(V: CB);
111 NewCB->eraseFromParent();
112 }
113}
114
115PreservedAnalyses CoroAnnotationElidePass::run(LazyCallGraph::SCC &C,
116 CGSCCAnalysisManager &AM,
117 LazyCallGraph &CG,
118 CGSCCUpdateResult &UR) {
119 bool Changed = false;
120 CallGraphUpdater CGUpdater;
121 CGUpdater.initialize(LCG&: CG, SCC&: C, AM, UR);
122
123 auto &FAM =
124 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(IR&: C, ExtraArgs&: CG).getManager();
125
126 for (LazyCallGraph::Node &N : C) {
127 Function *Callee = &N.getFunction();
128 Function *NewCallee = Callee->getParent()->getFunction(
129 Name: (Callee->getName() + ".noalloc").str());
130 if (!NewCallee)
131 continue;
132
133 SmallVector<CallBase *, 4> Users;
134 for (auto *U : Callee->users()) {
135 if (auto *CB = dyn_cast<CallBase>(Val: U)) {
136 if (CB->getCalledFunction() == Callee)
137 Users.push_back(Elt: CB);
138 }
139 }
140 auto FramePtrArgPosition = NewCallee->arg_size() - 1;
141 auto FrameSize =
142 NewCallee->getParamDereferenceableBytes(ArgNo: FramePtrArgPosition);
143 auto FrameAlign =
144 NewCallee->getParamAlign(ArgNo: FramePtrArgPosition).valueOrOne();
145
146 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: *Callee);
147
148 for (auto *CB : Users) {
149 auto *Caller = CB->getFunction();
150 if (!Caller)
151 continue;
152
153 bool IsCallerPresplitCoroutine = Caller->isPresplitCoroutine();
154 bool HasAttr = CB->hasFnAttr(Kind: llvm::Attribute::CoroElideSafe);
155 if (IsCallerPresplitCoroutine && HasAttr) {
156 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(IR&: *Caller);
157
158 auto BlockFreq = BFI.getBlockFreq(BB: CB->getParent()).getFrequency();
159 auto EntryFreq = BFI.getEntryFreq().getFrequency();
160 uint64_t MinFreq =
161 static_cast<uint64_t>(EntryFreq * CoroElideBranchRatio);
162
163 if (BlockFreq < MinFreq) {
164 ORE.emit(RemarkBuilder: [&]() {
165 return OptimizationRemarkMissed(
166 DEBUG_TYPE, "CoroAnnotationElideUnlikely", Caller)
167 << "'" << ore::NV("callee", Callee->getName())
168 << "' not elided in '"
169 << ore::NV("caller", Caller->getName())
170 << "' because of low frequency: "
171 << ore::NV("block_freq", BlockFreq)
172 << " (threshold: " << ore::NV("min_freq", MinFreq) << ")";
173 });
174 continue;
175 }
176
177 auto *CallerN = CG.lookup(F: *Caller);
178 auto *CallerC = CallerN ? CG.lookupSCC(N&: *CallerN) : nullptr;
179 // If CallerC is nullptr, it means LazyCallGraph hasn't visited Caller
180 // yet. Skip the call graph update.
181 auto ShouldUpdateCallGraph = !!CallerC;
182 processCall(CB, Caller, NewCallee, FrameSize, FrameAlign);
183
184 ORE.emit(RemarkBuilder: [&]() {
185 return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller)
186 << "'" << ore::NV("callee", Callee->getName())
187 << "' elided in '" << ore::NV("caller", Caller->getName())
188 << "' (block_freq: " << ore::NV("block_freq", BlockFreq)
189 << ")";
190 });
191
192 FAM.invalidate(IR&: *Caller, PA: PreservedAnalyses::none());
193 Changed = true;
194 if (ShouldUpdateCallGraph)
195 updateCGAndAnalysisManagerForCGSCCPass(G&: CG, C&: *CallerC, N&: *CallerN, AM, UR,
196 FAM);
197
198 } else {
199 ORE.emit(RemarkBuilder: [&]() {
200 return OptimizationRemarkMissed(DEBUG_TYPE, "CoroAnnotationElide",
201 Caller)
202 << "'" << ore::NV("callee", Callee->getName())
203 << "' not elided in '" << ore::NV("caller", Caller->getName())
204 << "' (caller_presplit="
205 << ore::NV("caller_presplit", IsCallerPresplitCoroutine)
206 << ", elide_safe_attr=" << ore::NV("elide_safe_attr", HasAttr)
207 << ")";
208 });
209 }
210 }
211 }
212
213 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
214}
215