1 | //===-- CrossDSOCFI.cpp - Externalize this module's CFI checks ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This pass exports all numeric type identifiers found in the module (in !type
// metadata on global objects and in the "cfi.functions" named metadata) in the
// form of a __cfi_check function, which can be used to verify cross-DSO call
// targets.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "llvm/Transforms/IPO/CrossDSOCFI.h" |
15 | #include "llvm/ADT/SetVector.h" |
16 | #include "llvm/ADT/Statistic.h" |
17 | #include "llvm/IR/Constants.h" |
18 | #include "llvm/IR/Function.h" |
19 | #include "llvm/IR/GlobalObject.h" |
20 | #include "llvm/IR/IRBuilder.h" |
21 | #include "llvm/IR/Instructions.h" |
22 | #include "llvm/IR/Intrinsics.h" |
23 | #include "llvm/IR/MDBuilder.h" |
24 | #include "llvm/IR/Module.h" |
25 | #include "llvm/TargetParser/Triple.h" |
26 | |
27 | using namespace llvm; |
28 | |
// Component name for this pass; consumed by the STATISTIC macro below.
#define DEBUG_TYPE "cross-dso-cfi"

// Incremented once for every type identifier that receives a case in the
// __cfi_check switch emitted by buildCFICheck().
STATISTIC(NumTypeIds, "Number of unique type identifiers");
32 | |
33 | namespace { |
34 | |
35 | struct CrossDSOCFI { |
36 | MDNode *VeryLikelyWeights; |
37 | |
38 | ConstantInt *extractNumericTypeId(MDNode *MD); |
39 | void buildCFICheck(Module &M); |
40 | bool runOnModule(Module &M); |
41 | }; |
42 | |
43 | } // anonymous namespace |
44 | |
45 | /// Extracts a numeric type identifier from an MDNode containing type metadata. |
46 | ConstantInt *CrossDSOCFI::(MDNode *MD) { |
47 | // This check excludes vtables for classes inside anonymous namespaces. |
48 | auto TM = dyn_cast<ValueAsMetadata>(Val: MD->getOperand(I: 1)); |
49 | if (!TM) |
50 | return nullptr; |
51 | auto C = dyn_cast_or_null<ConstantInt>(Val: TM->getValue()); |
52 | if (!C) return nullptr; |
53 | // We are looking for i64 constants. |
54 | if (C->getBitWidth() != 64) return nullptr; |
55 | |
56 | return C; |
57 | } |
58 | |
59 | /// buildCFICheck - emits __cfi_check for the current module. |
60 | void CrossDSOCFI::buildCFICheck(Module &M) { |
61 | // FIXME: verify that __cfi_check ends up near the end of the code section, |
62 | // but before the jump slots created in LowerTypeTests. |
63 | SetVector<uint64_t> TypeIds; |
64 | SmallVector<MDNode *, 2> Types; |
65 | for (GlobalObject &GO : M.global_objects()) { |
66 | Types.clear(); |
67 | GO.getMetadata(KindID: LLVMContext::MD_type, MDs&: Types); |
68 | for (MDNode *Type : Types) |
69 | if (ConstantInt *TypeId = extractNumericTypeId(MD: Type)) |
70 | TypeIds.insert(X: TypeId->getZExtValue()); |
71 | } |
72 | |
73 | NamedMDNode *CfiFunctionsMD = M.getNamedMetadata(Name: "cfi.functions" ); |
74 | if (CfiFunctionsMD) { |
75 | for (auto *Func : CfiFunctionsMD->operands()) { |
76 | assert(Func->getNumOperands() >= 2); |
77 | for (unsigned I = 2; I < Func->getNumOperands(); ++I) |
78 | if (ConstantInt *TypeId = |
79 | extractNumericTypeId(MD: cast<MDNode>(Val: Func->getOperand(I).get()))) |
80 | TypeIds.insert(X: TypeId->getZExtValue()); |
81 | } |
82 | } |
83 | |
84 | LLVMContext &Ctx = M.getContext(); |
85 | FunctionCallee C = M.getOrInsertFunction( |
86 | Name: "__cfi_check" , RetTy: Type::getVoidTy(C&: Ctx), Args: Type::getInt64Ty(C&: Ctx), |
87 | Args: PointerType::getUnqual(C&: Ctx), Args: PointerType::getUnqual(C&: Ctx)); |
88 | Function *F = cast<Function>(Val: C.getCallee()); |
89 | // Take over the existing function. The frontend emits a weak stub so that the |
90 | // linker knows about the symbol; this pass replaces the function body. |
91 | F->deleteBody(); |
92 | F->setAlignment(Align(4096)); |
93 | |
94 | Triple T(M.getTargetTriple()); |
95 | if (T.isARM() || T.isThumb()) |
96 | F->addFnAttr(Kind: "target-features" , Val: "+thumb-mode" ); |
97 | |
98 | auto args = F->arg_begin(); |
99 | Value &CallSiteTypeId = *(args++); |
100 | CallSiteTypeId.setName("CallSiteTypeId" ); |
101 | Value &Addr = *(args++); |
102 | Addr.setName("Addr" ); |
103 | Value &CFICheckFailData = *(args++); |
104 | CFICheckFailData.setName("CFICheckFailData" ); |
105 | assert(args == F->arg_end()); |
106 | |
107 | BasicBlock *BB = BasicBlock::Create(Context&: Ctx, Name: "entry" , Parent: F); |
108 | BasicBlock *ExitBB = BasicBlock::Create(Context&: Ctx, Name: "exit" , Parent: F); |
109 | |
110 | BasicBlock *TrapBB = BasicBlock::Create(Context&: Ctx, Name: "fail" , Parent: F); |
111 | IRBuilder<> IRBFail(TrapBB); |
112 | FunctionCallee CFICheckFailFn = M.getOrInsertFunction( |
113 | Name: "__cfi_check_fail" , RetTy: Type::getVoidTy(C&: Ctx), Args: PointerType::getUnqual(C&: Ctx), |
114 | Args: PointerType::getUnqual(C&: Ctx)); |
115 | IRBFail.CreateCall(Callee: CFICheckFailFn, Args: {&CFICheckFailData, &Addr}); |
116 | IRBFail.CreateBr(Dest: ExitBB); |
117 | |
118 | IRBuilder<> IRBExit(ExitBB); |
119 | IRBExit.CreateRetVoid(); |
120 | |
121 | IRBuilder<> IRB(BB); |
122 | SwitchInst *SI = IRB.CreateSwitch(V: &CallSiteTypeId, Dest: TrapBB, NumCases: TypeIds.size()); |
123 | for (uint64_t TypeId : TypeIds) { |
124 | ConstantInt *CaseTypeId = ConstantInt::get(Ty: Type::getInt64Ty(C&: Ctx), V: TypeId); |
125 | BasicBlock *TestBB = BasicBlock::Create(Context&: Ctx, Name: "test" , Parent: F); |
126 | IRBuilder<> IRBTest(TestBB); |
127 | |
128 | Value *Test = IRBTest.CreateIntrinsic( |
129 | ID: Intrinsic::type_test, |
130 | Args: {&Addr, |
131 | MetadataAsValue::get(Context&: Ctx, MD: ConstantAsMetadata::get(C: CaseTypeId))}); |
132 | BranchInst *BI = IRBTest.CreateCondBr(Cond: Test, True: ExitBB, False: TrapBB); |
133 | BI->setMetadata(KindID: LLVMContext::MD_prof, Node: VeryLikelyWeights); |
134 | |
135 | SI->addCase(OnVal: CaseTypeId, Dest: TestBB); |
136 | ++NumTypeIds; |
137 | } |
138 | } |
139 | |
140 | bool CrossDSOCFI::runOnModule(Module &M) { |
141 | VeryLikelyWeights = MDBuilder(M.getContext()).createLikelyBranchWeights(); |
142 | if (M.getModuleFlag(Key: "Cross-DSO CFI" ) == nullptr) |
143 | return false; |
144 | buildCFICheck(M); |
145 | return true; |
146 | } |
147 | |
148 | PreservedAnalyses CrossDSOCFIPass::run(Module &M, ModuleAnalysisManager &AM) { |
149 | CrossDSOCFI Impl; |
150 | bool Changed = Impl.runOnModule(M); |
151 | if (!Changed) |
152 | return PreservedAnalyses::all(); |
153 | return PreservedAnalyses::none(); |
154 | } |
155 | |