| 1 | //===-- AbstractCallSite.cpp - Implementation of abstract call sites ------===// | 
|---|
| 2 | // | 
|---|
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|---|
| 4 | // See https://llvm.org/LICENSE.txt for license information. | 
|---|
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|---|
| 6 | // | 
|---|
| 7 | //===----------------------------------------------------------------------===// | 
|---|
| 8 | // | 
|---|
| 9 | // This file implements abstract call sites which unify the interface for | 
|---|
| 10 | // direct, indirect, and callback call sites. | 
|---|
| 11 | // | 
|---|
| 12 | // For more information see: | 
|---|
| 13 | // https://llvm.org/devmtg/2018-10/talk-abstracts.html#talk20 | 
|---|
| 14 | // | 
|---|
| 15 | //===----------------------------------------------------------------------===// | 
|---|
| 16 |  | 
|---|
| 17 | #include "llvm/IR/AbstractCallSite.h" | 
|---|
| 18 | #include "llvm/ADT/Statistic.h" | 
|---|
| 19 |  | 
|---|
| 20 | using namespace llvm; | 
|---|
| 21 |  | 
|---|
// Category name under which the STATISTIC counters below are reported.
#define DEBUG_TYPE "abstract-call-sites"

// Outcome counters for the AbstractCallSite(const Use *) constructor. Each
// constructor invocation bumps exactly one of these on its way out.

// The use was an argument named as callback callee by !callback metadata.
STATISTIC(NumCallbackCallSites, "Number of callback call sites created");
// The use was the callee operand of a call (direct or indirect call site).
STATISTIC(NumDirectAbstractCallSites,
          "Number of direct abstract call sites created");
// The user of the use was not a call (even after looking through a
// single-use constant cast expression).
STATISTIC(NumInvalidAbstractCallSitesUnknownUse,
          "Number of invalid abstract call sites created (unknown use)");
// The use was an argument, but the called function is not known statically.
STATISTIC(NumInvalidAbstractCallSitesUnknownCallee,
          "Number of invalid abstract call sites created (unknown callee)");
// The called function has no !callback metadata, or none of its encodings
// names the used argument as callback callee.
STATISTIC(NumInvalidAbstractCallSitesNoCallback,
          "Number of invalid abstract call sites created (no callback)");
