1 | //===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Adjust optimization to make the code more kernel verifier friendly. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
#include "BPF.h"
#include "BPFCORE.h"
#include "BPFTargetMachine.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsBPF.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
26 | |
27 | #define DEBUG_TYPE "bpf-adjust-opt" |
28 | |
29 | using namespace llvm; |
30 | using namespace llvm::PatternMatch; |
31 | |
32 | static cl::opt<bool> |
33 | DisableBPFserializeICMP("bpf-disable-serialize-icmp" , cl::Hidden, |
34 | cl::desc("BPF: Disable Serializing ICMP insns." ), |
35 | cl::init(Val: false)); |
36 | |
37 | static cl::opt<bool> DisableBPFavoidSpeculation( |
38 | "bpf-disable-avoid-speculation" , cl::Hidden, |
39 | cl::desc("BPF: Disable Avoiding Speculative Code Motion." ), |
40 | cl::init(Val: false)); |
41 | |
42 | namespace { |
43 | class BPFAdjustOptImpl { |
44 | struct PassThroughInfo { |
45 | Instruction *Input; |
46 | Instruction *UsedInst; |
47 | uint32_t OpIdx; |
48 | PassThroughInfo(Instruction *I, Instruction *U, uint32_t Idx) |
49 | : Input(I), UsedInst(U), OpIdx(Idx) {} |
50 | }; |
51 | |
52 | public: |
53 | BPFAdjustOptImpl(Module *M) : M(M) {} |
54 | |
55 | bool run(); |
56 | |
57 | private: |
58 | Module *M; |
59 | SmallVector<PassThroughInfo, 16> PassThroughs; |
60 | |
61 | bool adjustICmpToBuiltin(); |
62 | void adjustBasicBlock(BasicBlock &BB); |
63 | bool serializeICMPCrossBB(BasicBlock &BB); |
64 | void adjustInst(Instruction &I); |
65 | bool serializeICMPInBB(Instruction &I); |
66 | bool avoidSpeculation(Instruction &I); |
67 | bool insertPassThrough(); |
68 | }; |
69 | |
70 | } // End anonymous namespace |
71 | |
72 | bool BPFAdjustOptImpl::run() { |
73 | bool Changed = adjustICmpToBuiltin(); |
74 | |
75 | for (Function &F : *M) |
76 | for (auto &BB : F) { |
77 | adjustBasicBlock(BB); |
78 | for (auto &I : BB) |
79 | adjustInst(I); |
80 | } |
81 | return insertPassThrough() || Changed; |
82 | } |
83 | |
84 | // Commit acabad9ff6bf ("[InstCombine] try to canonicalize icmp with |
85 | // trunc op into mask and cmp") added a transformation to |
86 | // convert "(conv)a < power_2_const" to "a & <const>" in certain |
87 | // cases and bpf kernel verifier has to handle the resulted code |
88 | // conservatively and this may reject otherwise legitimate program. |
89 | // Here, we change related icmp code to a builtin which will |
90 | // be restored to original icmp code later to prevent that |
91 | // InstCombine transformatin. |
92 | bool BPFAdjustOptImpl::adjustICmpToBuiltin() { |
93 | bool Changed = false; |
94 | ICmpInst *ToBeDeleted = nullptr; |
95 | for (Function &F : *M) |
96 | for (auto &BB : F) |
97 | for (auto &I : BB) { |
98 | if (ToBeDeleted) { |
99 | ToBeDeleted->eraseFromParent(); |
100 | ToBeDeleted = nullptr; |
101 | } |
102 | |
103 | auto *Icmp = dyn_cast<ICmpInst>(Val: &I); |
104 | if (!Icmp) |
105 | continue; |
106 | |
107 | Value *Op0 = Icmp->getOperand(i_nocapture: 0); |
108 | if (!isa<TruncInst>(Val: Op0)) |
109 | continue; |
110 | |
111 | auto ConstOp1 = dyn_cast<ConstantInt>(Val: Icmp->getOperand(i_nocapture: 1)); |
112 | if (!ConstOp1) |
113 | continue; |
114 | |
115 | auto ConstOp1Val = ConstOp1->getValue().getZExtValue(); |
116 | auto Op = Icmp->getPredicate(); |
117 | if (Op == ICmpInst::ICMP_ULT || Op == ICmpInst::ICMP_UGE) { |
118 | if ((ConstOp1Val - 1) & ConstOp1Val) |
119 | continue; |
120 | } else if (Op == ICmpInst::ICMP_ULE || Op == ICmpInst::ICMP_UGT) { |
121 | if (ConstOp1Val & (ConstOp1Val + 1)) |
122 | continue; |
123 | } else { |
124 | continue; |
125 | } |
126 | |
127 | Constant *Opcode = |
128 | ConstantInt::get(Ty: Type::getInt32Ty(C&: BB.getContext()), V: Op); |
129 | Function *Fn = Intrinsic::getDeclaration( |
130 | M, id: Intrinsic::bpf_compare, Tys: {Op0->getType(), ConstOp1->getType()}); |
131 | auto *NewInst = CallInst::Create(Func: Fn, Args: {Opcode, Op0, ConstOp1}); |
132 | NewInst->insertBefore(InsertPos: &I); |
133 | Icmp->replaceAllUsesWith(V: NewInst); |
134 | Changed = true; |
135 | ToBeDeleted = Icmp; |
136 | } |
137 | |
138 | return Changed; |
139 | } |
140 | |
141 | bool BPFAdjustOptImpl::insertPassThrough() { |
142 | for (auto &Info : PassThroughs) { |
143 | auto *CI = BPFCoreSharedInfo::insertPassThrough( |
144 | M, BB: Info.UsedInst->getParent(), Input: Info.Input, Before: Info.UsedInst); |
145 | Info.UsedInst->setOperand(i: Info.OpIdx, Val: CI); |
146 | } |
147 | |
148 | return !PassThroughs.empty(); |
149 | } |
150 | |
151 | // To avoid combining conditionals in the same basic block by |
152 | // instrcombine optimization. |
153 | bool BPFAdjustOptImpl::serializeICMPInBB(Instruction &I) { |
154 | // For: |
155 | // comp1 = icmp <opcode> ...; |
156 | // comp2 = icmp <opcode> ...; |
157 | // ... or comp1 comp2 ... |
158 | // changed to: |
159 | // comp1 = icmp <opcode> ...; |
160 | // comp2 = icmp <opcode> ...; |
161 | // new_comp1 = __builtin_bpf_passthrough(seq_num, comp1) |
162 | // ... or new_comp1 comp2 ... |
163 | Value *Op0, *Op1; |
164 | // Use LogicalOr (accept `or i1` as well as `select i1 Op0, true, Op1`) |
165 | if (!match(V: &I, P: m_LogicalOr(L: m_Value(V&: Op0), R: m_Value(V&: Op1)))) |
166 | return false; |
167 | auto *Icmp1 = dyn_cast<ICmpInst>(Val: Op0); |
168 | if (!Icmp1) |
169 | return false; |
170 | auto *Icmp2 = dyn_cast<ICmpInst>(Val: Op1); |
171 | if (!Icmp2) |
172 | return false; |
173 | |
174 | Value *Icmp1Op0 = Icmp1->getOperand(i_nocapture: 0); |
175 | Value *Icmp2Op0 = Icmp2->getOperand(i_nocapture: 0); |
176 | if (Icmp1Op0 != Icmp2Op0) |
177 | return false; |
178 | |
179 | // Now we got two icmp instructions which feed into |
180 | // an "or" instruction. |
181 | PassThroughInfo Info(Icmp1, &I, 0); |
182 | PassThroughs.push_back(Elt: Info); |
183 | return true; |
184 | } |
185 | |
186 | // To avoid combining conditionals in the same basic block by |
187 | // instrcombine optimization. |
188 | bool BPFAdjustOptImpl::serializeICMPCrossBB(BasicBlock &BB) { |
189 | // For: |
190 | // B1: |
191 | // comp1 = icmp <opcode> ...; |
192 | // if (comp1) goto B2 else B3; |
193 | // B2: |
194 | // comp2 = icmp <opcode> ...; |
195 | // if (comp2) goto B4 else B5; |
196 | // B4: |
197 | // ... |
198 | // changed to: |
199 | // B1: |
200 | // comp1 = icmp <opcode> ...; |
201 | // comp1 = __builtin_bpf_passthrough(seq_num, comp1); |
202 | // if (comp1) goto B2 else B3; |
203 | // B2: |
204 | // comp2 = icmp <opcode> ...; |
205 | // if (comp2) goto B4 else B5; |
206 | // B4: |
207 | // ... |
208 | |
209 | // Check basic predecessors, if two of them (say B1, B2) are using |
210 | // icmp instructions to generate conditions and one is the predesessor |
211 | // of another (e.g., B1 is the predecessor of B2). Add a passthrough |
212 | // barrier after icmp inst of block B1. |
213 | BasicBlock *B2 = BB.getSinglePredecessor(); |
214 | if (!B2) |
215 | return false; |
216 | |
217 | BasicBlock *B1 = B2->getSinglePredecessor(); |
218 | if (!B1) |
219 | return false; |
220 | |
221 | Instruction *TI = B2->getTerminator(); |
222 | auto *BI = dyn_cast<BranchInst>(Val: TI); |
223 | if (!BI || !BI->isConditional()) |
224 | return false; |
225 | auto *Cond = dyn_cast<ICmpInst>(Val: BI->getCondition()); |
226 | if (!Cond || B2->getFirstNonPHI() != Cond) |
227 | return false; |
228 | Value *B2Op0 = Cond->getOperand(i_nocapture: 0); |
229 | auto Cond2Op = Cond->getPredicate(); |
230 | |
231 | TI = B1->getTerminator(); |
232 | BI = dyn_cast<BranchInst>(Val: TI); |
233 | if (!BI || !BI->isConditional()) |
234 | return false; |
235 | Cond = dyn_cast<ICmpInst>(Val: BI->getCondition()); |
236 | if (!Cond) |
237 | return false; |
238 | Value *B1Op0 = Cond->getOperand(i_nocapture: 0); |
239 | auto Cond1Op = Cond->getPredicate(); |
240 | |
241 | if (B1Op0 != B2Op0) |
242 | return false; |
243 | |
244 | if (Cond1Op == ICmpInst::ICMP_SGT || Cond1Op == ICmpInst::ICMP_SGE) { |
245 | if (Cond2Op != ICmpInst::ICMP_SLT && Cond2Op != ICmpInst::ICMP_SLE) |
246 | return false; |
247 | } else if (Cond1Op == ICmpInst::ICMP_SLT || Cond1Op == ICmpInst::ICMP_SLE) { |
248 | if (Cond2Op != ICmpInst::ICMP_SGT && Cond2Op != ICmpInst::ICMP_SGE) |
249 | return false; |
250 | } else if (Cond1Op == ICmpInst::ICMP_ULT || Cond1Op == ICmpInst::ICMP_ULE) { |
251 | if (Cond2Op != ICmpInst::ICMP_UGT && Cond2Op != ICmpInst::ICMP_UGE) |
252 | return false; |
253 | } else if (Cond1Op == ICmpInst::ICMP_UGT || Cond1Op == ICmpInst::ICMP_UGE) { |
254 | if (Cond2Op != ICmpInst::ICMP_ULT && Cond2Op != ICmpInst::ICMP_ULE) |
255 | return false; |
256 | } else { |
257 | return false; |
258 | } |
259 | |
260 | PassThroughInfo Info(Cond, BI, 0); |
261 | PassThroughs.push_back(Elt: Info); |
262 | |
263 | return true; |
264 | } |
265 | |
266 | // To avoid speculative hoisting certain computations out of |
267 | // a basic block. |
268 | bool BPFAdjustOptImpl::avoidSpeculation(Instruction &I) { |
269 | if (auto *LdInst = dyn_cast<LoadInst>(Val: &I)) { |
270 | if (auto *GV = dyn_cast<GlobalVariable>(Val: LdInst->getOperand(i_nocapture: 0))) { |
271 | if (GV->hasAttribute(Kind: BPFCoreSharedInfo::AmaAttr) || |
272 | GV->hasAttribute(Kind: BPFCoreSharedInfo::TypeIdAttr)) |
273 | return false; |
274 | } |
275 | } |
276 | |
277 | if (!isa<LoadInst>(Val: &I) && !isa<CallInst>(Val: &I)) |
278 | return false; |
279 | |
280 | // For: |
281 | // B1: |
282 | // var = ... |
283 | // ... |
284 | // /* icmp may not be in the same block as var = ... */ |
285 | // comp1 = icmp <opcode> var, <const>; |
286 | // if (comp1) goto B2 else B3; |
287 | // B2: |
288 | // ... var ... |
289 | // change to: |
290 | // B1: |
291 | // var = ... |
292 | // ... |
293 | // /* icmp may not be in the same block as var = ... */ |
294 | // comp1 = icmp <opcode> var, <const>; |
295 | // if (comp1) goto B2 else B3; |
296 | // B2: |
297 | // var = __builtin_bpf_passthrough(seq_num, var); |
298 | // ... var ... |
299 | bool isCandidate = false; |
300 | SmallVector<PassThroughInfo, 4> Candidates; |
301 | for (User *U : I.users()) { |
302 | Instruction *Inst = dyn_cast<Instruction>(Val: U); |
303 | if (!Inst) |
304 | continue; |
305 | |
306 | // May cover a little bit more than the |
307 | // above pattern. |
308 | if (auto *Icmp1 = dyn_cast<ICmpInst>(Val: Inst)) { |
309 | Value *Icmp1Op1 = Icmp1->getOperand(i_nocapture: 1); |
310 | if (!isa<Constant>(Val: Icmp1Op1)) |
311 | return false; |
312 | isCandidate = true; |
313 | continue; |
314 | } |
315 | |
316 | // Ignore the use in the same basic block as the definition. |
317 | if (Inst->getParent() == I.getParent()) |
318 | continue; |
319 | |
320 | // use in a different basic block, If there is a call or |
321 | // load/store insn before this instruction in this basic |
322 | // block. Most likely it cannot be hoisted out. Skip it. |
323 | for (auto &I2 : *Inst->getParent()) { |
324 | if (isa<CallInst>(Val: &I2)) |
325 | return false; |
326 | if (isa<LoadInst>(Val: &I2) || isa<StoreInst>(Val: &I2)) |
327 | return false; |
328 | if (&I2 == Inst) |
329 | break; |
330 | } |
331 | |
332 | // It should be used in a GEP or a simple arithmetic like |
333 | // ZEXT/SEXT which is used for GEP. |
334 | if (Inst->getOpcode() == Instruction::ZExt || |
335 | Inst->getOpcode() == Instruction::SExt) { |
336 | PassThroughInfo Info(&I, Inst, 0); |
337 | Candidates.push_back(Elt: Info); |
338 | } else if (auto *GI = dyn_cast<GetElementPtrInst>(Val: Inst)) { |
339 | // traverse GEP inst to find Use operand index |
340 | unsigned i, e; |
341 | for (i = 1, e = GI->getNumOperands(); i != e; ++i) { |
342 | Value *V = GI->getOperand(i_nocapture: i); |
343 | if (V == &I) |
344 | break; |
345 | } |
346 | if (i == e) |
347 | continue; |
348 | |
349 | PassThroughInfo Info(&I, GI, i); |
350 | Candidates.push_back(Elt: Info); |
351 | } |
352 | } |
353 | |
354 | if (!isCandidate || Candidates.empty()) |
355 | return false; |
356 | |
357 | llvm::append_range(C&: PassThroughs, R&: Candidates); |
358 | return true; |
359 | } |
360 | |
361 | void BPFAdjustOptImpl::adjustBasicBlock(BasicBlock &BB) { |
362 | if (!DisableBPFserializeICMP && serializeICMPCrossBB(BB)) |
363 | return; |
364 | } |
365 | |
366 | void BPFAdjustOptImpl::adjustInst(Instruction &I) { |
367 | if (!DisableBPFserializeICMP && serializeICMPInBB(I)) |
368 | return; |
369 | if (!DisableBPFavoidSpeculation && avoidSpeculation(I)) |
370 | return; |
371 | } |
372 | |
373 | PreservedAnalyses BPFAdjustOptPass::run(Module &M, ModuleAnalysisManager &AM) { |
374 | return BPFAdjustOptImpl(&M).run() ? PreservedAnalyses::none() |
375 | : PreservedAnalyses::all(); |
376 | } |
377 | |