1//===-- WebAssemblyAddMissingPrototypes.cpp - Fix prototypeless functions -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
/// Add prototypes to prototype-less functions.
11///
/// WebAssembly has strict function prototype checking, so we need function
/// declarations to match the call sites. Clang treats prototype-less functions
14/// as varargs (foo(...)) which happens to work on existing platforms but
15/// doesn't under WebAssembly. This pass will find all the call sites of each
16/// prototype-less function, ensure they agree, and then set the signature
17/// on the function declaration accordingly.
18///
19//===----------------------------------------------------------------------===//
20
21#include "WebAssembly.h"
22#include "llvm/IR/Constants.h"
23#include "llvm/IR/Module.h"
24#include "llvm/IR/Operator.h"
25#include "llvm/Pass.h"
26#include "llvm/Support/Debug.h"
27#include "llvm/Transforms/Utils/Local.h"
28#include "llvm/Transforms/Utils/ModuleUtils.h"
29using namespace llvm;
30
31#define DEBUG_TYPE "wasm-add-missing-prototypes"
32
33namespace {
34class WebAssemblyAddMissingPrototypes final : public ModulePass {
35 StringRef getPassName() const override {
36 return "Add prototypes to prototypes-less functions";
37 }
38
39 void getAnalysisUsage(AnalysisUsage &AU) const override {
40 AU.setPreservesCFG();
41 ModulePass::getAnalysisUsage(AU);
42 }
43
44 bool runOnModule(Module &M) override;
45
46public:
47 static char ID;
48 WebAssemblyAddMissingPrototypes() : ModulePass(ID) {}
49};
50} // End anonymous namespace
51
// Storage for the pass identifier passed to ModulePass(ID) in the constructor.
char WebAssemblyAddMissingPrototypes::ID = 0;
// Registers the pass under the name DEBUG_TYPE ("wasm-add-missing-prototypes")
// with the given description.
// NOTE(review): the two trailing `false` flags are believed to be CFGOnly and
// is_analysis -- confirm against llvm/PassSupport.h.
INITIALIZE_PASS(WebAssemblyAddMissingPrototypes, DEBUG_TYPE,
                "Add prototypes to prototypes-less functions", false, false)
55
56ModulePass *llvm::createWebAssemblyAddMissingPrototypes() {
57 return new WebAssemblyAddMissingPrototypes();
58}
59
60bool WebAssemblyAddMissingPrototypes::runOnModule(Module &M) {
61 LLVM_DEBUG(dbgs() << "********** Add Missing Prototypes **********\n");
62
63 std::vector<std::pair<Function *, Function *>> Replacements;
64
65 // Find all the prototype-less function declarations
66 for (Function &F : M) {
67 if (!F.isDeclaration() || !F.hasFnAttribute(Kind: "no-prototype"))
68 continue;
69
70 LLVM_DEBUG(dbgs() << "Found no-prototype function: " << F.getName()
71 << "\n");
72
73 // When clang emits prototype-less C functions it uses (...), i.e. varargs
74 // function that take no arguments (have no sentinel). When we see a
75 // no-prototype attribute we expect the function have these properties.
76 if (!F.isVarArg())
77 report_fatal_error(
78 reason: "Functions with 'no-prototype' attribute must take varargs: " +
79 F.getName());
80 unsigned NumParams = F.getFunctionType()->getNumParams();
81 if (NumParams != 0) {
82 if (!(NumParams == 1 && F.arg_begin()->hasStructRetAttr()))
83 report_fatal_error(reason: "Functions with 'no-prototype' attribute should "
84 "not have params: " +
85 F.getName());
86 }
87
88 // Find calls of this function, looking through bitcasts.
89 SmallVector<CallBase *> Calls;
90 SmallVector<Value *> Worklist;
91 Worklist.push_back(Elt: &F);
92 while (!Worklist.empty()) {
93 Value *V = Worklist.pop_back_val();
94 for (User *U : V->users()) {
95 if (auto *BC = dyn_cast<BitCastOperator>(Val: U))
96 Worklist.push_back(Elt: BC);
97 else if (auto *CB = dyn_cast<CallBase>(Val: U))
98 if (CB->getCalledOperand() == V)
99 Calls.push_back(Elt: CB);
100 }
101 }
102
103 // Create a function prototype based on the first call site that we find.
104 FunctionType *NewType = nullptr;
105 for (CallBase *CB : Calls) {
106 LLVM_DEBUG(dbgs() << "prototype-less call of " << F.getName() << ":\n");
107 LLVM_DEBUG(dbgs() << *CB << "\n");
108 FunctionType *DestType = CB->getFunctionType();
109 if (!NewType) {
110 // Create a new function with the correct type
111 NewType = DestType;
112 LLVM_DEBUG(dbgs() << "found function type: " << *NewType << "\n");
113 } else if (NewType != DestType) {
114 errs() << "warning: prototype-less function used with "
115 "conflicting signatures: "
116 << F.getName() << "\n";
117 LLVM_DEBUG(dbgs() << " " << *DestType << "\n");
118 LLVM_DEBUG(dbgs() << " " << *NewType << "\n");
119 }
120 }
121
122 if (!NewType) {
123 LLVM_DEBUG(
124 dbgs() << "could not derive a function prototype from usage: " +
125 F.getName() + "\n");
126 // We could not derive a type for this function. In this case strip
127 // the isVarArg and make it a simple zero-arg function. This has more
128 // chance of being correct. The current signature of (...) is illegal in
129 // C since it doesn't have any arguments before the "...", we this at
130 // least makes it possible for this symbol to be resolved by the linker.
131 NewType = FunctionType::get(Result: F.getFunctionType()->getReturnType(), isVarArg: false);
132 }
133
134 Function *NewF =
135 Function::Create(Ty: NewType, Linkage: F.getLinkage(), N: F.getName() + ".fixed_sig");
136 NewF->setAttributes(F.getAttributes());
137 NewF->removeFnAttr(Kind: "no-prototype");
138 Replacements.emplace_back(args: &F, args&: NewF);
139 }
140
141 for (auto &Pair : Replacements) {
142 Function *OldF = Pair.first;
143 Function *NewF = Pair.second;
144 std::string Name = std::string(OldF->getName());
145 M.getFunctionList().push_back(val: NewF);
146 OldF->replaceAllUsesWith(
147 V: ConstantExpr::getPointerBitCastOrAddrSpaceCast(C: NewF, Ty: OldF->getType()));
148 OldF->eraseFromParent();
149 NewF->setName(Name);
150 }
151
152 return !Replacements.empty();
153}
154