1//===- UniformityAnalysis.cpp ---------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Analysis/UniformityAnalysis.h"
10#include "llvm/ADT/GenericUniformityImpl.h"
11#include "llvm/Analysis/CycleAnalysis.h"
12#include "llvm/Analysis/TargetTransformInfo.h"
13#include "llvm/IR/Dominators.h"
14#include "llvm/IR/InstIterator.h"
15#include "llvm/IR/Instructions.h"
16#include "llvm/InitializePasses.h"
17
18using namespace llvm;
19
20template <>
21bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
22 const Instruction &I) const {
23 return isDivergent(V: (const Value *)&I);
24}
25
26template <>
27bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
28 const Instruction &Instr) {
29 return markDivergent(Val: cast<Value>(Val: &Instr));
30}
31
32template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
33 for (auto &I : instructions(F)) {
34 if (TTI->isSourceOfDivergence(V: &I))
35 markDivergent(I);
36 else if (TTI->isAlwaysUniform(V: &I))
37 addUniformOverride(Instr: I);
38 }
39 for (auto &Arg : F.args()) {
40 if (TTI->isSourceOfDivergence(V: &Arg)) {
41 markDivergent(Val: &Arg);
42 }
43 }
44}
45
46template <>
47void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
48 const Value *V) {
49 for (const auto *User : V->users()) {
50 if (const auto *UserInstr = dyn_cast<const Instruction>(Val: User)) {
51 markDivergent(I: *UserInstr);
52 }
53 }
54}
55
56template <>
57void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
58 const Instruction &Instr) {
59 assert(!isAlwaysUniform(Instr));
60 if (Instr.isTerminator())
61 return;
62 pushUsers(V: cast<Value>(Val: &Instr));
63}
64
65template <>
66bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
67 const Instruction &I, const Cycle &DefCycle) const {
68 assert(!isAlwaysUniform(I));
69 for (const Use &U : I.operands()) {
70 if (auto *I = dyn_cast<Instruction>(Val: &U)) {
71 if (DefCycle.contains(Block: I->getParent()))
72 return true;
73 }
74 }
75 return false;
76}
77
78template <>
79void llvm::GenericUniformityAnalysisImpl<
80 SSAContext>::propagateTemporalDivergence(const Instruction &I,
81 const Cycle &DefCycle) {
82 for (auto *User : I.users()) {
83 auto *UserInstr = cast<Instruction>(Val: User);
84 if (DefCycle.contains(Block: UserInstr->getParent()))
85 continue;
86 markDivergent(I: *UserInstr);
87 recordTemporalDivergence(Val: &I, User: UserInstr, Cycle: &DefCycle);
88 }
89}
90
91template <>
92bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
93 const Use &U) const {
94 const auto *V = U.get();
95 if (isDivergent(V))
96 return true;
97 if (const auto *DefInstr = dyn_cast<Instruction>(Val: V)) {
98 const auto *UseInstr = cast<Instruction>(Val: U.getUser());
99 return isTemporalDivergent(ObservingBlock: *UseInstr->getParent(), Def: *DefInstr);
100 }
101 return false;
102}
103
// Explicitly instantiate the IR (SSAContext) versions of GenericUniformityInfo
// and of GenericUniformityAnalysisImplDeleter for its implementation object.
// This makes their member definitions, including the deleter's operator(),
// available from this translation unit.
template class llvm::GenericUniformityInfo<SSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<SSAContext>>;
109
110//===----------------------------------------------------------------------===//
111// UniformityInfoAnalysis and related pass implementations
112//===----------------------------------------------------------------------===//
113
114llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
115 FunctionAnalysisManager &FAM) {
116 auto &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: F);
117 auto &TTI = FAM.getResult<TargetIRAnalysis>(IR&: F);
118 auto &CI = FAM.getResult<CycleAnalysis>(IR&: F);
119 UniformityInfo UI{DT, CI, &TTI};
120 // Skip computation if we can assume everything is uniform.
121 if (TTI.hasBranchDivergence(F: &F))
122 UI.compute();
123
124 return UI;
125}
126
// Out-of-line definition of the AnalysisKey declared by UniformityInfoAnalysis.
AnalysisKey UniformityInfoAnalysis::Key;
128
129UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
130 : OS(OS) {}
131
132PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
133 FunctionAnalysisManager &AM) {
134 OS << "UniformityInfo for function '" << F.getName() << "':\n";
135 AM.getResult<UniformityInfoAnalysis>(IR&: F).print(out&: OS);
136
137 return PreservedAnalyses::all();
138}
139
140//===----------------------------------------------------------------------===//
141// UniformityInfoWrapperPass Implementation
142//===----------------------------------------------------------------------===//
143
// Legacy pass-manager identifier for UniformityInfoWrapperPass; it is handed
// to the FunctionPass base-class constructor.
char UniformityInfoWrapperPass::ID = 0;
145
146UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {}
147
// Register the legacy pass under the name "uniformity" ("Uniformity Analysis")
// together with the analyses it depends on. The two trailing `true` arguments
// are presumably the CFG-only and is-analysis flags of the INITIALIZE_PASS
// macros (see llvm/PassSupport.h) — confirm.
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                      "Uniformity Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
                    "Uniformity Analysis", true, true)
155
/// Declares the analyses this pass depends on. The pass never changes the IR,
/// so it preserves everything.
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<DominatorTreeWrapperPass>();
  // Required transitively, presumably because the UniformityInfo stored by
  // runOnFunction keeps referring to the CycleInfo afterwards — confirm.
  // NOTE(review): the dominator tree is handed to UniformityInfo as well but is
  // only addRequired; confirm it is not consulted after compute().
  AU.addRequiredTransitive<CycleInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
}
162
163bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
164 auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
165 auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
166 auto &targetTransformInfo =
167 getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
168
169 m_function = &F;
170 m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
171
172 // Skip computation if we can assume everything is uniform.
173 if (targetTransformInfo.hasBranchDivergence(F: m_function))
174 m_uniformityInfo.compute();
175
176 return false;
177}
178
179void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
180 OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
181}
182
183void UniformityInfoWrapperPass::releaseMemory() {
184 m_uniformityInfo = UniformityInfo{};
185 m_function = nullptr;
186}
187