1//===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Adjust optimization to make the code more kernel verifier friendly.
10//
11//===----------------------------------------------------------------------===//
12
13#include "BPF.h"
14#include "BPFCORE.h"
15#include "llvm/IR/Instruction.h"
16#include "llvm/IR/Instructions.h"
17#include "llvm/IR/IntrinsicsBPF.h"
18#include "llvm/IR/Module.h"
19#include "llvm/IR/PatternMatch.h"
20#include "llvm/IR/Type.h"
21#include "llvm/IR/User.h"
22#include "llvm/IR/Value.h"
23#include "llvm/Pass.h"
24#include "llvm/Transforms/Utils/BasicBlockUtils.h"
25
26#define DEBUG_TYPE "bpf-adjust-opt"
27
28using namespace llvm;
29using namespace llvm::PatternMatch;
30
31static cl::opt<bool>
32 DisableBPFserializeICMP("bpf-disable-serialize-icmp", cl::Hidden,
33 cl::desc("BPF: Disable Serializing ICMP insns."),
34 cl::init(Val: false));
35
36static cl::opt<bool> DisableBPFavoidSpeculation(
37 "bpf-disable-avoid-speculation", cl::Hidden,
38 cl::desc("BPF: Disable Avoiding Speculative Code Motion."),
39 cl::init(Val: false));
40
41namespace {
42class BPFAdjustOptImpl {
43 struct PassThroughInfo {
44 Instruction *Input;
45 Instruction *UsedInst;
46 uint32_t OpIdx;
47 PassThroughInfo(Instruction *I, Instruction *U, uint32_t Idx)
48 : Input(I), UsedInst(U), OpIdx(Idx) {}
49 };
50
51public:
52 BPFAdjustOptImpl(Module *M) : M(M) {}
53
54 bool run();
55
56private:
57 Module *M;
58 SmallVector<PassThroughInfo, 16> PassThroughs;
59
60 bool adjustICmpToBuiltin();
61 void adjustBasicBlock(BasicBlock &BB);
62 bool serializeICMPCrossBB(BasicBlock &BB);
63 void adjustInst(Instruction &I);
64 bool serializeICMPInBB(Instruction &I);
65 bool avoidSpeculation(Instruction &I);
66 bool insertPassThrough();
67};
68
69} // End anonymous namespace
70
71bool BPFAdjustOptImpl::run() {
72 bool Changed = adjustICmpToBuiltin();
73
74 for (Function &F : *M)
75 for (auto &BB : F) {
76 adjustBasicBlock(BB);
77 for (auto &I : BB)
78 adjustInst(I);
79 }
80 return insertPassThrough() || Changed;
81}
82
83// Commit acabad9ff6bf ("[InstCombine] try to canonicalize icmp with
84// trunc op into mask and cmp") added a transformation to
85// convert "(conv)a < power_2_const" to "a & <const>" in certain
86// cases and bpf kernel verifier has to handle the resulted code
87// conservatively and this may reject otherwise legitimate program.
88// Here, we change related icmp code to a builtin which will
89// be restored to original icmp code later to prevent that
90// InstCombine transformatin.
91bool BPFAdjustOptImpl::adjustICmpToBuiltin() {
92 bool Changed = false;
93 ICmpInst *ToBeDeleted = nullptr;
94 for (Function &F : *M)
95 for (auto &BB : F)
96 for (auto &I : BB) {
97 if (ToBeDeleted) {
98 ToBeDeleted->eraseFromParent();
99 ToBeDeleted = nullptr;
100 }
101
102 auto *Icmp = dyn_cast<ICmpInst>(Val: &I);
103 if (!Icmp)
104 continue;
105
106 Value *Op0 = Icmp->getOperand(i_nocapture: 0);
107 if (!isa<TruncInst>(Val: Op0))
108 continue;
109
110 auto ConstOp1 = dyn_cast<ConstantInt>(Val: Icmp->getOperand(i_nocapture: 1));
111 if (!ConstOp1)
112 continue;
113
114 auto ConstOp1Val = ConstOp1->getValue().getZExtValue();
115 auto Op = Icmp->getPredicate();
116 if (Op == ICmpInst::ICMP_ULT || Op == ICmpInst::ICMP_UGE) {
117 if ((ConstOp1Val - 1) & ConstOp1Val)
118 continue;
119 } else if (Op == ICmpInst::ICMP_ULE || Op == ICmpInst::ICMP_UGT) {
120 if (ConstOp1Val & (ConstOp1Val + 1))
121 continue;
122 } else {
123 continue;
124 }
125
126 Constant *Opcode =
127 ConstantInt::get(Ty: Type::getInt32Ty(C&: BB.getContext()), V: Op);
128 Function *Fn = Intrinsic::getOrInsertDeclaration(
129 M, id: Intrinsic::bpf_compare, Tys: {Op0->getType(), ConstOp1->getType()});
130 auto *NewInst = CallInst::Create(Func: Fn, Args: {Opcode, Op0, ConstOp1});
131 NewInst->insertBefore(InsertPos: I.getIterator());
132 Icmp->replaceAllUsesWith(V: NewInst);
133 Changed = true;
134 ToBeDeleted = Icmp;
135 }
136
137 return Changed;
138}
139
140bool BPFAdjustOptImpl::insertPassThrough() {
141 for (auto &Info : PassThroughs) {
142 auto *CI = BPFCoreSharedInfo::insertPassThrough(
143 M, BB: Info.UsedInst->getParent(), Input: Info.Input, Before: Info.UsedInst);
144 Info.UsedInst->setOperand(i: Info.OpIdx, Val: CI);
145 }
146
147 return !PassThroughs.empty();
148}
149
150// To avoid combining conditionals in the same basic block by
151// instrcombine optimization.
152bool BPFAdjustOptImpl::serializeICMPInBB(Instruction &I) {
153 // For:
154 // comp1 = icmp <opcode> ...;
155 // comp2 = icmp <opcode> ...;
156 // ... or comp1 comp2 ...
157 // changed to:
158 // comp1 = icmp <opcode> ...;
159 // comp2 = icmp <opcode> ...;
160 // new_comp1 = __builtin_bpf_passthrough(seq_num, comp1)
161 // ... or new_comp1 comp2 ...
162 Value *Op0, *Op1;
163 // Use LogicalOr (accept `or i1` as well as `select i1 Op0, true, Op1`)
164 if (!match(V: &I, P: m_LogicalOr(L: m_Value(V&: Op0), R: m_Value(V&: Op1))))
165 return false;
166 auto *Icmp1 = dyn_cast<ICmpInst>(Val: Op0);
167 if (!Icmp1)
168 return false;
169 auto *Icmp2 = dyn_cast<ICmpInst>(Val: Op1);
170 if (!Icmp2)
171 return false;
172
173 Value *Icmp1Op0 = Icmp1->getOperand(i_nocapture: 0);
174 Value *Icmp2Op0 = Icmp2->getOperand(i_nocapture: 0);
175 if (Icmp1Op0 != Icmp2Op0)
176 return false;
177
178 // Now we got two icmp instructions which feed into
179 // an "or" instruction.
180 PassThroughInfo Info(Icmp1, &I, 0);
181 PassThroughs.push_back(Elt: Info);
182 return true;
183}
184
185// To avoid combining conditionals in the same basic block by
186// instrcombine optimization.
187bool BPFAdjustOptImpl::serializeICMPCrossBB(BasicBlock &BB) {
188 // For:
189 // B1:
190 // comp1 = icmp <opcode> ...;
191 // if (comp1) goto B2 else B3;
192 // B2:
193 // comp2 = icmp <opcode> ...;
194 // if (comp2) goto B4 else B5;
195 // B4:
196 // ...
197 // changed to:
198 // B1:
199 // comp1 = icmp <opcode> ...;
200 // comp1 = __builtin_bpf_passthrough(seq_num, comp1);
201 // if (comp1) goto B2 else B3;
202 // B2:
203 // comp2 = icmp <opcode> ...;
204 // if (comp2) goto B4 else B5;
205 // B4:
206 // ...
207
208 // Check basic predecessors, if two of them (say B1, B2) are using
209 // icmp instructions to generate conditions and one is the predesessor
210 // of another (e.g., B1 is the predecessor of B2). Add a passthrough
211 // barrier after icmp inst of block B1.
212 BasicBlock *B2 = BB.getSinglePredecessor();
213 if (!B2)
214 return false;
215
216 BasicBlock *B1 = B2->getSinglePredecessor();
217 if (!B1)
218 return false;
219
220 Instruction *TI = B2->getTerminator();
221 auto *BI = dyn_cast<BranchInst>(Val: TI);
222 if (!BI || !BI->isConditional())
223 return false;
224 auto *Cond = dyn_cast<ICmpInst>(Val: BI->getCondition());
225 if (!Cond || &*B2->getFirstNonPHIIt() != Cond)
226 return false;
227 Value *B2Op0 = Cond->getOperand(i_nocapture: 0);
228 auto Cond2Op = Cond->getPredicate();
229
230 TI = B1->getTerminator();
231 BI = dyn_cast<BranchInst>(Val: TI);
232 if (!BI || !BI->isConditional())
233 return false;
234 Cond = dyn_cast<ICmpInst>(Val: BI->getCondition());
235 if (!Cond)
236 return false;
237 Value *B1Op0 = Cond->getOperand(i_nocapture: 0);
238 auto Cond1Op = Cond->getPredicate();
239
240 if (B1Op0 != B2Op0)
241 return false;
242
243 if (Cond1Op == ICmpInst::ICMP_SGT || Cond1Op == ICmpInst::ICMP_SGE) {
244 if (Cond2Op != ICmpInst::ICMP_SLT && Cond2Op != ICmpInst::ICMP_SLE)
245 return false;
246 } else if (Cond1Op == ICmpInst::ICMP_SLT || Cond1Op == ICmpInst::ICMP_SLE) {
247 if (Cond2Op != ICmpInst::ICMP_SGT && Cond2Op != ICmpInst::ICMP_SGE)
248 return false;
249 } else if (Cond1Op == ICmpInst::ICMP_ULT || Cond1Op == ICmpInst::ICMP_ULE) {
250 if (Cond2Op != ICmpInst::ICMP_UGT && Cond2Op != ICmpInst::ICMP_UGE)
251 return false;
252 } else if (Cond1Op == ICmpInst::ICMP_UGT || Cond1Op == ICmpInst::ICMP_UGE) {
253 if (Cond2Op != ICmpInst::ICMP_ULT && Cond2Op != ICmpInst::ICMP_ULE)
254 return false;
255 } else {
256 return false;
257 }
258
259 PassThroughInfo Info(Cond, BI, 0);
260 PassThroughs.push_back(Elt: Info);
261
262 return true;
263}
264
265// To avoid speculative hoisting certain computations out of
266// a basic block.
267bool BPFAdjustOptImpl::avoidSpeculation(Instruction &I) {
268 if (auto *LdInst = dyn_cast<LoadInst>(Val: &I)) {
269 if (auto *GV = dyn_cast<GlobalVariable>(Val: LdInst->getOperand(i_nocapture: 0))) {
270 if (GV->hasAttribute(Kind: BPFCoreSharedInfo::AmaAttr) ||
271 GV->hasAttribute(Kind: BPFCoreSharedInfo::TypeIdAttr))
272 return false;
273 }
274 }
275
276 if (!isa<LoadInst>(Val: &I) && !isa<CallInst>(Val: &I))
277 return false;
278
279 // For:
280 // B1:
281 // var = ...
282 // ...
283 // /* icmp may not be in the same block as var = ... */
284 // comp1 = icmp <opcode> var, <const>;
285 // if (comp1) goto B2 else B3;
286 // B2:
287 // ... var ...
288 // change to:
289 // B1:
290 // var = ...
291 // ...
292 // /* icmp may not be in the same block as var = ... */
293 // comp1 = icmp <opcode> var, <const>;
294 // if (comp1) goto B2 else B3;
295 // B2:
296 // var = __builtin_bpf_passthrough(seq_num, var);
297 // ... var ...
298 bool isCandidate = false;
299 SmallVector<PassThroughInfo, 4> Candidates;
300 for (User *U : I.users()) {
301 Instruction *Inst = dyn_cast<Instruction>(Val: U);
302 if (!Inst)
303 continue;
304
305 // May cover a little bit more than the
306 // above pattern.
307 if (auto *Icmp1 = dyn_cast<ICmpInst>(Val: Inst)) {
308 Value *Icmp1Op1 = Icmp1->getOperand(i_nocapture: 1);
309 if (!isa<Constant>(Val: Icmp1Op1))
310 return false;
311 isCandidate = true;
312 continue;
313 }
314
315 // Ignore the use in the same basic block as the definition.
316 if (Inst->getParent() == I.getParent())
317 continue;
318
319 // use in a different basic block, If there is a call or
320 // load/store insn before this instruction in this basic
321 // block. Most likely it cannot be hoisted out. Skip it.
322 for (auto &I2 : *Inst->getParent()) {
323 if (isa<CallInst>(Val: &I2))
324 return false;
325 if (isa<LoadInst>(Val: &I2) || isa<StoreInst>(Val: &I2))
326 return false;
327 if (&I2 == Inst)
328 break;
329 }
330
331 // It should be used in a GEP or a simple arithmetic like
332 // ZEXT/SEXT which is used for GEP.
333 if (Inst->getOpcode() == Instruction::ZExt ||
334 Inst->getOpcode() == Instruction::SExt) {
335 PassThroughInfo Info(&I, Inst, 0);
336 Candidates.push_back(Elt: Info);
337 } else if (auto *GI = dyn_cast<GetElementPtrInst>(Val: Inst)) {
338 // traverse GEP inst to find Use operand index
339 unsigned i, e;
340 for (i = 1, e = GI->getNumOperands(); i != e; ++i) {
341 Value *V = GI->getOperand(i_nocapture: i);
342 if (V == &I)
343 break;
344 }
345 if (i == e)
346 continue;
347
348 PassThroughInfo Info(&I, GI, i);
349 Candidates.push_back(Elt: Info);
350 }
351 }
352
353 if (!isCandidate || Candidates.empty())
354 return false;
355
356 llvm::append_range(C&: PassThroughs, R&: Candidates);
357 return true;
358}
359
360void BPFAdjustOptImpl::adjustBasicBlock(BasicBlock &BB) {
361 if (!DisableBPFserializeICMP && serializeICMPCrossBB(BB))
362 return;
363}
364
365void BPFAdjustOptImpl::adjustInst(Instruction &I) {
366 if (!DisableBPFserializeICMP && serializeICMPInBB(I))
367 return;
368 if (!DisableBPFavoidSpeculation && avoidSpeculation(I))
369 return;
370}
371
372PreservedAnalyses BPFAdjustOptPass::run(Module &M, ModuleAnalysisManager &AM) {
373 return BPFAdjustOptImpl(&M).run() ? PreservedAnalyses::none()
374 : PreservedAnalyses::all();
375}
376