1 | //===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
// This file implements the machine function pass that inserts reads/writes of
// the CSRs required by RISC-V instructions.
//
// Currently the pass implements:
// - Saving frm and writing the static rounding mode before an RVV
//   floating-point instruction that has one, then restoring the saved frm
//   value afterwards.
14 | // |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | #include "MCTargetDesc/RISCVBaseInfo.h" |
18 | #include "RISCV.h" |
19 | #include "RISCVSubtarget.h" |
20 | #include "llvm/CodeGen/MachineFunctionPass.h" |
21 | using namespace llvm; |
22 | |
23 | #define DEBUG_TYPE "riscv-insert-read-write-csr" |
24 | #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass" |
25 | |
26 | static cl::opt<bool> |
27 | DisableFRMInsertOpt("riscv-disable-frm-insert-opt" , cl::init(Val: false), |
28 | cl::Hidden, |
29 | cl::desc("Disable optimized frm insertion." )); |
30 | |
31 | namespace { |
32 | |
33 | class RISCVInsertReadWriteCSR : public MachineFunctionPass { |
34 | const TargetInstrInfo *TII; |
35 | |
36 | public: |
37 | static char ID; |
38 | |
39 | RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {} |
40 | |
41 | bool runOnMachineFunction(MachineFunction &MF) override; |
42 | |
43 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
44 | AU.setPreservesCFG(); |
45 | MachineFunctionPass::getAnalysisUsage(AU); |
46 | } |
47 | |
48 | StringRef getPassName() const override { |
49 | return RISCV_INSERT_READ_WRITE_CSR_NAME; |
50 | } |
51 | |
52 | private: |
53 | bool emitWriteRoundingMode(MachineBasicBlock &MBB); |
54 | bool emitWriteRoundingModeOpt(MachineBasicBlock &MBB); |
55 | }; |
56 | |
57 | } // end anonymous namespace |
58 | |
59 | char RISCVInsertReadWriteCSR::ID = 0; |
60 | |
61 | INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE, |
62 | RISCV_INSERT_READ_WRITE_CSR_NAME, false, false) |
63 | |
64 | // TODO: Use more accurate rounding mode at the start of MBB. |
65 | bool RISCVInsertReadWriteCSR::emitWriteRoundingModeOpt(MachineBasicBlock &MBB) { |
66 | bool Changed = false; |
67 | MachineInstr *LastFRMChanger = nullptr; |
68 | unsigned CurrentRM = RISCVFPRndMode::DYN; |
69 | Register SavedFRM; |
70 | |
71 | for (MachineInstr &MI : MBB) { |
72 | if (MI.getOpcode() == RISCV::SwapFRMImm || |
73 | MI.getOpcode() == RISCV::WriteFRMImm) { |
74 | CurrentRM = MI.getOperand(i: 0).getImm(); |
75 | SavedFRM = Register(); |
76 | continue; |
77 | } |
78 | |
79 | if (MI.getOpcode() == RISCV::WriteFRM) { |
80 | CurrentRM = RISCVFPRndMode::DYN; |
81 | SavedFRM = Register(); |
82 | continue; |
83 | } |
84 | |
85 | if (MI.isCall() || MI.isInlineAsm() || |
86 | MI.readsRegister(Reg: RISCV::FRM, /*TRI=*/nullptr)) { |
87 | // Restore FRM before unknown operations. |
88 | if (SavedFRM.isValid()) |
89 | BuildMI(BB&: MBB, I&: MI, MIMD: MI.getDebugLoc(), MCID: TII->get(Opcode: RISCV::WriteFRM)) |
90 | .addReg(RegNo: SavedFRM); |
91 | CurrentRM = RISCVFPRndMode::DYN; |
92 | SavedFRM = Register(); |
93 | continue; |
94 | } |
95 | |
96 | assert(!MI.modifiesRegister(RISCV::FRM, /*TRI=*/nullptr) && |
97 | "Expected that MI could not modify FRM." ); |
98 | |
99 | int FRMIdx = RISCVII::getFRMOpNum(Desc: MI.getDesc()); |
100 | if (FRMIdx < 0) |
101 | continue; |
102 | unsigned InstrRM = MI.getOperand(i: FRMIdx).getImm(); |
103 | |
104 | LastFRMChanger = &MI; |
105 | |
106 | // Make MI implicit use FRM. |
107 | MI.addOperand(Op: MachineOperand::CreateReg(Reg: RISCV::FRM, /*IsDef*/ isDef: false, |
108 | /*IsImp*/ isImp: true)); |
109 | Changed = true; |
110 | |
111 | // Skip if MI uses same rounding mode as FRM. |
112 | if (InstrRM == CurrentRM) |
113 | continue; |
114 | |
115 | if (!SavedFRM.isValid()) { |
116 | // Save current FRM value to SavedFRM. |
117 | MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo(); |
118 | SavedFRM = MRI->createVirtualRegister(RegClass: &RISCV::GPRRegClass); |
119 | BuildMI(BB&: MBB, I&: MI, MIMD: MI.getDebugLoc(), MCID: TII->get(Opcode: RISCV::SwapFRMImm), DestReg: SavedFRM) |
120 | .addImm(Val: InstrRM); |
121 | } else { |
122 | // Don't need to save current FRM when SavedFRM having value. |
123 | BuildMI(BB&: MBB, I&: MI, MIMD: MI.getDebugLoc(), MCID: TII->get(Opcode: RISCV::WriteFRMImm)) |
124 | .addImm(Val: InstrRM); |
125 | } |
126 | CurrentRM = InstrRM; |
127 | } |
128 | |
129 | // Restore FRM if needed. |
130 | if (SavedFRM.isValid()) { |
131 | assert(LastFRMChanger && "Expected valid pointer." ); |
132 | MachineInstrBuilder MIB = |
133 | BuildMI(MF&: *MBB.getParent(), MIMD: {}, MCID: TII->get(Opcode: RISCV::WriteFRM)) |
134 | .addReg(RegNo: SavedFRM); |
135 | MBB.insertAfter(I: LastFRMChanger, MI: MIB); |
136 | } |
137 | |
138 | return Changed; |
139 | } |
140 | |
141 | // This function also swaps frm and restores it when encountering an RVV |
142 | // floating point instruction with a static rounding mode. |
143 | bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) { |
144 | bool Changed = false; |
145 | for (MachineInstr &MI : MBB) { |
146 | int FRMIdx = RISCVII::getFRMOpNum(Desc: MI.getDesc()); |
147 | if (FRMIdx < 0) |
148 | continue; |
149 | |
150 | unsigned FRMImm = MI.getOperand(i: FRMIdx).getImm(); |
151 | |
152 | // The value is a hint to this pass to not alter the frm value. |
153 | if (FRMImm == RISCVFPRndMode::DYN) |
154 | continue; |
155 | |
156 | Changed = true; |
157 | |
158 | // Save |
159 | MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo(); |
160 | Register SavedFRM = MRI->createVirtualRegister(RegClass: &RISCV::GPRRegClass); |
161 | BuildMI(BB&: MBB, I&: MI, MIMD: MI.getDebugLoc(), MCID: TII->get(Opcode: RISCV::SwapFRMImm), |
162 | DestReg: SavedFRM) |
163 | .addImm(Val: FRMImm); |
164 | MI.addOperand(Op: MachineOperand::CreateReg(Reg: RISCV::FRM, /*IsDef*/ isDef: false, |
165 | /*IsImp*/ isImp: true)); |
166 | // Restore |
167 | MachineInstrBuilder MIB = |
168 | BuildMI(MF&: *MBB.getParent(), MIMD: {}, MCID: TII->get(Opcode: RISCV::WriteFRM)) |
169 | .addReg(RegNo: SavedFRM); |
170 | MBB.insertAfter(I: MI, MI: MIB); |
171 | } |
172 | return Changed; |
173 | } |
174 | |
175 | bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) { |
176 | // Skip if the vector extension is not enabled. |
177 | const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); |
178 | if (!ST.hasVInstructions()) |
179 | return false; |
180 | |
181 | TII = ST.getInstrInfo(); |
182 | |
183 | bool Changed = false; |
184 | |
185 | for (MachineBasicBlock &MBB : MF) { |
186 | if (DisableFRMInsertOpt) |
187 | Changed |= emitWriteRoundingMode(MBB); |
188 | else |
189 | Changed |= emitWriteRoundingModeOpt(MBB); |
190 | } |
191 | |
192 | return Changed; |
193 | } |
194 | |
195 | FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() { |
196 | return new RISCVInsertReadWriteCSR(); |
197 | } |
198 | |