1//===- UniformityAnalysis.cpp ---------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Analysis/UniformityAnalysis.h"
10#include "llvm/ADT/GenericUniformityImpl.h"
11#include "llvm/Analysis/CycleAnalysis.h"
12#include "llvm/Analysis/TargetTransformInfo.h"
13#include "llvm/IR/Constants.h"
14#include "llvm/IR/Dominators.h"
15#include "llvm/IR/InstIterator.h"
16#include "llvm/IR/Instructions.h"
17#include "llvm/InitializePasses.h"
18
19using namespace llvm;
20
21template <>
22bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
23 const Instruction &I) const {
24 return isDivergent(V: (const Value *)&I);
25}
26
27template <>
28bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
29 const Instruction &Instr) {
30 return markDivergent(DivVal: cast<Value>(Val: &Instr));
31}
32
33template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
34 for (auto &I : instructions(F)) {
35 if (TTI->isSourceOfDivergence(V: &I))
36 markDivergent(I);
37 else if (TTI->isAlwaysUniform(V: &I))
38 addUniformOverride(Instr: I);
39 }
40 for (auto &Arg : F.args()) {
41 if (TTI->isSourceOfDivergence(V: &Arg)) {
42 markDivergent(DivVal: &Arg);
43 }
44 }
45}
46
47template <>
48void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
49 const Value *V) {
50 for (const auto *User : V->users()) {
51 if (const auto *UserInstr = dyn_cast<const Instruction>(Val: User)) {
52 markDivergent(I: *UserInstr);
53 }
54 }
55}
56
57template <>
58void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
59 const Instruction &Instr) {
60 assert(!isAlwaysUniform(Instr));
61 if (Instr.isTerminator())
62 return;
63 pushUsers(V: cast<Value>(Val: &Instr));
64}
65
66template <>
67bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
68 const Instruction &I, const Cycle &DefCycle) const {
69 assert(!isAlwaysUniform(I));
70 for (const Use &U : I.operands()) {
71 if (auto *I = dyn_cast<Instruction>(Val: &U)) {
72 if (DefCycle.contains(Block: I->getParent()))
73 return true;
74 }
75 }
76 return false;
77}
78
79template <>
80void llvm::GenericUniformityAnalysisImpl<
81 SSAContext>::propagateTemporalDivergence(const Instruction &I,
82 const Cycle &DefCycle) {
83 if (isDivergent(I))
84 return;
85 for (auto *User : I.users()) {
86 auto *UserInstr = cast<Instruction>(Val: User);
87 if (DefCycle.contains(Block: UserInstr->getParent()))
88 continue;
89 markDivergent(I: *UserInstr);
90 }
91}
92
93template <>
94bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
95 const Use &U) const {
96 const auto *V = U.get();
97 if (isDivergent(V))
98 return true;
99 if (const auto *DefInstr = dyn_cast<Instruction>(Val: V)) {
100 const auto *UseInstr = cast<Instruction>(Val: U.getUser());
101 return isTemporalDivergent(ObservingBlock: *UseInstr->getParent(), Def: *DefInstr);
102 }
103 return false;
104}
105
// Explicitly instantiate GenericUniformityInfo for LLVM IR (SSAContext) and
// GenericUniformityAnalysisImplDeleter for the IR implementation class, so that
// GenericUniformityAnalysisImplDeleter::operator() is emitted in this
// translation unit. NOTE(review): presumably the deleter is what
// GenericUniformityInfo uses to destroy its implementation object; confirm in
// GenericUniformityInfo.h.
template class llvm::GenericUniformityInfo<SSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<SSAContext>>;
111
112//===----------------------------------------------------------------------===//
113// UniformityInfoAnalysis and related pass implementations
114//===----------------------------------------------------------------------===//
115
116llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
117 FunctionAnalysisManager &FAM) {
118 auto &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: F);
119 auto &TTI = FAM.getResult<TargetIRAnalysis>(IR&: F);
120 auto &CI = FAM.getResult<CycleAnalysis>(IR&: F);
121 UniformityInfo UI{DT, CI, &TTI};
122 // Skip computation if we can assume everything is uniform.
123 if (TTI.hasBranchDivergence(F: &F))
124 UI.compute();
125
126 return UI;
127}
128
// Static key identifying UniformityInfoAnalysis to the new pass manager.
AnalysisKey UniformityInfoAnalysis::Key;
130
131UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
132 : OS(OS) {}
133
134PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
135 FunctionAnalysisManager &AM) {
136 OS << "UniformityInfo for function '" << F.getName() << "':\n";
137 AM.getResult<UniformityInfoAnalysis>(IR&: F).print(out&: OS);
138
139 return PreservedAnalyses::all();
140}
141
142//===----------------------------------------------------------------------===//
143// UniformityInfoWrapperPass Implementation
144//===----------------------------------------------------------------------===//
145
// Pass identifier for the legacy pass manager; the pass is identified by the
// address of this object, and its value is not otherwise used.
char UniformityInfoWrapperPass::ID = 0;
147
148UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {
149 initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
150}
151
// Register the legacy pass under the command-line name "uniformity", together
// with the passes it depends on. Per the INITIALIZE_PASS_BEGIN/END signature in
// PassSupport.h, the trailing `true, true` arguments mark the pass as
// CFG-only and as an analysis.
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                      "Uniformity Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
                    "Uniformity Analysis", true, true)
159
/// Declares the analyses this pass needs.
///
/// The pass never modifies the IR, so it preserves everything. Cycle info is
/// required transitively because the computed UniformityInfo is built on top
/// of it. NOTE(review): UniformityInfo is also constructed with a reference to
/// the dominator tree; if queries consult it after runOnFunction, it should be
/// required transitively as well — confirm.
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequiredTransitive<CycleInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
}
166
167bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
168 auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
169 auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
170 auto &targetTransformInfo =
171 getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
172
173 m_function = &F;
174 m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
175
176 // Skip computation if we can assume everything is uniform.
177 if (targetTransformInfo.hasBranchDivergence(F: m_function))
178 m_uniformityInfo.compute();
179
180 return false;
181}
182
183void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
184 OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
185}
186
187void UniformityInfoWrapperPass::releaseMemory() {
188 m_uniformityInfo = UniformityInfo{};
189 m_function = nullptr;
190}
191