1//===- Coroutines.cpp -----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the common infrastructure for Coroutine Passes.
10//
11//===----------------------------------------------------------------------===//
12
13#include "CoroInternal.h"
14#include "llvm/ADT/SmallVector.h"
15#include "llvm/ADT/StringRef.h"
16#include "llvm/Analysis/CallGraph.h"
17#include "llvm/IR/Attributes.h"
18#include "llvm/IR/Constants.h"
19#include "llvm/IR/DerivedTypes.h"
20#include "llvm/IR/Function.h"
21#include "llvm/IR/InstIterator.h"
22#include "llvm/IR/Instructions.h"
23#include "llvm/IR/IntrinsicInst.h"
24#include "llvm/IR/Intrinsics.h"
25#include "llvm/IR/Module.h"
26#include "llvm/IR/Type.h"
27#include "llvm/Support/Casting.h"
28#include "llvm/Support/ErrorHandling.h"
29#include "llvm/Transforms/Coroutines/ABI.h"
30#include "llvm/Transforms/Coroutines/CoroInstr.h"
31#include "llvm/Transforms/Coroutines/CoroShape.h"
32#include "llvm/Transforms/Utils/Local.h"
33#include <cassert>
34#include <cstddef>
35#include <utility>
36
37using namespace llvm;
38
39// Construct the lowerer base class and initialize its members.
40coro::LowererBase::LowererBase(Module &M)
41 : TheModule(M), Context(M.getContext()),
42 Int8Ptr(PointerType::get(C&: Context, AddressSpace: 0)),
43 ResumeFnType(FunctionType::get(Result: Type::getVoidTy(C&: Context), Params: Int8Ptr,
44 /*isVarArg=*/false)),
45 NullPtr(ConstantPointerNull::get(T: Int8Ptr)) {}
46
47// Creates a call to llvm.coro.subfn.addr to obtain a resume function address.
48// It generates the following:
49//
50// call ptr @llvm.coro.subfn.addr(ptr %Arg, i8 %index)
51
52CallInst *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
53 Instruction *InsertPt) {
54 auto *IndexVal = ConstantInt::get(Ty: Type::getInt8Ty(C&: Context), V: Index);
55 auto *Fn =
56 Intrinsic::getOrInsertDeclaration(M: &TheModule, id: Intrinsic::coro_subfn_addr);
57
58 assert(Index >= CoroSubFnInst::IndexFirst &&
59 Index < CoroSubFnInst::IndexLast &&
60 "makeSubFnCall: Index value out of range");
61 return CallInst::Create(Func: Fn, Args: {Arg, IndexVal}, NameStr: "", InsertBefore: InsertPt->getIterator());
62}
63
64// We can only efficiently check for non-overloaded intrinsics.
65// The following intrinsics are absent for that reason:
66// coro_align, coro_size, coro_suspend_async, coro_suspend_retcon
67static Intrinsic::ID NonOverloadedCoroIntrinsics[] = {
68 Intrinsic::coro_alloc,
69 Intrinsic::coro_async_context_alloc,
70 Intrinsic::coro_async_context_dealloc,
71 Intrinsic::coro_async_resume,
72 Intrinsic::coro_async_size_replace,
73 Intrinsic::coro_await_suspend_bool,
74 Intrinsic::coro_await_suspend_handle,
75 Intrinsic::coro_await_suspend_void,
76 Intrinsic::coro_begin,
77 Intrinsic::coro_begin_custom_abi,
78 Intrinsic::coro_destroy,
79 Intrinsic::coro_done,
80 Intrinsic::coro_end,
81 Intrinsic::coro_end_async,
82 Intrinsic::coro_frame,
83 Intrinsic::coro_free,
84 Intrinsic::coro_id,
85 Intrinsic::coro_id_async,
86 Intrinsic::coro_id_retcon,
87 Intrinsic::coro_id_retcon_once,
88 Intrinsic::coro_promise,
89 Intrinsic::coro_resume,
90 Intrinsic::coro_save,
91 Intrinsic::coro_subfn_addr,
92 Intrinsic::coro_suspend,
93};
94
95bool coro::isSuspendBlock(BasicBlock *BB) {
96 return isa<AnyCoroSuspendInst>(Val: BB->front());
97}
98
99bool coro::declaresAnyIntrinsic(const Module &M) {
100 return declaresIntrinsics(M, List: NonOverloadedCoroIntrinsics);
101}
102
103// Checks whether the module declares any of the listed intrinsics.
104bool coro::declaresIntrinsics(const Module &M, ArrayRef<Intrinsic::ID> List) {
105#ifndef NDEBUG
106 for (Intrinsic::ID ID : List)
107 assert(!Intrinsic::isOverloaded(ID) &&
108 "Only non-overloaded intrinsics supported");
109#endif
110
111 for (Intrinsic::ID ID : List)
112 if (Intrinsic::getDeclarationIfExists(M: &M, id: ID))
113 return true;
114 return false;
115}
116
117// Replace all coro.frees associated with the provided CoroId either with 'null'
118// if Elide is true and with its frame parameter otherwise.
119void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
120 SmallVector<CoroFreeInst *, 4> CoroFrees;
121 for (User *U : CoroId->users())
122 if (auto CF = dyn_cast<CoroFreeInst>(Val: U))
123 CoroFrees.push_back(Elt: CF);
124
125 if (CoroFrees.empty())
126 return;
127
128 Value *Replacement =
129 Elide
130 ? ConstantPointerNull::get(T: PointerType::get(C&: CoroId->getContext(), AddressSpace: 0))
131 : CoroFrees.front()->getFrame();
132
133 for (CoroFreeInst *CF : CoroFrees) {
134 CF->replaceAllUsesWith(V: Replacement);
135 CF->eraseFromParent();
136 }
137}
138
139void coro::suppressCoroAllocs(CoroIdInst *CoroId) {
140 SmallVector<CoroAllocInst *, 4> CoroAllocs;
141 for (User *U : CoroId->users())
142 if (auto *CA = dyn_cast<CoroAllocInst>(Val: U))
143 CoroAllocs.push_back(Elt: CA);
144
145 if (CoroAllocs.empty())
146 return;
147
148 coro::suppressCoroAllocs(Context&: CoroId->getContext(), CoroAllocs);
149}
150
151// Replacing llvm.coro.alloc with false will suppress dynamic
152// allocation as it is expected for the frontend to generate the code that
153// looks like:
154// id = coro.id(...)
155// mem = coro.alloc(id) ? malloc(coro.size()) : 0;
156// coro.begin(id, mem)
157void coro::suppressCoroAllocs(LLVMContext &Context,
158 ArrayRef<CoroAllocInst *> CoroAllocs) {
159 auto *False = ConstantInt::getFalse(Context);
160 for (auto *CA : CoroAllocs) {
161 CA->replaceAllUsesWith(V: False);
162 CA->eraseFromParent();
163 }
164}
165
166static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
167 CoroSuspendInst *SuspendInst) {
168 Module *M = SuspendInst->getModule();
169 auto *Fn = Intrinsic::getOrInsertDeclaration(M, id: Intrinsic::coro_save);
170 auto *SaveInst = cast<CoroSaveInst>(
171 Val: CallInst::Create(Func: Fn, Args: CoroBegin, NameStr: "", InsertBefore: SuspendInst->getIterator()));
172 assert(!SuspendInst->getCoroSave());
173 SuspendInst->setArgOperand(i: 0, v: SaveInst);
174 return SaveInst;
175}
176
177// Collect "interesting" coroutine intrinsics.
178void coro::Shape::analyze(Function &F,
179 SmallVectorImpl<CoroFrameInst *> &CoroFrames,
180 SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves,
181 CoroPromiseInst *&CoroPromise) {
182 clear();
183
184 bool HasFinalSuspend = false;
185 bool HasUnwindCoroEnd = false;
186 size_t FinalSuspendIndex = 0;
187
188 for (Instruction &I : instructions(F)) {
189 // FIXME: coro_await_suspend_* are not proper `IntrinisicInst`s
190 // because they might be invoked
191 if (auto AWS = dyn_cast<CoroAwaitSuspendInst>(Val: &I)) {
192 CoroAwaitSuspends.push_back(Elt: AWS);
193 } else if (auto II = dyn_cast<IntrinsicInst>(Val: &I)) {
194 switch (II->getIntrinsicID()) {
195 default:
196 continue;
197 case Intrinsic::coro_size:
198 CoroSizes.push_back(Elt: cast<CoroSizeInst>(Val: II));
199 break;
200 case Intrinsic::coro_align:
201 CoroAligns.push_back(Elt: cast<CoroAlignInst>(Val: II));
202 break;
203 case Intrinsic::coro_frame:
204 CoroFrames.push_back(Elt: cast<CoroFrameInst>(Val: II));
205 break;
206 case Intrinsic::coro_save:
207 // After optimizations, coro_suspends using this coro_save might have
208 // been removed, remember orphaned coro_saves to remove them later.
209 if (II->use_empty())
210 UnusedCoroSaves.push_back(Elt: cast<CoroSaveInst>(Val: II));
211 break;
212 case Intrinsic::coro_suspend_async: {
213 auto *Suspend = cast<CoroSuspendAsyncInst>(Val: II);
214 Suspend->checkWellFormed();
215 CoroSuspends.push_back(Elt: Suspend);
216 break;
217 }
218 case Intrinsic::coro_suspend_retcon: {
219 auto Suspend = cast<CoroSuspendRetconInst>(Val: II);
220 CoroSuspends.push_back(Elt: Suspend);
221 break;
222 }
223 case Intrinsic::coro_suspend: {
224 auto Suspend = cast<CoroSuspendInst>(Val: II);
225 CoroSuspends.push_back(Elt: Suspend);
226 if (Suspend->isFinal()) {
227 if (HasFinalSuspend)
228 report_fatal_error(
229 reason: "Only one suspend point can be marked as final");
230 HasFinalSuspend = true;
231 FinalSuspendIndex = CoroSuspends.size() - 1;
232 }
233 break;
234 }
235 case Intrinsic::coro_begin:
236 case Intrinsic::coro_begin_custom_abi: {
237 auto CB = cast<CoroBeginInst>(Val: II);
238
239 // Ignore coro id's that aren't pre-split.
240 auto Id = dyn_cast<CoroIdInst>(Val: CB->getId());
241 if (Id && !Id->getInfo().isPreSplit())
242 break;
243
244 if (CoroBegin)
245 report_fatal_error(
246 reason: "coroutine should have exactly one defining @llvm.coro.begin");
247 CB->addRetAttr(Kind: Attribute::NonNull);
248 CB->addRetAttr(Kind: Attribute::NoAlias);
249 CB->removeFnAttr(Kind: Attribute::NoDuplicate);
250 CoroBegin = CB;
251 break;
252 }
253 case Intrinsic::coro_end_async:
254 case Intrinsic::coro_end:
255 CoroEnds.push_back(Elt: cast<AnyCoroEndInst>(Val: II));
256 if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(Val: II)) {
257 AsyncEnd->checkWellFormed();
258 }
259
260 if (CoroEnds.back()->isUnwind())
261 HasUnwindCoroEnd = true;
262
263 if (CoroEnds.back()->isFallthrough() && isa<CoroEndInst>(Val: II)) {
264 // Make sure that the fallthrough coro.end is the first element in the
265 // CoroEnds vector.
266 // Note: I don't think this is neccessary anymore.
267 if (CoroEnds.size() > 1) {
268 if (CoroEnds.front()->isFallthrough())
269 report_fatal_error(
270 reason: "Only one coro.end can be marked as fallthrough");
271 std::swap(a&: CoroEnds.front(), b&: CoroEnds.back());
272 }
273 }
274 break;
275 case Intrinsic::coro_promise:
276 assert(CoroPromise == nullptr &&
277 "CoroEarly must ensure coro.promise unique");
278 CoroPromise = cast<CoroPromiseInst>(Val: II);
279 break;
280 }
281 }
282 }
283
284 // If there is no CoroBegin then this is not a coroutine.
285 if (!CoroBegin)
286 return;
287
288 // Determination of ABI and initializing lowering info
289 auto Id = CoroBegin->getId();
290 switch (auto IntrID = Id->getIntrinsicID()) {
291 case Intrinsic::coro_id: {
292 ABI = coro::ABI::Switch;
293 SwitchLowering.HasFinalSuspend = HasFinalSuspend;
294 SwitchLowering.HasUnwindCoroEnd = HasUnwindCoroEnd;
295
296 auto SwitchId = getSwitchCoroId();
297 SwitchLowering.ResumeSwitch = nullptr;
298 SwitchLowering.PromiseAlloca = SwitchId->getPromise();
299 SwitchLowering.ResumeEntryBlock = nullptr;
300
301 // Move final suspend to the last element in the CoroSuspends vector.
302 if (SwitchLowering.HasFinalSuspend &&
303 FinalSuspendIndex != CoroSuspends.size() - 1)
304 std::swap(a&: CoroSuspends[FinalSuspendIndex], b&: CoroSuspends.back());
305 break;
306 }
307 case Intrinsic::coro_id_async: {
308 ABI = coro::ABI::Async;
309 auto *AsyncId = getAsyncCoroId();
310 AsyncId->checkWellFormed();
311 AsyncLowering.Context = AsyncId->getStorage();
312 AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
313 AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
314 AsyncLowering.ContextAlignment = AsyncId->getStorageAlignment().value();
315 AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
316 AsyncLowering.AsyncCC = F.getCallingConv();
317 break;
318 }
319 case Intrinsic::coro_id_retcon:
320 case Intrinsic::coro_id_retcon_once: {
321 ABI = IntrID == Intrinsic::coro_id_retcon ? coro::ABI::Retcon
322 : coro::ABI::RetconOnce;
323 auto ContinuationId = getRetconCoroId();
324 ContinuationId->checkWellFormed();
325 auto Prototype = ContinuationId->getPrototype();
326 RetconLowering.ResumePrototype = Prototype;
327 RetconLowering.Alloc = ContinuationId->getAllocFunction();
328 RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
329 RetconLowering.ReturnBlock = nullptr;
330 RetconLowering.IsFrameInlineInStorage = false;
331 break;
332 }
333 default:
334 llvm_unreachable("coro.begin is not dependent on a coro.id call");
335 }
336}
337
338// If for some reason, we were not able to find coro.begin, bailout.
339void coro::Shape::invalidateCoroutine(
340 Function &F, SmallVectorImpl<CoroFrameInst *> &CoroFrames) {
341 assert(!CoroBegin);
342 {
343 // Replace coro.frame which are supposed to be lowered to the result of
344 // coro.begin with poison.
345 auto *Poison = PoisonValue::get(T: PointerType::get(C&: F.getContext(), AddressSpace: 0));
346 for (CoroFrameInst *CF : CoroFrames) {
347 CF->replaceAllUsesWith(V: Poison);
348 CF->eraseFromParent();
349 }
350 CoroFrames.clear();
351
352 // Replace all coro.suspend with poison and remove related coro.saves if
353 // present.
354 for (AnyCoroSuspendInst *CS : CoroSuspends) {
355 CS->replaceAllUsesWith(V: PoisonValue::get(T: CS->getType()));
356 CS->eraseFromParent();
357 if (auto *CoroSave = CS->getCoroSave())
358 CoroSave->eraseFromParent();
359 }
360 CoroSuspends.clear();
361
362 // Replace all coro.ends with unreachable instruction.
363 for (AnyCoroEndInst *CE : CoroEnds)
364 changeToUnreachable(I: CE);
365 }
366}
367
368void coro::SwitchABI::init() {
369 assert(Shape.ABI == coro::ABI::Switch);
370 {
371 for (auto *AnySuspend : Shape.CoroSuspends) {
372 auto Suspend = dyn_cast<CoroSuspendInst>(Val: AnySuspend);
373 if (!Suspend) {
374#ifndef NDEBUG
375 AnySuspend->dump();
376#endif
377 report_fatal_error(reason: "coro.id must be paired with coro.suspend");
378 }
379
380 if (!Suspend->getCoroSave())
381 createCoroSave(CoroBegin: Shape.CoroBegin, SuspendInst: Suspend);
382 }
383 }
384}
385
386void coro::AsyncABI::init() { assert(Shape.ABI == coro::ABI::Async); }
387
388void coro::AnyRetconABI::init() {
389 assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce);
390 {
391 // Determine the result value types, and make sure they match up with
392 // the values passed to the suspends.
393 auto ResultTys = Shape.getRetconResultTypes();
394 auto ResumeTys = Shape.getRetconResumeTypes();
395
396 for (auto *AnySuspend : Shape.CoroSuspends) {
397 auto Suspend = dyn_cast<CoroSuspendRetconInst>(Val: AnySuspend);
398 if (!Suspend) {
399#ifndef NDEBUG
400 AnySuspend->dump();
401#endif
402 report_fatal_error(reason: "coro.id.retcon.* must be paired with "
403 "coro.suspend.retcon");
404 }
405
406 // Check that the argument types of the suspend match the results.
407 auto SI = Suspend->value_begin(), SE = Suspend->value_end();
408 auto RI = ResultTys.begin(), RE = ResultTys.end();
409 for (; SI != SE && RI != RE; ++SI, ++RI) {
410 auto SrcTy = (*SI)->getType();
411 if (SrcTy != *RI) {
412 // The optimizer likes to eliminate bitcasts leading into variadic
413 // calls, but that messes with our invariants. Re-insert the
414 // bitcast and ignore this type mismatch.
415 if (CastInst::isBitCastable(SrcTy, DestTy: *RI)) {
416 auto BCI = new BitCastInst(*SI, *RI, "", Suspend->getIterator());
417 SI->set(BCI);
418 continue;
419 }
420
421#ifndef NDEBUG
422 Suspend->dump();
423 Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
424#endif
425 report_fatal_error(reason: "argument to coro.suspend.retcon does not "
426 "match corresponding prototype function result");
427 }
428 }
429 if (SI != SE || RI != RE) {
430#ifndef NDEBUG
431 Suspend->dump();
432 Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
433#endif
434 report_fatal_error(reason: "wrong number of arguments to coro.suspend.retcon");
435 }
436
437 // Check that the result type of the suspend matches the resume types.
438 Type *SResultTy = Suspend->getType();
439 ArrayRef<Type *> SuspendResultTys;
440 if (SResultTy->isVoidTy()) {
441 // leave as empty array
442 } else if (auto SResultStructTy = dyn_cast<StructType>(Val: SResultTy)) {
443 SuspendResultTys = SResultStructTy->elements();
444 } else {
445 // forms an ArrayRef using SResultTy, be careful
446 SuspendResultTys = SResultTy;
447 }
448 if (SuspendResultTys.size() != ResumeTys.size()) {
449#ifndef NDEBUG
450 Suspend->dump();
451 Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
452#endif
453 report_fatal_error(reason: "wrong number of results from coro.suspend.retcon");
454 }
455 for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
456 if (SuspendResultTys[I] != ResumeTys[I]) {
457#ifndef NDEBUG
458 Suspend->dump();
459 Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
460#endif
461 report_fatal_error(reason: "result from coro.suspend.retcon does not "
462 "match corresponding prototype function param");
463 }
464 }
465 }
466 }
467}
468
469void coro::Shape::cleanCoroutine(
470 SmallVectorImpl<CoroFrameInst *> &CoroFrames,
471 SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves, CoroPromiseInst *PI) {
472 // The coro.frame intrinsic is always lowered to the result of coro.begin.
473 for (CoroFrameInst *CF : CoroFrames) {
474 CF->replaceAllUsesWith(V: CoroBegin);
475 CF->eraseFromParent();
476 }
477 CoroFrames.clear();
478
479 // Remove orphaned coro.saves.
480 for (CoroSaveInst *CoroSave : UnusedCoroSaves)
481 CoroSave->eraseFromParent();
482 UnusedCoroSaves.clear();
483
484 if (PI) {
485 PI->replaceAllUsesWith(V: PI->isFromPromise()
486 ? cast<Value>(Val: CoroBegin)
487 : cast<Value>(Val: getPromiseAlloca()));
488 PI->eraseFromParent();
489 }
490}
491
492static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
493 Call->setCallingConv(Callee->getCallingConv());
494 // TODO: attributes?
495}
496
497static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){
498 if (CG)
499 (*CG)[Call->getFunction()]->addCalledFunction(Call, M: (*CG)[Callee]);
500}
501
502Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
503 CallGraph *CG) const {
504 switch (ABI) {
505 case coro::ABI::Switch:
506 llvm_unreachable("can't allocate memory in coro switch-lowering");
507
508 case coro::ABI::Retcon:
509 case coro::ABI::RetconOnce: {
510 auto Alloc = RetconLowering.Alloc;
511 Size = Builder.CreateIntCast(V: Size,
512 DestTy: Alloc->getFunctionType()->getParamType(i: 0),
513 /*is signed*/ isSigned: false);
514 auto *Call = Builder.CreateCall(Callee: Alloc, Args: Size);
515 propagateCallAttrsFromCallee(Call, Callee: Alloc);
516 addCallToCallGraph(CG, Call, Callee: Alloc);
517 return Call;
518 }
519 case coro::ABI::Async:
520 llvm_unreachable("can't allocate memory in coro async-lowering");
521 }
522 llvm_unreachable("Unknown coro::ABI enum");
523}
524
525void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
526 CallGraph *CG) const {
527 switch (ABI) {
528 case coro::ABI::Switch:
529 llvm_unreachable("can't allocate memory in coro switch-lowering");
530
531 case coro::ABI::Retcon:
532 case coro::ABI::RetconOnce: {
533 auto Dealloc = RetconLowering.Dealloc;
534 Ptr = Builder.CreateBitCast(V: Ptr,
535 DestTy: Dealloc->getFunctionType()->getParamType(i: 0));
536 auto *Call = Builder.CreateCall(Callee: Dealloc, Args: Ptr);
537 propagateCallAttrsFromCallee(Call, Callee: Dealloc);
538 addCallToCallGraph(CG, Call, Callee: Dealloc);
539 return;
540 }
541 case coro::ABI::Async:
542 llvm_unreachable("can't allocate memory in coro async-lowering");
543 }
544 llvm_unreachable("Unknown coro::ABI enum");
545}
546
547[[noreturn]] static void fail(const Instruction *I, const char *Reason,
548 Value *V) {
549#ifndef NDEBUG
550 I->dump();
551 if (V) {
552 errs() << " Value: ";
553 V->printAsOperand(llvm::errs());
554 errs() << '\n';
555 }
556#endif
557 report_fatal_error(reason: Reason);
558}
559
560/// Check that the given value is a well-formed prototype for the
561/// llvm.coro.id.retcon.* intrinsics.
562static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
563 auto F = dyn_cast<Function>(Val: V->stripPointerCasts());
564 if (!F)
565 fail(I, Reason: "llvm.coro.id.retcon.* prototype not a Function", V);
566
567 auto FT = F->getFunctionType();
568
569 if (isa<CoroIdRetconInst>(Val: I)) {
570 bool ResultOkay;
571 if (FT->getReturnType()->isPointerTy()) {
572 ResultOkay = true;
573 } else if (auto SRetTy = dyn_cast<StructType>(Val: FT->getReturnType())) {
574 ResultOkay = (!SRetTy->isOpaque() &&
575 SRetTy->getNumElements() > 0 &&
576 SRetTy->getElementType(N: 0)->isPointerTy());
577 } else {
578 ResultOkay = false;
579 }
580 if (!ResultOkay)
581 fail(I, Reason: "llvm.coro.id.retcon prototype must return pointer as first "
582 "result", V: F);
583
584 if (FT->getReturnType() !=
585 I->getFunction()->getFunctionType()->getReturnType())
586 fail(I, Reason: "llvm.coro.id.retcon prototype return type must be same as"
587 "current function return type", V: F);
588 } else {
589 // No meaningful validation to do here for llvm.coro.id.unique.once.
590 }
591
592 if (FT->getNumParams() == 0 || !FT->getParamType(i: 0)->isPointerTy())
593 fail(I, Reason: "llvm.coro.id.retcon.* prototype must take pointer as "
594 "its first parameter", V: F);
595}
596
597/// Check that the given value is a well-formed allocator.
598static void checkWFAlloc(const Instruction *I, Value *V) {
599 auto F = dyn_cast<Function>(Val: V->stripPointerCasts());
600 if (!F)
601 fail(I, Reason: "llvm.coro.* allocator not a Function", V);
602
603 auto FT = F->getFunctionType();
604 if (!FT->getReturnType()->isPointerTy())
605 fail(I, Reason: "llvm.coro.* allocator must return a pointer", V: F);
606
607 if (FT->getNumParams() != 1 ||
608 !FT->getParamType(i: 0)->isIntegerTy())
609 fail(I, Reason: "llvm.coro.* allocator must take integer as only param", V: F);
610}
611
612/// Check that the given value is a well-formed deallocator.
613static void checkWFDealloc(const Instruction *I, Value *V) {
614 auto F = dyn_cast<Function>(Val: V->stripPointerCasts());
615 if (!F)
616 fail(I, Reason: "llvm.coro.* deallocator not a Function", V);
617
618 auto FT = F->getFunctionType();
619 if (!FT->getReturnType()->isVoidTy())
620 fail(I, Reason: "llvm.coro.* deallocator must return void", V: F);
621
622 if (FT->getNumParams() != 1 ||
623 !FT->getParamType(i: 0)->isPointerTy())
624 fail(I, Reason: "llvm.coro.* deallocator must take pointer as only param", V: F);
625}
626
627static void checkConstantInt(const Instruction *I, Value *V,
628 const char *Reason) {
629 if (!isa<ConstantInt>(Val: V)) {
630 fail(I, Reason, V);
631 }
632}
633
634void AnyCoroIdRetconInst::checkWellFormed() const {
635 checkConstantInt(I: this, V: getArgOperand(i: SizeArg),
636 Reason: "size argument to coro.id.retcon.* must be constant");
637 checkConstantInt(I: this, V: getArgOperand(i: AlignArg),
638 Reason: "alignment argument to coro.id.retcon.* must be constant");
639 checkWFRetconPrototype(I: this, V: getArgOperand(i: PrototypeArg));
640 checkWFAlloc(I: this, V: getArgOperand(i: AllocArg));
641 checkWFDealloc(I: this, V: getArgOperand(i: DeallocArg));
642}
643
644static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
645 auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(Val: V->stripPointerCasts());
646 if (!AsyncFuncPtrAddr)
647 fail(I, Reason: "llvm.coro.id.async async function pointer not a global", V);
648}
649
650void CoroIdAsyncInst::checkWellFormed() const {
651 checkConstantInt(I: this, V: getArgOperand(i: SizeArg),
652 Reason: "size argument to coro.id.async must be constant");
653 checkConstantInt(I: this, V: getArgOperand(i: AlignArg),
654 Reason: "alignment argument to coro.id.async must be constant");
655 checkConstantInt(I: this, V: getArgOperand(i: StorageArg),
656 Reason: "storage argument offset to coro.id.async must be constant");
657 checkAsyncFuncPointer(I: this, V: getArgOperand(i: AsyncFuncPtrArg));
658}
659
660static void checkAsyncContextProjectFunction(const Instruction *I,
661 Function *F) {
662 auto *FunTy = cast<FunctionType>(Val: F->getValueType());
663 if (!FunTy->getReturnType()->isPointerTy())
664 fail(I,
665 Reason: "llvm.coro.suspend.async resume function projection function must "
666 "return a ptr type",
667 V: F);
668 if (FunTy->getNumParams() != 1 || !FunTy->getParamType(i: 0)->isPointerTy())
669 fail(I,
670 Reason: "llvm.coro.suspend.async resume function projection function must "
671 "take one ptr type as parameter",
672 V: F);
673}
674
675void CoroSuspendAsyncInst::checkWellFormed() const {
676 checkAsyncContextProjectFunction(I: this, F: getAsyncContextProjectionFunction());
677}
678
679void CoroAsyncEndInst::checkWellFormed() const {
680 auto *MustTailCallFunc = getMustTailCallFunction();
681 if (!MustTailCallFunc)
682 return;
683 auto *FnTy = MustTailCallFunc->getFunctionType();
684 if (FnTy->getNumParams() != (arg_size() - 3))
685 fail(I: this,
686 Reason: "llvm.coro.end.async must tail call function argument type must "
687 "match the tail arguments",
688 V: MustTailCallFunc);
689}
690