1//===- UniformityAnalysis.cpp ---------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Analysis/UniformityAnalysis.h"
10#include "llvm/ADT/GenericUniformityImpl.h"
11#include "llvm/ADT/SmallBitVector.h"
12#include "llvm/Analysis/CycleAnalysis.h"
13#include "llvm/Analysis/TargetTransformInfo.h"
14#include "llvm/IR/Dominators.h"
15#include "llvm/IR/InstIterator.h"
16#include "llvm/IR/Instructions.h"
17#include "llvm/InitializePasses.h"
18
19using namespace llvm;
20
21template <>
22bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
23 const Instruction &I) const {
24 return isDivergent(V: (const Value *)&I);
25}
26
27template <>
28bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
29 const Instruction &Instr) {
30 return markDivergent(Val: cast<Value>(Val: &Instr));
31}
32
33template <>
34void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
35 const Value *V) {
36 for (const auto *User : V->users()) {
37 if (const auto *UserInstr = dyn_cast<const Instruction>(Val: User)) {
38 markDivergent(I: *UserInstr);
39 }
40 }
41}
42
43template <>
44void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
45 const Instruction &Instr) {
46 assert(!isAlwaysUniform(Instr));
47 if (Instr.isTerminator())
48 return;
49 pushUsers(V: cast<Value>(Val: &Instr));
50}
51
52template <>
53bool llvm::GenericUniformityAnalysisImpl<SSAContext>::printDivergentArgs(
54 raw_ostream &OS) const {
55 bool haveDivergentArgs = false;
56 for (const auto &Arg : F.args()) {
57 if (isDivergent(V: &Arg)) {
58 if (!haveDivergentArgs) {
59 OS << "DIVERGENT ARGUMENTS:\n";
60 haveDivergentArgs = true;
61 }
62 OS << " DIVERGENT: " << Context.print(value: &Arg) << '\n';
63 }
64 }
65 return haveDivergentArgs;
66}
67
68template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
69 // Pre-populate UniformValues with uniform values, then seed divergence.
70 // NeverUniform values are not inserted -- they are divergent by definition
71 // and will be reported as such by isDivergent() (not in UniformValues).
72 SmallVector<const Value *, 4> DivergentArgs;
73 for (auto &Arg : F.args()) {
74 if (TTI->getValueUniformity(V: &Arg) == ValueUniformity::NeverUniform)
75 DivergentArgs.push_back(Elt: &Arg);
76 else
77 UniformValues.insert(V: &Arg);
78 }
79 for (auto &I : instructions(F)) {
80 ValueUniformity IU = TTI->getValueUniformity(V: &I);
81 switch (IU) {
82 case ValueUniformity::AlwaysUniform:
83 UniformValues.insert(V: &I);
84 addUniformOverride(Instr: I);
85 continue;
86 case ValueUniformity::NeverUniform:
87 // Skip inserting -- divergent by definition. Add to Worklist directly
88 // so compute() propagates divergence to users.
89 if (I.isTerminator())
90 DivergentTermBlocks.insert(Ptr: I.getParent());
91 Worklist.push_back(x: &I);
92 continue;
93 case ValueUniformity::Custom:
94 UniformValues.insert(V: &I);
95 addCustomUniformityCandidate(I: &I);
96 continue;
97 case ValueUniformity::Default:
98 UniformValues.insert(V: &I);
99 break;
100 }
101 }
102 // Arguments are not instructions and cannot go on the Worklist, so we
103 // propagate their divergence to users explicitly here. This must happen
104 // after all instructions are in UniformValues so markDivergent (called
105 // inside pushUsers) can successfully erase user instructions from the set.
106 for (const Value *Arg : DivergentArgs)
107 pushUsers(V: Arg);
108}
109
110template <>
111bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
112 const Instruction &I, const Cycle &DefCycle) const {
113 assert(!isAlwaysUniform(I));
114 for (const Use &U : I.operands()) {
115 if (auto *I = dyn_cast<Instruction>(Val: &U)) {
116 if (DefCycle.contains(Block: I->getParent()))
117 return true;
118 }
119 }
120 return false;
121}
122
123template <>
124void llvm::GenericUniformityAnalysisImpl<
125 SSAContext>::propagateTemporalDivergence(const Instruction &I,
126 const Cycle &DefCycle) {
127 for (auto *User : I.users()) {
128 auto *UserInstr = cast<Instruction>(Val: User);
129 if (DefCycle.contains(Block: UserInstr->getParent()))
130 continue;
131 markDivergent(I: *UserInstr);
132 recordTemporalDivergence(Val: &I, User: UserInstr, Cycle: &DefCycle);
133 }
134}
135
136template <>
137bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
138 const Use &U) const {
139 const auto *V = U.get();
140 if (isDivergent(V))
141 return true;
142 if (const auto *DefInstr = dyn_cast<Instruction>(Val: V)) {
143 const auto *UseInstr = cast<Instruction>(Val: U.getUser());
144 return isTemporalDivergent(ObservingBlock: *UseInstr->getParent(), Def: *DefInstr);
145 }
146 return false;
147}
148
149template <>
150bool GenericUniformityAnalysisImpl<SSAContext>::isCustomUniform(
151 const Instruction &I) const {
152 SmallBitVector UniformArgs(I.getNumOperands());
153 for (auto [Idx, Use] : enumerate(First: I.operands()))
154 UniformArgs[Idx] = !isDivergentUse(U: Use);
155 return TTI->isUniform(I: &I, UniformArgs);
156}
157
// Explicit instantiation definitions for the SSA context. The second one
// instantiates GenericUniformityAnalysisImplDeleter for the SSA implementation
// class, which ensures its operator() is emitted in this translation unit.
template class llvm::GenericUniformityInfo<SSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<SSAContext>>;
163
164//===----------------------------------------------------------------------===//
165// UniformityInfoAnalysis and related pass implementations
166//===----------------------------------------------------------------------===//
167
168llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
169 FunctionAnalysisManager &FAM) {
170 auto &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: F);
171 auto &TTI = FAM.getResult<TargetIRAnalysis>(IR&: F);
172 auto &CI = FAM.getResult<CycleAnalysis>(IR&: F);
173 UniformityInfo UI{DT, CI, &TTI};
174 // Skip computation if we can assume everything is uniform.
175 if (TTI.hasBranchDivergence(F: &F))
176 UI.compute();
177
178 return UI;
179}
180
// Out-of-class definition of the static AnalysisKey member of
// UniformityInfoAnalysis.
// NOTE(review): presumably the key the analysis manager uses to identify this
// analysis -- confirm against the header.
AnalysisKey UniformityInfoAnalysis::Key;
182
// Stores the output stream that run() prints to.
UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
    : OS(OS) {}
