1//===-- AMDGPUGlobalISelDivergenceLowering.cpp ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file
10/// GlobalISel pass that selects divergent i1 phis as lane mask phis.
/// Lane mask merging uses the same algorithm as SDAG in SILowerI1Copies.
12/// Handles all cases of temporal divergence.
13/// For divergent non-phi i1 and uniform i1 uses outside of the cycle this pass
14/// currently depends on LCSSA to insert phis with one incoming.
15//
16//===----------------------------------------------------------------------===//
17
18#include "AMDGPU.h"
19#include "SILowerI1Copies.h"
20#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
21#include "llvm/CodeGen/MachineFunctionPass.h"
22#include "llvm/CodeGen/MachineUniformityAnalysis.h"
23#include "llvm/InitializePasses.h"
24
25#define DEBUG_TYPE "amdgpu-global-isel-divergence-lowering"
26
27using namespace llvm;
28
29namespace {
30
31class AMDGPUGlobalISelDivergenceLowering : public MachineFunctionPass {
32public:
33 static char ID;
34
35public:
36 AMDGPUGlobalISelDivergenceLowering() : MachineFunctionPass(ID) {
37 initializeAMDGPUGlobalISelDivergenceLoweringPass(
38 *PassRegistry::getPassRegistry());
39 }
40
41 bool runOnMachineFunction(MachineFunction &MF) override;
42
43 StringRef getPassName() const override {
44 return "AMDGPU GlobalISel divergence lowering";
45 }
46
47 void getAnalysisUsage(AnalysisUsage &AU) const override {
48 AU.setPreservesCFG();
49 AU.addRequired<MachineDominatorTreeWrapperPass>();
50 AU.addRequired<MachinePostDominatorTreeWrapperPass>();
51 AU.addRequired<MachineUniformityAnalysisPass>();
52 MachineFunctionPass::getAnalysisUsage(AU);
53 }
54};
55
56class DivergenceLoweringHelper : public PhiLoweringHelper {
57public:
58 DivergenceLoweringHelper(MachineFunction *MF, MachineDominatorTree *DT,
59 MachinePostDominatorTree *PDT,
60 MachineUniformityInfo *MUI);
61
62private:
63 MachineUniformityInfo *MUI = nullptr;
64 MachineIRBuilder B;
65 Register buildRegCopyToLaneMask(Register Reg);
66
67public:
68 void markAsLaneMask(Register DstReg) const override;
69 void getCandidatesForLowering(
70 SmallVectorImpl<MachineInstr *> &Vreg1Phis) const override;
71 void collectIncomingValuesFromPhi(
72 const MachineInstr *MI,
73 SmallVectorImpl<Incoming> &Incomings) const override;
74 void replaceDstReg(Register NewReg, Register OldReg,
75 MachineBasicBlock *MBB) override;
76 void buildMergeLaneMasks(MachineBasicBlock &MBB,
77 MachineBasicBlock::iterator I, const DebugLoc &DL,
78 Register DstReg, Register PrevReg,
79 Register CurReg) override;
80 void constrainAsLaneMask(Incoming &In) override;
81};
82
83DivergenceLoweringHelper::DivergenceLoweringHelper(
84 MachineFunction *MF, MachineDominatorTree *DT,
85 MachinePostDominatorTree *PDT, MachineUniformityInfo *MUI)
86 : PhiLoweringHelper(MF, DT, PDT), MUI(MUI), B(*MF) {}
87
88// _(s1) -> SReg_32/64(s1)
89void DivergenceLoweringHelper::markAsLaneMask(Register DstReg) const {
90 assert(MRI->getType(DstReg) == LLT::scalar(1));
91
92 if (MRI->getRegClassOrNull(Reg: DstReg)) {
93 if (MRI->constrainRegClass(Reg: DstReg, RC: ST->getBoolRC()))
94 return;
95 llvm_unreachable("Failed to constrain register class");
96 }
97
98 MRI->setRegClass(Reg: DstReg, RC: ST->getBoolRC());
99}
100
101void DivergenceLoweringHelper::getCandidatesForLowering(
102 SmallVectorImpl<MachineInstr *> &Vreg1Phis) const {
103 LLT S1 = LLT::scalar(SizeInBits: 1);
104
105 // Add divergent i1 phis to the list
106 for (MachineBasicBlock &MBB : *MF) {
107 for (MachineInstr &MI : MBB.phis()) {
108 Register Dst = MI.getOperand(i: 0).getReg();
109 if (MRI->getType(Reg: Dst) == S1 && MUI->isDivergent(V: Dst))
110 Vreg1Phis.push_back(Elt: &MI);
111 }
112 }
113}
114
115void DivergenceLoweringHelper::collectIncomingValuesFromPhi(
116 const MachineInstr *MI, SmallVectorImpl<Incoming> &Incomings) const {
117 for (unsigned i = 1; i < MI->getNumOperands(); i += 2) {
118 Incomings.emplace_back(Args: MI->getOperand(i).getReg(),
119 Args: MI->getOperand(i: i + 1).getMBB(), Args: Register());
120 }
121}
122
123void DivergenceLoweringHelper::replaceDstReg(Register NewReg, Register OldReg,
124 MachineBasicBlock *MBB) {
125 BuildMI(BB&: *MBB, I: MBB->getFirstNonPHI(), MIMD: {}, MCID: TII->get(Opcode: AMDGPU::COPY), DestReg: OldReg)
126 .addReg(RegNo: NewReg);
127}
128
129// Copy Reg to new lane mask register, insert a copy after instruction that
130// defines Reg while skipping phis if needed.
131Register DivergenceLoweringHelper::buildRegCopyToLaneMask(Register Reg) {
132 Register LaneMask = createLaneMaskReg(MRI, LaneMaskRegAttrs);
133 MachineInstr *Instr = MRI->getVRegDef(Reg);
134 MachineBasicBlock *MBB = Instr->getParent();
135 B.setInsertPt(MBB&: *MBB, II: MBB->SkipPHIsAndLabels(I: std::next(x: Instr->getIterator())));
136 B.buildCopy(Res: LaneMask, Op: Reg);
137 return LaneMask;
138}
139
140// bb.previous
141// %PrevReg = ...
142//
143// bb.current
144// %CurReg = ...
145//
146// %DstReg - not defined
147//
148// -> (wave32 example, new registers have sreg_32 reg class and S1 LLT)
149//
150// bb.previous
151// %PrevReg = ...
152// %PrevRegCopy:sreg_32(s1) = COPY %PrevReg
153//
154// bb.current
155// %CurReg = ...
156// %CurRegCopy:sreg_32(s1) = COPY %CurReg
157// ...
158// %PrevMaskedReg:sreg_32(s1) = ANDN2 %PrevRegCopy, ExecReg - active lanes 0
159// %CurMaskedReg:sreg_32(s1) = AND %ExecReg, CurRegCopy - inactive lanes to 0
160// %DstReg:sreg_32(s1) = OR %PrevMaskedReg, CurMaskedReg
161//
162// DstReg = for active lanes rewrite bit in PrevReg with bit from CurReg
163void DivergenceLoweringHelper::buildMergeLaneMasks(
164 MachineBasicBlock &MBB, MachineBasicBlock::iterator I, const DebugLoc &DL,
165 Register DstReg, Register PrevReg, Register CurReg) {
166 // DstReg = (PrevReg & !EXEC) | (CurReg & EXEC)
167 // TODO: check if inputs are constants or results of a compare.
168
169 Register PrevRegCopy = buildRegCopyToLaneMask(Reg: PrevReg);
170 Register CurRegCopy = buildRegCopyToLaneMask(Reg: CurReg);
171 Register PrevMaskedReg = createLaneMaskReg(MRI, LaneMaskRegAttrs);
172 Register CurMaskedReg = createLaneMaskReg(MRI, LaneMaskRegAttrs);
173
174 B.setInsertPt(MBB, II: I);
175 B.buildInstr(Opc: AndN2Op, DstOps: {PrevMaskedReg}, SrcOps: {PrevRegCopy, ExecReg});
176 B.buildInstr(Opc: AndOp, DstOps: {CurMaskedReg}, SrcOps: {ExecReg, CurRegCopy});
177 B.buildInstr(Opc: OrOp, DstOps: {DstReg}, SrcOps: {PrevMaskedReg, CurMaskedReg});
178}
179
180// GlobalISel has to constrain S1 incoming taken as-is with lane mask register
181// class. Insert a copy of Incoming.Reg to new lane mask inside Incoming.Block,
182// Incoming.Reg becomes that new lane mask.
183void DivergenceLoweringHelper::constrainAsLaneMask(Incoming &In) {
184 B.setInsertPt(MBB&: *In.Block, II: In.Block->getFirstTerminator());
185
186 auto Copy = B.buildCopy(Res: LLT::scalar(SizeInBits: 1), Op: In.Reg);
187 MRI->setRegClass(Reg: Copy.getReg(Idx: 0), RC: ST->getBoolRC());
188 In.Reg = Copy.getReg(Idx: 0);
189}
190
191} // End anonymous namespace.
192
// Pass registration. The pass is registered under DEBUG_TYPE with the same
// descriptive name returned by getPassName(), and lists as dependencies the
// three analyses requested in getAnalysisUsage().
193INITIALIZE_PASS_BEGIN(AMDGPUGlobalISelDivergenceLowering, DEBUG_TYPE,
194 "AMDGPU GlobalISel divergence lowering", false, false)
195INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
196INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
197INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
198INITIALIZE_PASS_END(AMDGPUGlobalISelDivergenceLowering, DEBUG_TYPE,
199 "AMDGPU GlobalISel divergence lowering", false, false)
200
// Definition of the pass's static ID member (passed to the
// MachineFunctionPass constructor).
201char AMDGPUGlobalISelDivergenceLowering::ID = 0;
202
// Reference to the ID above, exported from namespace llvm so that code
// outside this file can name the pass without seeing the class.
203char &llvm::AMDGPUGlobalISelDivergenceLoweringID =
204 AMDGPUGlobalISelDivergenceLowering::ID;
205
// Factory for the pass: returns a newly allocated instance as a FunctionPass
// pointer. NOTE(review): ownership is presumably taken by the caller (the
// pass manager) -- confirm against the call site.
206FunctionPass *llvm::createAMDGPUGlobalISelDivergenceLoweringPass() {
207 return new AMDGPUGlobalISelDivergenceLowering();
208}
209
210bool AMDGPUGlobalISelDivergenceLowering::runOnMachineFunction(
211 MachineFunction &MF) {
212 MachineDominatorTree &DT =
213 getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
214 MachinePostDominatorTree &PDT =
215 getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
216 MachineUniformityInfo &MUI =
217 getAnalysis<MachineUniformityAnalysisPass>().getUniformityInfo();
218
219 DivergenceLoweringHelper Helper(&MF, &DT, &PDT, &MUI);
220
221 return Helper.lowerPhis();
222}
223