1//===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// This pass transforms all Call or Invoke instructions that are annotated
11// "coro_elide_safe" to call the `.noalloc` variant of coroutine instead.
12// The frame of the callee coroutine is allocated inside the caller. A pointer
13// to the allocated frame will be passed into the `.noalloc` ramp function.
14//
15//===----------------------------------------------------------------------===//
16
17#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
18
19#include "llvm/Analysis/CGSCCPassManager.h"
20#include "llvm/Analysis/LazyCallGraph.h"
21#include "llvm/Analysis/OptimizationRemarkEmitter.h"
22#include "llvm/IR/Analysis.h"
23#include "llvm/IR/IRBuilder.h"
24#include "llvm/IR/Instruction.h"
25#include "llvm/IR/Module.h"
26#include "llvm/IR/PassManager.h"
27#include "llvm/Transforms/Utils/CallGraphUpdater.h"
28#include "llvm/Transforms/Utils/Cloning.h"
29
30#include <cassert>
31
32using namespace llvm;
33
34#define DEBUG_TYPE "coro-annotation-elide"
35
36static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
37 for (Instruction &I : F->getEntryBlock())
38 if (!isa<AllocaInst>(Val: &I))
39 return &I;
40 llvm_unreachable("no terminator in the entry block");
41}
42
43// Create an alloca in the caller, using FrameSize and FrameAlign as the callee
44// coroutine's activation frame.
45static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize,
46 Align FrameAlign) {
47 LLVMContext &C = Caller->getContext();
48 BasicBlock::iterator InsertPt =
49 getFirstNonAllocaInTheEntryBlock(F: Caller)->getIterator();
50 const DataLayout &DL = Caller->getDataLayout();
51 auto FrameTy = ArrayType::get(ElementType: Type::getInt8Ty(C), NumElements: FrameSize);
52 auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
53 Frame->setAlignment(FrameAlign);
54 return Frame;
55}
56
57// Given a call or invoke instruction to the elide safe coroutine, this function
58// does the following:
59// - Allocate a frame for the callee coroutine in the caller using alloca.
60// - Replace the old CB with a new Call or Invoke to `NewCallee`, with the
61// pointer to the frame as an additional argument to NewCallee.
62static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
63 uint64_t FrameSize, Align FrameAlign) {
64 // TODO: generate the lifetime intrinsics for the new frame. This will require
65 // introduction of two pesudo lifetime intrinsics in the frontend around the
66 // `co_await` expression and convert them to real lifetime intrinsics here.
67 auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign);
68 auto NewCBInsertPt = CB->getIterator();
69 llvm::CallBase *NewCB = nullptr;
70 SmallVector<Value *, 4> NewArgs;
71 NewArgs.append(in_start: CB->arg_begin(), in_end: CB->arg_end());
72 NewArgs.push_back(Elt: FramePtr);
73
74 if (auto *CI = dyn_cast<CallInst>(Val: CB)) {
75 auto *NewCI = CallInst::Create(Ty: NewCallee->getFunctionType(), Func: NewCallee,
76 Args: NewArgs, NameStr: "", InsertBefore: NewCBInsertPt);
77 NewCI->setTailCallKind(CI->getTailCallKind());
78 NewCB = NewCI;
79 } else if (auto *II = dyn_cast<InvokeInst>(Val: CB)) {
80 NewCB = InvokeInst::Create(Ty: NewCallee->getFunctionType(), Func: NewCallee,
81 IfNormal: II->getNormalDest(), IfException: II->getUnwindDest(),
82 Args: NewArgs, Bundles: {}, NameStr: "", InsertBefore: NewCBInsertPt);
83 } else {
84 llvm_unreachable("CallBase should either be Call or Invoke!");
85 }
86
87 NewCB->setCalledFunction(FTy: NewCallee->getFunctionType(), Fn: NewCallee);
88 NewCB->setCallingConv(CB->getCallingConv());
89 NewCB->setAttributes(CB->getAttributes());
90 NewCB->setDebugLoc(CB->getDebugLoc());
91 std::copy(first: CB->bundle_op_info_begin(), last: CB->bundle_op_info_end(),
92 result: NewCB->bundle_op_info_begin());
93
94 NewCB->removeFnAttr(Kind: llvm::Attribute::CoroElideSafe);
95 CB->replaceAllUsesWith(V: NewCB);
96
97 InlineFunctionInfo IFI;
98 InlineResult IR = InlineFunction(CB&: *NewCB, IFI);
99 if (IR.isSuccess()) {
100 CB->eraseFromParent();
101 } else {
102 NewCB->replaceAllUsesWith(V: CB);
103 NewCB->eraseFromParent();
104 }
105}
106
107PreservedAnalyses CoroAnnotationElidePass::run(LazyCallGraph::SCC &C,
108 CGSCCAnalysisManager &AM,
109 LazyCallGraph &CG,
110 CGSCCUpdateResult &UR) {
111 bool Changed = false;
112 CallGraphUpdater CGUpdater;
113 CGUpdater.initialize(LCG&: CG, SCC&: C, AM, UR);
114
115 auto &FAM =
116 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(IR&: C, ExtraArgs&: CG).getManager();
117
118 for (LazyCallGraph::Node &N : C) {
119 Function *Callee = &N.getFunction();
120 Function *NewCallee = Callee->getParent()->getFunction(
121 Name: (Callee->getName() + ".noalloc").str());
122 if (!NewCallee)
123 continue;
124
125 SmallVector<CallBase *, 4> Users;
126 for (auto *U : Callee->users()) {
127 if (auto *CB = dyn_cast<CallBase>(Val: U)) {
128 if (CB->getCalledFunction() == Callee)
129 Users.push_back(Elt: CB);
130 }
131 }
132 auto FramePtrArgPosition = NewCallee->arg_size() - 1;
133 auto FrameSize =
134 NewCallee->getParamDereferenceableBytes(ArgNo: FramePtrArgPosition);
135 auto FrameAlign =
136 NewCallee->getParamAlign(ArgNo: FramePtrArgPosition).valueOrOne();
137
138 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: *Callee);
139
140 for (auto *CB : Users) {
141 auto *Caller = CB->getFunction();
142 if (!Caller)
143 continue;
144
145 bool IsCallerPresplitCoroutine = Caller->isPresplitCoroutine();
146 bool HasAttr = CB->hasFnAttr(Kind: llvm::Attribute::CoroElideSafe);
147 if (IsCallerPresplitCoroutine && HasAttr) {
148 auto *CallerN = CG.lookup(F: *Caller);
149 auto *CallerC = CallerN ? CG.lookupSCC(N&: *CallerN) : nullptr;
150 // If CallerC is nullptr, it means LazyCallGraph hasn't visited Caller
151 // yet. Skip the call graph update.
152 auto ShouldUpdateCallGraph = !!CallerC;
153 processCall(CB, Caller, NewCallee, FrameSize, FrameAlign);
154
155 ORE.emit(RemarkBuilder: [&]() {
156 return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller)
157 << "'" << ore::NV("callee", Callee->getName())
158 << "' elided in '" << ore::NV("caller", Caller->getName())
159 << "'";
160 });
161
162 FAM.invalidate(IR&: *Caller, PA: PreservedAnalyses::none());
163 Changed = true;
164 if (ShouldUpdateCallGraph)
165 updateCGAndAnalysisManagerForCGSCCPass(G&: CG, C&: *CallerC, N&: *CallerN, AM, UR,
166 FAM);
167
168 } else {
169 ORE.emit(RemarkBuilder: [&]() {
170 return OptimizationRemarkMissed(DEBUG_TYPE, "CoroAnnotationElide",
171 Caller)
172 << "'" << ore::NV("callee", Callee->getName())
173 << "' not elided in '" << ore::NV("caller", Caller->getName())
174 << "' (caller_presplit="
175 << ore::NV("caller_presplit", IsCallerPresplitCoroutine)
176 << ", elide_safe_attr=" << ore::NV("elide_safe_attr", HasAttr)
177 << ")";
178 });
179 }
180 }
181 }
182
183 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
184}
185