1//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
/// \file Pass to pre-configure the shapes of AMX registers.
/// An AMX register needs to be configured before use. The shapes of AMX
/// registers are encoded in the 1st and 2nd machine operands of AMX pseudo
/// instructions.
12///
13/// The instruction ldtilecfg is used to config the shapes. It must be reachable
14/// for all variable shapes. ldtilecfg will be inserted more than once if we
15/// cannot find a dominating point for all AMX instructions.
16///
17/// The configure register is caller saved according to ABI. We need to insert
18/// ldtilecfg again after the call instruction if callee clobbers any AMX
19/// registers.
20///
/// This pass calculates all the points where ldtilecfg needs to be inserted
/// and inserts it there. It reports an error if the reachability conditions
/// aren't met.
23//
24//===----------------------------------------------------------------------===//
25
26#include "X86.h"
27#include "X86InstrBuilder.h"
28#include "X86MachineFunctionInfo.h"
29#include "X86RegisterInfo.h"
30#include "X86Subtarget.h"
31#include "llvm/ADT/ScopeExit.h"
32#include "llvm/ADT/SmallSet.h"
33#include "llvm/CodeGen/MachineFunctionPass.h"
34#include "llvm/CodeGen/MachineInstr.h"
35#include "llvm/CodeGen/MachineLoopInfo.h"
36#include "llvm/CodeGen/MachineModuleInfo.h"
37#include "llvm/CodeGen/MachineRegisterInfo.h"
38#include "llvm/CodeGen/Passes.h"
39#include "llvm/CodeGen/TargetInstrInfo.h"
40#include "llvm/CodeGen/TargetRegisterInfo.h"
41#include "llvm/IR/Module.h"
42#include "llvm/InitializePasses.h"
43
44using namespace llvm;
45
46#define DEBUG_TYPE "x86-pre-tile-config"
47
48static void emitErrorMsg(MachineFunction &MF) {
49 LLVMContext &Context = MF.getFunction().getContext();
50 Context.emitError(
51 ErrorStr: MF.getName() +
52 ": Failed to config tile register, please define the shape earlier");
53}
54
55namespace {
56
57struct MIRef {
58 MachineInstr *MI = nullptr;
59 MachineBasicBlock *MBB = nullptr;
60 // A virtual position for instruction that will be inserted after MI.
61 size_t Pos = 0;
62 MIRef() = default;
63 MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
64 for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
65 ++I, ++Pos)
66 MI = &*I;
67 }
68 MIRef(MachineInstr *MI)
69 : MI(MI), MBB(MI->getParent()),
70 Pos(std::distance(first: MBB->instr_begin(), last: ++MI->getIterator())) {}
71 MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
72 : MI(MI), MBB(MBB),
73 Pos(std::distance(first: MBB->instr_begin(), last: ++MI->getIterator())) {}
74 MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
75 : MI(MI), MBB(MBB), Pos(Pos) {}
76 operator bool() const { return MBB != nullptr; }
77 bool operator==(const MIRef &RHS) const {
78 return MI == RHS.MI && MBB == RHS.MBB;
79 }
80 bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
81 bool operator<(const MIRef &RHS) const {
82 // Comparison between different BBs happens when inserting a MIRef into set.
83 // So we compare MBB first to make the insertion happy.
84 return std::tie(args: MBB, args: Pos) < std::tie(args: RHS.MBB, args: RHS.Pos);
85 }
86 bool operator>(const MIRef &RHS) const {
87 // Comparison between different BBs happens when inserting a MIRef into set.
88 // So we compare MBB first to make the insertion happy.
89 return std::tie(args: MBB, args: Pos) > std::tie(args: RHS.MBB, args: RHS.Pos);
90 }
91};
92
93struct BBInfo {
94 MIRef FirstAMX;
95 MIRef LastCall;
96 bool HasAMXRegLiveIn = false;
97 bool TileCfgForbidden = false;
98 bool NeedTileCfgLiveIn = false;
99};
100
101class X86PreTileConfigImpl {
102 std::function<MachineLoopInfo *()> GetMLI;
103 MachineRegisterInfo *MRI = nullptr;
104 const MachineLoopInfo *MLI = nullptr;
105 SmallPtrSet<MachineInstr *, 8> DefVisited;
106 DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
107 DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
108
109 /// Check if the callee will clobber AMX registers.
110 bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
111 auto Iter = llvm::find_if(
112 Range: MI.operands(), P: [](MachineOperand &MO) { return MO.isRegMask(); });
113 if (Iter == MI.operands_end())
114 return false;
115 UsableRegs.clearBitsInMask(Mask: Iter->getRegMask());
116 return !UsableRegs.none();
117 }
118
119 /// Check if MI is AMX pseudo instruction.
120 bool isAMXInstruction(MachineInstr &MI) {
121 if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
122 return false;
123 switch (MI.getOpcode()) {
124 case X86::PTILESTOREDV:
125 case X86::PTCVTROWD2PSrreV:
126 case X86::PTCVTROWD2PSrriV:
127 case X86::PTCVTROWPS2BF16HrreV:
128 case X86::PTCVTROWPS2BF16HrriV:
129 case X86::PTCVTROWPS2BF16LrreV:
130 case X86::PTCVTROWPS2BF16LrriV:
131 case X86::PTCVTROWPS2PHHrreV:
132 case X86::PTCVTROWPS2PHHrriV:
133 case X86::PTCVTROWPS2PHLrreV:
134 case X86::PTCVTROWPS2PHLrriV:
135 case X86::PTILEMOVROWrreV:
136 case X86::PTILEMOVROWrriV:
137 return true;
138 }
139
140 // We can simply check if it is AMX instruction by its def.
141 // But we should exclude old API which uses physical registers.
142 MachineOperand &MO = MI.getOperand(i: 0);
143 if (!MO.isReg() || !MO.getReg().isVirtual())
144 return false;
145
146 if (MRI->getRegClass(Reg: MO.getReg())->getID() != X86::TILERegClassID)
147 return false;
148
149 collectShapeInfo(MI);
150 return true;
151 }
152
153 /// Check if it is an edge from loop bottom to loop head.
154 bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
155 if (!MLI->isLoopHeader(BB: Header))
156 return false;
157 auto *ML = MLI->getLoopFor(BB: Header);
158 if (ML->contains(BB: Bottom) && ML->isLoopLatch(BB: Bottom))
159 return true;
160
161 return false;
162 }
163
164 /// Collect the shape def information for later use.
165 void collectShapeInfo(MachineInstr &MI);
166
167 /// Try to hoist shapes definded below AMX instructions.
168 bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
169 MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
170 auto FirstShapeBelowAMX = llvm::lower_bound(Range&: Shapes, Value&: FirstAMX);
171 auto InsertPoint = FirstAMX.MI->getIterator();
172 for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
173 // Do not hoist instructions that access memory.
174 if (I->MI->mayLoadOrStore())
175 return false;
176 for (auto &MO : I->MI->operands()) {
177 if (MO.isDef())
178 continue;
179 // Do not hoist instructions if the sources' def under AMX instruction.
180 // TODO: We can handle isMoveImmediate MI here.
181 if (MO.isReg() && MIRef(MRI->getVRegDef(Reg: MO.getReg())) > FirstAMX)
182 return false;
183 // TODO: Maybe need more checks here.
184 }
185 MBB->insert(I: InsertPoint, M: I->MI->removeFromParent());
186 }
187 // We only need to mark the last shape in the BB now.
188 Shapes.clear();
189 Shapes.push_back(Elt: MIRef(&*--InsertPoint, MBB));
190 return true;
191 }
192
193 /// Clear MF related structures.
194 void releaseMemory() {
195 ShapeBBs.clear();
196 DefVisited.clear();
197 BBVisitedInfo.clear();
198 }
199
200public:
201 X86PreTileConfigImpl(std::function<MachineLoopInfo *()> GetMLI)
202 : GetMLI(GetMLI) {}
203 bool runOnMachineFunction(MachineFunction &MF);
204};
205
206class X86PreTileConfigLegacy : public MachineFunctionPass {
207public:
208 X86PreTileConfigLegacy() : MachineFunctionPass(ID) {}
209
210 /// Return the pass name.
211 StringRef getPassName() const override {
212 return "Tile Register Pre-configure";
213 }
214
215 /// X86PreTileConfig analysis usage.
216 void getAnalysisUsage(AnalysisUsage &AU) const override {
217 AU.setPreservesAll();
218 AU.addRequired<MachineLoopInfoWrapperPass>();
219 MachineFunctionPass::getAnalysisUsage(AU);
220 }
221
222 /// Perform ldtilecfg instructions inserting.
223 bool runOnMachineFunction(MachineFunction &MF) override;
224
225 static char ID;
226};
227
228} // end anonymous namespace
229
// Definition of the legacy pass ID declared in the class above.
char X86PreTileConfigLegacy::ID = 0;

