1//===-- X86TileConfig.cpp - Tile Register Configure----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
/// \file Pass to configure the shape of AMX physical registers.
/// AMX registers need to be configured before use. The X86PreTileConfig pass
/// inserts the pldtilecfg instruction, but at that point the shape of each
/// physical tile register is not yet known because register allocation has
/// not been done. This pass runs after the register allocation pass. It
/// collects the shape information of each physical tile register and stores
/// the shape in the stack slot that is allocated for loading the config into
/// the tile config register.
17//
18//===----------------------------------------------------------------------===//
19
20#include "X86.h"
21#include "X86InstrBuilder.h"
22#include "X86MachineFunctionInfo.h"
23#include "X86Subtarget.h"
24#include "llvm/CodeGen/LiveIntervals.h"
25#include "llvm/CodeGen/MachineFrameInfo.h"
26#include "llvm/CodeGen/MachineFunctionPass.h"
27#include "llvm/CodeGen/MachineInstr.h"
28#include "llvm/CodeGen/MachineRegisterInfo.h"
29#include "llvm/CodeGen/Passes.h"
30#include "llvm/CodeGen/TargetInstrInfo.h"
31#include "llvm/CodeGen/TargetRegisterInfo.h"
32#include "llvm/CodeGen/TileShapeInfo.h"
33#include "llvm/CodeGen/VirtRegMap.h"
34#include "llvm/InitializePasses.h"
35
36using namespace llvm;
37
38#define DEBUG_TYPE "x86-tile-config"
39
40namespace {
41
42struct X86TileConfigLegacy : public MachineFunctionPass {
43
44 X86TileConfigLegacy() : MachineFunctionPass(ID) {}
45
46 /// Return the pass name.
47 StringRef getPassName() const override { return "Tile Register Configure"; }
48
49 /// X86TileConfig analysis usage.
50 void getAnalysisUsage(AnalysisUsage &AU) const override {
51 AU.setPreservesAll();
52 AU.addRequired<VirtRegMapWrapperLegacy>();
53 AU.addRequired<LiveIntervalsWrapperPass>();
54 MachineFunctionPass::getAnalysisUsage(AU);
55 }
56
57 /// Perform register allocation.
58 bool runOnMachineFunction(MachineFunction &MF) override;
59
60 MachineFunctionProperties getRequiredProperties() const override {
61 return MachineFunctionProperties().setNoPHIs();
62 }
63
64 static char ID;
65};
66
67} // end anonymous namespace
68
69char X86TileConfigLegacy::ID = 0;
70
71INITIALIZE_PASS_BEGIN(X86TileConfigLegacy, DEBUG_TYPE,
72 "Tile Register Configure", false, false)
73INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
74INITIALIZE_PASS_END(X86TileConfigLegacy, DEBUG_TYPE, "Tile Register Configure",
75 false, false)
76
77static bool tileConfig(MachineFunction &MF,
78 llvm::function_ref<LiveIntervals *()> GetLIs,
79 llvm::function_ref<VirtRegMap *()> GetVRM) {
80 X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
81 // Early exit in the common case of non-AMX code.
82 if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA)
83 return false;
84
85 const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
86 const X86RegisterInfo *TRI = ST.getRegisterInfo();
87 const TargetInstrInfo *TII = ST.getInstrInfo();
88 MachineRegisterInfo &MRI = MF.getRegInfo();
89 LiveIntervals &LIS = *GetLIs();
90 VirtRegMap &VRM = *GetVRM();
91
92 if (VRM.isShapeMapEmpty())
93 return false;
94
95 int SS = INT_MAX;
96 for (MachineBasicBlock &MBB : MF) {
97 for (MachineInstr &MI : MBB) {
98 if (MI.getOpcode() == X86::PLDTILECFGV) {
99 SS = MI.getOperand(i: 0).getIndex();
100 break;
101 }
102 }
103 if (SS != INT_MAX)
104 break;
105 }
106 // Didn't find PLDTILECFGV, just return false;
107 if (SS == INT_MAX)
108 return false;
109
110 // Try to find a point to insert MIs for constant shapes.
111 // Here we are leveraging the palette id inserted in PreRA pass.
112 unsigned ConstPos = 0;
113 MachineInstr *ConstMI = nullptr;
114 for (MachineInstr &MI : MF.front()) {
115 if (MI.getOpcode() == X86::MOV8mi && SS == MI.getOperand(i: 0).getIndex()) {
116 ConstMI = &MI;
117 break;
118 }
119 ++ConstPos;
120 }
121 assert(ConstMI && "Cannot find an insertion point");
122
123 unsigned AMXRegNum = TRI->getRegClass(i: X86::TILERegClassID)->getNumRegs();
124 SmallVector<Register, 8> Phys2Virt(AMXRegNum, 0);
125 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
126 Register VirtReg = Register::index2VirtReg(Index: I);
127 if (MRI.reg_nodbg_empty(RegNo: VirtReg))
128 continue;
129 if (!TRI->isTileRegisterClass(RC: MRI.getRegClass(Reg: VirtReg)))
130 continue;
131 MCRegister PhysReg = VRM.getPhys(virtReg: VirtReg);
132 if (!PhysReg)
133 continue;
134 unsigned Index = PhysReg - X86::TMM0;
135 if (!Phys2Virt[Index])
136 Phys2Virt[Index] = VirtReg;
137 }
138
139 // Fill in the shape of each tile physical register.
140 for (unsigned I = 0; I < AMXRegNum; ++I) {
141 if (!Phys2Virt[I])
142 continue;
143 DebugLoc DL;
144 bool IsRow = true;
145 MachineInstr *NewMI = nullptr;
146 ShapeT Shape = VRM.getShape(virtReg: Phys2Virt[I]);
147 for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) {
148 // Here is the data format for the tile config.
149 // 0 palette
150 // 1 start_row
151 // 2-15 reserved, must be zero
152 // 16-17 tile0.colsb Tile 0 bytes per row.
153 // 18-19 tile1.colsb Tile 1 bytes per row.
154 // 20-21 tile2.colsb Tile 2 bytes per row.
155 // ... (sequence continues)
156 // 30-31 tile7.colsb Tile 7 bytes per row.
157 // 32-47 reserved, must be zero
158 // 48 tile0.rows Tile 0 rows.
159 // 49 tile1.rows Tile 1 rows.
160 // 50 tile2.rows Tile 2 rows.
161 // ... (sequence continues)
162 // 55 tile7.rows Tile 7 rows.
163 // 56-63 reserved, must be zero
164 int64_t Imm = INT64_MAX;
165 int Offset = IsRow ? 48 + I : 16 + I * 2;
166 for (auto &DefMI : MRI.def_instructions(Reg: R)) {
167 MachineBasicBlock &MBB = *DefMI.getParent();
168 if (DefMI.isMoveImmediate()) {
169 if (Imm != INT64_MAX) {
170 // FIXME: We should handle this case in future.
171 assert(Imm == DefMI.getOperand(1).getImm() &&
172 "Cannot initialize with different shapes");
173 continue;
174 }
175 if (DefMI.getOperand(i: 1).isImm()) {
176 Imm = DefMI.getOperand(i: 1).getImm();
177 } else {
178 assert(DefMI.getOpcode() == X86::MOV32r0 &&
179 "The opcode is assumed to be MOV32r0 if the operand is not "
180 "immediate.");
181 Imm = 0;
182 }
183
184 NewMI = addFrameReference(
185 MIB: BuildMI(BB&: MF.front(), I: ++ConstMI->getIterator(), MIMD: DL,
186 MCID: TII->get(Opcode: IsRow ? X86::MOV8mi : X86::MOV16mi)),
187 FI: SS, Offset)
188 .addImm(Val: Imm);
189 ConstMI = NewMI;
190 LIS.InsertMachineInstrInMaps(MI&: *NewMI);
191 } else {
192 unsigned SubIdx = IsRow ? X86::sub_8bit : X86::sub_16bit;
193 unsigned RegSize = TRI->getRegSizeInBits(RC: *MRI.getRegClass(Reg: R));
194 if ((IsRow && RegSize == 8) || (!IsRow && RegSize == 16))
195 SubIdx = 0;
196 auto Iter = DefMI.getIterator();
197 if (&MBB == &MF.front() &&
198 (unsigned)std::distance(first: MBB.instr_begin(), last: Iter) < ConstPos)
199 Iter = ConstMI->getIterator();
200 NewMI = addFrameReference(
201 MIB: BuildMI(BB&: MBB, I: ++Iter, MIMD: DL,
202 MCID: TII->get(Opcode: IsRow ? X86::MOV8mr : X86::MOV16mr)),
203 FI: SS, Offset)
204 .addReg(RegNo: R, Flags: {}, SubReg: SubIdx);
205 SlotIndex SIdx = LIS.InsertMachineInstrInMaps(MI&: *NewMI);
206 LIS.extendToIndices(LR&: LIS.getInterval(Reg: R), Indices: {SIdx.getRegSlot()});
207 }
208 }
209 IsRow = false;
210 }
211 }
212 return true;
213}
214
215FunctionPass *llvm::createX86TileConfigLegacyPass() {
216 return new X86TileConfigLegacy();
217}
218
219bool X86TileConfigLegacy::runOnMachineFunction(MachineFunction &MF) {
220 return tileConfig(
221 MF,
222 GetLIs: [this]() { return &getAnalysis<LiveIntervalsWrapperPass>().getLIS(); },
223 GetVRM: [this]() { return &getAnalysis<VirtRegMapWrapperLegacy>().getVRM(); });
224}
225
226PreservedAnalyses X86TileConfigPass::run(MachineFunction &MF,
227 MachineFunctionAnalysisManager &MFAM) {
228 bool Changed = tileConfig(
229 MF, GetLIs: [&MFAM, &MF]() { return &MFAM.getResult<LiveIntervalsAnalysis>(IR&: MF); },
230 GetVRM: [&MFAM, &MF]() { return &MFAM.getResult<VirtRegMapAnalysis>(IR&: MF); });
231 return Changed ? getMachineFunctionPassPreservedAnalyses()
232 .preserveSet<CFGAnalyses>()
233 : PreservedAnalyses::all();
234}
235