1//===- MachineUniformityAnalysis.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/CodeGen/MachineUniformityAnalysis.h"
10#include "llvm/ADT/GenericUniformityImpl.h"
11#include "llvm/Analysis/TargetTransformInfo.h"
12#include "llvm/CodeGen/MachineCycleAnalysis.h"
13#include "llvm/CodeGen/MachineDominators.h"
14#include "llvm/CodeGen/MachineRegisterInfo.h"
15#include "llvm/CodeGen/MachineSSAContext.h"
16#include "llvm/CodeGen/TargetInstrInfo.h"
17#include "llvm/InitializePasses.h"
18
19using namespace llvm;
20
21template <>
22bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
23 const MachineInstr &I) const {
24 for (auto &op : I.all_defs()) {
25 if (isDivergent(V: op.getReg()))
26 return true;
27 }
28 return false;
29}
30
31template <>
32bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
33 const MachineInstr &Instr) {
34 bool insertedDivergent = false;
35 const auto &MRI = F.getRegInfo();
36 const auto &RBI = *F.getSubtarget().getRegBankInfo();
37 const auto &TRI = *MRI.getTargetRegisterInfo();
38 for (auto &op : Instr.all_defs()) {
39 if (!op.getReg().isVirtual())
40 continue;
41 assert(!op.getSubReg());
42 if (TRI.isUniformReg(MRI, RBI, Reg: op.getReg()))
43 continue;
44 insertedDivergent |= markDivergent(Val: op.getReg());
45 }
46 return insertedDivergent;
47}
48
49template <>
50void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
51 const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
52
53 for (const MachineBasicBlock &block : F) {
54 for (const MachineInstr &instr : block) {
55 auto uniformity = InstrInfo.getInstructionUniformity(MI: instr);
56
57 switch (uniformity) {
58 case InstructionUniformity::AlwaysUniform:
59 addUniformOverride(Instr: instr);
60 break;
61 case InstructionUniformity::NeverUniform:
62 markDivergent(I: instr);
63 break;
64 case InstructionUniformity::Default:
65 break;
66 }
67 }
68 }
69}
70
71template <>
72void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
73 Register Reg) {
74 assert(isDivergent(Reg));
75 const auto &RegInfo = F.getRegInfo();
76 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
77 markDivergent(I: UserInstr);
78 }
79}
80
81template <>
82void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
83 const MachineInstr &Instr) {
84 assert(!isAlwaysUniform(Instr));
85 if (Instr.isTerminator())
86 return;
87 for (const MachineOperand &op : Instr.all_defs()) {
88 auto Reg = op.getReg();
89 if (isDivergent(V: Reg))
90 pushUsers(Reg);
91 }
92}
93
94template <>
95bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
96 const MachineInstr &I, const MachineCycle &DefCycle) const {
97 assert(!isAlwaysUniform(I));
98 for (auto &Op : I.operands()) {
99 if (!Op.isReg() || !Op.readsReg())
100 continue;
101 auto Reg = Op.getReg();
102
103 // FIXME: Physical registers need to be properly checked instead of always
104 // returning true
105 if (Reg.isPhysical())
106 return true;
107
108 auto *Def = F.getRegInfo().getVRegDef(Reg);
109 if (DefCycle.contains(Block: Def->getParent()))
110 return true;
111 }
112 return false;
113}
114
115template <>
116void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
117 propagateTemporalDivergence(const MachineInstr &I,
118 const MachineCycle &DefCycle) {
119 const auto &RegInfo = F.getRegInfo();
120 for (auto &Op : I.all_defs()) {
121 if (!Op.getReg().isVirtual())
122 continue;
123 auto Reg = Op.getReg();
124 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
125 if (DefCycle.contains(Block: UserInstr.getParent()))
126 continue;
127 markDivergent(I: UserInstr);
128
129 recordTemporalDivergence(Val: Reg, User: &UserInstr, Cycle: &DefCycle);
130 }
131 }
132}
133
134template <>
135bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
136 const MachineOperand &U) const {
137 if (!U.isReg())
138 return false;
139
140 auto Reg = U.getReg();
141 if (isDivergent(V: Reg))
142 return true;
143
144 const auto &RegInfo = F.getRegInfo();
145 auto *Def = RegInfo.getOneDef(Reg);
146 if (!Def)
147 return true;
148
149 auto *DefInstr = Def->getParent();
150 auto *UseInstr = U.getParent();
151 return isTemporalDivergent(ObservingBlock: *UseInstr->getParent(), Def: *DefInstr);
152}
153
154// This ensures explicit instantiation of
155// GenericUniformityAnalysisImpl::ImplDeleter::operator()
156template class llvm::GenericUniformityInfo<MachineSSAContext>;
157template struct llvm::GenericUniformityAnalysisImplDeleter<
158 llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
159
160MachineUniformityInfo llvm::computeMachineUniformityInfo(
161 MachineFunction &F, const MachineCycleInfo &cycleInfo,
162 const MachineDominatorTree &domTree, bool HasBranchDivergence) {
163 assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
164 MachineUniformityInfo UI(domTree, cycleInfo);
165 if (HasBranchDivergence)
166 UI.compute();
167 return UI;
168}
169
170namespace {
171
172class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
173public:
174 static char ID;
175
176 MachineUniformityInfoPrinterPass();
177
178 bool runOnMachineFunction(MachineFunction &F) override;
179 void getAnalysisUsage(AnalysisUsage &AU) const override;
180};
181
182} // namespace
183
184AnalysisKey MachineUniformityAnalysis::Key;
185
186MachineUniformityAnalysis::Result
187MachineUniformityAnalysis::run(MachineFunction &MF,
188 MachineFunctionAnalysisManager &MFAM) {
189 auto &DomTree = MFAM.getResult<MachineDominatorTreeAnalysis>(IR&: MF);
190 auto &CI = MFAM.getResult<MachineCycleAnalysis>(IR&: MF);
191 auto &FAM = MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(IR&: MF)
192 .getManager();
193 auto &F = MF.getFunction();
194 auto &TTI = FAM.getResult<TargetIRAnalysis>(IR&: F);
195 return computeMachineUniformityInfo(F&: MF, cycleInfo: CI, domTree: DomTree,
196 HasBranchDivergence: TTI.hasBranchDivergence(F: &F));
197}
198
199PreservedAnalyses
200MachineUniformityPrinterPass::run(MachineFunction &MF,
201 MachineFunctionAnalysisManager &MFAM) {
202 auto &MUI = MFAM.getResult<MachineUniformityAnalysis>(IR&: MF);
203 OS << "MachineUniformityInfo for function: ";
204 MF.getFunction().printAsOperand(O&: OS, /*PrintType=*/false);
205 OS << '\n';
206 MUI.print(out&: OS);
207 return PreservedAnalyses::all();
208}
209
210char MachineUniformityAnalysisPass::ID = 0;
211
212MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
213 : MachineFunctionPass(ID) {}
214
215INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
216 "Machine Uniformity Info Analysis", false, true)
217INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
218INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
219INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
220 "Machine Uniformity Info Analysis", false, true)
221
222void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
223 AU.setPreservesAll();
224 AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
225 AU.addRequired<MachineDominatorTreeWrapperPass>();
226 MachineFunctionPass::getAnalysisUsage(AU);
227}
228
229bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
230 auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
231 auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
232 // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
233 // default NoTTI
234 UI = computeMachineUniformityInfo(F&: MF, cycleInfo: CI, domTree: DomTree, HasBranchDivergence: true);
235 return false;
236}
237
238void MachineUniformityAnalysisPass::print(raw_ostream &OS,
239 const Module *) const {
240 OS << "MachineUniformityInfo for function: ";
241 UI.getFunction().getFunction().printAsOperand(O&: OS, /*PrintType=*/false);
242 OS << '\n';
243 UI.print(out&: OS);
244}
245
246char MachineUniformityInfoPrinterPass::ID = 0;
247
248MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
249 : MachineFunctionPass(ID) {}
250
251INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
252 "print-machine-uniformity",
253 "Print Machine Uniformity Info Analysis", true, true)
254INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
255INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
256 "print-machine-uniformity",
257 "Print Machine Uniformity Info Analysis", true, true)
258
259void MachineUniformityInfoPrinterPass::getAnalysisUsage(
260 AnalysisUsage &AU) const {
261 AU.setPreservesAll();
262 AU.addRequired<MachineUniformityAnalysisPass>();
263 MachineFunctionPass::getAnalysisUsage(AU);
264}
265
266bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
267 MachineFunction &F) {
268 auto &UI = getAnalysis<MachineUniformityAnalysisPass>();
269 UI.print(OS&: errs());
270 return false;
271}
272