// Register the legacy pass under the command line name "tilepreconfig" and
// declare its dependency on the machine loop info wrapper pass.
INITIALIZE_PASS_BEGIN(X86PreTileConfigLegacy, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
INITIALIZE_PASS_END(X86PreTileConfigLegacy, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)
237
238void X86PreTileConfigImpl::collectShapeInfo(MachineInstr &MI) {
239 auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
240 MIRef MIR(MI, MBB);
241 auto &Refs = ShapeBBs[MBB];
242 auto I = llvm::lower_bound(Range&: Refs, Value&: MIR);
243 if (I == Refs.end() || *I != MIR)
244 Refs.insert(I, Elt: MIR);
245 };
246
247 SmallVector<Register, 8> WorkList(
248 {MI.getOperand(i: 1).getReg(), MI.getOperand(i: 2).getReg()});
249 while (!WorkList.empty()) {
250 Register R = WorkList.pop_back_val();
251 MachineInstr *DefMI = MRI->getVRegDef(Reg: R);
252 assert(DefMI && "R must has one define instruction");
253 MachineBasicBlock *DefMBB = DefMI->getParent();
254 if (DefMI->isMoveImmediate() || !DefVisited.insert(Ptr: DefMI).second)
255 continue;
256
257 if (DefMI->isPHI()) {
258 for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
259 if (isLoopBackEdge(Header: DefMBB, Bottom: DefMI->getOperand(i: I + 1).getMBB()))
260 RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
261 else
262 WorkList.push_back(Elt: DefMI->getOperand(i: I).getReg());
263 } else {
264 RecordShape(DefMI, DefMBB);
265 }
266 }
267}
268
269bool X86PreTileConfigImpl::runOnMachineFunction(MachineFunction &MF) {
270 scope_exit ClearStateOnExit([this] { releaseMemory(); });
271
272 X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
273 // Early exit in the common case of non-AMX code.
274 if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA)
275 return false;
276
277 const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
278 const TargetInstrInfo *TII = ST.getInstrInfo();
279 const TargetRegisterInfo *TRI = ST.getRegisterInfo();
280 const TargetRegisterClass *RC = TRI->getRegClass(i: X86::TILERegClassID);
281
282 BitVector AMXRegs(TRI->getNumRegs());
283 for (unsigned I = 0; I < RC->getNumRegs(); I++)
284 AMXRegs.set(X86::TMM0 + I);
285
286 // Iterate MF to collect information.
287 MRI = &MF.getRegInfo();
288 MLI = GetMLI();
289 SmallSet<MIRef, 8> CfgNeedInsert;
290 SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
291 for (auto &MBB : MF) {
292 size_t Pos = 0;
293 auto &Info = BBVisitedInfo[&MBB];
294 for (auto &MI : MBB) {
295 ++Pos;
296 if (isAMXInstruction(MI)) {
297 // If there's call before the AMX, we need to reload tile config.
298 if (Info.LastCall)
299 CfgNeedInsert.insert(V: Info.LastCall);
300 else // Otherwise, we need tile config to live in this BB.
301 Info.NeedTileCfgLiveIn = true;
302 // Always record the first AMX in case there's shape def after it.
303 if (!Info.FirstAMX)
304 Info.FirstAMX = MIRef(&MI, &MBB, Pos);
305 } else if (MI.isCall() && isDestructiveCall(MI, UsableRegs: AMXRegs)) {
306 // Record the call only if the callee clobbers all AMX registers.
307 Info.LastCall = MIRef(&MI, &MBB, Pos);
308 }
309 }
310 if (Info.NeedTileCfgLiveIn) {
311 if (&MBB == &MF.front())
312 CfgNeedInsert.insert(V: MIRef(&MBB));
313 else
314 CfgLiveInBBs.push_back(Elt: &MBB);
315 }
316 if (Info.FirstAMX || Info.HasAMXRegLiveIn)
317 for (auto *Succ : MBB.successors())
318 if (!isLoopBackEdge(Header: Succ, Bottom: &MBB))
319 BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
320 }
321
322 // Update NeedTileCfgLiveIn for predecessors.
323 while (!CfgLiveInBBs.empty()) {
324 MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
325 for (auto *Pred : MBB->predecessors()) {
326 auto &Info = BBVisitedInfo[Pred];
327 if (Info.LastCall) {
328 CfgNeedInsert.insert(V: Info.LastCall);
329 } else if (!Info.NeedTileCfgLiveIn) {
330 Info.NeedTileCfgLiveIn = true;
331 if (Pred == &MF.front())
332 CfgNeedInsert.insert(V: MIRef(Pred));
333 else
334 CfgLiveInBBs.push_back(Elt: Pred);
335 }
336 }
337 }
338
339 // There's no AMX instruction if we didn't find a tile config live in point.
340 if (CfgNeedInsert.empty())
341 return false;
342
343 // Avoid to insert ldtilecfg before any shape defs.
344 SmallVector<MachineBasicBlock *, 8> WorkList;
345 for (auto &I : ShapeBBs) {
346 auto &Info = BBVisitedInfo[I.first];
347 // TODO: We can hoist shapes across BBs here.
348 if (Info.HasAMXRegLiveIn) {
349 // We are not able to config tile registers since the shape to config
350 // is not defined yet. Emit error message and continue. The function
351 // would not config tile registers.
352 emitErrorMsg(MF);
353 return false;
354 }
355 if (Info.FirstAMX && Info.FirstAMX < I.second.back() &&
356 !hoistShapesInBB(MBB: I.first, Shapes&: I.second)) {
357 emitErrorMsg(MF);
358 return false;
359 }
360 WorkList.push_back(Elt: I.first);
361 }
362 while (!WorkList.empty()) {
363 MachineBasicBlock *MBB = WorkList.pop_back_val();
364 for (auto *Pred : MBB->predecessors()) {
365 auto &Info = BBVisitedInfo[Pred];
366 if (!Info.TileCfgForbidden && !isLoopBackEdge(Header: MBB, Bottom: Pred)) {
367 Info.TileCfgForbidden = true;
368 WorkList.push_back(Elt: Pred);
369 }
370 }
371 }
372
373 DebugLoc DL;
374 SmallSet<MIRef, 8> VisitedOrInserted;
375 int SS = MF.getFrameInfo().CreateStackObject(
376 Size: ST.getTileConfigSize(), Alignment: ST.getTileConfigAlignment(), isSpillSlot: false);
377
378 // Try to insert for the tile config live in points.
379 for (const auto &I : CfgNeedInsert) {
380 SmallSet<MIRef, 8> InsertPoints;
381 SmallVector<MIRef, 8> WorkList({I});
382 while (!WorkList.empty()) {
383 MIRef I = WorkList.pop_back_val();
384 if (!VisitedOrInserted.count(V: I)) {
385 if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
386 // If the BB is all shapes reachable, stop sink and try to insert.
387 InsertPoints.insert(V: I);
388 } else {
389 // Avoid the BB to be multi visited.
390 VisitedOrInserted.insert(V: I);
391 // Sink the inserting point along the chain with NeedTileCfgLiveIn =
392 // true when MBB isn't all shapes reachable.
393 for (auto *Succ : I.MBB->successors())
394 if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
395 WorkList.push_back(Elt: MIRef(Succ));
396 }
397 }
398 }
399
400 // A given point might be forked due to shape conditions are not met.
401 for (MIRef I : InsertPoints) {
402 // Make sure we insert ldtilecfg after the last shape def in MBB.
403 auto It = ShapeBBs.find(Val: I.MBB);
404 if (It != ShapeBBs.end() && I < It->second.back())
405 I = It->second.back();
406 // There're chances the MBB is sunk more than once. Record it to avoid
407 // multi insert.
408 if (VisitedOrInserted.insert(V: I).second) {
409 auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
410 addFrameReference(MIB: BuildMI(BB&: *I.MBB, I: ++II, MIMD: DL, MCID: TII->get(Opcode: X86::PLDTILECFGV)),
411 FI: SS);
412 }
413 }
414 }
415
416 // Zero stack slot.
417 MachineBasicBlock &MBB = MF.front();
418 MachineInstr *MI = &*MBB.begin();
419 if (ST.hasAVX512()) {
420 Register Zmm = MRI->createVirtualRegister(RegClass: &X86::VR512RegClass);
421 BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::AVX512_512_SET0), DestReg: Zmm);
422 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::VMOVUPSZmr)), FI: SS)
423 .addReg(RegNo: Zmm);
424 } else if (ST.hasAVX2()) {
425 Register Ymm = MRI->createVirtualRegister(RegClass: &X86::VR256RegClass);
426 BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::AVX_SET0), DestReg: Ymm);
427 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::VMOVUPSYmr)), FI: SS)
428 .addReg(RegNo: Ymm);
429 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::VMOVUPSYmr)), FI: SS, Offset: 32)
430 .addReg(RegNo: Ymm);
431 } else {
432 assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
433 unsigned StoreOpc = ST.hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr;
434 Register Xmm = MRI->createVirtualRegister(RegClass: &X86::VR128RegClass);
435 BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::V_SET0), DestReg: Xmm);
436 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS).addReg(RegNo: Xmm);
437 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 16)
438 .addReg(RegNo: Xmm);
439 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 32)
440 .addReg(RegNo: Xmm);
441 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 48)
442 .addReg(RegNo: Xmm);
443 }
444 // Fill in the palette first.
445 addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::MOV8mi)), FI: SS).addImm(Val: 1);
446
447 return true;
448}
449
450FunctionPass *llvm::createX86PreTileConfigLegacyPass() {
451 return new X86PreTileConfigLegacy();
452}
453
454bool X86PreTileConfigLegacy::runOnMachineFunction(MachineFunction &MF) {
455 X86PreTileConfigImpl Impl(
456 [this]() { return &getAnalysis<MachineLoopInfoWrapperPass>().getLI(); });
457 return Impl.runOnMachineFunction(MF);
458}
459
460PreservedAnalyses
461X86PreTileConfigPass::run(MachineFunction &MF,
462 MachineFunctionAnalysisManager &MFAM) {
463 X86PreTileConfigImpl Impl(
464 [&MFAM, &MF]() { return &MFAM.getResult<MachineLoopAnalysis>(IR&: MF); });
465 return Impl.runOnMachineFunction(MF)
466 ? getMachineFunctionPassPreservedAnalyses()
467 .preserveSet<CFGAnalyses>()
468 : PreservedAnalyses::all();
469}
470