185
186PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
187 FunctionAnalysisManager &AM) {
188 OS << "UniformityInfo for function '" << F.getName() << "':\n";
189 AM.getResult<UniformityInfoAnalysis>(IR&: F).print(out&: OS);
190
191 return PreservedAnalyses::all();
192}
193
194//===----------------------------------------------------------------------===//
195// UniformityInfoWrapperPass Implementation
196//===----------------------------------------------------------------------===//
197
// Definition of the static pass ID member; it is passed to the FunctionPass
// base-class constructor by the UniformityInfoWrapperPass constructor.
char UniformityInfoWrapperPass::ID = 0;
199
// Default constructor: forwards the pass ID to the FunctionPass base class.
UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {}
201
// Pass registration for UniformityInfoWrapperPass under the name "uniformity",
// declaring the dominator tree, cycle info and target transform info wrapper
// passes as dependencies.
// NOTE(review): the trailing `false, true` arguments are presumably the
// CFG-only and is-analysis flags -- confirm against the macro definition.
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                      "Uniformity Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
                    "Uniformity Analysis", false, true)
209
// Declares this pass's analysis requirements: it preserves everything and
// requires the dominator tree, cycle info and target transform info wrapper
// passes, which runOnFunction() queries.
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<DominatorTreeWrapperPass>();
  // NOTE(review): cycle info is the only transitive requirement, presumably
  // because the stored UniformityInfo keeps referring to it after
  // runOnFunction() returns -- confirm.
  AU.addRequiredTransitive<CycleInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
}
216
217bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
218 auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
219 auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
220 auto &targetTransformInfo =
221 getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
222
223 m_function = &F;
224 m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
225
226 // Skip computation if we can assume everything is uniform.
227 if (targetTransformInfo.hasBranchDivergence(F: m_function))
228 m_uniformityInfo.compute();
229
230 return false;
231}
232
233void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
234 OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
235 m_uniformityInfo.print(out&: OS);
236}
237
238void UniformityInfoWrapperPass::releaseMemory() {
239 m_uniformityInfo = UniformityInfo{};
240 m_function = nullptr;
241}
242