1 | //===- MachineUniformityAnalysis.cpp --------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "llvm/CodeGen/MachineUniformityAnalysis.h" |
10 | #include "llvm/ADT/GenericUniformityImpl.h" |
11 | #include "llvm/CodeGen/MachineCycleAnalysis.h" |
12 | #include "llvm/CodeGen/MachineDominators.h" |
13 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
14 | #include "llvm/CodeGen/MachineSSAContext.h" |
15 | #include "llvm/CodeGen/TargetInstrInfo.h" |
16 | #include "llvm/InitializePasses.h" |
17 | |
18 | using namespace llvm; |
19 | |
20 | template <> |
21 | bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs( |
22 | const MachineInstr &I) const { |
23 | for (auto &op : I.all_defs()) { |
24 | if (isDivergent(V: op.getReg())) |
25 | return true; |
26 | } |
27 | return false; |
28 | } |
29 | |
30 | template <> |
31 | bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent( |
32 | const MachineInstr &Instr) { |
33 | bool insertedDivergent = false; |
34 | const auto &MRI = F.getRegInfo(); |
35 | const auto &RBI = *F.getSubtarget().getRegBankInfo(); |
36 | const auto &TRI = *MRI.getTargetRegisterInfo(); |
37 | for (auto &op : Instr.all_defs()) { |
38 | if (!op.getReg().isVirtual()) |
39 | continue; |
40 | assert(!op.getSubReg()); |
41 | if (TRI.isUniformReg(MRI, RBI, Reg: op.getReg())) |
42 | continue; |
43 | insertedDivergent |= markDivergent(DivVal: op.getReg()); |
44 | } |
45 | return insertedDivergent; |
46 | } |
47 | |
48 | template <> |
49 | void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() { |
50 | const auto &InstrInfo = *F.getSubtarget().getInstrInfo(); |
51 | |
52 | for (const MachineBasicBlock &block : F) { |
53 | for (const MachineInstr &instr : block) { |
54 | auto uniformity = InstrInfo.getInstructionUniformity(MI: instr); |
55 | if (uniformity == InstructionUniformity::AlwaysUniform) { |
56 | addUniformOverride(Instr: instr); |
57 | continue; |
58 | } |
59 | |
60 | if (uniformity == InstructionUniformity::NeverUniform) { |
61 | markDivergent(I: instr); |
62 | } |
63 | } |
64 | } |
65 | } |
66 | |
67 | template <> |
68 | void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( |
69 | Register Reg) { |
70 | assert(isDivergent(Reg)); |
71 | const auto &RegInfo = F.getRegInfo(); |
72 | for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { |
73 | markDivergent(I: UserInstr); |
74 | } |
75 | } |
76 | |
77 | template <> |
78 | void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( |
79 | const MachineInstr &Instr) { |
80 | assert(!isAlwaysUniform(Instr)); |
81 | if (Instr.isTerminator()) |
82 | return; |
83 | for (const MachineOperand &op : Instr.all_defs()) { |
84 | auto Reg = op.getReg(); |
85 | if (isDivergent(V: Reg)) |
86 | pushUsers(Reg); |
87 | } |
88 | } |
89 | |
90 | template <> |
91 | bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle( |
92 | const MachineInstr &I, const MachineCycle &DefCycle) const { |
93 | assert(!isAlwaysUniform(I)); |
94 | for (auto &Op : I.operands()) { |
95 | if (!Op.isReg() || !Op.readsReg()) |
96 | continue; |
97 | auto Reg = Op.getReg(); |
98 | |
99 | // FIXME: Physical registers need to be properly checked instead of always |
100 | // returning true |
101 | if (Reg.isPhysical()) |
102 | return true; |
103 | |
104 | auto *Def = F.getRegInfo().getVRegDef(Reg); |
105 | if (DefCycle.contains(Block: Def->getParent())) |
106 | return true; |
107 | } |
108 | return false; |
109 | } |
110 | |
111 | template <> |
112 | void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>:: |
113 | propagateTemporalDivergence(const MachineInstr &I, |
114 | const MachineCycle &DefCycle) { |
115 | const auto &RegInfo = F.getRegInfo(); |
116 | for (auto &Op : I.all_defs()) { |
117 | if (!Op.getReg().isVirtual()) |
118 | continue; |
119 | auto Reg = Op.getReg(); |
120 | if (isDivergent(V: Reg)) |
121 | continue; |
122 | for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { |
123 | if (DefCycle.contains(Block: UserInstr.getParent())) |
124 | continue; |
125 | markDivergent(I: UserInstr); |
126 | } |
127 | } |
128 | } |
129 | |
130 | template <> |
131 | bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse( |
132 | const MachineOperand &U) const { |
133 | if (!U.isReg()) |
134 | return false; |
135 | |
136 | auto Reg = U.getReg(); |
137 | if (isDivergent(V: Reg)) |
138 | return true; |
139 | |
140 | const auto &RegInfo = F.getRegInfo(); |
141 | auto *Def = RegInfo.getOneDef(Reg); |
142 | if (!Def) |
143 | return true; |
144 | |
145 | auto *DefInstr = Def->getParent(); |
146 | auto *UseInstr = U.getParent(); |
147 | return isTemporalDivergent(ObservingBlock: *UseInstr->getParent(), Def: *DefInstr); |
148 | } |
149 | |
150 | // This ensures explicit instantiation of |
151 | // GenericUniformityAnalysisImpl::ImplDeleter::operator() |
152 | template class llvm::GenericUniformityInfo<MachineSSAContext>; |
153 | template struct llvm::GenericUniformityAnalysisImplDeleter< |
154 | llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>; |
155 | |
156 | MachineUniformityInfo llvm::computeMachineUniformityInfo( |
157 | MachineFunction &F, const MachineCycleInfo &cycleInfo, |
158 | const MachineDominatorTree &domTree, bool HasBranchDivergence) { |
159 | assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!" ); |
160 | MachineUniformityInfo UI(domTree, cycleInfo); |
161 | if (HasBranchDivergence) |
162 | UI.compute(); |
163 | return UI; |
164 | } |
165 | |
166 | namespace { |
167 | |
168 | class MachineUniformityInfoPrinterPass : public MachineFunctionPass { |
169 | public: |
170 | static char ID; |
171 | |
172 | MachineUniformityInfoPrinterPass(); |
173 | |
174 | bool runOnMachineFunction(MachineFunction &F) override; |
175 | void getAnalysisUsage(AnalysisUsage &AU) const override; |
176 | }; |
177 | |
178 | } // namespace |
179 | |
180 | char MachineUniformityAnalysisPass::ID = 0; |
181 | |
182 | MachineUniformityAnalysisPass::MachineUniformityAnalysisPass() |
183 | : MachineFunctionPass(ID) { |
184 | initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry()); |
185 | } |
186 | |
187 | INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity" , |
188 | "Machine Uniformity Info Analysis" , true, true) |
189 | INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass) |
190 | INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) |
191 | INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity" , |
192 | "Machine Uniformity Info Analysis" , true, true) |
193 | |
194 | void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const { |
195 | AU.setPreservesAll(); |
196 | AU.addRequired<MachineCycleInfoWrapperPass>(); |
197 | AU.addRequired<MachineDominatorTreeWrapperPass>(); |
198 | MachineFunctionPass::getAnalysisUsage(AU); |
199 | } |
200 | |
201 | bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) { |
202 | auto &DomTree = |
203 | getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree().getBase(); |
204 | auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo(); |
205 | // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a |
206 | // default NoTTI |
207 | UI = computeMachineUniformityInfo(F&: MF, cycleInfo: CI, domTree: DomTree, HasBranchDivergence: true); |
208 | return false; |
209 | } |
210 | |
211 | void MachineUniformityAnalysisPass::print(raw_ostream &OS, |
212 | const Module *) const { |
213 | OS << "MachineUniformityInfo for function: " << UI.getFunction().getName() |
214 | << "\n" ; |
215 | UI.print(out&: OS); |
216 | } |
217 | |
218 | char MachineUniformityInfoPrinterPass::ID = 0; |
219 | |
220 | MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass() |
221 | : MachineFunctionPass(ID) { |
222 | initializeMachineUniformityInfoPrinterPassPass( |
223 | *PassRegistry::getPassRegistry()); |
224 | } |
225 | |
226 | INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass, |
227 | "print-machine-uniformity" , |
228 | "Print Machine Uniformity Info Analysis" , true, true) |
229 | INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass) |
230 | INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass, |
231 | "print-machine-uniformity" , |
232 | "Print Machine Uniformity Info Analysis" , true, true) |
233 | |
234 | void MachineUniformityInfoPrinterPass::getAnalysisUsage( |
235 | AnalysisUsage &AU) const { |
236 | AU.setPreservesAll(); |
237 | AU.addRequired<MachineUniformityAnalysisPass>(); |
238 | MachineFunctionPass::getAnalysisUsage(AU); |
239 | } |
240 | |
241 | bool MachineUniformityInfoPrinterPass::runOnMachineFunction( |
242 | MachineFunction &F) { |
243 | auto &UI = getAnalysis<MachineUniformityAnalysisPass>(); |
244 | UI.print(OS&: errs()); |
245 | return false; |
246 | } |
247 | |