1//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file Pass to pre-config the shapes of AMX registers
10/// AMX register needs to be configured before use. The shapes of AMX register
11/// are encoded in the 1st and 2nd machine operand of AMX pseudo instructions.
12///
13/// The instruction ldtilecfg is used to config the shapes. It must be reachable
14/// for all variable shapes. ldtilecfg will be inserted more than once if we
15/// cannot find a dominating point for all AMX instructions.
16///
17/// The configure register is caller saved according to ABI. We need to insert
18/// ldtilecfg again after the call instruction if callee clobbers any AMX
19/// registers.
20///
/// This pass calculates all points where ldtilecfg needs to be inserted, and
/// inserts them. It reports an error if the reachability conditions aren't met.
23//
24//===----------------------------------------------------------------------===//
25
26#include "X86.h"
27#include "X86InstrBuilder.h"
28#include "X86MachineFunctionInfo.h"
29#include "X86RegisterInfo.h"
30#include "X86Subtarget.h"
31#include "llvm/ADT/SmallSet.h"
32#include "llvm/CodeGen/MachineFunctionPass.h"
33#include "llvm/CodeGen/MachineInstr.h"
34#include "llvm/CodeGen/MachineLoopInfo.h"
35#include "llvm/CodeGen/MachineModuleInfo.h"
36#include "llvm/CodeGen/MachineRegisterInfo.h"
37#include "llvm/CodeGen/Passes.h"
38#include "llvm/CodeGen/TargetInstrInfo.h"
39#include "llvm/CodeGen/TargetRegisterInfo.h"
40#include "llvm/IR/Module.h"
41#include "llvm/InitializePasses.h"
42
43using namespace llvm;
44
45#define DEBUG_TYPE "tile-pre-config"
46
47static void emitErrorMsg(MachineFunction &MF) {
48 LLVMContext &Context = MF.getFunction().getContext();
49 Context.emitError(
50 ErrorStr: MF.getName() +
51 ": Failed to config tile register, please define the shape earlier");
52}
53
54namespace {
55
56struct MIRef {
57 MachineInstr *MI = nullptr;
58 MachineBasicBlock *MBB = nullptr;
59 // A virtual position for instruction that will be inserted after MI.
60 size_t Pos = 0;
61 MIRef() = default;
62 MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
63 for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
64 ++I, ++Pos)
65 MI = &*I;
66 }
67 MIRef(MachineInstr *MI)
68 : MI(MI), MBB(MI->getParent()),
69 Pos(std::distance(first: MBB->instr_begin(), last: ++MI->getIterator())) {}
70 MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
71 : MI(MI), MBB(MBB),
72 Pos(std::distance(first: MBB->instr_begin(), last: ++MI->getIterator())) {}
73 MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
74 : MI(MI), MBB(MBB), Pos(Pos) {}
75 operator bool() const { return MBB != nullptr; }
76 bool operator==(const MIRef &RHS) const {
77 return MI == RHS.MI && MBB == RHS.MBB;
78 }
79 bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
80 bool operator<(const MIRef &RHS) const {
81 // Comparison between different BBs happens when inserting a MIRef into set.
82 // So we compare MBB first to make the insertion happy.
83 return std::tie(args: MBB, args: Pos) < std::tie(args: RHS.MBB, args: RHS.Pos);
84 }
85 bool operator>(const MIRef &RHS) const {
86 // Comparison between different BBs happens when inserting a MIRef into set.
87 // So we compare MBB first to make the insertion happy.
88 return std::tie(args: MBB, args: Pos) > std::tie(args: RHS.MBB, args: RHS.Pos);
89 }
90};
91
/// Per-basic-block facts collected while scanning the function in
/// runOnMachineFunction.
struct BBInfo {
  // First AMX instruction in the block, if any.
  MIRef FirstAMX;
  // Last call in the block that clobbers all AMX registers, if any.
  MIRef LastCall;
  // True if an AMX register may be live on entry (propagated from
  // predecessors that contain AMX instructions, excluding loop back edges).
  bool HasAMXRegLiveIn = false;
  // True if ldtilecfg may not be inserted in this block because a shape def
  // appears in this block or below it.
  bool TileCfgForbidden = false;
  // True if the tile configuration must already be valid on entry.
  bool NeedTileCfgLiveIn = false;
};
99
100class X86PreTileConfig : public MachineFunctionPass {
101 MachineRegisterInfo *MRI = nullptr;
102 const MachineLoopInfo *MLI = nullptr;
103 SmallSet<MachineInstr *, 8> DefVisited;
104 DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
105 DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
106
107 /// Check if the callee will clobber AMX registers.
108 bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
109 auto Iter = llvm::find_if(
110 Range: MI.operands(), P: [](MachineOperand &MO) { return MO.isRegMask(); });
111 if (Iter == MI.operands_end())
112 return false;
113 UsableRegs.clearBitsInMask(Mask: Iter->getRegMask());
114 return !UsableRegs.none();
115 }
116
117 /// Check if MI is AMX pseudo instruction.
118 bool isAMXInstruction(MachineInstr &MI) {
119 if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
120 return false;
121 switch (MI.getOpcode()) {
122 case X86::PTILESTOREDV:
123 case X86::PTCVTROWD2PSrreV:
124 case X86::PTCVTROWD2PSrriV:
125 case X86::PTCVTROWPS2BF16HrreV:
126 case X86::PTCVTROWPS2BF16HrriV:
127 case X86::PTCVTROWPS2BF16LrreV:
128 case X86::PTCVTROWPS2BF16LrriV:
129 case X86::PTCVTROWPS2PHHrreV:
130 case X86::PTCVTROWPS2PHHrriV:
131 case X86::PTCVTROWPS2PHLrreV:
132 case X86::PTCVTROWPS2PHLrriV:
133 case X86::PTILEMOVROWrreV:
134 case X86::PTILEMOVROWrriV:
135 return true;
136 }
137
138 // We can simply check if it is AMX instruction by its def.
139 // But we should exclude old API which uses physical registers.
140 MachineOperand &MO = MI.getOperand(i: 0);
141 if (!MO.isReg() || !MO.getReg().isVirtual())
142 return false;
143
144 unsigned Shapes = 0;
145 if (MRI->getRegClass(Reg: MO.getReg())->getID() == X86::TILERegClassID)
146 Shapes = 1;
147 if (MRI->getRegClass(Reg: MO.getReg())->getID() == X86::TILEPAIRRegClassID)
148 Shapes = 2;
149 if (!Shapes)
150 return false;
151
152 collectShapeInfo(MI, Shapes);
153 return true;
154 }
155
156 /// Check if it is an edge from loop bottom to loop head.
157 bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
158 if (!MLI->isLoopHeader(BB: Header))
159 return false;
160 auto *ML = MLI->getLoopFor(BB: Header);
161 if (ML->contains(BB: Bottom) && ML->isLoopLatch(BB: Bottom))
162 return true;
163
164 return false;
165 }
166
167 /// Collect the shape def information for later use.
168 void collectShapeInfo(MachineInstr &MI, unsigned Shapes);
169
170 /// Try to hoist shapes definded below AMX instructions.
171 bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
172 MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
173 auto FirstShapeBelowAMX = llvm::lower_bound(Range&: Shapes, Value&: FirstAMX);
174 auto InsertPoint = FirstAMX.MI->getIterator();
175 for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
176 // Do not hoist instructions that access memory.
177 if (I->MI->mayLoadOrStore())
178 return false;
179 for (auto &MO : I->MI->operands()) {
180 if (MO.isDef())
181 continue;
182 // Do not hoist instructions if the sources' def under AMX instruction.
183 // TODO: We can handle isMoveImmediate MI here.
184 if (MO.isReg() && MIRef(MRI->getVRegDef(Reg: MO.getReg())) > FirstAMX)
185 return false;
186 // TODO: Maybe need more checks here.
187 }
188 MBB->insert(I: InsertPoint, M: I->MI->removeFromParent());
189 }
190 // We only need to mark the last shape in the BB now.
191 Shapes.clear();
192 Shapes.push_back(Elt: MIRef(&*--InsertPoint, MBB));
193 return true;
194 }
195
196public:
197 X86PreTileConfig() : MachineFunctionPass(ID) {}
198
199 /// Return the pass name.
200 StringRef getPassName() const override {
201 return "Tile Register Pre-configure";
202 }
203
204 /// X86PreTileConfig analysis usage.
205 void getAnalysisUsage(AnalysisUsage &AU) const override {
206 AU.setPreservesAll();
207 AU.addRequired<MachineLoopInfoWrapperPass>();
208 MachineFunctionPass::getAnalysisUsage(AU);
209 }
210
211 /// Clear MF related structures.
212 void releaseMemory() override {
213 ShapeBBs.clear();
214 DefVisited.clear();
215 BBVisitedInfo.clear();
216 }
217
218 /// Perform ldtilecfg instructions inserting.
219 bool runOnMachineFunction(MachineFunction &MF) override;
220
221 static char ID;
222};
223
224} // end anonymous namespace
225
char X86PreTileConfig::ID = 0;

