//===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Check the IR and adjust it so that the generated code is friendly to the
// BPF verifier.
// The following is done for IR checking:
//   - no relocation globals in PHI node.
// The following are done for IR adjustment:
//   - remove __builtin_bpf_passthrough builtins. Target independent IR
//     optimizations are done and those builtins can be removed.
//   - replace __builtin_bpf_compare builtins with regular icmp instructions.
//   - rewrite comparisons against min/max intrinsic calls as equivalent
//     and/or combinations of plain comparisons (see sinkMinMax below).
//   - remove llvm.bpf.getelementptr.and.load builtins.
//   - remove llvm.bpf.getelementptr.and.store builtins.
//   - for loads and stores with base addresses from non-zero address space
//     cast base address to zero address space (support for BPF address spaces).
//
//===----------------------------------------------------------------------===//

#include "BPF.h"
#include "BPFCORE.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsBPF.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

#define DEBUG_TYPE "bpf-check-and-opt-ir"

using namespace llvm;

namespace {

class BPFCheckAndAdjustIR final : public ModulePass {
  bool runOnModule(Module &M) override;

public:
  static char ID;
  BPFCheckAndAdjustIR() : ModulePass(ID) {}
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  void checkIR(Module &M);
  bool adjustIR(Module &M);
  bool removePassThroughBuiltin(Module &M);
  bool removeCompareBuiltin(Module &M);
  bool sinkMinMax(Module &M);
  bool removeGEPBuiltins(Module &M);
  bool insertASpaceCasts(Module &M);
};
} // End anonymous namespace

char BPFCheckAndAdjustIR::ID = 0;
INITIALIZE_PASS(BPFCheckAndAdjustIR, DEBUG_TYPE, "BPF Check And Adjust IR",
                false, false)

ModulePass *llvm::createBPFCheckAndAdjustIR() {
  return new BPFCheckAndAdjustIR();
}

void BPFCheckAndAdjustIR::checkIR(Module &M) {
  // Ensure relocation globals do not appear in PHI nodes.
  // This may happen if the compiler generated the following code:
  //   B1:
  //      g1 = @llvm.skb_buff:0:1...
  //      ...
  //      goto B_COMMON
  //   B2:
  //      g2 = @llvm.skb_buff:0:2...
  //      ...
  //      goto B_COMMON
  //   B_COMMON:
  //      g = PHI(g1, g2)
  //      x = load g
  //      ...
  // If anything like the above "g = PHI(g1, g2)" is found, issue a fatal
  // error.
  for (Function &F : M)
    for (auto &BB : F)
      for (auto &I : BB) {
        PHINode *PN = dyn_cast<PHINode>(&I);
        if (!PN || PN->use_empty())
          continue;
        for (int i = 0, e = PN->getNumIncomingValues(); i < e; ++i) {
          auto *GV = dyn_cast<GlobalVariable>(PN->getIncomingValue(i));
          if (!GV)
            continue;
          if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
              GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
            report_fatal_error("relocation global in PHI node");
        }
      }
}

bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module &M) {
  // Remove __builtin_bpf_passthrough() calls, which are used to prevent
  // certain IR optimizations. Now that the major IR optimizations are done,
  // they can be removed.
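  //
  // For illustration only (a sketch; the type suffixes and value names are
  // hypothetical): the passthrough intrinsic carries a sequence number as its
  // first argument and the protected value as its second, so
  //   %v = call i32 @llvm.bpf.passthrough.i32.i32(i32 %seq, i32 %x)
  // is replaced by using %x directly, after which the call itself is erased.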
  bool Changed = false;
  CallInst *ToBeDeleted = nullptr;
  for (Function &F : M)
    for (auto &BB : F)
      for (auto &I : BB) {
        if (ToBeDeleted) {
          ToBeDeleted->eraseFromParent();
          ToBeDeleted = nullptr;
        }

        auto *Call = dyn_cast<CallInst>(&I);
        if (!Call)
          continue;
        auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
        if (!GV)
          continue;
        if (!GV->getName().starts_with("llvm.bpf.passthrough"))
          continue;
        Changed = true;
        Value *Arg = Call->getArgOperand(1);
        Call->replaceAllUsesWith(Arg);
        ToBeDeleted = Call;
      }
  return Changed;
}

bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
  // Remove __builtin_bpf_compare() calls, which are used to prevent
  // certain IR optimizations. Now that the major IR optimizations are done,
  // they can be removed.
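  //
  // For illustration only (a sketch; the operand types are hypothetical):
  // the builtin encodes an icmp predicate constant as its first argument, so
  // something like
  //   %c = call i1 @llvm.bpf.compare(i32 <pred>, i64 %a, i64 %b)
  // is rebuilt as
  //   %c = icmp <pred> i64 %a, %b
  // and the original call is erased.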
  bool Changed = false;
  CallInst *ToBeDeleted = nullptr;
  for (Function &F : M)
    for (auto &BB : F)
      for (auto &I : BB) {
        if (ToBeDeleted) {
          ToBeDeleted->eraseFromParent();
          ToBeDeleted = nullptr;
        }

        auto *Call = dyn_cast<CallInst>(&I);
        if (!Call)
          continue;
        auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
        if (!GV)
          continue;
        if (!GV->getName().starts_with("llvm.bpf.compare"))
          continue;

        Changed = true;
        Value *Arg0 = Call->getArgOperand(0);
        Value *Arg1 = Call->getArgOperand(1);
        Value *Arg2 = Call->getArgOperand(2);

        auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue();
        CmpInst::Predicate Opcode = (CmpInst::Predicate)OpVal;

        auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2);
        ICmp->insertBefore(Call->getIterator());

        Call->replaceAllUsesWith(ICmp);
        ToBeDeleted = Call;
      }
  return Changed;
}

struct MinMaxSinkInfo {
  ICmpInst *ICmp;
  Value *Other;
  ICmpInst::Predicate Predicate;
  CallInst *MinMax;
  ZExtInst *ZExt;
  SExtInst *SExt;

  MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)
      : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr),
        ZExt(nullptr), SExt(nullptr) {}
};

static bool sinkMinMaxInBB(BasicBlock &BB,
                           const std::function<bool(Instruction *)> &Filter) {
  // Check if V is:
  //   (fn %a %b) or (ext (fn %a %b))
  // Where:
  //   ext := sext | zext
  //   fn  := smin | umin | smax | umax
  auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) {
    if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
      V = ZExt->getOperand(0);
      Info.ZExt = ZExt;
    } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
      V = SExt->getOperand(0);
      Info.SExt = SExt;
    }

    auto *Call = dyn_cast<CallInst>(V);
    if (!Call)
      return false;

    auto *Called = dyn_cast<Function>(Call->getCalledOperand());
    if (!Called)
      return false;

    switch (Called->getIntrinsicID()) {
    case Intrinsic::smin:
    case Intrinsic::umin:
    case Intrinsic::smax:
    case Intrinsic::umax:
      break;
    default:
      return false;
    }

    if (!Filter(Call))
      return false;

    Info.MinMax = Call;

    return true;
  };

  auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
                             MinMaxSinkInfo &Info) {
    if (Info.SExt) {
      if (Info.SExt->getType() == V->getType())
        return V;
      return Builder.CreateSExt(V, Info.SExt->getType());
    }
    if (Info.ZExt) {
      if (Info.ZExt->getType() == V->getType())
        return V;
      return Builder.CreateZExt(V, Info.ZExt->getType());
    }
    return V;
  };

  bool Changed = false;
  SmallVector<MinMaxSinkInfo, 2> SinkList;

  // Check BB for instructions like:
  //   insn := (icmp %a (fn ...)) | (icmp (fn ...) %a)
  //
  // Where:
  //   fn := min | max | (ext (min ...)) | (ext (max ...))
  //   ext := sext | zext
  //
  // Put such instructions into SinkList.
  for (Instruction &I : BB) {
    ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
    if (!ICmp)
      continue;
    if (!ICmp->isRelational())
      continue;
    MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
                         ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
    MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
    bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
    bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
    if (!(FirstMinMax ^ SecondMinMax))
      continue;
    SinkList.push_back(FirstMinMax ? First : Second);
  }

  // Iterate SinkList and replace each (icmp ...) with corresponding
  // `x < a && x < b` or similar expression.
  for (auto &Info : SinkList) {
    ICmpInst *ICmp = Info.ICmp;
    CallInst *MinMax = Info.MinMax;
    Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
    ICmpInst::Predicate P = Info.Predicate;
    if (ICmpInst::isSigned(P) && IID != Intrinsic::smin &&
        IID != Intrinsic::smax)
      continue;

    IRBuilder<> Builder(ICmp);
    Value *X = Info.Other;
    Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
    Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
    bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
    bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
    bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
    bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
    assert(IsMin ^ IsMax);
    assert(IsLess ^ IsGreater);

    Value *Replacement;
    Value *LHS = Builder.CreateICmp(P, X, A);
    Value *RHS = Builder.CreateICmp(P, X, B);
    if ((IsLess && IsMin) || (IsGreater && IsMax))
      // x < min(a, b) -> x < a && x < b
      // x > max(a, b) -> x > a && x > b
      Replacement = Builder.CreateLogicalAnd(LHS, RHS);
    else
      // x > min(a, b) -> x > a || x > b
      // x < max(a, b) -> x < a || x < b
      Replacement = Builder.CreateLogicalOr(LHS, RHS);

    ICmp->replaceAllUsesWith(Replacement);

    Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
    for (Instruction *I : ToRemove)
      if (I && I->use_empty())
        I->eraseFromParent();

    Changed = true;
  }

  return Changed;
}

// Do the following transformation:
//
//   x < min(a, b) -> x < a && x < b
//   x > min(a, b) -> x > a || x > b
//   x < max(a, b) -> x < a || x < b
//   x > max(a, b) -> x > a && x > b
//
// Such patterns are introduced by the LICM.cpp:hoistMinMax()
// transformation and might lead to BPF verification failures for
// older kernels.
//
// To minimize "collateral" changes, only do it for icmp + min/max
// calls when the icmp is inside a loop and the min/max call is outside
// of that loop.
//
// Verification failure happens when:
// - the RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
// - the verifier can recognize RHS as a constant scalar in some context;
// - the verifier can't recognize RHS1 as a constant scalar in the same
//   context.
//
// The "constant scalar" is not a compile time constant, but a register
// that holds a scalar value known to the verifier at some point in time
// during abstract interpretation.
//
// See also:
// https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
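//
// For illustration only (a hand-written sketch; value names are
// hypothetical): with a signed comparison against smin the rewrite looks
// roughly like
//   %m = call i32 @llvm.smin.i32(i32 %a, i32 %b)   ; outside the loop
//   %c = icmp slt i32 %x, %m                       ; inside the loop
// becoming
//   %c1 = icmp slt i32 %x, %a
//   %c2 = icmp slt i32 %x, %b
//   %c  = select i1 %c1, i1 %c2, i1 false          ; CreateLogicalAnd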
bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
  bool Changed = false;

  for (Function &F : M) {
    if (F.isDeclaration())
      continue;

    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
    for (Loop *L : LI)
      for (BasicBlock *BB : L->blocks()) {
        // Filter out instructions coming from the same loop
        Loop *BBLoop = LI.getLoopFor(BB);
        auto OtherLoopFilter = [&](Instruction *I) {
          return LI.getLoopFor(I->getParent()) != BBLoop;
        };
        Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter);
      }
  }

  return Changed;
}

void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<LoopInfoWrapperPass>();
}

static void unrollGEPLoad(CallInst *Call) {
  auto [GEP, Load] = BPFPreserveStaticOffsetPass::reconstructLoad(Call);
  GEP->insertBefore(Call->getIterator());
  Load->insertBefore(Call->getIterator());
  Call->replaceAllUsesWith(Load);
  Call->eraseFromParent();
}

static void unrollGEPStore(CallInst *Call) {
  auto [GEP, Store] = BPFPreserveStaticOffsetPass::reconstructStore(Call);
  GEP->insertBefore(Call->getIterator());
  Store->insertBefore(Call->getIterator());
  Call->eraseFromParent();
}

static bool removeGEPBuiltinsInFunc(Function &F) {
  SmallVector<CallInst *> GEPLoads;
  SmallVector<CallInst *> GEPStores;
  for (auto &BB : F)
    for (auto &Insn : BB)
      if (auto *Call = dyn_cast<CallInst>(&Insn))
        if (auto *Called = Call->getCalledFunction())
          switch (Called->getIntrinsicID()) {
          case Intrinsic::bpf_getelementptr_and_load:
            GEPLoads.push_back(Call);
            break;
          case Intrinsic::bpf_getelementptr_and_store:
            GEPStores.push_back(Call);
            break;
          }

  if (GEPLoads.empty() && GEPStores.empty())
    return false;

  for_each(GEPLoads, unrollGEPLoad);
  for_each(GEPStores, unrollGEPStore);

  return true;
}

// Rewrites the following builtins:
// - llvm.bpf.getelementptr.and.load
// - llvm.bpf.getelementptr.and.store
// as (load (getelementptr ...)) or (store (getelementptr ...)).
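//
// For illustration only (a sketch; the intrinsic's full argument list is
// omitted and the struct type is hypothetical):
//   %v = call i32 @llvm.bpf.getelementptr.and.load(ptr %base, ...)
// is unrolled into
//   %p = getelementptr %struct.foo, ptr %base, i32 0, i32 1
//   %v = load i32, ptr %p
// using BPFPreserveStaticOffsetPass::reconstructLoad()/reconstructStore().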
bool BPFCheckAndAdjustIR::removeGEPBuiltins(Module &M) {
  bool Changed = false;
  for (auto &F : M)
    Changed = removeGEPBuiltinsInFunc(F) || Changed;
  return Changed;
}

// Wrap ToWrap with a cast to address space zero:
// - if ToWrap is a getelementptr,
//   wrap its base pointer instead and return a copy;
// - if ToWrap is an Instruction, insert the address space cast
//   immediately after ToWrap;
// - if ToWrap is not an Instruction (function parameter
//   or a global value), insert the address space cast at the
//   beginning of the Function F;
// - use Cache to avoid inserting too many casts.
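//
// For illustration only (a sketch; addrspace(1) and the value names are
// hypothetical): a GEP on a non-zero address space pointer
//   %p = getelementptr i8, ptr addrspace(1) %base, i64 8
// is re-created around an addrspacecast of its base:
//   %base.cast = addrspacecast ptr addrspace(1) %base to ptr
//   %p.new = getelementptr i8, ptr %base.cast, i64 8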
static Value *aspaceWrapValue(DenseMap<Value *, Value *> &Cache, Function *F,
                              Value *ToWrap) {
  auto It = Cache.find(ToWrap);
  if (It != Cache.end())
    return It->getSecond();

  if (auto *GEP = dyn_cast<GetElementPtrInst>(ToWrap)) {
    Value *Ptr = GEP->getPointerOperand();
    Value *WrappedPtr = aspaceWrapValue(Cache, F, Ptr);
    auto *GEPTy = cast<PointerType>(GEP->getType());
    auto *NewGEP = GEP->clone();
    NewGEP->insertAfter(GEP->getIterator());
    NewGEP->mutateType(PointerType::getUnqual(GEPTy->getContext()));
    NewGEP->setOperand(GEP->getPointerOperandIndex(), WrappedPtr);
    NewGEP->setName(GEP->getName());
    Cache[ToWrap] = NewGEP;
    return NewGEP;
  }

  IRBuilder IB(F->getContext());
  if (Instruction *InsnPtr = dyn_cast<Instruction>(ToWrap))
    IB.SetInsertPoint(*InsnPtr->getInsertionPointAfterDef());
  else
    IB.SetInsertPoint(F->getEntryBlock().getFirstInsertionPt());
  auto *ASZeroPtrTy = IB.getPtrTy(0);
  auto *ACast = IB.CreateAddrSpaceCast(ToWrap, ASZeroPtrTy, ToWrap->getName());
  Cache[ToWrap] = ACast;
  return ACast;
}

// Wrap a pointer operand OpNum of instruction I
// with cast to address space zero
static void aspaceWrapOperand(DenseMap<Value *, Value *> &Cache, Instruction *I,
                              unsigned OpNum) {
  Value *OldOp = I->getOperand(OpNum);
  if (OldOp->getType()->getPointerAddressSpace() == 0)
    return;

  Value *NewOp = aspaceWrapValue(Cache, I->getFunction(), OldOp);
  I->setOperand(OpNum, NewOp);
  // Check if there are any remaining users of old GEP,
  // delete those w/o users
  for (;;) {
    auto *OldGEP = dyn_cast<GetElementPtrInst>(OldOp);
    if (!OldGEP)
      break;
    if (!OldGEP->use_empty())
      break;
    OldOp = OldGEP->getPointerOperand();
    OldGEP->eraseFromParent();
  }
}

// Support for BPF address spaces:
// - for each function in the module M, update the pointer operand of
//   each memory access instruction (load/store/cmpxchg/atomicrmw)
//   by casting it from a non-zero address space to the zero address space,
//   e.g.:
//
//   (load (ptr addrspace (N) %p) ...)
//     -> (load (addrspacecast ptr addrspace (N) %p to ptr))
//
// - assign section with name .addr_space.N for globals defined in
//   non-zero address space N
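//
// For illustration only (a sketch; the global's name and type are
// hypothetical): a definition such as
//   @g = addrspace(1) global i32 0
// gets section ".addr_space.1", which groups all address space 1 globals
// into a single section.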
bool BPFCheckAndAdjustIR::insertASpaceCasts(Module &M) {
  bool Changed = false;
  for (Function &F : M) {
    DenseMap<Value *, Value *> CastsCache;
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        unsigned PtrOpNum;

        if (auto *LD = dyn_cast<LoadInst>(&I))
          PtrOpNum = LD->getPointerOperandIndex();
        else if (auto *ST = dyn_cast<StoreInst>(&I))
          PtrOpNum = ST->getPointerOperandIndex();
        else if (auto *CmpXchg = dyn_cast<AtomicCmpXchgInst>(&I))
          PtrOpNum = CmpXchg->getPointerOperandIndex();
        else if (auto *RMW = dyn_cast<AtomicRMWInst>(&I))
          PtrOpNum = RMW->getPointerOperandIndex();
        else
          continue;

        aspaceWrapOperand(CastsCache, &I, PtrOpNum);
      }
    }
    Changed |= !CastsCache.empty();
  }
  // Merge all globals within same address space into single
  // .addr_space.<addr space no> section
  for (GlobalVariable &G : M.globals()) {
    if (G.getAddressSpace() == 0 || G.hasSection())
      continue;
    SmallString<16> SecName;
    raw_svector_ostream OS(SecName);
    OS << ".addr_space." << G.getAddressSpace();
    G.setSection(SecName);
    // Prevent having separate section for constants
    G.setConstant(false);
  }
  return Changed;
}

bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
  bool Changed = removePassThroughBuiltin(M);
  Changed = removeCompareBuiltin(M) || Changed;
  Changed = sinkMinMax(M) || Changed;
  Changed = removeGEPBuiltins(M) || Changed;
  Changed = insertASpaceCasts(M) || Changed;
  return Changed;
}

bool BPFCheckAndAdjustIR::runOnModule(Module &M) {
  checkIR(M);
  return adjustIR(M);
}