1//===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Adjust optimization to make the code more kernel verifier friendly.
10//
11//===----------------------------------------------------------------------===//
12
13#include "BPF.h"
14#include "BPFCORE.h"
15#include "BPFTargetMachine.h"
16#include "llvm/IR/Instruction.h"
17#include "llvm/IR/Instructions.h"
18#include "llvm/IR/IntrinsicsBPF.h"
19#include "llvm/IR/Module.h"
20#include "llvm/IR/PatternMatch.h"
21#include "llvm/IR/Type.h"
22#include "llvm/IR/User.h"
23#include "llvm/IR/Value.h"
24#include "llvm/Pass.h"
25#include "llvm/Transforms/Utils/BasicBlockUtils.h"
26
27#define DEBUG_TYPE "bpf-adjust-opt"
28
29using namespace llvm;
30using namespace llvm::PatternMatch;
31
32static cl::opt<bool>
33 DisableBPFserializeICMP("bpf-disable-serialize-icmp", cl::Hidden,
34 cl::desc("BPF: Disable Serializing ICMP insns."),
35 cl::init(Val: false));
36
37static cl::opt<bool> DisableBPFavoidSpeculation(
38 "bpf-disable-avoid-speculation", cl::Hidden,
39 cl::desc("BPF: Disable Avoiding Speculative Code Motion."),
40 cl::init(Val: false));
41
42namespace {
43class BPFAdjustOptImpl {
44 struct PassThroughInfo {
45 Instruction *Input;
46 Instruction *UsedInst;
47 uint32_t OpIdx;
48 PassThroughInfo(Instruction *I, Instruction *U, uint32_t Idx)
49 : Input(I), UsedInst(U), OpIdx(Idx) {}
50 };
51
52public:
53 BPFAdjustOptImpl(Module *M) : M(M) {}
54
55 bool run();
56
57private:
58 Module *M;
59 SmallVector<PassThroughInfo, 16> PassThroughs;
60
61 bool adjustICmpToBuiltin();
62 void adjustBasicBlock(BasicBlock &BB);
63 bool serializeICMPCrossBB(BasicBlock &BB);
64 void adjustInst(Instruction &I);
65 bool serializeICMPInBB(Instruction &I);
66 bool avoidSpeculation(Instruction &I);
67 bool insertPassThrough();
68};
69
70} // End anonymous namespace
71
72bool BPFAdjustOptImpl::run() {
73 bool Changed = adjustICmpToBuiltin();
74
75 for (Function &F : *M)
76 for (auto &BB : F) {
77 adjustBasicBlock(BB);
78 for (auto &I : BB)
79 adjustInst(I);
80 }
81 return insertPassThrough() || Changed;
82}
83
84// Commit acabad9ff6bf ("[InstCombine] try to canonicalize icmp with
85// trunc op into mask and cmp") added a transformation to
86// convert "(conv)a < power_2_const" to "a & <const>" in certain
87// cases and bpf kernel verifier has to handle the resulted code
88// conservatively and this may reject otherwise legitimate program.
89// Here, we change related icmp code to a builtin which will
90// be restored to original icmp code later to prevent that
91// InstCombine transformatin.
92bool BPFAdjustOptImpl::adjustICmpToBuiltin() {
93 bool Changed = false;
94 ICmpInst *ToBeDeleted = nullptr;
95 for (Function &F : *M)
96 for (auto &BB : F)
97 for (auto &I : BB) {
98 if (ToBeDeleted) {
99 ToBeDeleted->eraseFromParent();
100 ToBeDeleted = nullptr;
101 }
102
103 auto *Icmp = dyn_cast<ICmpInst>(Val: &I);
104 if (!Icmp)
105 continue;
106
107 Value *Op0 = Icmp->getOperand(i_nocapture: 0);
108 if (!isa<TruncInst>(Val: Op0))
109 continue;
110
111 auto ConstOp1 = dyn_cast<ConstantInt>(Val: Icmp->getOperand(i_nocapture: 1));
112 if (!ConstOp1)
113 continue;
114
115 auto ConstOp1Val = ConstOp1->getValue().getZExtValue();
116 auto Op = Icmp->getPredicate();
117 if (Op == ICmpInst::ICMP_ULT || Op == ICmpInst::ICMP_UGE) {
118 if ((ConstOp1Val - 1) & ConstOp1Val)
119 continue;
120 } else if (Op == ICmpInst::ICMP_ULE || Op == ICmpInst::ICMP_UGT) {
121 if (ConstOp1Val & (ConstOp1Val + 1))
122 continue;
123 } else {
124 continue;
125 }
126
127 Constant *Opcode =
128 ConstantInt::get(Ty: Type::getInt32Ty(C&: BB.getContext()), V: Op);
129 Function *Fn = Intrinsic::getDeclaration(
130 M, id: Intrinsic::bpf_compare, Tys: {Op0->getType(), ConstOp1->getType()});
131 auto *NewInst = CallInst::Create(Func: Fn, Args: {Opcode, Op0, ConstOp1});
132 NewInst->insertBefore(InsertPos: &I);
133 Icmp->replaceAllUsesWith(V: NewInst);
134 Changed = true;
135 ToBeDeleted = Icmp;
136 }
137
138 return Changed;
139}
140
141bool BPFAdjustOptImpl::insertPassThrough() {
142 for (auto &Info : PassThroughs) {
143 auto *CI = BPFCoreSharedInfo::insertPassThrough(
144 M, BB: Info.UsedInst->getParent(), Input: Info.Input, Before: Info.UsedInst);
145 Info.UsedInst->setOperand(i: Info.OpIdx, Val: CI);
146 }
147
148 return !PassThroughs.empty();
149}
150
151// To avoid combining conditionals in the same basic block by
152// instrcombine optimization.
153bool BPFAdjustOptImpl::serializeICMPInBB(Instruction &I) {
154 // For:
155 // comp1 = icmp <opcode> ...;
156 // comp2 = icmp <opcode> ...;
157 // ... or comp1 comp2 ...
158 // changed to:
159 // comp1 = icmp <opcode> ...;
160 // comp2 = icmp <opcode> ...;
161 // new_comp1 = __builtin_bpf_passthrough(seq_num, comp1)
162 // ... or new_comp1 comp2 ...
163 Value *Op0, *Op1;
164 // Use LogicalOr (accept `or i1` as well as `select i1 Op0, true, Op1`)
165 if (!match(V: &I, P: m_LogicalOr(L: m_Value(V&: Op0), R: m_Value(V&: Op1))))
166 return false;
167 auto *Icmp1 = dyn_cast<ICmpInst>(Val: Op0);
168 if (!Icmp1)
169 return false;
170 auto *Icmp2 = dyn_cast<ICmpInst>(Val: Op1);
171 if (!Icmp2)
172 return false;
173
174 Value *Icmp1Op0 = Icmp1->getOperand(i_nocapture: 0);
175 Value *Icmp2Op0 = Icmp2->getOperand(i_nocapture: 0);
176 if (Icmp1Op0 != Icmp2Op0)
177 return false;
178
179 // Now we got two icmp instructions which feed into
180 // an "or" instruction.
181 PassThroughInfo Info(Icmp1, &I, 0);
182 PassThroughs.push_back(Elt: Info);
183 return true;
184}
185
186// To avoid combining conditionals in the same basic block by
187// instrcombine optimization.
188bool BPFAdjustOptImpl::serializeICMPCrossBB(BasicBlock &BB) {
189 // For:
190 // B1:
191 // comp1 = icmp <opcode> ...;
192 // if (comp1) goto B2 else B3;
193 // B2:
194 // comp2 = icmp <opcode> ...;
195 // if (comp2) goto B4 else B5;
196 // B4:
197 // ...
198 // changed to:
199 // B1:
200 // comp1 = icmp <opcode> ...;
201 // comp1 = __builtin_bpf_passthrough(seq_num, comp1);
202 // if (comp1) goto B2 else B3;
203 // B2:
204 // comp2 = icmp <opcode> ...;
205 // if (comp2) goto B4 else B5;
206 // B4:
207 // ...
208
209 // Check basic predecessors, if two of them (say B1, B2) are using
210 // icmp instructions to generate conditions and one is the predesessor
211 // of another (e.g., B1 is the predecessor of B2). Add a passthrough
212 // barrier after icmp inst of block B1.
213 BasicBlock *B2 = BB.getSinglePredecessor();
214 if (!B2)
215 return false;
216
217 BasicBlock *B1 = B2->getSinglePredecessor();
218 if (!B1)
219 return false;
220
221 Instruction *TI = B2->getTerminator();
222 auto *BI = dyn_cast<BranchInst>(Val: TI);
223 if (!BI || !BI->isConditional())
224 return false;
225 auto *Cond = dyn_cast<ICmpInst>(Val: BI->getCondition());
226 if (!Cond || B2->getFirstNonPHI() != Cond)
227 return false;
228 Value *B2Op0 = Cond->getOperand(i_nocapture: 0);
229 auto Cond2Op = Cond->getPredicate();
230
231 TI = B1->getTerminator();
232 BI = dyn_cast<BranchInst>(Val: TI);
233 if (!BI || !BI->isConditional())
234 return false;
235 Cond = dyn_cast<ICmpInst>(Val: BI->getCondition());
236 if (!Cond)
237 return false;
238 Value *B1Op0 = Cond->getOperand(i_nocapture: 0);
239 auto Cond1Op = Cond->getPredicate();
240
241 if (B1Op0 != B2Op0)
242 return false;
243
244 if (Cond1Op == ICmpInst::ICMP_SGT || Cond1Op == ICmpInst::ICMP_SGE) {
245 if (Cond2Op != ICmpInst::ICMP_SLT && Cond2Op != ICmpInst::ICMP_SLE)
246 return false;
247 } else if (Cond1Op == ICmpInst::ICMP_SLT || Cond1Op == ICmpInst::ICMP_SLE) {
248 if (Cond2Op != ICmpInst::ICMP_SGT && Cond2Op != ICmpInst::ICMP_SGE)
249 return false;
250 } else if (Cond1Op == ICmpInst::ICMP_ULT || Cond1Op == ICmpInst::ICMP_ULE) {
251 if (Cond2Op != ICmpInst::ICMP_UGT && Cond2Op != ICmpInst::ICMP_UGE)
252 return false;
253 } else if (Cond1Op == ICmpInst::ICMP_UGT || Cond1Op == ICmpInst::ICMP_UGE) {
254 if (Cond2Op != ICmpInst::ICMP_ULT && Cond2Op != ICmpInst::ICMP_ULE)
255 return false;
256 } else {
257 return false;
258 }
259
260 PassThroughInfo Info(Cond, BI, 0);
261 PassThroughs.push_back(Elt: Info);
262
263 return true;
264}
265
266// To avoid speculative hoisting certain computations out of
267// a basic block.
268bool BPFAdjustOptImpl::avoidSpeculation(Instruction &I) {
269 if (auto *LdInst = dyn_cast<LoadInst>(Val: &I)) {
270 if (auto *GV = dyn_cast<GlobalVariable>(Val: LdInst->getOperand(i_nocapture: 0))) {
271 if (GV->hasAttribute(Kind: BPFCoreSharedInfo::AmaAttr) ||
272 GV->hasAttribute(Kind: BPFCoreSharedInfo::TypeIdAttr))
273 return false;
274 }
275 }
276
277 if (!isa<LoadInst>(Val: &I) && !isa<CallInst>(Val: &I))
278 return false;
279
280 // For:
281 // B1:
282 // var = ...
283 // ...
284 // /* icmp may not be in the same block as var = ... */
285 // comp1 = icmp <opcode> var, <const>;
286 // if (comp1) goto B2 else B3;
287 // B2:
288 // ... var ...
289 // change to:
290 // B1:
291 // var = ...
292 // ...
293 // /* icmp may not be in the same block as var = ... */
294 // comp1 = icmp <opcode> var, <const>;
295 // if (comp1) goto B2 else B3;
296 // B2:
297 // var = __builtin_bpf_passthrough(seq_num, var);
298 // ... var ...
299 bool isCandidate = false;
300 SmallVector<PassThroughInfo, 4> Candidates;
301 for (User *U : I.users()) {
302 Instruction *Inst = dyn_cast<Instruction>(Val: U);
303 if (!Inst)
304 continue;
305
306 // May cover a little bit more than the
307 // above pattern.
308 if (auto *Icmp1 = dyn_cast<ICmpInst>(Val: Inst)) {
309 Value *Icmp1Op1 = Icmp1->getOperand(i_nocapture: 1);
310 if (!isa<Constant>(Val: Icmp1Op1))
311 return false;
312 isCandidate = true;
313 continue;
314 }
315
316 // Ignore the use in the same basic block as the definition.
317 if (Inst->getParent() == I.getParent())
318 continue;
319
320 // use in a different basic block, If there is a call or
321 // load/store insn before this instruction in this basic
322 // block. Most likely it cannot be hoisted out. Skip it.
323 for (auto &I2 : *Inst->getParent()) {
324 if (isa<CallInst>(Val: &I2))
325 return false;
326 if (isa<LoadInst>(Val: &I2) || isa<StoreInst>(Val: &I2))
327 return false;
328 if (&I2 == Inst)
329 break;
330 }
331
332 // It should be used in a GEP or a simple arithmetic like
333 // ZEXT/SEXT which is used for GEP.
334 if (Inst->getOpcode() == Instruction::ZExt ||
335 Inst->getOpcode() == Instruction::SExt) {
336 PassThroughInfo Info(&I, Inst, 0);
337 Candidates.push_back(Elt: Info);
338 } else if (auto *GI = dyn_cast<GetElementPtrInst>(Val: Inst)) {
339 // traverse GEP inst to find Use operand index
340 unsigned i, e;
341 for (i = 1, e = GI->getNumOperands(); i != e; ++i) {
342 Value *V = GI->getOperand(i_nocapture: i);
343 if (V == &I)
344 break;
345 }
346 if (i == e)
347 continue;
348
349 PassThroughInfo Info(&I, GI, i);
350 Candidates.push_back(Elt: Info);
351 }
352 }
353
354 if (!isCandidate || Candidates.empty())
355 return false;
356
357 llvm::append_range(C&: PassThroughs, R&: Candidates);
358 return true;
359}
360
361void BPFAdjustOptImpl::adjustBasicBlock(BasicBlock &BB) {
362 if (!DisableBPFserializeICMP && serializeICMPCrossBB(BB))
363 return;
364}
365
366void BPFAdjustOptImpl::adjustInst(Instruction &I) {
367 if (!DisableBPFserializeICMP && serializeICMPInBB(I))
368 return;
369 if (!DisableBPFavoidSpeculation && avoidSpeculation(I))
370 return;
371}
372
373PreservedAnalyses BPFAdjustOptPass::run(Module &M, ModuleAnalysisManager &AM) {
374 return BPFAdjustOptImpl(&M).run() ? PreservedAnalyses::none()
375 : PreservedAnalyses::all();
376}
377