1 | //===-- CrossDSOCFI.cpp - Externalize this module's CFI checks ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This pass exports all numeric type metadata identifiers found in the module
// in the form of a __cfi_check function, which can be used to verify cross-DSO
// call targets.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "llvm/Transforms/IPO/CrossDSOCFI.h" |
15 | #include "llvm/ADT/SetVector.h" |
16 | #include "llvm/ADT/Statistic.h" |
17 | #include "llvm/IR/Constants.h" |
18 | #include "llvm/IR/Function.h" |
19 | #include "llvm/IR/GlobalObject.h" |
20 | #include "llvm/IR/IRBuilder.h" |
21 | #include "llvm/IR/Instructions.h" |
22 | #include "llvm/IR/Intrinsics.h" |
23 | #include "llvm/IR/MDBuilder.h" |
24 | #include "llvm/IR/Module.h" |
25 | #include "llvm/TargetParser/Triple.h" |
26 | #include "llvm/Transforms/IPO.h" |
27 | |
28 | using namespace llvm; |
29 | |
30 | #define DEBUG_TYPE "cross-dso-cfi" |
31 | |
32 | STATISTIC(NumTypeIds, "Number of unique type identifiers" ); |
33 | |
34 | namespace { |
35 | |
36 | struct CrossDSOCFI { |
37 | MDNode *VeryLikelyWeights; |
38 | |
39 | ConstantInt *extractNumericTypeId(MDNode *MD); |
40 | void buildCFICheck(Module &M); |
41 | bool runOnModule(Module &M); |
42 | }; |
43 | |
44 | } // anonymous namespace |
45 | |
46 | /// Extracts a numeric type identifier from an MDNode containing type metadata. |
47 | ConstantInt *CrossDSOCFI::(MDNode *MD) { |
48 | // This check excludes vtables for classes inside anonymous namespaces. |
49 | auto TM = dyn_cast<ValueAsMetadata>(Val: MD->getOperand(I: 1)); |
50 | if (!TM) |
51 | return nullptr; |
52 | auto C = dyn_cast_or_null<ConstantInt>(Val: TM->getValue()); |
53 | if (!C) return nullptr; |
54 | // We are looking for i64 constants. |
55 | if (C->getBitWidth() != 64) return nullptr; |
56 | |
57 | return C; |
58 | } |
59 | |
60 | /// buildCFICheck - emits __cfi_check for the current module. |
61 | void CrossDSOCFI::buildCFICheck(Module &M) { |
62 | // FIXME: verify that __cfi_check ends up near the end of the code section, |
63 | // but before the jump slots created in LowerTypeTests. |
64 | SetVector<uint64_t> TypeIds; |
65 | SmallVector<MDNode *, 2> Types; |
66 | for (GlobalObject &GO : M.global_objects()) { |
67 | Types.clear(); |
68 | GO.getMetadata(KindID: LLVMContext::MD_type, MDs&: Types); |
69 | for (MDNode *Type : Types) |
70 | if (ConstantInt *TypeId = extractNumericTypeId(MD: Type)) |
71 | TypeIds.insert(X: TypeId->getZExtValue()); |
72 | } |
73 | |
74 | NamedMDNode *CfiFunctionsMD = M.getNamedMetadata(Name: "cfi.functions" ); |
75 | if (CfiFunctionsMD) { |
76 | for (auto *Func : CfiFunctionsMD->operands()) { |
77 | assert(Func->getNumOperands() >= 2); |
78 | for (unsigned I = 2; I < Func->getNumOperands(); ++I) |
79 | if (ConstantInt *TypeId = |
80 | extractNumericTypeId(MD: cast<MDNode>(Val: Func->getOperand(I).get()))) |
81 | TypeIds.insert(X: TypeId->getZExtValue()); |
82 | } |
83 | } |
84 | |
85 | LLVMContext &Ctx = M.getContext(); |
86 | FunctionCallee C = M.getOrInsertFunction( |
87 | Name: "__cfi_check" , RetTy: Type::getVoidTy(C&: Ctx), Args: Type::getInt64Ty(C&: Ctx), |
88 | Args: PointerType::getUnqual(C&: Ctx), Args: PointerType::getUnqual(C&: Ctx)); |
89 | Function *F = cast<Function>(Val: C.getCallee()); |
90 | // Take over the existing function. The frontend emits a weak stub so that the |
91 | // linker knows about the symbol; this pass replaces the function body. |
92 | F->deleteBody(); |
93 | F->setAlignment(Align(4096)); |
94 | |
95 | Triple T(M.getTargetTriple()); |
96 | if (T.isARM() || T.isThumb()) |
97 | F->addFnAttr(Kind: "target-features" , Val: "+thumb-mode" ); |
98 | |
99 | auto args = F->arg_begin(); |
100 | Value &CallSiteTypeId = *(args++); |
101 | CallSiteTypeId.setName("CallSiteTypeId" ); |
102 | Value &Addr = *(args++); |
103 | Addr.setName("Addr" ); |
104 | Value &CFICheckFailData = *(args++); |
105 | CFICheckFailData.setName("CFICheckFailData" ); |
106 | assert(args == F->arg_end()); |
107 | |
108 | BasicBlock *BB = BasicBlock::Create(Context&: Ctx, Name: "entry" , Parent: F); |
109 | BasicBlock *ExitBB = BasicBlock::Create(Context&: Ctx, Name: "exit" , Parent: F); |
110 | |
111 | BasicBlock *TrapBB = BasicBlock::Create(Context&: Ctx, Name: "fail" , Parent: F); |
112 | IRBuilder<> IRBFail(TrapBB); |
113 | FunctionCallee CFICheckFailFn = M.getOrInsertFunction( |
114 | Name: "__cfi_check_fail" , RetTy: Type::getVoidTy(C&: Ctx), Args: PointerType::getUnqual(C&: Ctx), |
115 | Args: PointerType::getUnqual(C&: Ctx)); |
116 | IRBFail.CreateCall(Callee: CFICheckFailFn, Args: {&CFICheckFailData, &Addr}); |
117 | IRBFail.CreateBr(Dest: ExitBB); |
118 | |
119 | IRBuilder<> IRBExit(ExitBB); |
120 | IRBExit.CreateRetVoid(); |
121 | |
122 | IRBuilder<> IRB(BB); |
123 | SwitchInst *SI = IRB.CreateSwitch(V: &CallSiteTypeId, Dest: TrapBB, NumCases: TypeIds.size()); |
124 | for (uint64_t TypeId : TypeIds) { |
125 | ConstantInt *CaseTypeId = ConstantInt::get(Ty: Type::getInt64Ty(C&: Ctx), V: TypeId); |
126 | BasicBlock *TestBB = BasicBlock::Create(Context&: Ctx, Name: "test" , Parent: F); |
127 | IRBuilder<> IRBTest(TestBB); |
128 | Function *BitsetTestFn = Intrinsic::getDeclaration(M: &M, id: Intrinsic::type_test); |
129 | |
130 | Value *Test = IRBTest.CreateCall( |
131 | Callee: BitsetTestFn, Args: {&Addr, MetadataAsValue::get( |
132 | Context&: Ctx, MD: ConstantAsMetadata::get(C: CaseTypeId))}); |
133 | BranchInst *BI = IRBTest.CreateCondBr(Cond: Test, True: ExitBB, False: TrapBB); |
134 | BI->setMetadata(KindID: LLVMContext::MD_prof, Node: VeryLikelyWeights); |
135 | |
136 | SI->addCase(OnVal: CaseTypeId, Dest: TestBB); |
137 | ++NumTypeIds; |
138 | } |
139 | } |
140 | |
141 | bool CrossDSOCFI::runOnModule(Module &M) { |
142 | VeryLikelyWeights = MDBuilder(M.getContext()).createLikelyBranchWeights(); |
143 | if (M.getModuleFlag(Key: "Cross-DSO CFI" ) == nullptr) |
144 | return false; |
145 | buildCFICheck(M); |
146 | return true; |
147 | } |
148 | |
149 | PreservedAnalyses CrossDSOCFIPass::run(Module &M, ModuleAnalysisManager &AM) { |
150 | CrossDSOCFI Impl; |
151 | bool Changed = Impl.runOnModule(M); |
152 | if (!Changed) |
153 | return PreservedAnalyses::all(); |
154 | return PreservedAnalyses::none(); |
155 | } |
156 | |