1//===------ RISCVIndirectBranchTracking.cpp - Enables lpad mechanism ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// The pass adds LPAD (AUIPC with rd = X0) machine instructions at the
// beginning of each basic block or function that may be the target of an
// indirect jump/call instruction, and immediately after each call to a
// function carrying the returns_twice attribute.
12//
13//===----------------------------------------------------------------------===//
14
15#include "RISCV.h"
16#include "RISCVInstrInfo.h"
17#include "RISCVSubtarget.h"
18#include "llvm/ADT/Statistic.h"
19#include "llvm/CodeGen/MachineBasicBlock.h"
20#include "llvm/CodeGen/MachineFunctionPass.h"
21#include "llvm/CodeGen/MachineInstrBuilder.h"
22#include "llvm/CodeGen/MachineModuleInfo.h"
23
24#define DEBUG_TYPE "riscv-indirect-branch-tracking"
25#define PASS_NAME "RISC-V Indirect Branch Tracking"
26
27using namespace llvm;
28
29cl::opt<uint32_t> PreferredLandingPadLabel(
30 "riscv-landing-pad-label", cl::ReallyHidden,
31 cl::desc("Use preferred fixed label for all labels"));
32
33namespace {
34class RISCVIndirectBranchTracking : public MachineFunctionPass {
35public:
36 static char ID;
37 RISCVIndirectBranchTracking() : MachineFunctionPass(ID) {}
38
39 StringRef getPassName() const override { return PASS_NAME; }
40
41 bool runOnMachineFunction(MachineFunction &MF) override;
42
43private:
44 const Align LpadAlign = Align(4);
45};
46
47} // end anonymous namespace
48
49INITIALIZE_PASS(RISCVIndirectBranchTracking, DEBUG_TYPE, PASS_NAME, false,
50 false)
51
52char RISCVIndirectBranchTracking::ID = 0;
53
54FunctionPass *llvm::createRISCVIndirectBranchTrackingPass() {
55 return new RISCVIndirectBranchTracking();
56}
57
58static void
59emitLpad(MachineBasicBlock &MBB, const RISCVInstrInfo *TII, uint32_t Label,
60 MachineBasicBlock::iterator I = MachineBasicBlock::iterator{}) {
61 if (!I.isValid())
62 I = MBB.begin();
63 BuildMI(BB&: MBB, I, MIMD: MBB.findDebugLoc(MBBI: I), MCID: TII->get(Opcode: RISCV::AUIPC), DestReg: RISCV::X0)
64 .addImm(Val: Label);
65}
66
67static bool isCallReturnTwice(const MachineOperand &MOp) {
68 if (!MOp.isGlobal())
69 return false;
70 auto *CalleeFn = dyn_cast<Function>(Val: MOp.getGlobal());
71 if (!CalleeFn)
72 return false;
73 AttributeList Attrs = CalleeFn->getAttributes();
74 return Attrs.hasFnAttr(Kind: Attribute::ReturnsTwice);
75}
76
77bool RISCVIndirectBranchTracking::runOnMachineFunction(MachineFunction &MF) {
78 const auto &Subtarget = MF.getSubtarget<RISCVSubtarget>();
79 const RISCVInstrInfo *TII = Subtarget.getInstrInfo();
80 if (!Subtarget.hasStdExtZicfilp())
81 return false;
82
83 uint32_t FixedLabel = 0;
84 if (PreferredLandingPadLabel.getNumOccurrences() > 0) {
85 if (!isUInt<20>(x: PreferredLandingPadLabel))
86 report_fatal_error(reason: "riscv-landing-pad-label=<val>, <val> needs to fit in "
87 "unsigned 20-bits");
88 FixedLabel = PreferredLandingPadLabel;
89 }
90
91 bool Changed = false;
92 for (MachineBasicBlock &MBB : MF) {
93 if (&MBB == &MF.front()) {
94 Function &F = MF.getFunction();
95 // When trap is taken, landing pad is not needed.
96 if (F.hasFnAttribute(Kind: "interrupt"))
97 continue;
98
99 if (F.hasAddressTaken() || !F.hasLocalLinkage()) {
100 emitLpad(MBB, TII, Label: FixedLabel);
101 if (MF.getAlignment() < LpadAlign)
102 MF.setAlignment(LpadAlign);
103 Changed = true;
104 }
105 continue;
106 }
107
108 if (MBB.hasAddressTaken()) {
109 emitLpad(MBB, TII, Label: FixedLabel);
110 if (MBB.getAlignment() < LpadAlign)
111 MBB.setAlignment(LpadAlign);
112 Changed = true;
113 }
114 }
115
116 // Check for calls to functions with ReturnsTwice attribute and insert
117 // LPAD after such calls
118 for (MachineBasicBlock &MBB : MF) {
119 for (MachineBasicBlock::iterator I = MBB.begin(); I != MBB.end(); ++I) {
120 if (I->isCall() && I->getNumOperands() > 0 &&
121 isCallReturnTwice(MOp: I->getOperand(i: 0))) {
122 auto NextI = std::next(x: I);
123 emitLpad(MBB, TII, Label: FixedLabel, I: NextI);
124 Changed = true;
125 }
126 }
127 }
128
129 return Changed;
130}
131