1//===- MachineUniformityAnalysis.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/CodeGen/MachineUniformityAnalysis.h"
10#include "llvm/ADT/GenericUniformityImpl.h"
11#include "llvm/Analysis/TargetTransformInfo.h"
12#include "llvm/CodeGen/MachineCycleAnalysis.h"
13#include "llvm/CodeGen/MachineDominators.h"
14#include "llvm/CodeGen/MachineRegisterInfo.h"
15#include "llvm/CodeGen/MachineSSAContext.h"
16#include "llvm/CodeGen/TargetInstrInfo.h"
17#include "llvm/InitializePasses.h"
18
19using namespace llvm;
20
21template <>
22bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
23 const MachineInstr &I) const {
24 for (auto &op : I.all_defs()) {
25 if (isDivergent(V: op.getReg()))
26 return true;
27 }
28 return false;
29}
30
31template <>
32bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
33 const MachineInstr &Instr) {
34 bool insertedDivergent = false;
35 const auto &MRI = F.getRegInfo();
36 const auto &RBI = *F.getSubtarget().getRegBankInfo();
37 const auto &TRI = *MRI.getTargetRegisterInfo();
38 for (auto &op : Instr.all_defs()) {
39 if (!op.getReg().isVirtual())
40 continue;
41 assert(!op.getSubReg());
42 if (TRI.isUniformReg(MRI, RBI, Reg: op.getReg()))
43 continue;
44 insertedDivergent |= markDivergent(Val: op.getReg());
45 }
46 return insertedDivergent;
47}
48
49template <>
50void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
51 // Pre-populate UniformValues with all register defs. Physical register defs
52 // are included because they are never analyzed for divergence (initialize
53 // and markDefsDivergent skip them), so they must be in UniformValues to
54 // avoid being falsely reported as divergent.
55 for (const MachineBasicBlock &BB : F) {
56 for (const MachineInstr &MI : BB.instrs()) {
57 for (const MachineOperand &Op : MI.all_defs()) {
58 Register Reg = Op.getReg();
59 if (Reg)
60 UniformValues.insert(V: Reg);
61 }
62 }
63 }
64
65 const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
66
67 for (const MachineBasicBlock &block : F) {
68 for (const MachineInstr &instr : block) {
69 auto uniformity = InstrInfo.getValueUniformity(MI: instr);
70
71 switch (uniformity) {
72 case ValueUniformity::AlwaysUniform:
73 addUniformOverride(Instr: instr);
74 break;
75 case ValueUniformity::NeverUniform:
76 markDivergent(I: instr);
77 break;
78 case ValueUniformity::Custom:
79 break;
80 case ValueUniformity::Default:
81 break;
82 }
83 }
84 }
85}
86
87template <>
88void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
89 Register Reg) {
90 assert(isDivergent(Reg));
91 const auto &RegInfo = F.getRegInfo();
92 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
93 markDivergent(I: UserInstr);
94 }
95}
96
97template <>
98void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
99 const MachineInstr &Instr) {
100 assert(!isAlwaysUniform(Instr));
101 if (Instr.isTerminator())
102 return;
103 for (const MachineOperand &op : Instr.all_defs()) {
104 auto Reg = op.getReg();
105 if (isDivergent(V: Reg))
106 pushUsers(Reg);
107 }
108}
109
110template <>
111bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
112 const MachineInstr &I, const MachineCycle &DefCycle) const {
113 assert(!isAlwaysUniform(I));
114 for (auto &Op : I.operands()) {
115 if (!Op.isReg() || !Op.readsReg())
116 continue;
117 auto Reg = Op.getReg();
118
119 // FIXME: Physical registers need to be properly checked instead of always
120 // returning true
121 if (Reg.isPhysical())
122 return true;
123
124 auto *Def = F.getRegInfo().getVRegDef(Reg);
125 if (DefCycle.contains(Block: Def->getParent()))
126 return true;
127 }
128 return false;
129}
130
131template <>
132void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
133 propagateTemporalDivergence(const MachineInstr &I,
134 const MachineCycle &DefCycle) {
135 const auto &RegInfo = F.getRegInfo();
136 for (auto &Op : I.all_defs()) {
137 if (!Op.getReg().isVirtual())
138 continue;
139 auto Reg = Op.getReg();
140 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
141 if (DefCycle.contains(Block: UserInstr.getParent()))
142 continue;
143 markDivergent(I: UserInstr);
144
145 recordTemporalDivergence(Val: Reg, User: &UserInstr, Cycle: &DefCycle);
146 }
147 }
148}
149
150template <>
151bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
152 const MachineOperand &U) const {
153 if (!U.isReg())
154 return false;
155
156 auto Reg = U.getReg();
157 if (isDivergent(V: Reg))
158 return true;
159
160 const auto &RegInfo = F.getRegInfo();
161 auto *Def = RegInfo.getOneDef(Reg);
162 if (!Def)
163 return true;
164
165 auto *DefInstr = Def->getParent();
166 auto *UseInstr = U.getParent();
167 return isTemporalDivergent(ObservingBlock: *UseInstr->getParent(), Def: *DefInstr);
168}
169
170template <>
171bool GenericUniformityAnalysisImpl<MachineSSAContext>::isCustomUniform(
172 const MachineInstr &MI) const {
173 llvm_unreachable("no MIR instructions use Custom uniformity yet");
174}
175
176// This ensures explicit instantiation of
177// GenericUniformityAnalysisImpl::ImplDeleter::operator()
178template class llvm::GenericUniformityInfo<MachineSSAContext>;
179template struct llvm::GenericUniformityAnalysisImplDeleter<
180 llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
181
182MachineUniformityInfo llvm::computeMachineUniformityInfo(
183 MachineFunction &F, const MachineCycleInfo &cycleInfo,
184 const MachineDominatorTree &domTree, bool HasBranchDivergence) {
185 assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
186 MachineUniformityInfo UI(domTree, cycleInfo);
187 if (HasBranchDivergence)
188 UI.compute();
189 return UI;
190}
191
192namespace {
193
194class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
195public:
196 static char ID;
197
198 MachineUniformityInfoPrinterPass();
199
200 bool runOnMachineFunction(MachineFunction &F) override;
201 void getAnalysisUsage(AnalysisUsage &AU) const override;
202};
203
204} // namespace
205
206AnalysisKey MachineUniformityAnalysis::Key;
207
208MachineUniformityAnalysis::Result
209MachineUniformityAnalysis::run(MachineFunction &MF,
210 MachineFunctionAnalysisManager &MFAM) {
211 auto &DomTree = MFAM.getResult<MachineDominatorTreeAnalysis>(IR&: MF);
212 auto &CI = MFAM.getResult<MachineCycleAnalysis>(IR&: MF);
213 auto &FAM = MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(IR&: MF)
214 .getManager();
215 auto &F = MF.getFunction();
216 auto &TTI = FAM.getResult<TargetIRAnalysis>(IR&: F);
217 return computeMachineUniformityInfo(F&: MF, cycleInfo: CI, domTree: DomTree,
218 HasBranchDivergence: TTI.hasBranchDivergence(F: &F));
219}
220
221PreservedAnalyses
222MachineUniformityPrinterPass::run(MachineFunction &MF,
223 MachineFunctionAnalysisManager &MFAM) {
224 auto &MUI = MFAM.getResult<MachineUniformityAnalysis>(IR&: MF);
225 OS << "MachineUniformityInfo for function: ";
226 MF.getFunction().printAsOperand(O&: OS, /*PrintType=*/false);
227 OS << '\n';
228 MUI.print(out&: OS);
229 return PreservedAnalyses::all();
230}
231
232char MachineUniformityAnalysisPass::ID = 0;
233
234MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
235 : MachineFunctionPass(ID) {}
236
237INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
238 "Machine Uniformity Info Analysis", false, true)
239INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
240INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
241INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
242 "Machine Uniformity Info Analysis", false, true)
243
244void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
245 AU.setPreservesAll();
246 AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
247 AU.addRequired<MachineDominatorTreeWrapperPass>();
248 MachineFunctionPass::getAnalysisUsage(AU);
249}
250
251bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
252 auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
253 auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
254 // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
255 // default NoTTI
256 UI = computeMachineUniformityInfo(F&: MF, cycleInfo: CI, domTree: DomTree, HasBranchDivergence: true);
257 return false;
258}
259
260void MachineUniformityAnalysisPass::print(raw_ostream &OS,
261 const Module *) const {
262 OS << "MachineUniformityInfo for function: ";
263 UI.getFunction().getFunction().printAsOperand(O&: OS, /*PrintType=*/false);
264 OS << '\n';
265 UI.print(out&: OS);
266}
267
268char MachineUniformityInfoPrinterPass::ID = 0;
269
270MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
271 : MachineFunctionPass(ID) {}
272
273INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
274 "print-machine-uniformity",
275 "Print Machine Uniformity Info Analysis", true, true)
276INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
277INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
278 "print-machine-uniformity",
279 "Print Machine Uniformity Info Analysis", true, true)
280
281void MachineUniformityInfoPrinterPass::getAnalysisUsage(
282 AnalysisUsage &AU) const {
283 AU.setPreservesAll();
284 AU.addRequired<MachineUniformityAnalysisPass>();
285 MachineFunctionPass::getAnalysisUsage(AU);
286}
287
288bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
289 MachineFunction &F) {
290 auto &UI = getAnalysis<MachineUniformityAnalysisPass>();
291 UI.print(OS&: errs());
292 return false;
293}
294