//===-- X86FastPreTileConfig.cpp - Fast Tile Register Configure ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to preconfig the shape of physical tile registers.
/// It inserts ldtilecfg ahead of each group of tile registers. The algorithm
/// walks the instructions of each basic block in reverse order. All the tile
/// registers that live out of the basic block are spilled and reloaded
/// before their users. It also checks the dependency of the shapes to ensure
/// each shape is defined before its ldtilecfg.
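///
/// A rough sketch of the transformation (illustrative MIR, not taken from a
/// real test case):
///   %r:gr16 = MOV16ri 16
///   %c:gr16 = MOV16ri 16
///   %t:tile = PTILEZEROV %r, %c
/// becomes
///   %r:gr16 = MOV16ri 16
///   %c:gr16 = MOV16ri 16
///   PLDTILECFGV %stack.0, 1, $noreg, 0, $noreg
///   %t:tile = PTILEZEROV %r, %c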
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86MachineFunctionInfo.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/Support/Debug.h"

using namespace llvm;

#define DEBUG_TYPE "fastpretileconfig"

STATISTIC(NumStores, "Number of stores added");
STATISTIC(NumLoads, "Number of loads added");

namespace {

class X86FastPreTileConfig : public MachineFunctionPass {
  MachineFunction *MF = nullptr;
  const X86Subtarget *ST = nullptr;
  const TargetInstrInfo *TII = nullptr;
  MachineRegisterInfo *MRI = nullptr;
  X86MachineFunctionInfo *X86FI = nullptr;
  MachineFrameInfo *MFI = nullptr;
  const TargetRegisterInfo *TRI = nullptr;
  MachineBasicBlock *MBB = nullptr;
  int CfgSS = -1;
  struct PHIInfo {
    Register Row;
    Register Col;
    Register StackAddr;
  };
  DenseMap<MachineInstr *, struct PHIInfo> VisitedPHIs;

  /// Maps virtual regs to the frame index where these values are spilled.
  IndexedMap<int, VirtReg2IndexFunctor> StackSlotForVirtReg;

  /// Has a bit set for each tile virtual register that was determined to be
  /// live across blocks.
  BitVector MayLiveAcrossBlocks;

  int getStackSpaceFor(Register VirtReg);
  void InitializeTileConfigStackSpace();
  bool mayLiveOut(Register VirtReg, MachineInstr *CfgMI);
  void spill(MachineBasicBlock::iterator Before, Register VirtReg, bool Kill);
  void reload(MachineBasicBlock::iterator UseMI, Register VirtReg,
              MachineOperand *RowMO, MachineOperand *ColMO);
  void canonicalizePHIs(MachineBasicBlock &MBB);
  void convertPHI(MachineBasicBlock *MBB, MachineInstr &PHI);
  void convertPHIs(MachineBasicBlock &MBB);
  bool configBasicBlock(MachineBasicBlock &MBB);

public:
  X86FastPreTileConfig() : MachineFunctionPass(ID), StackSlotForVirtReg(-1) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Fast Tile Register Preconfigure";
  }

  /// Perform tile register configuration.
  bool runOnMachineFunction(MachineFunction &MFunc) override;

  static char ID;
};

} // end anonymous namespace

char X86FastPreTileConfig::ID = 0;

INITIALIZE_PASS_BEGIN(X86FastPreTileConfig, DEBUG_TYPE,
                      "Fast Tile Register Preconfigure", false, false)
INITIALIZE_PASS_END(X86FastPreTileConfig, DEBUG_TYPE,
                    "Fast Tile Register Preconfigure", false, false)

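// Return true if A "dominates" B inside the single block MBB, i.e. A appears
// at or before B in the instruction stream; B == MBB.end() is treated as
// dominated by everything. This is a plain linear scan from the block entry,
// so no dominator analysis is needed for these within-block queries.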
static bool dominates(MachineBasicBlock &MBB,
                      MachineBasicBlock::const_iterator A,
                      MachineBasicBlock::const_iterator B) {
  auto MBBEnd = MBB.end();
  if (B == MBBEnd)
    return true;

  MachineBasicBlock::const_iterator I = MBB.begin();
  for (; &*I != A && &*I != B; ++I)
    ;

  return &*I == A;
}

/// This allocates space for the specified virtual register to be held on the
/// stack.
int X86FastPreTileConfig::getStackSpaceFor(Register VirtReg) {
  // Find the location VirtReg would belong to...
  int SS = StackSlotForVirtReg[VirtReg];
  // Already has space allocated?
  if (SS != -1)
    return SS;

  // Allocate a new stack object for this spill location...
  const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
  unsigned Size = TRI->getSpillSize(RC);
  Align Alignment = TRI->getSpillAlign(RC);
  int FrameIdx = MFI->CreateSpillStackObject(Size, Alignment);

  // Assign the slot.
  StackSlotForVirtReg[VirtReg] = FrameIdx;
  return FrameIdx;
}

/// Returns false if \p VirtReg is known to not live out of the current config.
/// If \p VirtReg lives out of the current MBB, it must live out of the current
/// config.
bool X86FastPreTileConfig::mayLiveOut(Register VirtReg, MachineInstr *CfgMI) {
  if (MayLiveAcrossBlocks.test(VirtReg.virtRegIndex()))
    return true;

  for (const MachineInstr &UseInst : MRI->use_nodbg_instructions(VirtReg)) {
    if (UseInst.getParent() != MBB) {
      MayLiveAcrossBlocks.set(VirtReg.virtRegIndex());
      return true;
    }

    // The use and def are in the same MBB. If the tile register is
    // reconfigured, it is clobbered and we need to spill and reload the
    // tile register.
    if (CfgMI) {
      if (dominates(*MBB, *CfgMI, UseInst)) {
        MayLiveAcrossBlocks.set(VirtReg.virtRegIndex());
        return true;
      }
    }
  }

  return false;
}

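// Zero-initialize the 64-byte tile configuration area on the stack in the
// entry block, using the widest zeroed vector store the subtarget provides
// (one ZMM store, two YMM stores, or four XMM stores), then write 1 into
// byte 0, the palette field that ldtilecfg expects.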
void X86FastPreTileConfig::InitializeTileConfigStackSpace() {
  MachineBasicBlock &MBB = MF->front();
  MachineInstr *MI = &*MBB.getFirstNonPHI();
  DebugLoc DL;
  if (ST->hasAVX512()) {
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::AVX512_512_SET0), Zmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), CfgSS)
        .addReg(Zmm);
  } else if (ST->hasAVX2()) {
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::AVX_SET0), Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), CfgSS)
        .addReg(Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), CfgSS,
                      32)
        .addReg(Ymm);
  } else {
    assert(ST->hasSSE2() && "AMX should assume SSE2 enabled");
    unsigned StoreOpc = ST->hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr;
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::V_SET0), Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS, 16)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS, 32)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS, 48)
        .addReg(Xmm);
  }
  // Fill in the palette first.
  addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), CfgSS)
      .addImm(1);
}

/// Insert spill instruction for \p VirtReg before \p Before.
/// TODO: Update DBG_VALUEs with \p VirtReg operands with the stack slot.
void X86FastPreTileConfig::spill(MachineBasicBlock::iterator Before,
                                 Register VirtReg, bool Kill) {
  LLVM_DEBUG(dbgs() << "Spilling " << printReg(VirtReg, TRI) << " \n");
  int FI = getStackSpaceFor(VirtReg);
  LLVM_DEBUG(dbgs() << " to stack slot #" << FI << '\n');

  const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
  // Don't need shape information for the tile store, because it is adjacent
  // to the tile def instruction.
  TII->storeRegToStackSlot(*MBB, Before, VirtReg, Kill, FI, &RC, TRI,
                           Register());
  ++NumStores;

  // TODO: update DBG_VALUEs
}

/// Insert reload instruction for \p OrigReg before \p UseMI.
void X86FastPreTileConfig::reload(MachineBasicBlock::iterator UseMI,
                                  Register OrigReg, MachineOperand *RowMO,
                                  MachineOperand *ColMO) {
  int FI = getStackSpaceFor(OrigReg);
  const TargetRegisterClass &RC = *MRI->getRegClass(OrigReg);
  Register TileReg;
  // Fold copy to tileload
  // BB1:
  // spill src to s
  //
  // BB2:
  // t = copy src
  // -->
  // t = tileload (s)
  if (UseMI->isCopy())
    TileReg = UseMI->getOperand(0).getReg();
  else
    TileReg = MRI->createVirtualRegister(&RC);
  // Can't use TII->loadRegFromStackSlot(), because we need the shape
  // information for the reload.
  // tileloadd (%sp, %idx), %tmm
  unsigned Opc = X86::PTILELOADDV;
  Register StrideReg = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
  // FIXME: MBB is not the parent of UseMI.
  MachineInstr *NewMI = BuildMI(*UseMI->getParent(), UseMI, DebugLoc(),
                                TII->get(X86::MOV64ri), StrideReg)
                            .addImm(64);
  NewMI = addFrameReference(
      BuildMI(*UseMI->getParent(), UseMI, DebugLoc(), TII->get(Opc), TileReg)
          .addReg(RowMO->getReg())
          .addReg(ColMO->getReg()),
      FI);
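  // addFrameReference() appended the five X86 memory operands (frame-index
  // base, scale, index, displacement, segment) after the def/row/col
  // operands, so operand 5 is the index register of that address; use it to
  // carry the 64-byte row stride.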
  MachineOperand &MO = NewMI->getOperand(5);
  MO.setReg(StrideReg);
  MO.setIsKill(true);
  RowMO->setIsKill(false);
  ColMO->setIsKill(false);
  // Erase the copy instruction after it is folded.
  if (UseMI->isCopy()) {
    UseMI->eraseFromParent();
  } else {
    // Replace the register in the user MI.
    for (auto &MO : UseMI->operands()) {
      if (MO.isReg() && MO.getReg() == OrigReg)
        MO.setReg(TileReg);
    }
  }

  ++NumLoads;
  LLVM_DEBUG(dbgs() << "Reloading " << printReg(OrigReg, TRI) << " into "
                    << printReg(TileReg, TRI) << '\n');
}

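// Return the number of physical tile registers a register covers: 1 for a
// plain tile register (TILE class / TMM0..TMM7), 2 for a tile pair
// (TILEPAIR class / TMM0_TMM1..TMM6_TMM7), and 0 for non-tile registers.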
static unsigned getTileDefNum(MachineRegisterInfo *MRI, Register Reg) {
  if (Reg.isVirtual()) {
    unsigned RegClassID = MRI->getRegClass(Reg)->getID();
    if (RegClassID == X86::TILERegClassID)
      return 1;
    if (RegClassID == X86::TILEPAIRRegClassID)
      return 2;
  } else {
    if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
      return 1;
    if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7)
      return 2;
  }
  return 0;
}

static bool isTileRegister(MachineRegisterInfo *MRI, Register VirtReg) {
  return getTileDefNum(MRI, VirtReg) > 0;
}

static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {
  // The instruction must have 3 operands: tile def, row, col.
  if (MI.isDebugInstr() || MI.getNumOperands() < 3 || !MI.isPseudo())
    return false;
  MachineOperand &MO = MI.getOperand(0);

  if (!MO.isReg())
    return false;

  return getTileDefNum(MRI, MO.getReg()) > 0;
}

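// Return the shape (row/column machine operands) of TileReg by looking at its
// defining instruction, walking through COPYs. PHI defs are not expected here
// because tile PHIs are converted before their uses are visited.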
static ShapeT getShape(MachineRegisterInfo *MRI, Register TileReg) {
  MachineInstr *MI = MRI->getVRegDef(TileReg);
  if (isTileDef(MRI, *MI)) {
    MachineOperand *RowMO = &MI->getOperand(1);
    MachineOperand *ColMO = &MI->getOperand(2);
    return ShapeT(RowMO, ColMO, MRI);
  } else if (MI->isCopy()) {
    TileReg = MI->getOperand(1).getReg();
    return getShape(MRI, TileReg);
  }

  // The def should not be a PHI node, because we walk the MBB in reverse post
  // order.
  assert(MI->isPHI() && "Unexpected PHI when get shape.");
  llvm_unreachable("Unexpected MI when get shape.");
}

// BB0:
// spill t0 to s0
// BB1:
// spill t1 to s1
//
// BB2:
// t = phi [t0, bb0] [t1, bb1]
// -->
// row = phi [r0, bb0] [r1, bb1]
// col = phi [c0, bb0] [c1, bb1]
// s = phi [s0, bb0] [s1, bb1]
// t = tileload row, col, s
// The new instructions are inserted after the last PHI node. The order of
// the original PHI nodes is not preserved.
void X86FastPreTileConfig::convertPHI(MachineBasicBlock *MBB,
                                      MachineInstr &PHI) {
  // 1. Create instructions to get the stack slot address of each incoming
  //    block.
  // 2. Create a PHI node for the stack address.
  // 3. Create PHI nodes for the shape. If one of the incoming shapes is an
  //    immediate, use the immediate and delete the PHI node.
  // 4. Create a tileload instruction from the stack address.
  Register StackAddrReg = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
  MachineInstrBuilder AddrPHI = BuildMI(*MBB, ++PHI.getIterator(), DebugLoc(),
                                        TII->get(X86::PHI), StackAddrReg);
  Register RowReg = MRI->createVirtualRegister(&X86::GR16RegClass);
  MachineInstrBuilder RowPHI = BuildMI(*MBB, ++PHI.getIterator(), DebugLoc(),
                                       TII->get(X86::PHI), RowReg);
  Register ColReg = MRI->createVirtualRegister(&X86::GR16RegClass);
  MachineInstrBuilder ColPHI = BuildMI(*MBB, ++PHI.getIterator(), DebugLoc(),
                                       TII->get(X86::PHI), ColReg);
  // Record the mapping of the phi node and its row/column information.
  VisitedPHIs[&PHI] = {RowReg, ColReg, StackAddrReg};

  for (unsigned I = 1, E = PHI.getNumOperands(); I != E; I += 2) {
    // Get the pair of incoming operands: the tile register and its MBB.
    Register InTileReg = PHI.getOperand(I).getReg();
    // Mark it as live out, so that it will be spilled when we visit the
    // incoming MBB. Otherwise, since the phi will be deleted, the spill
    // would be missed when visiting the incoming MBB.
    MayLiveAcrossBlocks.set(InTileReg.virtRegIndex());
    MachineBasicBlock *InMBB = PHI.getOperand(I + 1).getMBB();

    MachineInstr *TileDefMI = MRI->getVRegDef(InTileReg);
    MachineBasicBlock::iterator InsertPos;
    if (TileDefMI->isPHI()) {
      InsertPos = TileDefMI->getParent()->getFirstNonPHI();
      if (auto It = VisitedPHIs.find(TileDefMI);
          It != VisitedPHIs.end()) { // circular phi reference
        //          def t1
        //         /       \
        //    def t2        t3 = phi(t1, t4) <--
        //         \       /                   |
        //          t4 = phi(t2, t3)------------
        //
        // For each (row, column and stack address) append phi incoming value.
        // Create r3 = phi(r1, r4)
        // Create r4 = phi(r2, r3)
        Register InRowReg = It->second.Row;
        Register InColReg = It->second.Col;
        Register InStackAddrReg = It->second.StackAddr;
        RowPHI.addReg(InRowReg).addMBB(InMBB);
        ColPHI.addReg(InColReg).addMBB(InMBB);
        AddrPHI.addReg(InStackAddrReg).addMBB(InMBB);
        continue;
      } else {
        // Recursively convert the PHI to a tileload.
        convertPHI(TileDefMI->getParent(), *TileDefMI);
        // The PHI node is converted to a tileload instruction. Get the stack
        // address from the tileload operands.
        MachineInstr *TileLoad = MRI->getVRegDef(InTileReg);
        assert(TileLoad && TileLoad->getOpcode() == X86::PTILELOADDV);
        Register InRowReg = TileLoad->getOperand(1).getReg();
        Register InColReg = TileLoad->getOperand(2).getReg();
        Register InStackAddrReg = TileLoad->getOperand(3).getReg();
        RowPHI.addReg(InRowReg).addMBB(InMBB);
        ColPHI.addReg(InColReg).addMBB(InMBB);
        AddrPHI.addReg(InStackAddrReg).addMBB(InMBB);
      }
    } else {
      InsertPos = TileDefMI->getIterator();

      // Fill the incoming operands of the row/column phi instructions.
      ShapeT Shape = getShape(MRI, InTileReg);
      Shape.getRow()->setIsKill(false);
      Shape.getCol()->setIsKill(false);
      RowPHI.addReg(Shape.getRow()->getReg()).addMBB(InMBB);
      ColPHI.addReg(Shape.getCol()->getReg()).addMBB(InMBB);

      // The incoming tile register lives out of its defining BB, so it will
      // be spilled. Create an MI to get the spill stack slot address for
      // the tile register.
      int FI = getStackSpaceFor(InTileReg);
      Register InStackAddrReg =
          MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
      addOffset(BuildMI(*TileDefMI->getParent(), InsertPos, DebugLoc(),
                        TII->get(X86::LEA64r), InStackAddrReg)
                    .addFrameIndex(FI),
                0);
      AddrPHI.addReg(InStackAddrReg).addMBB(InMBB);
    }
  }

  MachineBasicBlock::iterator InsertPos = MBB->getFirstNonPHI();
  Register StrideReg = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
  BuildMI(*MBB, InsertPos, DebugLoc(), TII->get(X86::MOV64ri), StrideReg)
      .addImm(64);
  Register TileReg = PHI.getOperand(0).getReg();
  MachineInstr *NewMI = addDirectMem(
      BuildMI(*MBB, InsertPos, DebugLoc(), TII->get(X86::PTILELOADDV), TileReg)
          .addReg(RowReg)
          .addReg(ColReg),
      StackAddrReg);
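  // As in reload(): operand 5 is the index register of the memory operand
  // appended by addDirectMem(); use it to carry the 64-byte stride.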
  MachineOperand &MO = NewMI->getOperand(5);
  MO.setReg(StrideReg);
  MO.setIsKill(true);
  PHI.eraseFromParent();
  VisitedPHIs.erase(&PHI);
}

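// Return true if MI defines a virtual tile register in its first operand.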
static bool isTileRegDef(MachineRegisterInfo *MRI, MachineInstr &MI) {
  MachineOperand &MO = MI.getOperand(0);
  if (MO.isReg() && MO.getReg().isVirtual() &&
      isTileRegister(MRI, MO.getReg()))
    return true;
  return false;
}

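// Break dependencies between tile PHIs in MBB: if a tile PHI's incoming value
// from this very block is itself defined by another PHI in this block,
// forward that PHI's own incoming value from this block instead (see the
// example below), so each PHI can later be converted independently.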
void X86FastPreTileConfig::canonicalizePHIs(MachineBasicBlock &MBB) {
  SmallVector<MachineInstr *, 8> PHIs;

  for (MachineInstr &MI : MBB) {
    if (!MI.isPHI())
      break;
    if (!isTileRegDef(MRI, MI))
      continue;
    PHIs.push_back(&MI);
  }
  // Canonicalize the phi nodes first. One tile phi may depend on a previous
  // phi node. For the case below, we need to convert %t4.
  //
  // BB0:
  // %t3 = phi (t1 BB1, t2 BB0)
  // %t4 = phi (t5 BB1, t3 BB0)
  // -->
  // %t3 = phi (t1 BB1, t2 BB0)
  // %t4 = phi (t5 BB1, t2 BB0)
  //
  while (!PHIs.empty()) {
    MachineInstr *PHI = PHIs.pop_back_val();

    // Find an operand that is incoming from this same MBB and whose def
    // is also a phi node.
    MachineOperand *InMO = nullptr;
    MachineInstr *DefMI = nullptr;
    for (unsigned I = 1, E = PHI->getNumOperands(); I != E; I += 2) {
      Register InTileReg = PHI->getOperand(I).getReg();
      MachineBasicBlock *InMBB = PHI->getOperand(I + 1).getMBB();
      DefMI = MRI->getVRegDef(InTileReg);
      if (InMBB != &MBB || !DefMI->isPHI())
        continue;

      InMO = &PHI->getOperand(I);
      break;
    }
    // If we can't find such an operand, do nothing.
    if (!InMO)
      continue;

    // The current phi node depends on a previous phi node. Break the
    // dependency.
    Register DefTileReg;
    for (unsigned I = 1, E = DefMI->getNumOperands(); I != E; I += 2) {
      MachineBasicBlock *InMBB = PHI->getOperand(I + 1).getMBB();
      if (InMBB != &MBB)
        continue;
      DefTileReg = DefMI->getOperand(I).getReg();
      InMO->setReg(DefTileReg);
      break;
    }
  }
}

void X86FastPreTileConfig::convertPHIs(MachineBasicBlock &MBB) {
  SmallVector<MachineInstr *, 8> PHIs;
  for (MachineInstr &MI : MBB) {
    if (!MI.isPHI())
      break;
    if (!isTileRegDef(MRI, MI))
      continue;
    PHIs.push_back(&MI);
  }
  while (!PHIs.empty()) {
    MachineInstr *MI = PHIs.pop_back_val();
    VisitedPHIs.clear();
    convertPHI(&MBB, *MI);
  }
}

// PreTileConfig should configure the tile registers based on basic
// block.
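//
// The block is walked bottom-up. LastShapeMI tracks the latest instruction in
// the block that defines a shape needed by the tile defs visited so far, and
// LastTileCfg is the last ldtilecfg inserted during the walk. A new ldtilecfg
// is inserted (a) right after a call when unconfigured tile instructions were
// seen below it, because the config register is volatile across calls, (b)
// after LastShapeMI when a tile def above it is reached, so that the config
// covering the defs below does not include that def, and (c) once at the head
// of the block for whatever remains. Tile values whose uses end up on the
// other side of a config point are spilled and reloaded around it.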
bool X86FastPreTileConfig::configBasicBlock(MachineBasicBlock &MBB) {
  this->MBB = &MBB;
  bool Change = false;
  MachineInstr *LastShapeMI = nullptr;
  MachineInstr *LastTileCfg = nullptr;
  bool HasUnconfigTile = false;

  auto Config = [&](MachineInstr &Before) {
    if (CfgSS == -1)
      CfgSS = MFI->CreateStackObject(ST->getTileConfigSize(),
                                     ST->getTileConfigAlignment(), false);
    LastTileCfg = addFrameReference(
        BuildMI(MBB, Before, DebugLoc(), TII->get(X86::PLDTILECFGV)), CfgSS);
    LastShapeMI = nullptr;
    Change = true;
  };
  auto HasTileOperand = [](MachineRegisterInfo *MRI, MachineInstr &MI) {
    for (const MachineOperand &MO : MI.operands()) {
      if (!MO.isReg())
        continue;
      Register Reg = MO.getReg();
      if (Reg.isVirtual() && isTileRegister(MRI, Reg))
        return true;
    }
    return false;
  };
  for (MachineInstr &MI : reverse(MBB)) {
    // We have transformed phi nodes before configuring the BB.
    if (MI.isPHI())
      break;
    // Don't collect the shape of a used tile; the tile should be defined
    // before the tile use. Spill and reload happen if there is only a tile
    // use after ldtilecfg, so the shape can be collected from the reload.
    // Take the code below for example: %t would be reloaded before the
    // tilestore.
    // call
    // ....
    // tilestore %r, %c, %t
    // -->
    // call
    // ldtilecfg
    // %t = tileload %r, %c
    // tilestore %r, %c, %t
    if (HasTileOperand(MRI, MI))
      HasUnconfigTile = true;
    // According to the AMX ABI, all the tile registers including the config
    // register are volatile. Callers need to save/restore the config
    // register.
    if (MI.isCall() && HasUnconfigTile) {
      MachineBasicBlock::iterator I;
      if (LastShapeMI && dominates(MBB, MI, LastShapeMI))
        I = ++LastShapeMI->getIterator();
      else
        I = ++MI.getIterator();
      Config(*I);
      HasUnconfigTile = false;
      continue;
    }
    if (!isTileDef(MRI, MI))
      continue;
    //
    //---------------------------------------------------------------------
    // Don't handle COPY instructions. If the src and dst of the COPY can be
    // in the same config, as in the case below, we just check the shape of
    // t0.
    // def row0
    // def col0
    // ldtilecfg
    // t0 = tilezero(row0, col0)
    // t1 = copy t0
    // ...
    // If the src and dst of the COPY can NOT be in the same config, as in
    // the case below, a reload is generated before the copy instruction.
    // def row0
    // def col0
    // t0 = tilezero(row0, col0)
    // spill t0
    // ...
    // def row1
    // def col1
    // ldtilecfg
    // t1 = tilezero(row1, col1)
    // reload t0
    // t1 = copy t0
    //---------------------------------------------------------------------
    //
    // If MI dominates the last shape-defining instruction, we need to insert
    // ldtilecfg after LastShapeMI now. The config doesn't include the
    // current MI.
    // def row0
    // def col0
    // tilezero(row0, col0) <- MI
    // def row1
    // def col1
    // ldtilecfg <- insert
    // tilezero(row1, col1)
    if (LastShapeMI && dominates(MBB, MI, LastShapeMI))
      Config(*(++LastShapeMI->getIterator()));
    MachineOperand *RowMO = &MI.getOperand(1);
    MachineOperand *ColMO = &MI.getOperand(2);
    MachineInstr *RowMI = MRI->getVRegDef(RowMO->getReg());
    MachineInstr *ColMI = MRI->getVRegDef(ColMO->getReg());
    // If the shape is defined in the current MBB, check the domination.
    // FIXME how about loop?
    if (RowMI->getParent() == &MBB) {
      if (!LastShapeMI)
        LastShapeMI = RowMI;
      else if (dominates(MBB, LastShapeMI, RowMI))
        LastShapeMI = RowMI;
    }
    if (ColMI->getParent() == &MBB) {
      if (!LastShapeMI)
        LastShapeMI = ColMI;
      else if (dominates(MBB, LastShapeMI, ColMI))
        LastShapeMI = ColMI;
    }
    unsigned TileDefNum = getTileDefNum(MRI, MI.getOperand(0).getReg());
    if (TileDefNum > 1) {
      for (unsigned I = 1; I < TileDefNum; I++) {
        MachineOperand *ColxMO = &MI.getOperand(2 + I);
        MachineInstr *ColxMI = MRI->getVRegDef(ColxMO->getReg());
        if (ColxMI->getParent() == &MBB) {
          if (!LastShapeMI)
            LastShapeMI = ColxMI;
          else if (dominates(MBB, LastShapeMI, ColxMI))
            LastShapeMI = ColxMI;
        }
      }
    }
    // If there is a user that lives out of the current tile config, spill
    // the tile register and reload it before the user.
    Register TileReg = MI.getOperand(0).getReg();
    if (mayLiveOut(TileReg, LastTileCfg))
      spill(++MI.getIterator(), TileReg, false);
    for (MachineInstr &UseMI : MRI->use_instructions(TileReg)) {
      if (UseMI.getParent() == &MBB) {
        // The user should not cross the ldtilecfg; only reload for users
        // that come after the last inserted ldtilecfg.
        if (!LastTileCfg || !dominates(MBB, LastTileCfg, UseMI))
          continue;
        // Reload before UseMI.
        reload(UseMI.getIterator(), TileReg, RowMO, ColMO);
      } else {
        // Don't reload for phi instructions; we handle phi reloads
        // separately.
        // TODO: merge the reloads for the same user MBB.
        if (!UseMI.isPHI())
          reload(UseMI.getIterator(), TileReg, RowMO, ColMO);
      }
    }
  }

  // Configure tile registers at the head of the MBB.
  if (HasUnconfigTile) {
    MachineInstr *Before;
    if (LastShapeMI == nullptr || LastShapeMI->isPHI())
      Before = &*MBB.getFirstNonPHI();
    else
      Before = &*(++LastShapeMI->getIterator());

    Config(*Before);
  }

  return Change;
}

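// Entry point. Bail out early for functions that do not use AMX under managed
// register allocation; otherwise canonicalize tile PHIs, visit the blocks in
// reverse post order converting tile PHIs and inserting per-block ldtilecfg
// instructions, and finally emit the code that zero-initializes the config
// stack slot in the entry block.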
bool X86FastPreTileConfig::runOnMachineFunction(MachineFunction &MFunc) {
  X86FI = MFunc.getInfo<X86MachineFunctionInfo>();
  // Early exit in the common case of non-AMX code.
  if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA)
    return false;

  MF = &MFunc;
  MRI = &MFunc.getRegInfo();
  ST = &MFunc.getSubtarget<X86Subtarget>();
  TII = ST->getInstrInfo();
  MFI = &MFunc.getFrameInfo();
  TRI = ST->getRegisterInfo();
  CfgSS = -1;

  unsigned NumVirtRegs = MRI->getNumVirtRegs();

  StackSlotForVirtReg.resize(NumVirtRegs);
  MayLiveAcrossBlocks.clear();
  // We will create registers during configuration. The "* 3" is to make sure
  // the virtual register number doesn't exceed the size of the bit vector.
  MayLiveAcrossBlocks.resize(NumVirtRegs * 3);
  bool Change = false;
  assert(MRI->isSSA());

  // Canonicalize the phi nodes first.
  for (MachineBasicBlock &MBB : MFunc)
    canonicalizePHIs(MBB);

  // Loop over all of the basic blocks in reverse post order and insert
  // ldtilecfg for tile registers. The reverse post order facilitates the
  // PHI node conversion.
  ReversePostOrderTraversal<MachineFunction *> RPOT(MF);
  for (MachineBasicBlock *MBB : RPOT) {
    convertPHIs(*MBB);
    Change |= configBasicBlock(*MBB);
  }

  if (Change)
    InitializeTileConfigStackSpace();

  StackSlotForVirtReg.clear();
  return Change;
}

FunctionPass *llvm::createX86FastPreTileConfigPass() {
  return new X86FastPreTileConfig();
}