|---|
| 33 |  | 
|---|
| 34 | void AbstractCallSite::getCallbackUses( | 
|---|
| 35 | const CallBase &CB, SmallVectorImpl<const Use *> &CallbackUses) { | 
|---|
| 36 | const Function *Callee = CB.getCalledFunction(); | 
|---|
| 37 | if (!Callee) | 
|---|
| 38 | return; | 
|---|
| 39 |  | 
|---|
| 40 | MDNode *CallbackMD = Callee->getMetadata(KindID: LLVMContext::MD_callback); | 
|---|
| 41 | if (!CallbackMD) | 
|---|
| 42 | return; | 
|---|
| 43 |  | 
|---|
| 44 | for (const MDOperand &Op : CallbackMD->operands()) { | 
|---|
| 45 | MDNode *OpMD = cast<MDNode>(Val: Op.get()); | 
|---|
| 46 | auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(Val: OpMD->getOperand(I: 0)); | 
|---|
| 47 | uint64_t CBCalleeIdx = | 
|---|
| 48 | cast<ConstantInt>(Val: CBCalleeIdxAsCM->getValue())->getZExtValue(); | 
|---|
| 49 | if (CBCalleeIdx < CB.arg_size()) | 
|---|
| 50 | CallbackUses.push_back(Elt: CB.arg_begin() + CBCalleeIdx); | 
|---|
| 51 | } | 
|---|
| 52 | } | 
|---|
| 53 |  | 
|---|
| 54 | /// Create an abstract call site from a use. | 
|---|
| 55 | AbstractCallSite::AbstractCallSite(const Use *U) | 
|---|
| 56 | : CB(dyn_cast<CallBase>(Val: U->getUser())) { | 
|---|
| 57 |  | 
|---|
| 58 | // First handle unknown users. | 
|---|
| 59 | if (!CB) { | 
|---|
| 60 |  | 
|---|
| 61 | // If the use is actually in a constant cast expression which itself | 
|---|
| 62 | // has only one use, we look through the constant cast expression. | 
|---|
| 63 | // This happens by updating the use @p U to the use of the constant | 
|---|
| 64 | // cast expression and afterwards re-initializing CB accordingly. | 
|---|
| 65 | if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Val: U->getUser())) | 
|---|
| 66 | if (CE->hasOneUse() && CE->isCast()) { | 
|---|
| 67 | U = &*CE->use_begin(); | 
|---|
| 68 | CB = dyn_cast<CallBase>(Val: U->getUser()); | 
|---|
| 69 | } | 
|---|
| 70 |  | 
|---|
| 71 | if (!CB) { | 
|---|
| 72 | NumInvalidAbstractCallSitesUnknownUse++; | 
|---|
| 73 | return; | 
|---|
| 74 | } | 
|---|
| 75 | } | 
|---|
| 76 |  | 
|---|
| 77 | // Then handle direct or indirect calls. Thus, if U is the callee of the | 
|---|
| 78 | // call site CB it is not a callback and we are done. | 
|---|
| 79 | if (CB->isCallee(U)) { | 
|---|
| 80 | NumDirectAbstractCallSites++; | 
|---|
| 81 | return; | 
|---|
| 82 | } | 
|---|
| 83 |  | 
|---|
| 84 | // If we cannot identify the broker function we cannot create a callback and | 
|---|
| 85 | // invalidate the abstract call site. | 
|---|
| 86 | Function *Callee = CB->getCalledFunction(); | 
|---|
| 87 | if (!Callee) { | 
|---|
| 88 | NumInvalidAbstractCallSitesUnknownCallee++; | 
|---|
| 89 | CB = nullptr; | 
|---|
| 90 | return; | 
|---|
| 91 | } | 
|---|
| 92 |  | 
|---|
| 93 | MDNode *CallbackMD = Callee->getMetadata(KindID: LLVMContext::MD_callback); | 
|---|
| 94 | if (!CallbackMD) { | 
|---|
| 95 | NumInvalidAbstractCallSitesNoCallback++; | 
|---|
| 96 | CB = nullptr; | 
|---|
| 97 | return; | 
|---|
| 98 | } | 
|---|
| 99 |  | 
|---|
| 100 | unsigned UseIdx = CB->getArgOperandNo(U); | 
|---|
| 101 | MDNode *CallbackEncMD = nullptr; | 
|---|
| 102 | for (const MDOperand &Op : CallbackMD->operands()) { | 
|---|
| 103 | MDNode *OpMD = cast<MDNode>(Val: Op.get()); | 
|---|
| 104 | auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(Val: OpMD->getOperand(I: 0)); | 
|---|
| 105 | uint64_t CBCalleeIdx = | 
|---|
| 106 | cast<ConstantInt>(Val: CBCalleeIdxAsCM->getValue())->getZExtValue(); | 
|---|
| 107 | if (CBCalleeIdx != UseIdx) | 
|---|
| 108 | continue; | 
|---|
| 109 | CallbackEncMD = OpMD; | 
|---|
| 110 | break; | 
|---|
| 111 | } | 
|---|
| 112 |  | 
|---|
| 113 | if (!CallbackEncMD) { | 
|---|
| 114 | NumInvalidAbstractCallSitesNoCallback++; | 
|---|
| 115 | CB = nullptr; | 
|---|
| 116 | return; | 
|---|
| 117 | } | 
|---|
| 118 |  | 
|---|
| 119 | NumCallbackCallSites++; | 
|---|
| 120 |  | 
|---|
| 121 | assert(CallbackEncMD->getNumOperands() >= 2 && "Incomplete !callback metadata"); | 
|---|
| 122 |  | 
|---|
| 123 | unsigned NumCallOperands = CB->arg_size(); | 
|---|
| 124 | // Skip the var-arg flag at the end when reading the metadata. | 
|---|
| 125 | for (unsigned u = 0, e = CallbackEncMD->getNumOperands() - 1; u < e; u++) { | 
|---|
| 126 | Metadata *OpAsM = CallbackEncMD->getOperand(I: u).get(); | 
|---|
| 127 | auto *OpAsCM = cast<ConstantAsMetadata>(Val: OpAsM); | 
|---|
| 128 | assert(OpAsCM->getType()->isIntegerTy(64) && | 
|---|
| 129 | "Malformed !callback metadata"); | 
|---|
| 130 |  | 
|---|
| 131 | int64_t Idx = cast<ConstantInt>(Val: OpAsCM->getValue())->getSExtValue(); | 
|---|
| 132 | assert(-1 <= Idx && Idx <= NumCallOperands && | 
|---|
| 133 | "Out-of-bounds !callback metadata index"); | 
|---|
| 134 |  | 
|---|
| 135 | CI.ParameterEncoding.push_back(Elt: Idx); | 
|---|
| 136 | } | 
|---|
| 137 |  | 
|---|
| 138 | if (!Callee->isVarArg()) | 
|---|
| 139 | return; | 
|---|
| 140 |  | 
|---|
| 141 | Metadata *VarArgFlagAsM = | 
|---|
| 142 | CallbackEncMD->getOperand(I: CallbackEncMD->getNumOperands() - 1).get(); | 
|---|
| 143 | auto *VarArgFlagAsCM = cast<ConstantAsMetadata>(Val: VarArgFlagAsM); | 
|---|
| 144 | assert(VarArgFlagAsCM->getType()->isIntegerTy(1) && | 
|---|
| 145 | "Malformed !callback metadata var-arg flag"); | 
|---|
| 146 |  | 
|---|
| 147 | if (VarArgFlagAsCM->getValue()->isNullValue()) | 
|---|
| 148 | return; | 
|---|
| 149 |  | 
|---|
| 150 | // Add all variadic arguments at the end. | 
|---|
| 151 | for (unsigned u = Callee->arg_size(); u < NumCallOperands; u++) | 
|---|
| 152 | CI.ParameterEncoding.push_back(Elt: u); | 
|---|
| 153 | } | 
|---|
| 154 |  | 
|---|