1 | //===- MachineUniformityAnalysis.cpp --------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "llvm/CodeGen/MachineUniformityAnalysis.h" |
10 | #include "llvm/ADT/GenericUniformityImpl.h" |
11 | #include "llvm/Analysis/TargetTransformInfo.h" |
12 | #include "llvm/CodeGen/MachineCycleAnalysis.h" |
13 | #include "llvm/CodeGen/MachineDominators.h" |
14 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
15 | #include "llvm/CodeGen/MachineSSAContext.h" |
16 | #include "llvm/CodeGen/TargetInstrInfo.h" |
17 | #include "llvm/InitializePasses.h" |
18 | |
19 | using namespace llvm; |
20 | |
21 | template <> |
22 | bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs( |
23 | const MachineInstr &I) const { |
24 | for (auto &op : I.all_defs()) { |
25 | if (isDivergent(V: op.getReg())) |
26 | return true; |
27 | } |
28 | return false; |
29 | } |
30 | |
31 | template <> |
32 | bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent( |
33 | const MachineInstr &Instr) { |
34 | bool insertedDivergent = false; |
35 | const auto &MRI = F.getRegInfo(); |
36 | const auto &RBI = *F.getSubtarget().getRegBankInfo(); |
37 | const auto &TRI = *MRI.getTargetRegisterInfo(); |
38 | for (auto &op : Instr.all_defs()) { |
39 | if (!op.getReg().isVirtual()) |
40 | continue; |
41 | assert(!op.getSubReg()); |
42 | if (TRI.isUniformReg(MRI, RBI, Reg: op.getReg())) |
43 | continue; |
44 | insertedDivergent |= markDivergent(Val: op.getReg()); |
45 | } |
46 | return insertedDivergent; |
47 | } |
48 | |
49 | template <> |
50 | void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() { |
51 | const auto &InstrInfo = *F.getSubtarget().getInstrInfo(); |
52 | |
53 | for (const MachineBasicBlock &block : F) { |
54 | for (const MachineInstr &instr : block) { |
55 | auto uniformity = InstrInfo.getInstructionUniformity(MI: instr); |
56 | if (uniformity == InstructionUniformity::AlwaysUniform) { |
57 | addUniformOverride(Instr: instr); |
58 | continue; |
59 | } |
60 | |
61 | if (uniformity == InstructionUniformity::NeverUniform) { |
62 | markDivergent(I: instr); |
63 | } |
64 | } |
65 | } |
66 | } |
67 | |
68 | template <> |
69 | void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( |
70 | Register Reg) { |
71 | assert(isDivergent(Reg)); |
72 | const auto &RegInfo = F.getRegInfo(); |
73 | for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { |
74 | markDivergent(I: UserInstr); |
75 | } |
76 | } |
77 | |
78 | template <> |
79 | void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( |
80 | const MachineInstr &Instr) { |
81 | assert(!isAlwaysUniform(Instr)); |
82 | if (Instr.isTerminator()) |
83 | return; |
84 | for (const MachineOperand &op : Instr.all_defs()) { |
85 | auto Reg = op.getReg(); |
86 | if (isDivergent(V: Reg)) |
87 | pushUsers(Reg); |
88 | } |
89 | } |
90 | |
91 | template <> |
92 | bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle( |
93 | const MachineInstr &I, const MachineCycle &DefCycle) const { |
94 | assert(!isAlwaysUniform(I)); |
95 | for (auto &Op : I.operands()) { |
96 | if (!Op.isReg() || !Op.readsReg()) |
97 | continue; |
98 | auto Reg = Op.getReg(); |
99 | |
100 | // FIXME: Physical registers need to be properly checked instead of always |
101 | // returning true |
102 | if (Reg.isPhysical()) |
103 | return true; |
104 | |
105 | auto *Def = F.getRegInfo().getVRegDef(Reg); |
106 | if (DefCycle.contains(Block: Def->getParent())) |
107 | return true; |
108 | } |
109 | return false; |
110 | } |
111 | |
112 | template <> |
113 | void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>:: |
114 | propagateTemporalDivergence(const MachineInstr &I, |
115 | const MachineCycle &DefCycle) { |
116 | const auto &RegInfo = F.getRegInfo(); |
117 | for (auto &Op : I.all_defs()) { |
118 | if (!Op.getReg().isVirtual()) |
119 | continue; |
120 | auto Reg = Op.getReg(); |
121 | for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { |
122 | if (DefCycle.contains(Block: UserInstr.getParent())) |
123 | continue; |
124 | markDivergent(I: UserInstr); |
125 | |
126 | recordTemporalDivergence(Val: Reg, User: &UserInstr, Cycle: &DefCycle); |
127 | } |
128 | } |
129 | } |
130 | |
131 | template <> |
132 | bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse( |
133 | const MachineOperand &U) const { |
134 | if (!U.isReg()) |
135 | return false; |
136 | |
137 | auto Reg = U.getReg(); |
138 | if (isDivergent(V: Reg)) |
139 | return true; |
140 | |
141 | const auto &RegInfo = F.getRegInfo(); |
142 | auto *Def = RegInfo.getOneDef(Reg); |
143 | if (!Def) |
144 | return true; |
145 | |
146 | auto *DefInstr = Def->getParent(); |
147 | auto *UseInstr = U.getParent(); |
148 | return isTemporalDivergent(ObservingBlock: *UseInstr->getParent(), Def: *DefInstr); |
149 | } |
150 | |
151 | // This ensures explicit instantiation of |
152 | // GenericUniformityAnalysisImpl::ImplDeleter::operator() |
153 | template class llvm::GenericUniformityInfo<MachineSSAContext>; |
154 | template struct llvm::GenericUniformityAnalysisImplDeleter< |
155 | llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>; |
156 | |
157 | MachineUniformityInfo llvm::computeMachineUniformityInfo( |
158 | MachineFunction &F, const MachineCycleInfo &cycleInfo, |
159 | const MachineDominatorTree &domTree, bool HasBranchDivergence) { |
160 | assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!" ); |
161 | MachineUniformityInfo UI(domTree, cycleInfo); |
162 | if (HasBranchDivergence) |
163 | UI.compute(); |
164 | return UI; |
165 | } |
166 | |
167 | namespace { |
168 | |
169 | class MachineUniformityInfoPrinterPass : public MachineFunctionPass { |
170 | public: |
171 | static char ID; |
172 | |
173 | MachineUniformityInfoPrinterPass(); |
174 | |
175 | bool runOnMachineFunction(MachineFunction &F) override; |
176 | void getAnalysisUsage(AnalysisUsage &AU) const override; |
177 | }; |
178 | |
179 | } // namespace |
180 | |
181 | AnalysisKey MachineUniformityAnalysis::Key; |
182 | |
183 | MachineUniformityAnalysis::Result |
184 | MachineUniformityAnalysis::run(MachineFunction &MF, |
185 | MachineFunctionAnalysisManager &MFAM) { |
186 | auto &DomTree = MFAM.getResult<MachineDominatorTreeAnalysis>(IR&: MF); |
187 | auto &CI = MFAM.getResult<MachineCycleAnalysis>(IR&: MF); |
188 | auto &FAM = MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(IR&: MF) |
189 | .getManager(); |
190 | auto &F = MF.getFunction(); |
191 | auto &TTI = FAM.getResult<TargetIRAnalysis>(IR&: F); |
192 | return computeMachineUniformityInfo(F&: MF, cycleInfo: CI, domTree: DomTree, |
193 | HasBranchDivergence: TTI.hasBranchDivergence(F: &F)); |
194 | } |
195 | |
196 | PreservedAnalyses |
197 | MachineUniformityPrinterPass::run(MachineFunction &MF, |
198 | MachineFunctionAnalysisManager &MFAM) { |
199 | auto &MUI = MFAM.getResult<MachineUniformityAnalysis>(IR&: MF); |
200 | OS << "MachineUniformityInfo for function: " ; |
201 | MF.getFunction().printAsOperand(O&: OS, /*PrintType=*/false); |
202 | OS << '\n'; |
203 | MUI.print(out&: OS); |
204 | return PreservedAnalyses::all(); |
205 | } |
206 | |
207 | char MachineUniformityAnalysisPass::ID = 0; |
208 | |
209 | MachineUniformityAnalysisPass::MachineUniformityAnalysisPass() |
210 | : MachineFunctionPass(ID) { |
211 | initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry()); |
212 | } |
213 | |
214 | INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity" , |
215 | "Machine Uniformity Info Analysis" , false, true) |
216 | INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass) |
217 | INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) |
218 | INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity" , |
219 | "Machine Uniformity Info Analysis" , false, true) |
220 | |
221 | void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const { |
222 | AU.setPreservesAll(); |
223 | AU.addRequiredTransitive<MachineCycleInfoWrapperPass>(); |
224 | AU.addRequired<MachineDominatorTreeWrapperPass>(); |
225 | MachineFunctionPass::getAnalysisUsage(AU); |
226 | } |
227 | |
228 | bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) { |
229 | auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree(); |
230 | auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo(); |
231 | // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a |
232 | // default NoTTI |
233 | UI = computeMachineUniformityInfo(F&: MF, cycleInfo: CI, domTree: DomTree, HasBranchDivergence: true); |
234 | return false; |
235 | } |
236 | |
237 | void MachineUniformityAnalysisPass::print(raw_ostream &OS, |
238 | const Module *) const { |
239 | OS << "MachineUniformityInfo for function: " ; |
240 | UI.getFunction().getFunction().printAsOperand(O&: OS, /*PrintType=*/false); |
241 | OS << '\n'; |
242 | UI.print(out&: OS); |
243 | } |
244 | |
245 | char MachineUniformityInfoPrinterPass::ID = 0; |
246 | |
247 | MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass() |
248 | : MachineFunctionPass(ID) { |
249 | initializeMachineUniformityInfoPrinterPassPass( |
250 | *PassRegistry::getPassRegistry()); |
251 | } |
252 | |
253 | INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass, |
254 | "print-machine-uniformity" , |
255 | "Print Machine Uniformity Info Analysis" , true, true) |
256 | INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass) |
257 | INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass, |
258 | "print-machine-uniformity" , |
259 | "Print Machine Uniformity Info Analysis" , true, true) |
260 | |
261 | void MachineUniformityInfoPrinterPass::getAnalysisUsage( |
262 | AnalysisUsage &AU) const { |
263 | AU.setPreservesAll(); |
264 | AU.addRequired<MachineUniformityAnalysisPass>(); |
265 | MachineFunctionPass::getAnalysisUsage(AU); |
266 | } |
267 | |
268 | bool MachineUniformityInfoPrinterPass::runOnMachineFunction( |
269 | MachineFunction &F) { |
270 | auto &UI = getAnalysis<MachineUniformityAnalysisPass>(); |
271 | UI.print(OS&: errs()); |
272 | return false; |
273 | } |
274 | |