1//===- UniformityAnalysis.cpp ---------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Analysis/UniformityAnalysis.h"
10#include "llvm/ADT/GenericUniformityImpl.h"
11#include "llvm/Analysis/CycleAnalysis.h"
12#include "llvm/Analysis/TargetTransformInfo.h"
13#include "llvm/IR/Dominators.h"
14#include "llvm/IR/InstIterator.h"
15#include "llvm/IR/Instructions.h"
16#include "llvm/InitializePasses.h"
17
18using namespace llvm;
19
20template <>
21bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
22 const Instruction &I) const {
23 return isDivergent(V: (const Value *)&I);
24}
25
26template <>
27bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
28 const Instruction &Instr) {
29 return markDivergent(Val: cast<Value>(Val: &Instr));
30}
31
32template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
33 for (auto &I : instructions(F)) {
34 InstructionUniformity IU = TTI->getInstructionUniformity(V: &I);
35 switch (IU) {
36 case InstructionUniformity::AlwaysUniform:
37 addUniformOverride(Instr: I);
38 continue;
39 case InstructionUniformity::NeverUniform:
40 markDivergent(I);
41 continue;
42 case InstructionUniformity::Default:
43 break;
44 }
45 }
46 for (auto &Arg : F.args()) {
47 if (TTI->getInstructionUniformity(V: &Arg) ==
48 InstructionUniformity::NeverUniform)
49 markDivergent(Val: &Arg);
50 }
51}
52
53template <>
54void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
55 const Value *V) {
56 for (const auto *User : V->users()) {
57 if (const auto *UserInstr = dyn_cast<const Instruction>(Val: User)) {
58 markDivergent(I: *UserInstr);
59 }
60 }
61}
62
63template <>
64void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
65 const Instruction &Instr) {
66 assert(!isAlwaysUniform(Instr));
67 if (Instr.isTerminator())
68 return;
69 pushUsers(V: cast<Value>(Val: &Instr));
70}
71
72template <>
73bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
74 const Instruction &I, const Cycle &DefCycle) const {
75 assert(!isAlwaysUniform(I));
76 for (const Use &U : I.operands()) {
77 if (auto *I = dyn_cast<Instruction>(Val: &U)) {
78 if (DefCycle.contains(Block: I->getParent()))
79 return true;
80 }
81 }
82 return false;
83}
84
85template <>
86void llvm::GenericUniformityAnalysisImpl<
87 SSAContext>::propagateTemporalDivergence(const Instruction &I,
88 const Cycle &DefCycle) {
89 for (auto *User : I.users()) {
90 auto *UserInstr = cast<Instruction>(Val: User);
91 if (DefCycle.contains(Block: UserInstr->getParent()))
92 continue;
93 markDivergent(I: *UserInstr);
94 recordTemporalDivergence(Val: &I, User: UserInstr, Cycle: &DefCycle);
95 }
96}
97
98template <>
99bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
100 const Use &U) const {
101 const auto *V = U.get();
102 if (isDivergent(V))
103 return true;
104 if (const auto *DefInstr = dyn_cast<Instruction>(Val: V)) {
105 const auto *UseInstr = cast<Instruction>(Val: U.getUser());
106 return isTemporalDivergent(ObservingBlock: *UseInstr->getParent(), Def: *DefInstr);
107 }
108 return false;
109}
110
// Explicitly instantiate the SSAContext (LLVM IR) flavour of the generic
// uniformity machinery in this translation unit: GenericUniformityInfo itself,
// and GenericUniformityAnalysisImplDeleter (including its operator()) for the
// matching GenericUniformityAnalysisImpl. Presumably this lets other
// translation units use UniformityInfo without including
// GenericUniformityImpl.h, which is included above — confirm.
template class llvm::GenericUniformityInfo<SSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<SSAContext>>;
116
117//===----------------------------------------------------------------------===//
118// UniformityInfoAnalysis and related pass implementations
119//===----------------------------------------------------------------------===//
120
121llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
122 FunctionAnalysisManager &FAM) {
123 auto &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: F);
124 auto &TTI = FAM.getResult<TargetIRAnalysis>(IR&: F);
125 auto &CI = FAM.getResult<CycleAnalysis>(IR&: F);
126 UniformityInfo UI{DT, CI, &TTI};
127 // Skip computation if we can assume everything is uniform.
128 if (TTI.hasBranchDivergence(F: &F))
129 UI.compute();
130
131 return UI;
132}
133
// Key that identifies UniformityInfoAnalysis to the new pass manager.
AnalysisKey UniformityInfoAnalysis::Key;
135
136UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
137 : OS(OS) {}
138
139PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
140 FunctionAnalysisManager &AM) {
141 OS << "UniformityInfo for function '" << F.getName() << "':\n";
142 AM.getResult<UniformityInfoAnalysis>(IR&: F).print(out&: OS);
143
144 return PreservedAnalyses::all();
145}
146
147//===----------------------------------------------------------------------===//
148// UniformityInfoWrapperPass Implementation
149//===----------------------------------------------------------------------===//
150
// Pass identifier for the legacy pass manager; it is handed to FunctionPass in
// the constructor below.
char UniformityInfoWrapperPass::ID = 0;

UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {}
154
// Registers the legacy wrapper pass under the command-line name "uniformity"
// along with the analyses it depends on. The two trailing arguments are the
// INITIALIZE_PASS cfg-only (false) and is-analysis (true) flags.
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                      "Uniformity Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
                    "Uniformity Analysis", false, true)
162
// Declares this pass's legacy-PM dependencies. It preserves everything
// (runOnFunction never changes the IR and returns false). CycleInfo is
// required transitively, presumably because the stored UniformityInfo keeps
// referring to it after runOnFunction returns — confirm.
// NOTE(review): the order of these requirements may be visible in legacy
// pass-pipeline dumps; avoid reordering without checking pipeline tests.
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequiredTransitive<CycleInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
}
169
170bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
171 auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
172 auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
173 auto &targetTransformInfo =
174 getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
175
176 m_function = &F;
177 m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
178
179 // Skip computation if we can assume everything is uniform.
180 if (targetTransformInfo.hasBranchDivergence(F: m_function))
181 m_uniformityInfo.compute();
182
183 return false;
184}
185
186void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
187 OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
188 m_uniformityInfo.print(out&: OS);
189}
190
191void UniformityInfoWrapperPass::releaseMemory() {
192 m_uniformityInfo = UniformityInfo{};
193 m_function = nullptr;
194}
195