//===-- X86TileConfig.cpp - Tile Register Configure----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to configure the shapes of AMX physical registers.
/// AMX registers need to be configured before use. The X86PreTileConfig pass
/// inserts the pldtilecfg instruction, but at that point the shape of each
/// physical tile register is not known yet, because register allocation has
/// not been done. This pass runs after register allocation. It collects the
/// shape information of each physical tile register and stores the shapes in
/// the stack slot that is loaded into the tile config register.
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86MachineFunctionInfo.h"
#include "X86Subtarget.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TileShapeInfo.h"
#include "llvm/CodeGen/VirtRegMap.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

#define DEBUG_TYPE "tileconfig"

namespace {

struct X86TileConfig : public MachineFunctionPass {

  X86TileConfig() : MachineFunctionPass(ID) {}

  /// Return the pass name.
  StringRef getPassName() const override { return "Tile Register Configure"; }

  /// X86TileConfig analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<VirtRegMapWrapperLegacy>();
    AU.addRequired<LiveIntervalsWrapperPass>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Configure the shapes of the physical tile registers.
  bool runOnMachineFunction(MachineFunction &mf) override;

  MachineFunctionProperties getRequiredProperties() const override {
    return MachineFunctionProperties().setNoPHIs();
  }

  static char ID;
};

} // end anonymous namespace

char X86TileConfig::ID = 0;

INITIALIZE_PASS_BEGIN(X86TileConfig, DEBUG_TYPE, "Tile Register Configure",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false,
                    false)

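// Return the number of single tiles Reg covers: 1 for a plain TILE register,
// 2 for a TILEPAIR register, and 0 if Reg is not an AMX register at all.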
static unsigned getAMXRegNum(MachineRegisterInfo *MRI, Register Reg) {
  if (Reg.isVirtual()) {
    unsigned RegClassID = MRI->getRegClass(Reg)->getID();
    if (RegClassID == X86::TILERegClassID)
      return 1;
    if (RegClassID == X86::TILEPAIRRegClassID)
      return 2;
  } else {
    if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
      return 1;
    if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7)
      return 2;
  }
  return 0;
}

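// Record the shape of the physical register(s) that VirtReg is assigned to.
// A tile-pair mapping is expanded into two single-tile entries in Phys2Shapes.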
static void collectVirtRegShapes(MachineRegisterInfo *MRI, VirtRegMap &VRM,
                                 Register VirtReg,
                                 SmallVector<ShapeT, 8> &Phys2Shapes) {
  unsigned Num = getAMXRegNum(MRI, VirtReg);
  MCRegister PhysReg = VRM.getPhys(VirtReg);
  if (!PhysReg)
    return;

  if (Num == 1) {
    unsigned Index = PhysReg - X86::TMM0;
    if (!Phys2Shapes[Index].isValid()) {
      ShapeT Shape = VRM.getShape(VirtReg);
      Phys2Shapes[Index] = std::move(Shape);
      return;
    }
  }
  // Split a tile-pair shape into two single-tile shapes, e.g. record
  // TMM0_TMM1's shape as TMM0's shape + TMM1's shape in Phys2Shapes.
  if (Num == 2) {
    unsigned Index0 = (PhysReg - X86::TMM0_TMM1) * 2;
    unsigned Index1 = (PhysReg - X86::TMM0_TMM1) * 2 + 1;

    ShapeT Shape = VRM.getShape(VirtReg);
    assert(Shape.getShapeNum() == 2 && "Unexpected shape number!");

    if (!Phys2Shapes[Index0].isValid()) {
      ShapeT Shape0(Shape.getRow(0), Shape.getCol(0), MRI);
      Phys2Shapes[Index0] = std::move(Shape0);
    }

    if (!Phys2Shapes[Index1].isValid()) {
      ShapeT Shape1(Shape.getRow(1), Shape.getCol(1), MRI);
      Phys2Shapes[Index1] = std::move(Shape1);
    }
  }
}

static bool isAMXRegClass(MachineRegisterInfo *MRI, Register Reg) {
  return getAMXRegNum(MRI, Reg) > 0;
}

bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
  X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
  // Early exit in the common case of non-AMX code.
  if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA)
    return false;

  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
  const TargetInstrInfo *TII = ST.getInstrInfo();
  MachineRegisterInfo &MRI = MF.getRegInfo();
  LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
  VirtRegMap &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM();

  if (VRM.isShapeMapEmpty())
    return false;

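  // Find the stack slot operand of the PLDTILECFGV that X86PreTileConfig
  // inserted; the shape data will be stored into this slot.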
  int SS = INT_MAX;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (MI.getOpcode() == X86::PLDTILECFGV) {
        SS = MI.getOperand(0).getIndex();
        break;
      }
    }
    if (SS != INT_MAX)
      break;
  }
  // Didn't find a PLDTILECFGV, just return false.
  if (SS == INT_MAX)
    return false;

  // Try to find a point to insert MIs for constant shapes.
  // Here we are leveraging the palette id inserted in the PreRA pass.
  unsigned ConstPos = 0;
  MachineInstr *ConstMI = nullptr;
  for (MachineInstr &MI : MF.front()) {
    if (MI.getOpcode() == X86::MOV8mi && SS == MI.getOperand(0).getIndex()) {
      ConstMI = &MI;
      break;
    }
    ++ConstPos;
  }
  assert(ConstMI && "Cannot find an insertion point");

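  // Gather the shape assigned to each physical tile register by scanning all
  // virtual registers that belong to AMX register classes.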
  unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs();
  SmallVector<ShapeT, 8> Phys2Shapes(AMXRegNum, ShapeT());
  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
    Register VirtReg = Register::index2VirtReg(I);
    if (MRI.reg_nodbg_empty(VirtReg))
      continue;
    if (!isAMXRegClass(&MRI, VirtReg))
      continue;
    collectVirtRegShapes(&MRI, VRM, VirtReg, Phys2Shapes);
  }

  // Fill in the shape of each tile physical register.
  for (unsigned I = 0; I < AMXRegNum; ++I) {
    ShapeT Shape = Phys2Shapes[I];
    if (!Shape.isValid())
      continue;
    DebugLoc DL;
    bool IsRow = true;
    MachineInstr *NewMI = nullptr;
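    // Emit the store of the row count first and the column size second;
    // IsRow tracks which of the two is currently being written.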
    for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) {
      // Here is the data format for the tile config.
      // 0      palette
      // 1      start_row
      // 2-15   reserved, must be zero
      // 16-17  tile0.colsb Tile 0 bytes per row.
      // 18-19  tile1.colsb Tile 1 bytes per row.
      // 20-21  tile2.colsb Tile 2 bytes per row.
      // ...    (sequence continues)
      // 30-31  tile7.colsb Tile 7 bytes per row.
      // 32-47  reserved, must be zero
      // 48     tile0.rows Tile 0 rows.
      // 49     tile1.rows Tile 1 rows.
      // 50     tile2.rows Tile 2 rows.
      // ...    (sequence continues)
      // 55     tile7.rows Tile 7 rows.
      // 56-63  reserved, must be zero
      int64_t Imm = INT64_MAX;
      int Offset = IsRow ? 48 + I : 16 + I * 2;
      for (auto &DefMI : MRI.def_instructions(R)) {
        MachineBasicBlock &MBB = *DefMI.getParent();
        if (DefMI.isMoveImmediate()) {
          if (Imm != INT64_MAX) {
            // FIXME: We should handle this case in the future.
            assert(Imm == DefMI.getOperand(1).getImm() &&
                   "Cannot initialize with different shapes");
            continue;
          }
          if (DefMI.getOperand(1).isImm()) {
            Imm = DefMI.getOperand(1).getImm();
          } else {
            assert(DefMI.getOpcode() == X86::MOV32r0 &&
                   "The opcode is assumed to be MOV32r0 if the operand is not "
                   "immediate.");
            Imm = 0;
          }

          NewMI = addFrameReference(
                      BuildMI(MF.front(), ++ConstMI->getIterator(), DL,
                              TII->get(IsRow ? X86::MOV8mi : X86::MOV16mi)),
                      SS, Offset)
                      .addImm(Imm);
          ConstMI = NewMI;
          LIS.InsertMachineInstrInMaps(*NewMI);
        } else {
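          // The shape comes from a register. Store the right-sized
          // subregister (8-bit for rows, 16-bit for columns) right after its
          // def, but never ahead of the palette byte written in the entry
          // block.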
          unsigned SubIdx = IsRow ? X86::sub_8bit : X86::sub_16bit;
          unsigned RegSize = TRI->getRegSizeInBits(*MRI.getRegClass(R));
          if ((IsRow && RegSize == 8) || (!IsRow && RegSize == 16))
            SubIdx = 0;
          auto Iter = DefMI.getIterator();
          if (&MBB == &MF.front() &&
              (unsigned)std::distance(MBB.instr_begin(), Iter) < ConstPos)
            Iter = ConstMI->getIterator();
          NewMI = addFrameReference(
                      BuildMI(MBB, ++Iter, DL,
                              TII->get(IsRow ? X86::MOV8mr : X86::MOV16mr)),
                      SS, Offset)
                      .addReg(R, 0, SubIdx);
          SlotIndex SIdx = LIS.InsertMachineInstrInMaps(*NewMI);
          LIS.extendToIndices(LIS.getInterval(R), {SIdx.getRegSlot()});
        }
      }
      IsRow = false;
    }
  }
  return true;
}

FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); }