// Pass registration. MachineLoopInfo is required so the pass can recognize
// loop back edges while walking the CFG.
INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)
233
234void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) {
235 auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
236 MIRef MIR(MI, MBB);
237 auto &Refs = ShapeBBs[MBB];
238 auto I = llvm::lower_bound(Range&: Refs, Value&: MIR);
239 if (I == Refs.end() || *I != MIR)
240 Refs.insert(I, Elt: MIR);
241 };
242
243 // All shapes have same row in multi-tile operand.
244 SmallVector<Register, 8> WorkList;
245 for (unsigned I = 1; I < Shapes + 2; ++I)
246 WorkList.push_back(Elt: MI.getOperand(i: I).getReg());
247 while (!WorkList.empty()) {
248 Register R = WorkList.pop_back_val();
249 MachineInstr *DefMI = MRI->getVRegDef(Reg: R);
250 assert(DefMI && "R must has one define instruction");
251 MachineBasicBlock *DefMBB = DefMI->getParent();
252 if (DefMI->isMoveImmediate() || !DefVisited.insert(Ptr: DefMI).second)
253 continue;
254
255 // This happens when column = 0 in multi-tile operand.
256 if (DefMI->getOpcode() == X86::COPY) {
257 MachineInstr *MI = MRI->getVRegDef(Reg: DefMI->getOperand(i: 1).getReg());
258 if (MI && MI->isMoveImmediate())
259 continue;
260 }
261
262 if (DefMI->isPHI()) {
263 for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
264 if (isLoopBackEdge(Header: DefMBB, Bottom: DefMI->getOperand(i: I + 1).getMBB()))
265 RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
266 else
267 WorkList.push_back(Elt: DefMI->getOperand(i: I).getReg());
268 } else {
269 RecordShape(DefMI, DefMBB);
270 }
271 }
272}
273
274bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
275 X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
276 // Early exit in the common case of non-AMX code.
277 if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA)
278 return false;
279
280 const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
281 const TargetInstrInfo *TII = ST.getInstrInfo();
282 const TargetRegisterInfo *TRI = ST.getRegisterInfo();
283 const TargetRegisterClass *RC = TRI->getRegClass(i: X86::TILERegClassID);
284
285 BitVector AMXRegs(TRI->getNumRegs());
286 for (unsigned I = 0; I < RC->getNumRegs(); I++)
287 AMXRegs.set(X86::TMM0 + I);
288
289 // Iterate MF to collect information.
290 MRI = &MF.getRegInfo();
291 MLI = &getAnalysis<MachineLoopInfoWrapperPass>().getLI();
292 SmallSet<MIRef, 8> CfgNeedInsert;
293 SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
294 for (auto &MBB : MF) {
295 size_t Pos = 0;
296 auto &Info = BBVisitedInfo[&MBB];
297 for (auto &MI : MBB) {
298 ++Pos;
299 if (isAMXInstruction(MI)) {
300 // If there's call before the AMX, we need to reload tile config.
301 if (Info.LastCall)
302 CfgNeedInsert.insert(V: Info.LastCall);
303 else // Otherwise, we need tile config to live in this BB.
304 Info.NeedTileCfgLiveIn = true;
305 // Always record the first AMX in case there's shape def after it.
306 if (!Info.FirstAMX)
307 Info.FirstAMX = MIRef(&MI, &MBB, Pos);
308 } else if (MI.isCall() && isDestructiveCall(MI, UsableRegs: AMXRegs)) {
309 // Record the call only if the callee clobbers all AMX registers.
310 Info.LastCall = MIRef(&MI, &MBB, Pos);
311 }
312 }
313 if (Info.NeedTileCfgLiveIn) {
314 if (&MBB == &MF.front())
315 CfgNeedInsert.insert(V: MIRef(&MBB));
316 else
317 CfgLiveInBBs.push_back(Elt: &MBB);
318 }
319 if (Info.FirstAMX || Info.HasAMXRegLiveIn)
320 for (auto *Succ : MBB.successors())
321 if (!isLoopBackEdge(Header: Succ, Bottom: &MBB))
322 BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
323 }
324
325 // Update NeedTileCfgLiveIn for predecessors.
326 while (!CfgLiveInBBs.empty()) {
327 MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
328 for (auto *Pred : MBB->predecessors()) {
329 auto &Info = BBVisitedInfo[Pred];
330 if (Info.LastCall) {
331 CfgNeedInsert.insert(V: Info.LastCall);
332 } else if (!Info.NeedTileCfgLiveIn) {
333 Info.NeedTileCfgLiveIn = true;
334 if (Pred == &MF.front())
335 CfgNeedInsert.insert(V: MIRef(Pred));
336 else
337 CfgLiveInBBs.push_back(Elt: Pred);
338 }
339 }
340 }
341
342 // There's no AMX instruction if we didn't find a tile config live in point.
343 if (CfgNeedInsert.empty())
344 return false;
345
346 // Avoid to insert ldtilecfg before any shape defs.
347 SmallVector<MachineBasicBlock *, 8> WorkList;
348 for (auto &I : ShapeBBs) {
349 auto &Info = BBVisitedInfo[I.first];
350 // TODO: We can hoist shapes across BBs here.
351 if (Info.HasAMXRegLiveIn) {
352 // We are not able to config tile registers since the shape to config
353 // is not defined yet. Emit error message and continue. The function
354 // would not config tile registers.
355 emitErrorMsg(MF);
356 return false;
357 }
358 if (Info.FirstAMX && Info.FirstAMX < I.second.back() &&
359 !hoistShapesInBB(MBB: I.first, Shapes&: I.second)) {
360 emitErrorMsg(MF);
361 return false;
362 }
363 WorkList.push_back(Elt: I.first);
364 }
365 while (!WorkList.empty()) {
366 MachineBasicBlock *MBB = WorkList.pop_back_val();
367 for (auto *Pred : MBB->predecessors()) {
368 auto &Info = BBVisitedInfo[Pred];
369 if (!Info.TileCfgForbidden && !isLoopBackEdge(Header: MBB, Bottom: Pred)) {
370 Info.TileCfgForbidden = true;
371 WorkList.push_back(Elt: Pred);
372 }
373 }
374 }
375
376 DebugLoc DL;
377 SmallSet<MIRef, 8> VisitedOrInserted;
378 int SS = MF.getFrameInfo().CreateStackObject(
379 Size: ST.getTileConfigSize(), Alignment: ST.getTileConfigAlignment(), isSpillSlot: false);
380
381 // Try to insert for the tile config live in points.
382 for (const auto &I : CfgNeedInsert) {
383 SmallSet<MIRef, 8> InsertPoints;
384 SmallVector<MIRef, 8> WorkList({I});
385 while (!WorkList.empty()) {
386 MIRef I = WorkList.pop_back_val();
387 if (!VisitedOrInserted.count(V: I)) {
388 if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
389 // If the BB is all shapes reachable, stop sink and try to insert.
390 InsertPoints.insert(V: I);
391 } else {
392 // Avoid the BB to be multi visited.
393 VisitedOrInserted.insert(V: I);
394 // Sink the inserting point along the chain with NeedTileCfgLiveIn =
395 // true when MBB isn't all shapes reachable.
396 for (auto *Succ : I.MBB->successors())
397 if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
398 WorkList.push_back(Elt: MIRef(Succ));
399 }
400 }
401 }
402
403 // A given point might be forked due to shape conditions are not met.
404 for (MIRef I : InsertPoints) {
405 // Make sure we insert ldtilecfg after the last shape def in MBB.
406 auto It = ShapeBBs.find(Val: I.MBB);
407 if (It != ShapeBBs.end() && I < It->second.back())
408 I = It->second.back();
409 // There're chances the MBB is sunk more than once. Record it to avoid
410 // multi insert.
411 if (VisitedOrInserted.insert(V: I).second) {
412 auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
413 addFrameReference(MIB: BuildMI(BB&: *I.MBB, I: ++II, MIMD: DL, MCID: TII->get(Opcode: X86::PLDTILECFGV)),
414 FI: SS);
415 }
416 }
417 }
418
419 // Zero stack slot.
420 MachineBasicBlock &MBB = MF.front();
421 MachineInstr *MI = &*MBB.begin();
422 if (ST.hasAVX512()) {
423 Register Zmm = MRI->createVirtualRegister(RegClass: &X86::VR512RegClass);
424 BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::AVX512_512_SET0), DestReg: Zmm);
425 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::VMOVUPSZmr)), FI: SS)
426 .addReg(RegNo: Zmm);
427 } else if (ST.hasAVX2()) {
428 Register Ymm = MRI->createVirtualRegister(RegClass: &X86::VR256RegClass);
429 BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::AVX_SET0), DestReg: Ymm);
430 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::VMOVUPSYmr)), FI: SS)
431 .addReg(RegNo: Ymm);
432 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::VMOVUPSYmr)), FI: SS, Offset: 32)
433 .addReg(RegNo: Ymm);
434 } else {
435 assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
436 unsigned StoreOpc = ST.hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr;
437 Register Xmm = MRI->createVirtualRegister(RegClass: &X86::VR128RegClass);
438 BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::V_SET0), DestReg: Xmm);
439 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS).addReg(RegNo: Xmm);
440 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 16)
441 .addReg(RegNo: Xmm);
442 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 32)
443 .addReg(RegNo: Xmm);
444 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 48)
445 .addReg(RegNo: Xmm);
446 }
447 // Fill in the palette first.
448 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::MOV8mi)), FI: SS).addImm(Val: 1);
449
450 return true;
451}
452
/// Factory used by the X86 target to create this pass.
FunctionPass *llvm::createX86PreTileConfigPass() {
  return new X86PreTileConfig();
}
456