1 | //===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Adjust optimization to make the code more kernel verifier friendly. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "BPF.h" |
14 | #include "BPFCORE.h" |
15 | #include "llvm/IR/Instruction.h" |
16 | #include "llvm/IR/Instructions.h" |
17 | #include "llvm/IR/IntrinsicsBPF.h" |
18 | #include "llvm/IR/Module.h" |
19 | #include "llvm/IR/PatternMatch.h" |
20 | #include "llvm/IR/Type.h" |
21 | #include "llvm/IR/User.h" |
22 | #include "llvm/IR/Value.h" |
23 | #include "llvm/Pass.h" |
24 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
25 | |
26 | #define DEBUG_TYPE "bpf-adjust-opt" |
27 | |
28 | using namespace llvm; |
29 | using namespace llvm::PatternMatch; |
30 | |
31 | static cl::opt<bool> |
32 | DisableBPFserializeICMP("bpf-disable-serialize-icmp" , cl::Hidden, |
33 | cl::desc("BPF: Disable Serializing ICMP insns." ), |
34 | cl::init(Val: false)); |
35 | |
36 | static cl::opt<bool> DisableBPFavoidSpeculation( |
37 | "bpf-disable-avoid-speculation" , cl::Hidden, |
38 | cl::desc("BPF: Disable Avoiding Speculative Code Motion." ), |
39 | cl::init(Val: false)); |
40 | |
41 | namespace { |
42 | class BPFAdjustOptImpl { |
43 | struct PassThroughInfo { |
44 | Instruction *Input; |
45 | Instruction *UsedInst; |
46 | uint32_t OpIdx; |
47 | PassThroughInfo(Instruction *I, Instruction *U, uint32_t Idx) |
48 | : Input(I), UsedInst(U), OpIdx(Idx) {} |
49 | }; |
50 | |
51 | public: |
52 | BPFAdjustOptImpl(Module *M) : M(M) {} |
53 | |
54 | bool run(); |
55 | |
56 | private: |
57 | Module *M; |
58 | SmallVector<PassThroughInfo, 16> PassThroughs; |
59 | |
60 | bool adjustICmpToBuiltin(); |
61 | void adjustBasicBlock(BasicBlock &BB); |
62 | bool serializeICMPCrossBB(BasicBlock &BB); |
63 | void adjustInst(Instruction &I); |
64 | bool serializeICMPInBB(Instruction &I); |
65 | bool avoidSpeculation(Instruction &I); |
66 | bool insertPassThrough(); |
67 | }; |
68 | |
69 | } // End anonymous namespace |
70 | |
71 | bool BPFAdjustOptImpl::run() { |
72 | bool Changed = adjustICmpToBuiltin(); |
73 | |
74 | for (Function &F : *M) |
75 | for (auto &BB : F) { |
76 | adjustBasicBlock(BB); |
77 | for (auto &I : BB) |
78 | adjustInst(I); |
79 | } |
80 | return insertPassThrough() || Changed; |
81 | } |
82 | |
83 | // Commit acabad9ff6bf ("[InstCombine] try to canonicalize icmp with |
84 | // trunc op into mask and cmp") added a transformation to |
85 | // convert "(conv)a < power_2_const" to "a & <const>" in certain |
86 | // cases and bpf kernel verifier has to handle the resulted code |
87 | // conservatively and this may reject otherwise legitimate program. |
88 | // Here, we change related icmp code to a builtin which will |
89 | // be restored to original icmp code later to prevent that |
90 | // InstCombine transformatin. |
91 | bool BPFAdjustOptImpl::adjustICmpToBuiltin() { |
92 | bool Changed = false; |
93 | ICmpInst *ToBeDeleted = nullptr; |
94 | for (Function &F : *M) |
95 | for (auto &BB : F) |
96 | for (auto &I : BB) { |
97 | if (ToBeDeleted) { |
98 | ToBeDeleted->eraseFromParent(); |
99 | ToBeDeleted = nullptr; |
100 | } |
101 | |
102 | auto *Icmp = dyn_cast<ICmpInst>(Val: &I); |
103 | if (!Icmp) |
104 | continue; |
105 | |
106 | Value *Op0 = Icmp->getOperand(i_nocapture: 0); |
107 | if (!isa<TruncInst>(Val: Op0)) |
108 | continue; |
109 | |
110 | auto ConstOp1 = dyn_cast<ConstantInt>(Val: Icmp->getOperand(i_nocapture: 1)); |
111 | if (!ConstOp1) |
112 | continue; |
113 | |
114 | auto ConstOp1Val = ConstOp1->getValue().getZExtValue(); |
115 | auto Op = Icmp->getPredicate(); |
116 | if (Op == ICmpInst::ICMP_ULT || Op == ICmpInst::ICMP_UGE) { |
117 | if ((ConstOp1Val - 1) & ConstOp1Val) |
118 | continue; |
119 | } else if (Op == ICmpInst::ICMP_ULE || Op == ICmpInst::ICMP_UGT) { |
120 | if (ConstOp1Val & (ConstOp1Val + 1)) |
121 | continue; |
122 | } else { |
123 | continue; |
124 | } |
125 | |
126 | Constant *Opcode = |
127 | ConstantInt::get(Ty: Type::getInt32Ty(C&: BB.getContext()), V: Op); |
128 | Function *Fn = Intrinsic::getOrInsertDeclaration( |
129 | M, id: Intrinsic::bpf_compare, Tys: {Op0->getType(), ConstOp1->getType()}); |
130 | auto *NewInst = CallInst::Create(Func: Fn, Args: {Opcode, Op0, ConstOp1}); |
131 | NewInst->insertBefore(InsertPos: I.getIterator()); |
132 | Icmp->replaceAllUsesWith(V: NewInst); |
133 | Changed = true; |
134 | ToBeDeleted = Icmp; |
135 | } |
136 | |
137 | return Changed; |
138 | } |
139 | |
140 | bool BPFAdjustOptImpl::insertPassThrough() { |
141 | for (auto &Info : PassThroughs) { |
142 | auto *CI = BPFCoreSharedInfo::insertPassThrough( |
143 | M, BB: Info.UsedInst->getParent(), Input: Info.Input, Before: Info.UsedInst); |
144 | Info.UsedInst->setOperand(i: Info.OpIdx, Val: CI); |
145 | } |
146 | |
147 | return !PassThroughs.empty(); |
148 | } |
149 | |
150 | // To avoid combining conditionals in the same basic block by |
151 | // instrcombine optimization. |
152 | bool BPFAdjustOptImpl::serializeICMPInBB(Instruction &I) { |
153 | // For: |
154 | // comp1 = icmp <opcode> ...; |
155 | // comp2 = icmp <opcode> ...; |
156 | // ... or comp1 comp2 ... |
157 | // changed to: |
158 | // comp1 = icmp <opcode> ...; |
159 | // comp2 = icmp <opcode> ...; |
160 | // new_comp1 = __builtin_bpf_passthrough(seq_num, comp1) |
161 | // ... or new_comp1 comp2 ... |
162 | Value *Op0, *Op1; |
163 | // Use LogicalOr (accept `or i1` as well as `select i1 Op0, true, Op1`) |
164 | if (!match(V: &I, P: m_LogicalOr(L: m_Value(V&: Op0), R: m_Value(V&: Op1)))) |
165 | return false; |
166 | auto *Icmp1 = dyn_cast<ICmpInst>(Val: Op0); |
167 | if (!Icmp1) |
168 | return false; |
169 | auto *Icmp2 = dyn_cast<ICmpInst>(Val: Op1); |
170 | if (!Icmp2) |
171 | return false; |
172 | |
173 | Value *Icmp1Op0 = Icmp1->getOperand(i_nocapture: 0); |
174 | Value *Icmp2Op0 = Icmp2->getOperand(i_nocapture: 0); |
175 | if (Icmp1Op0 != Icmp2Op0) |
176 | return false; |
177 | |
178 | // Now we got two icmp instructions which feed into |
179 | // an "or" instruction. |
180 | PassThroughInfo Info(Icmp1, &I, 0); |
181 | PassThroughs.push_back(Elt: Info); |
182 | return true; |
183 | } |
184 | |
185 | // To avoid combining conditionals in the same basic block by |
186 | // instrcombine optimization. |
187 | bool BPFAdjustOptImpl::serializeICMPCrossBB(BasicBlock &BB) { |
188 | // For: |
189 | // B1: |
190 | // comp1 = icmp <opcode> ...; |
191 | // if (comp1) goto B2 else B3; |
192 | // B2: |
193 | // comp2 = icmp <opcode> ...; |
194 | // if (comp2) goto B4 else B5; |
195 | // B4: |
196 | // ... |
197 | // changed to: |
198 | // B1: |
199 | // comp1 = icmp <opcode> ...; |
200 | // comp1 = __builtin_bpf_passthrough(seq_num, comp1); |
201 | // if (comp1) goto B2 else B3; |
202 | // B2: |
203 | // comp2 = icmp <opcode> ...; |
204 | // if (comp2) goto B4 else B5; |
205 | // B4: |
206 | // ... |
207 | |
208 | // Check basic predecessors, if two of them (say B1, B2) are using |
209 | // icmp instructions to generate conditions and one is the predesessor |
210 | // of another (e.g., B1 is the predecessor of B2). Add a passthrough |
211 | // barrier after icmp inst of block B1. |
212 | BasicBlock *B2 = BB.getSinglePredecessor(); |
213 | if (!B2) |
214 | return false; |
215 | |
216 | BasicBlock *B1 = B2->getSinglePredecessor(); |
217 | if (!B1) |
218 | return false; |
219 | |
220 | Instruction *TI = B2->getTerminator(); |
221 | auto *BI = dyn_cast<BranchInst>(Val: TI); |
222 | if (!BI || !BI->isConditional()) |
223 | return false; |
224 | auto *Cond = dyn_cast<ICmpInst>(Val: BI->getCondition()); |
225 | if (!Cond || &*B2->getFirstNonPHIIt() != Cond) |
226 | return false; |
227 | Value *B2Op0 = Cond->getOperand(i_nocapture: 0); |
228 | auto Cond2Op = Cond->getPredicate(); |
229 | |
230 | TI = B1->getTerminator(); |
231 | BI = dyn_cast<BranchInst>(Val: TI); |
232 | if (!BI || !BI->isConditional()) |
233 | return false; |
234 | Cond = dyn_cast<ICmpInst>(Val: BI->getCondition()); |
235 | if (!Cond) |
236 | return false; |
237 | Value *B1Op0 = Cond->getOperand(i_nocapture: 0); |
238 | auto Cond1Op = Cond->getPredicate(); |
239 | |
240 | if (B1Op0 != B2Op0) |
241 | return false; |
242 | |
243 | if (Cond1Op == ICmpInst::ICMP_SGT || Cond1Op == ICmpInst::ICMP_SGE) { |
244 | if (Cond2Op != ICmpInst::ICMP_SLT && Cond2Op != ICmpInst::ICMP_SLE) |
245 | return false; |
246 | } else if (Cond1Op == ICmpInst::ICMP_SLT || Cond1Op == ICmpInst::ICMP_SLE) { |
247 | if (Cond2Op != ICmpInst::ICMP_SGT && Cond2Op != ICmpInst::ICMP_SGE) |
248 | return false; |
249 | } else if (Cond1Op == ICmpInst::ICMP_ULT || Cond1Op == ICmpInst::ICMP_ULE) { |
250 | if (Cond2Op != ICmpInst::ICMP_UGT && Cond2Op != ICmpInst::ICMP_UGE) |
251 | return false; |
252 | } else if (Cond1Op == ICmpInst::ICMP_UGT || Cond1Op == ICmpInst::ICMP_UGE) { |
253 | if (Cond2Op != ICmpInst::ICMP_ULT && Cond2Op != ICmpInst::ICMP_ULE) |
254 | return false; |
255 | } else { |
256 | return false; |
257 | } |
258 | |
259 | PassThroughInfo Info(Cond, BI, 0); |
260 | PassThroughs.push_back(Elt: Info); |
261 | |
262 | return true; |
263 | } |
264 | |
265 | // To avoid speculative hoisting certain computations out of |
266 | // a basic block. |
267 | bool BPFAdjustOptImpl::avoidSpeculation(Instruction &I) { |
268 | if (auto *LdInst = dyn_cast<LoadInst>(Val: &I)) { |
269 | if (auto *GV = dyn_cast<GlobalVariable>(Val: LdInst->getOperand(i_nocapture: 0))) { |
270 | if (GV->hasAttribute(Kind: BPFCoreSharedInfo::AmaAttr) || |
271 | GV->hasAttribute(Kind: BPFCoreSharedInfo::TypeIdAttr)) |
272 | return false; |
273 | } |
274 | } |
275 | |
276 | if (!isa<LoadInst>(Val: &I) && !isa<CallInst>(Val: &I)) |
277 | return false; |
278 | |
279 | // For: |
280 | // B1: |
281 | // var = ... |
282 | // ... |
283 | // /* icmp may not be in the same block as var = ... */ |
284 | // comp1 = icmp <opcode> var, <const>; |
285 | // if (comp1) goto B2 else B3; |
286 | // B2: |
287 | // ... var ... |
288 | // change to: |
289 | // B1: |
290 | // var = ... |
291 | // ... |
292 | // /* icmp may not be in the same block as var = ... */ |
293 | // comp1 = icmp <opcode> var, <const>; |
294 | // if (comp1) goto B2 else B3; |
295 | // B2: |
296 | // var = __builtin_bpf_passthrough(seq_num, var); |
297 | // ... var ... |
298 | bool isCandidate = false; |
299 | SmallVector<PassThroughInfo, 4> Candidates; |
300 | for (User *U : I.users()) { |
301 | Instruction *Inst = dyn_cast<Instruction>(Val: U); |
302 | if (!Inst) |
303 | continue; |
304 | |
305 | // May cover a little bit more than the |
306 | // above pattern. |
307 | if (auto *Icmp1 = dyn_cast<ICmpInst>(Val: Inst)) { |
308 | Value *Icmp1Op1 = Icmp1->getOperand(i_nocapture: 1); |
309 | if (!isa<Constant>(Val: Icmp1Op1)) |
310 | return false; |
311 | isCandidate = true; |
312 | continue; |
313 | } |
314 | |
315 | // Ignore the use in the same basic block as the definition. |
316 | if (Inst->getParent() == I.getParent()) |
317 | continue; |
318 | |
319 | // use in a different basic block, If there is a call or |
320 | // load/store insn before this instruction in this basic |
321 | // block. Most likely it cannot be hoisted out. Skip it. |
322 | for (auto &I2 : *Inst->getParent()) { |
323 | if (isa<CallInst>(Val: &I2)) |
324 | return false; |
325 | if (isa<LoadInst>(Val: &I2) || isa<StoreInst>(Val: &I2)) |
326 | return false; |
327 | if (&I2 == Inst) |
328 | break; |
329 | } |
330 | |
331 | // It should be used in a GEP or a simple arithmetic like |
332 | // ZEXT/SEXT which is used for GEP. |
333 | if (Inst->getOpcode() == Instruction::ZExt || |
334 | Inst->getOpcode() == Instruction::SExt) { |
335 | PassThroughInfo Info(&I, Inst, 0); |
336 | Candidates.push_back(Elt: Info); |
337 | } else if (auto *GI = dyn_cast<GetElementPtrInst>(Val: Inst)) { |
338 | // traverse GEP inst to find Use operand index |
339 | unsigned i, e; |
340 | for (i = 1, e = GI->getNumOperands(); i != e; ++i) { |
341 | Value *V = GI->getOperand(i_nocapture: i); |
342 | if (V == &I) |
343 | break; |
344 | } |
345 | if (i == e) |
346 | continue; |
347 | |
348 | PassThroughInfo Info(&I, GI, i); |
349 | Candidates.push_back(Elt: Info); |
350 | } |
351 | } |
352 | |
353 | if (!isCandidate || Candidates.empty()) |
354 | return false; |
355 | |
356 | llvm::append_range(C&: PassThroughs, R&: Candidates); |
357 | return true; |
358 | } |
359 | |
360 | void BPFAdjustOptImpl::adjustBasicBlock(BasicBlock &BB) { |
361 | if (!DisableBPFserializeICMP && serializeICMPCrossBB(BB)) |
362 | return; |
363 | } |
364 | |
365 | void BPFAdjustOptImpl::adjustInst(Instruction &I) { |
366 | if (!DisableBPFserializeICMP && serializeICMPInBB(I)) |
367 | return; |
368 | if (!DisableBPFavoidSpeculation && avoidSpeculation(I)) |
369 | return; |
370 | } |
371 | |
372 | PreservedAnalyses BPFAdjustOptPass::run(Module &M, ModuleAnalysisManager &AM) { |
373 | return BPFAdjustOptImpl(&M).run() ? PreservedAnalyses::none() |
374 | : PreservedAnalyses::all(); |
375 | } |
376 | |