1//===- Coroutines.cpp -----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the common infrastructure for Coroutine Passes.
10//
11//===----------------------------------------------------------------------===//
12
13#include "CoroInternal.h"
14#include "llvm/ADT/SmallVector.h"
15#include "llvm/ADT/StringRef.h"
16#include "llvm/Analysis/CallGraph.h"
17#include "llvm/IR/Attributes.h"
18#include "llvm/IR/Constants.h"
19#include "llvm/IR/DerivedTypes.h"
20#include "llvm/IR/Function.h"
21#include "llvm/IR/InstIterator.h"
22#include "llvm/IR/Instructions.h"
23#include "llvm/IR/IntrinsicInst.h"
24#include "llvm/IR/Intrinsics.h"
25#include "llvm/IR/Module.h"
26#include "llvm/IR/Type.h"
27#include "llvm/Support/Casting.h"
28#include "llvm/Support/ErrorHandling.h"
29#include "llvm/Transforms/Coroutines/ABI.h"
30#include "llvm/Transforms/Coroutines/CoroInstr.h"
31#include "llvm/Transforms/Coroutines/CoroShape.h"
32#include "llvm/Transforms/Utils/Local.h"
33#include <cassert>
34#include <cstddef>
35#include <utility>
36
37using namespace llvm;
38
39// Construct the lowerer base class and initialize its members.
40coro::LowererBase::LowererBase(Module &M)
41 : TheModule(M), Context(M.getContext()),
42 Int8Ptr(PointerType::get(C&: Context, AddressSpace: 0)),
43 ResumeFnType(FunctionType::get(Result: Type::getVoidTy(C&: Context), Params: Int8Ptr,
44 /*isVarArg=*/false)),
45 NullPtr(ConstantPointerNull::get(T: Int8Ptr)) {}
46
47// Creates a call to llvm.coro.subfn.addr to obtain a resume function address.
48// It generates the following:
49//
50// call ptr @llvm.coro.subfn.addr(ptr %Arg, i8 %index)
51
52CallInst *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
53 Instruction *InsertPt) {
54 auto *IndexVal = ConstantInt::get(Ty: Type::getInt8Ty(C&: Context), V: Index);
55 auto *Fn =
56 Intrinsic::getOrInsertDeclaration(M: &TheModule, id: Intrinsic::coro_subfn_addr);
57
58 assert(Index >= CoroSubFnInst::IndexFirst &&
59 Index < CoroSubFnInst::IndexLast &&
60 "makeSubFnCall: Index value out of range");
61 return CallInst::Create(Func: Fn, Args: {Arg, IndexVal}, NameStr: "", InsertBefore: InsertPt->getIterator());
62}
63
64// We can only efficiently check for non-overloaded intrinsics.
65// The following intrinsics are absent for that reason:
66// coro_align, coro_size, coro_suspend_async, coro_suspend_retcon
67static Intrinsic::ID NonOverloadedCoroIntrinsics[] = {
68 Intrinsic::coro_alloc,
69 Intrinsic::coro_async_context_alloc,
70 Intrinsic::coro_async_context_dealloc,
71 Intrinsic::coro_async_resume,
72 Intrinsic::coro_async_size_replace,
73 Intrinsic::coro_await_suspend_bool,
74 Intrinsic::coro_await_suspend_handle,
75 Intrinsic::coro_await_suspend_void,
76 Intrinsic::coro_begin,
77 Intrinsic::coro_begin_custom_abi,
78 Intrinsic::coro_destroy,
79 Intrinsic::coro_done,
80 Intrinsic::coro_end,
81 Intrinsic::coro_end_async,
82 Intrinsic::coro_frame,
83 Intrinsic::coro_free,
84 Intrinsic::coro_id,
85 Intrinsic::coro_id_async,
86 Intrinsic::coro_id_retcon,
87 Intrinsic::coro_id_retcon_once,
88 Intrinsic::coro_noop,
89 Intrinsic::coro_prepare_async,
90 Intrinsic::coro_prepare_retcon,
91 Intrinsic::coro_promise,
92 Intrinsic::coro_resume,
93 Intrinsic::coro_save,
94 Intrinsic::coro_subfn_addr,
95 Intrinsic::coro_suspend,
96 Intrinsic::coro_is_in_ramp,
97};
98
99bool coro::isSuspendBlock(BasicBlock *BB) {
100 return isa<AnyCoroSuspendInst>(Val: BB->front());
101}
102
103bool coro::declaresAnyIntrinsic(const Module &M) {
104 return declaresIntrinsics(M, List: NonOverloadedCoroIntrinsics);
105}
106
107// Checks whether the module declares any of the listed intrinsics.
108bool coro::declaresIntrinsics(const Module &M, ArrayRef<Intrinsic::ID> List) {
109#ifndef NDEBUG
110 for (Intrinsic::ID ID : List)
111 assert(!Intrinsic::isOverloaded(ID) &&
112 "Only non-overloaded intrinsics supported");
113#endif
114
115 for (Intrinsic::ID ID : List)
116 if (Intrinsic::getDeclarationIfExists(M: &M, id: ID))
117 return true;
118 return false;
119}
120
121// Replace all coro.frees associated with the provided frame with 'null' and
122// erase all associated coro.deads
123void coro::elideCoroFree(Value *FramePtr) {
124 SmallVector<CoroFreeInst *, 4> CoroFrees;
125 SmallVector<CoroDeadInst *, 4> CoroDeads;
126 for (User *U : FramePtr->users()) {
127 if (auto *CF = dyn_cast<CoroFreeInst>(Val: U))
128 CoroFrees.push_back(Elt: CF);
129 else if (auto *CD = dyn_cast<CoroDeadInst>(Val: U))
130 CoroDeads.push_back(Elt: CD);
131 }
132
133 Value *Replacement =
134 ConstantPointerNull::get(T: PointerType::get(C&: FramePtr->getContext(), AddressSpace: 0));
135 for (CoroFreeInst *CF : CoroFrees) {
136 CF->replaceAllUsesWith(V: Replacement);
137 CF->eraseFromParent();
138 }
139
140 for (auto *CD : CoroDeads)
141 CD->eraseFromParent();
142}
143
144void coro::suppressCoroAllocs(CoroIdInst *CoroId) {
145 SmallVector<CoroAllocInst *, 4> CoroAllocs;
146 for (User *U : CoroId->users())
147 if (auto *CA = dyn_cast<CoroAllocInst>(Val: U))
148 CoroAllocs.push_back(Elt: CA);
149
150 if (CoroAllocs.empty())
151 return;
152
153 coro::suppressCoroAllocs(Context&: CoroId->getContext(), CoroAllocs);
154}
155
156// Replacing llvm.coro.alloc with false will suppress dynamic
157// allocation as it is expected for the frontend to generate the code that
158// looks like:
159// id = coro.id(...)
160// mem = coro.alloc(id) ? malloc(coro.size()) : 0;
161// coro.begin(id, mem)
162void coro::suppressCoroAllocs(LLVMContext &Context,
163 ArrayRef<CoroAllocInst *> CoroAllocs) {
164 auto *False = ConstantInt::getFalse(Context);
165 for (auto *CA : CoroAllocs) {
166 CA->replaceAllUsesWith(V: False);
167 CA->eraseFromParent();
168 }
169}
170
171static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
172 CoroSuspendInst *SuspendInst) {
173 Module *M = SuspendInst->getModule();
174 auto *Fn = Intrinsic::getOrInsertDeclaration(M, id: Intrinsic::coro_save);
175 auto *SaveInst = cast<CoroSaveInst>(
176 Val: CallInst::Create(Func: Fn, Args: CoroBegin, NameStr: "", InsertBefore: SuspendInst->getIterator()));
177 assert(!SuspendInst->getCoroSave());
178 SuspendInst->setArgOperand(i: 0, v: SaveInst);
179 return SaveInst;
180}
181
182// Collect "interesting" coroutine intrinsics.
183void coro::Shape::analyze(Function &F,
184 SmallVectorImpl<CoroFrameInst *> &CoroFrames,
185 SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves,
186 CoroPromiseInst *&CoroPromise) {
187 clear();
188
189 bool HasFinalSuspend = false;
190 bool HasUnwindCoroEnd = false;
191 size_t FinalSuspendIndex = 0;
192
193 for (Instruction &I : instructions(F)) {
194 // FIXME: coro_await_suspend_* are not proper `IntrinisicInst`s
195 // because they might be invoked
196 if (auto AWS = dyn_cast<CoroAwaitSuspendInst>(Val: &I)) {
197 CoroAwaitSuspends.push_back(Elt: AWS);
198 } else if (auto II = dyn_cast<IntrinsicInst>(Val: &I)) {
199 switch (II->getIntrinsicID()) {
200 default:
201 continue;
202 case Intrinsic::coro_size:
203 CoroSizes.push_back(Elt: cast<CoroSizeInst>(Val: II));
204 break;
205 case Intrinsic::coro_align:
206 CoroAligns.push_back(Elt: cast<CoroAlignInst>(Val: II));
207 break;
208 case Intrinsic::coro_frame:
209 CoroFrames.push_back(Elt: cast<CoroFrameInst>(Val: II));
210 break;
211 case Intrinsic::coro_save:
212 // After optimizations, coro_suspends using this coro_save might have
213 // been removed, remember orphaned coro_saves to remove them later.
214 if (II->use_empty())
215 UnusedCoroSaves.push_back(Elt: cast<CoroSaveInst>(Val: II));
216 break;
217 case Intrinsic::coro_suspend_async: {
218 auto *Suspend = cast<CoroSuspendAsyncInst>(Val: II);
219 Suspend->checkWellFormed();
220 CoroSuspends.push_back(Elt: Suspend);
221 break;
222 }
223 case Intrinsic::coro_suspend_retcon: {
224 auto Suspend = cast<CoroSuspendRetconInst>(Val: II);
225 CoroSuspends.push_back(Elt: Suspend);
226 break;
227 }
228 case Intrinsic::coro_suspend: {
229 auto Suspend = cast<CoroSuspendInst>(Val: II);
230 CoroSuspends.push_back(Elt: Suspend);
231 if (Suspend->isFinal()) {
232 if (HasFinalSuspend)
233 report_fatal_error(
234 reason: "Only one suspend point can be marked as final");
235 HasFinalSuspend = true;
236 FinalSuspendIndex = CoroSuspends.size() - 1;
237 }
238 break;
239 }
240 case Intrinsic::coro_begin:
241 case Intrinsic::coro_begin_custom_abi: {
242 auto CB = cast<CoroBeginInst>(Val: II);
243
244 // Ignore coro id's that aren't pre-split.
245 auto Id = dyn_cast<CoroIdInst>(Val: CB->getId());
246 if (Id && !Id->getInfo().isPreSplit())
247 break;
248
249 if (CoroBegin)
250 report_fatal_error(
251 reason: "coroutine should have exactly one defining @llvm.coro.begin");
252 CB->addRetAttr(Kind: Attribute::NonNull);
253 CB->addRetAttr(Kind: Attribute::NoAlias);
254 CB->removeFnAttr(Kind: Attribute::NoDuplicate);
255 CoroBegin = CB;
256 break;
257 }
258 case Intrinsic::coro_end_async:
259 case Intrinsic::coro_end:
260 CoroEnds.push_back(Elt: cast<AnyCoroEndInst>(Val: II));
261 if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(Val: II)) {
262 AsyncEnd->checkWellFormed();
263 }
264
265 if (CoroEnds.back()->isUnwind())
266 HasUnwindCoroEnd = true;
267
268 if (CoroEnds.back()->isFallthrough() && isa<CoroEndInst>(Val: II)) {
269 // Make sure that the fallthrough coro.end is the first element in the
270 // CoroEnds vector.
271 // Note: I don't think this is neccessary anymore.
272 if (CoroEnds.size() > 1) {
273 if (CoroEnds.front()->isFallthrough())
274 report_fatal_error(
275 reason: "Only one coro.end can be marked as fallthrough");
276 std::swap(a&: CoroEnds.front(), b&: CoroEnds.back());
277 }
278 }
279 break;
280 case Intrinsic::coro_is_in_ramp:
281 CoroIsInRampInsts.push_back(Elt: cast<CoroIsInRampInst>(Val: II));
282 break;
283 case Intrinsic::coro_promise:
284 assert(CoroPromise == nullptr &&
285 "CoroEarly must ensure coro.promise unique");
286 CoroPromise = cast<CoroPromiseInst>(Val: II);
287 break;
288 }
289 }
290 }
291
292 // If there is no CoroBegin then this is not a coroutine.
293 if (!CoroBegin)
294 return;
295
296 // Determination of ABI and initializing lowering info
297 auto Id = CoroBegin->getId();
298 switch (auto IntrID = Id->getIntrinsicID()) {
299 case Intrinsic::coro_id: {
300 ABI = coro::ABI::Switch;
301 SwitchLowering.HasFinalSuspend = HasFinalSuspend;
302 SwitchLowering.HasUnwindCoroEnd = HasUnwindCoroEnd;
303
304 auto SwitchId = getSwitchCoroId();
305 SwitchLowering.ResumeSwitch = nullptr;
306 SwitchLowering.PromiseAlloca = SwitchId->getPromise();
307 SwitchLowering.ResumeEntryBlock = nullptr;
308
309 // Move final suspend to the last element in the CoroSuspends vector.
310 if (SwitchLowering.HasFinalSuspend &&
311 FinalSuspendIndex != CoroSuspends.size() - 1)
312 std::swap(a&: CoroSuspends[FinalSuspendIndex], b&: CoroSuspends.back());
313 break;
314 }
315 case Intrinsic::coro_id_async: {
316 ABI = coro::ABI::Async;
317 auto *AsyncId = getAsyncCoroId();
318 AsyncId->checkWellFormed();
319 AsyncLowering.Context = AsyncId->getStorage();
320 AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
321 AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
322 AsyncLowering.ContextAlignment = AsyncId->getStorageAlignment().value();
323 AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
324 AsyncLowering.AsyncCC = F.getCallingConv();
325 break;
326 }
327 case Intrinsic::coro_id_retcon:
328 case Intrinsic::coro_id_retcon_once: {
329 ABI = IntrID == Intrinsic::coro_id_retcon ? coro::ABI::Retcon
330 : coro::ABI::RetconOnce;
331 auto ContinuationId = getRetconCoroId();
332 ContinuationId->checkWellFormed();
333 auto Prototype = ContinuationId->getPrototype();
334 RetconLowering.ResumePrototype = Prototype;
335 RetconLowering.Alloc = ContinuationId->getAllocFunction();
336 RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
337 RetconLowering.ReturnBlock = nullptr;
338 RetconLowering.IsFrameInlineInStorage = false;
339 break;
340 }
341 default:
342 llvm_unreachable("coro.begin is not dependent on a coro.id call");
343 }
344}
345
346// If for some reason, we were not able to find coro.begin, bailout.
347void coro::Shape::invalidateCoroutine(
348 Function &F, SmallVectorImpl<CoroFrameInst *> &CoroFrames) {
349 assert(!CoroBegin);
350 {
351 // Replace coro.frame which are supposed to be lowered to the result of
352 // coro.begin with poison.
353 auto *Poison = PoisonValue::get(T: PointerType::get(C&: F.getContext(), AddressSpace: 0));
354 for (CoroFrameInst *CF : CoroFrames) {
355 CF->replaceAllUsesWith(V: Poison);
356 CF->eraseFromParent();
357 }
358 CoroFrames.clear();
359
360 // Replace all coro.suspend with poison and remove related coro.saves if
361 // present.
362 for (AnyCoroSuspendInst *CS : CoroSuspends) {
363 CS->replaceAllUsesWith(V: PoisonValue::get(T: CS->getType()));
364 if (auto *CoroSave = CS->getCoroSave())
365 CoroSave->eraseFromParent();
366 CS->eraseFromParent();
367 }
368 CoroSuspends.clear();
369
370 // Replace all coro.ends with unreachable instruction.
371 for (AnyCoroEndInst *CE : CoroEnds)
372 changeToUnreachable(I: CE);
373 }
374}
375
376void coro::SwitchABI::init() {
377 assert(Shape.ABI == coro::ABI::Switch);
378 {
379 for (auto *AnySuspend : Shape.CoroSuspends) {
380 auto Suspend = dyn_cast<CoroSuspendInst>(Val: AnySuspend);
381 if (!Suspend) {
382#ifndef NDEBUG
383 AnySuspend->dump();
384#endif
385 report_fatal_error(reason: "coro.id must be paired with coro.suspend");
386 }
387
388 if (!Suspend->getCoroSave())
389 createCoroSave(CoroBegin: Shape.CoroBegin, SuspendInst: Suspend);
390 }
391 }
392}
393
394void coro::AsyncABI::init() { assert(Shape.ABI == coro::ABI::Async); }
395
396void coro::AnyRetconABI::init() {
397 assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce);
398 {
399 // Determine the result value types, and make sure they match up with
400 // the values passed to the suspends.
401 auto ResultTys = Shape.getRetconResultTypes();
402 auto ResumeTys = Shape.getRetconResumeTypes();
403
404 for (auto *AnySuspend : Shape.CoroSuspends) {
405 auto Suspend = dyn_cast<CoroSuspendRetconInst>(Val: AnySuspend);
406 if (!Suspend) {
407#ifndef NDEBUG
408 AnySuspend->dump();
409#endif
410 report_fatal_error(reason: "coro.id.retcon.* must be paired with "
411 "coro.suspend.retcon");
412 }
413
414 // Check that the argument types of the suspend match the results.
415 auto SI = Suspend->value_begin(), SE = Suspend->value_end();
416 auto RI = ResultTys.begin(), RE = ResultTys.end();
417 for (; SI != SE && RI != RE; ++SI, ++RI) {
418 auto SrcTy = (*SI)->getType();
419 if (SrcTy != *RI) {
420 // The optimizer likes to eliminate bitcasts leading into variadic
421 // calls, but that messes with our invariants. Re-insert the
422 // bitcast and ignore this type mismatch.
423 if (CastInst::isBitCastable(SrcTy, DestTy: *RI)) {
424 auto BCI = new BitCastInst(*SI, *RI, "", Suspend->getIterator());
425 SI->set(BCI);
426 continue;
427 }
428
429#ifndef NDEBUG
430 Suspend->dump();
431 Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
432#endif
433 report_fatal_error(reason: "argument to coro.suspend.retcon does not "
434 "match corresponding prototype function result");
435 }
436 }
437 if (SI != SE || RI != RE) {
438#ifndef NDEBUG
439 Suspend->dump();
440 Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
441#endif
442 report_fatal_error(reason: "wrong number of arguments to coro.suspend.retcon");
443 }
444
445 // Check that the result type of the suspend matches the resume types.
446 Type *SResultTy = Suspend->getType();
447 ArrayRef<Type *> SuspendResultTys;
448 if (SResultTy->isVoidTy()) {
449 // leave as empty array
450 } else if (auto SResultStructTy = dyn_cast<StructType>(Val: SResultTy)) {
451 SuspendResultTys = SResultStructTy->elements();
452 } else {
453 // forms an ArrayRef using SResultTy, be careful
454 SuspendResultTys = SResultTy;
455 }
456 if (SuspendResultTys.size() != ResumeTys.size()) {
457#ifndef NDEBUG
458 Suspend->dump();
459 Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
460#endif
461 report_fatal_error(reason: "wrong number of results from coro.suspend.retcon");
462 }
463 for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
464 if (SuspendResultTys[I] != ResumeTys[I]) {
465#ifndef NDEBUG
466 Suspend->dump();
467 Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
468#endif
469 report_fatal_error(reason: "result from coro.suspend.retcon does not "
470 "match corresponding prototype function param");
471 }
472 }
473 }
474 }
475}
476
477void coro::Shape::cleanCoroutine(
478 SmallVectorImpl<CoroFrameInst *> &CoroFrames,
479 SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves, CoroPromiseInst *PI) {
480 // The coro.frame intrinsic is always lowered to the result of coro.begin.
481 for (CoroFrameInst *CF : CoroFrames) {
482 CF->replaceAllUsesWith(V: CoroBegin);
483 CF->eraseFromParent();
484 }
485 CoroFrames.clear();
486
487 // Remove orphaned coro.saves.
488 for (CoroSaveInst *CoroSave : UnusedCoroSaves)
489 CoroSave->eraseFromParent();
490 UnusedCoroSaves.clear();
491
492 if (PI) {
493 PI->replaceAllUsesWith(V: PI->isFromPromise()
494 ? cast<Value>(Val: CoroBegin)
495 : cast<Value>(Val: getPromiseAlloca()));
496 PI->eraseFromParent();
497 }
498}
499
500static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
501 Call->setCallingConv(Callee->getCallingConv());
502 // TODO: attributes?
503}
504
505static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){
506 if (CG)
507 (*CG)[Call->getFunction()]->addCalledFunction(Call, M: (*CG)[Callee]);
508}
509
510Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
511 CallGraph *CG) const {
512 switch (ABI) {
513 case coro::ABI::Switch:
514 llvm_unreachable("can't allocate memory in coro switch-lowering");
515
516 case coro::ABI::Retcon:
517 case coro::ABI::RetconOnce: {
518 auto Alloc = RetconLowering.Alloc;
519 Size = Builder.CreateIntCast(V: Size,
520 DestTy: Alloc->getFunctionType()->getParamType(i: 0),
521 /*is signed*/ isSigned: false);
522 auto *Call = Builder.CreateCall(Callee: Alloc, Args: Size);
523 propagateCallAttrsFromCallee(Call, Callee: Alloc);
524 addCallToCallGraph(CG, Call, Callee: Alloc);
525 return Call;
526 }
527 case coro::ABI::Async:
528 llvm_unreachable("can't allocate memory in coro async-lowering");
529 }
530 llvm_unreachable("Unknown coro::ABI enum");
531}
532
533void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
534 CallGraph *CG) const {
535 switch (ABI) {
536 case coro::ABI::Switch:
537 llvm_unreachable("can't allocate memory in coro switch-lowering");
538
539 case coro::ABI::Retcon:
540 case coro::ABI::RetconOnce: {
541 auto Dealloc = RetconLowering.Dealloc;
542 Ptr = Builder.CreateBitCast(V: Ptr,
543 DestTy: Dealloc->getFunctionType()->getParamType(i: 0));
544 auto *Call = Builder.CreateCall(Callee: Dealloc, Args: Ptr);
545 propagateCallAttrsFromCallee(Call, Callee: Dealloc);
546 addCallToCallGraph(CG, Call, Callee: Dealloc);
547 return;
548 }
549 case coro::ABI::Async:
550 llvm_unreachable("can't allocate memory in coro async-lowering");
551 }
552 llvm_unreachable("Unknown coro::ABI enum");
553}
554
555[[noreturn]] static void fail(const Instruction *I, const char *Reason,
556 Value *V) {
557#ifndef NDEBUG
558 I->dump();
559 if (V) {
560 errs() << " Value: ";
561 V->printAsOperand(llvm::errs());
562 errs() << '\n';
563 }
564#endif
565 report_fatal_error(reason: Reason);
566}
567
568/// Check that the given value is a well-formed prototype for the
569/// llvm.coro.id.retcon.* intrinsics.
570static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
571 auto F = dyn_cast<Function>(Val: V->stripPointerCasts());
572 if (!F)
573 fail(I, Reason: "llvm.coro.id.retcon.* prototype not a Function", V);
574
575 auto FT = F->getFunctionType();
576
577 if (isa<CoroIdRetconInst>(Val: I)) {
578 bool ResultOkay;
579 if (FT->getReturnType()->isPointerTy()) {
580 ResultOkay = true;
581 } else if (auto SRetTy = dyn_cast<StructType>(Val: FT->getReturnType())) {
582 ResultOkay = (!SRetTy->isOpaque() &&
583 SRetTy->getNumElements() > 0 &&
584 SRetTy->getElementType(N: 0)->isPointerTy());
585 } else {
586 ResultOkay = false;
587 }
588 if (!ResultOkay)
589 fail(I, Reason: "llvm.coro.id.retcon prototype must return pointer as first "
590 "result", V: F);
591
592 if (FT->getReturnType() !=
593 I->getFunction()->getFunctionType()->getReturnType())
594 fail(I, Reason: "llvm.coro.id.retcon prototype return type must be same as"
595 "current function return type", V: F);
596 } else {
597 // No meaningful validation to do here for llvm.coro.id.unique.once.
598 }
599
600 if (FT->getNumParams() == 0 || !FT->getParamType(i: 0)->isPointerTy())
601 fail(I, Reason: "llvm.coro.id.retcon.* prototype must take pointer as "
602 "its first parameter", V: F);
603}
604
605/// Check that the given value is a well-formed allocator.
606static void checkWFAlloc(const Instruction *I, Value *V) {
607 auto F = dyn_cast<Function>(Val: V->stripPointerCasts());
608 if (!F)
609 fail(I, Reason: "llvm.coro.* allocator not a Function", V);
610
611 auto FT = F->getFunctionType();
612 if (!FT->getReturnType()->isPointerTy())
613 fail(I, Reason: "llvm.coro.* allocator must return a pointer", V: F);
614
615 if (FT->getNumParams() != 1 ||
616 !FT->getParamType(i: 0)->isIntegerTy())
617 fail(I, Reason: "llvm.coro.* allocator must take integer as only param", V: F);
618}
619
620/// Check that the given value is a well-formed deallocator.
621static void checkWFDealloc(const Instruction *I, Value *V) {
622 auto F = dyn_cast<Function>(Val: V->stripPointerCasts());
623 if (!F)
624 fail(I, Reason: "llvm.coro.* deallocator not a Function", V);
625
626 auto FT = F->getFunctionType();
627 if (!FT->getReturnType()->isVoidTy())
628 fail(I, Reason: "llvm.coro.* deallocator must return void", V: F);
629
630 if (FT->getNumParams() != 1 ||
631 !FT->getParamType(i: 0)->isPointerTy())
632 fail(I, Reason: "llvm.coro.* deallocator must take pointer as only param", V: F);
633}
634
635static void checkConstantInt(const Instruction *I, Value *V,
636 const char *Reason) {
637 if (!isa<ConstantInt>(Val: V)) {
638 fail(I, Reason, V);
639 }
640}
641
642void AnyCoroIdRetconInst::checkWellFormed() const {
643 checkConstantInt(I: this, V: getArgOperand(i: SizeArg),
644 Reason: "size argument to coro.id.retcon.* must be constant");
645 checkConstantInt(I: this, V: getArgOperand(i: AlignArg),
646 Reason: "alignment argument to coro.id.retcon.* must be constant");
647 checkWFRetconPrototype(I: this, V: getArgOperand(i: PrototypeArg));
648 checkWFAlloc(I: this, V: getArgOperand(i: AllocArg));
649 checkWFDealloc(I: this, V: getArgOperand(i: DeallocArg));
650}
651
652static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
653 auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(Val: V->stripPointerCasts());
654 if (!AsyncFuncPtrAddr)
655 fail(I, Reason: "llvm.coro.id.async async function pointer not a global", V);
656}
657
658void CoroIdAsyncInst::checkWellFormed() const {
659 checkConstantInt(I: this, V: getArgOperand(i: SizeArg),
660 Reason: "size argument to coro.id.async must be constant");
661 checkConstantInt(I: this, V: getArgOperand(i: AlignArg),
662 Reason: "alignment argument to coro.id.async must be constant");
663 checkConstantInt(I: this, V: getArgOperand(i: StorageArg),
664 Reason: "storage argument offset to coro.id.async must be constant");
665 checkAsyncFuncPointer(I: this, V: getArgOperand(i: AsyncFuncPtrArg));
666}
667
668static void checkAsyncContextProjectFunction(const Instruction *I,
669 Function *F) {
670 auto *FunTy = F->getFunctionType();
671 if (!FunTy->getReturnType()->isPointerTy())
672 fail(I,
673 Reason: "llvm.coro.suspend.async resume function projection function must "
674 "return a ptr type",
675 V: F);
676 if (FunTy->getNumParams() != 1 || !FunTy->getParamType(i: 0)->isPointerTy())
677 fail(I,
678 Reason: "llvm.coro.suspend.async resume function projection function must "
679 "take one ptr type as parameter",
680 V: F);
681}
682
683void CoroSuspendAsyncInst::checkWellFormed() const {
684 checkAsyncContextProjectFunction(I: this, F: getAsyncContextProjectionFunction());
685}
686
687void CoroAsyncEndInst::checkWellFormed() const {
688 auto *MustTailCallFunc = getMustTailCallFunction();
689 if (!MustTailCallFunc)
690 return;
691 auto *FnTy = MustTailCallFunc->getFunctionType();
692 if (FnTy->getNumParams() != (arg_size() - 3))
693 fail(I: this,
694 Reason: "llvm.coro.end.async must tail call function argument type must "
695 "match the tail arguments",
696 V: MustTailCallFunc);
697}
698