1 | //===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | /// \file Pass to pre-config the shapes of AMX registers |
10 | /// AMX register needs to be configured before use. The shapes of AMX register |
11 | /// are encoded in the 1st and 2nd machine operand of AMX pseudo instructions. |
12 | /// |
13 | /// The instruction ldtilecfg is used to config the shapes. It must be reachable |
14 | /// for all variable shapes. ldtilecfg will be inserted more than once if we |
15 | /// cannot find a dominating point for all AMX instructions. |
16 | /// |
17 | /// The configure register is caller saved according to ABI. We need to insert |
18 | /// ldtilecfg again after the call instruction if callee clobbers any AMX |
19 | /// registers. |
20 | /// |
/// This pass calculates all the points where ldtilecfg needs to be inserted
/// and inserts it there. It reports an error if the reachability conditions
/// aren't met.
23 | // |
24 | //===----------------------------------------------------------------------===// |
25 | |
26 | #include "X86.h" |
27 | #include "X86InstrBuilder.h" |
28 | #include "X86MachineFunctionInfo.h" |
29 | #include "X86RegisterInfo.h" |
30 | #include "X86Subtarget.h" |
31 | #include "llvm/ADT/SmallSet.h" |
32 | #include "llvm/CodeGen/MachineFunctionPass.h" |
33 | #include "llvm/CodeGen/MachineInstr.h" |
34 | #include "llvm/CodeGen/MachineLoopInfo.h" |
35 | #include "llvm/CodeGen/MachineModuleInfo.h" |
36 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
37 | #include "llvm/CodeGen/Passes.h" |
38 | #include "llvm/CodeGen/TargetInstrInfo.h" |
39 | #include "llvm/CodeGen/TargetRegisterInfo.h" |
40 | #include "llvm/IR/Module.h" |
41 | #include "llvm/InitializePasses.h" |
42 | |
43 | using namespace llvm; |
44 | |
45 | #define DEBUG_TYPE "tile-pre-config" |
46 | |
47 | static void emitErrorMsg(MachineFunction &MF) { |
48 | LLVMContext &Context = MF.getFunction().getContext(); |
49 | Context.emitError( |
50 | ErrorStr: MF.getName() + |
51 | ": Failed to config tile register, please define the shape earlier" ); |
52 | } |
53 | |
54 | namespace { |
55 | |
56 | struct MIRef { |
57 | MachineInstr *MI = nullptr; |
58 | MachineBasicBlock *MBB = nullptr; |
59 | // A virtual position for instruction that will be inserted after MI. |
60 | size_t Pos = 0; |
61 | MIRef() = default; |
62 | MIRef(MachineBasicBlock *MBB) : MBB(MBB) { |
63 | for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI(); |
64 | ++I, ++Pos) |
65 | MI = &*I; |
66 | } |
67 | MIRef(MachineInstr *MI) |
68 | : MI(MI), MBB(MI->getParent()), |
69 | Pos(std::distance(first: MBB->instr_begin(), last: ++MI->getIterator())) {} |
70 | MIRef(MachineInstr *MI, MachineBasicBlock *MBB) |
71 | : MI(MI), MBB(MBB), |
72 | Pos(std::distance(first: MBB->instr_begin(), last: ++MI->getIterator())) {} |
73 | MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos) |
74 | : MI(MI), MBB(MBB), Pos(Pos) {} |
75 | operator bool() const { return MBB != nullptr; } |
76 | bool operator==(const MIRef &RHS) const { |
77 | return MI == RHS.MI && MBB == RHS.MBB; |
78 | } |
79 | bool operator!=(const MIRef &RHS) const { return !(*this == RHS); } |
80 | bool operator<(const MIRef &RHS) const { |
81 | // Comparison between different BBs happens when inserting a MIRef into set. |
82 | // So we compare MBB first to make the insertion happy. |
83 | return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos); |
84 | } |
85 | bool operator>(const MIRef &RHS) const { |
86 | // Comparison between different BBs happens when inserting a MIRef into set. |
87 | // So we compare MBB first to make the insertion happy. |
88 | return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos); |
89 | } |
90 | }; |
91 | |
92 | struct BBInfo { |
93 | MIRef FirstAMX; |
94 | MIRef LastCall; |
95 | bool HasAMXRegLiveIn = false; |
96 | bool TileCfgForbidden = false; |
97 | bool NeedTileCfgLiveIn = false; |
98 | }; |
99 | |
100 | class X86PreTileConfig : public MachineFunctionPass { |
101 | MachineRegisterInfo *MRI = nullptr; |
102 | const MachineLoopInfo *MLI = nullptr; |
103 | SmallSet<MachineInstr *, 8> DefVisited; |
104 | DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo; |
105 | DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs; |
106 | |
107 | /// Check if the callee will clobber AMX registers. |
108 | bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) { |
109 | auto Iter = llvm::find_if( |
110 | Range: MI.operands(), P: [](MachineOperand &MO) { return MO.isRegMask(); }); |
111 | if (Iter == MI.operands_end()) |
112 | return false; |
113 | UsableRegs.clearBitsInMask(Mask: Iter->getRegMask()); |
114 | return !UsableRegs.none(); |
115 | } |
116 | |
117 | /// Check if MI is AMX pseudo instruction. |
118 | bool isAMXInstruction(MachineInstr &MI) { |
119 | if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3) |
120 | return false; |
121 | MachineOperand &MO = MI.getOperand(i: 0); |
122 | // We can simply check if it is AMX instruction by its def. |
123 | // But we should exclude old API which uses physical registers. |
124 | if (MO.isReg() && MO.getReg().isVirtual() && |
125 | MRI->getRegClass(Reg: MO.getReg())->getID() == X86::TILERegClassID) { |
126 | collectShapeInfo(MI); |
127 | return true; |
128 | } |
129 | // PTILESTOREDV is the only exception that doesn't def a AMX register. |
130 | return MI.getOpcode() == X86::PTILESTOREDV; |
131 | } |
132 | |
133 | /// Check if it is an edge from loop bottom to loop head. |
134 | bool isLoopBackEdge(MachineBasicBlock *, MachineBasicBlock *Bottom) { |
135 | if (!MLI->isLoopHeader(BB: Header)) |
136 | return false; |
137 | auto *ML = MLI->getLoopFor(BB: Header); |
138 | if (ML->contains(BB: Bottom) && ML->isLoopLatch(BB: Bottom)) |
139 | return true; |
140 | |
141 | return false; |
142 | } |
143 | |
144 | /// Collect the shape def information for later use. |
145 | void collectShapeInfo(MachineInstr &MI); |
146 | |
147 | /// Try to hoist shapes definded below AMX instructions. |
148 | bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) { |
149 | MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX; |
150 | auto FirstShapeBelowAMX = llvm::lower_bound(Range&: Shapes, Value&: FirstAMX); |
151 | auto InsertPoint = FirstAMX.MI->getIterator(); |
152 | for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) { |
153 | // Do not hoist instructions that access memory. |
154 | if (I->MI->mayLoadOrStore()) |
155 | return false; |
156 | for (auto &MO : I->MI->operands()) { |
157 | if (MO.isDef()) |
158 | continue; |
159 | // Do not hoist instructions if the sources' def under AMX instruction. |
160 | // TODO: We can handle isMoveImmediate MI here. |
161 | if (MO.isReg() && MIRef(MRI->getVRegDef(Reg: MO.getReg())) > FirstAMX) |
162 | return false; |
163 | // TODO: Maybe need more checks here. |
164 | } |
165 | MBB->insert(I: InsertPoint, M: I->MI->removeFromParent()); |
166 | } |
167 | // We only need to mark the last shape in the BB now. |
168 | Shapes.clear(); |
169 | Shapes.push_back(Elt: MIRef(&*--InsertPoint, MBB)); |
170 | return true; |
171 | } |
172 | |
173 | public: |
174 | X86PreTileConfig() : MachineFunctionPass(ID) {} |
175 | |
176 | /// Return the pass name. |
177 | StringRef getPassName() const override { |
178 | return "Tile Register Pre-configure" ; |
179 | } |
180 | |
181 | /// X86PreTileConfig analysis usage. |
182 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
183 | AU.setPreservesAll(); |
184 | AU.addRequired<MachineLoopInfoWrapperPass>(); |
185 | MachineFunctionPass::getAnalysisUsage(AU); |
186 | } |
187 | |
188 | /// Clear MF related structures. |
189 | void releaseMemory() override { |
190 | ShapeBBs.clear(); |
191 | DefVisited.clear(); |
192 | BBVisitedInfo.clear(); |
193 | } |
194 | |
195 | /// Perform ldtilecfg instructions inserting. |
196 | bool runOnMachineFunction(MachineFunction &MF) override; |
197 | |
198 | static char ID; |
199 | }; |
200 | |
201 | } // end anonymous namespace |
202 | |
203 | char X86PreTileConfig::ID = 0; |
204 | |
205 | INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig" , |
206 | "Tile Register Pre-configure" , false, false) |
207 | INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass) |
208 | INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig" , |
209 | "Tile Register Pre-configure" , false, false) |
210 | |
211 | void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) { |
212 | auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) { |
213 | MIRef MIR(MI, MBB); |
214 | auto I = llvm::lower_bound(Range&: ShapeBBs[MBB], Value&: MIR); |
215 | if (I == ShapeBBs[MBB].end() || *I != MIR) |
216 | ShapeBBs[MBB].insert(I, Elt: MIR); |
217 | }; |
218 | |
219 | SmallVector<Register, 8> WorkList( |
220 | {MI.getOperand(i: 1).getReg(), MI.getOperand(i: 2).getReg()}); |
221 | while (!WorkList.empty()) { |
222 | Register R = WorkList.pop_back_val(); |
223 | MachineInstr *DefMI = MRI->getVRegDef(Reg: R); |
224 | assert(DefMI && "R must has one define instruction" ); |
225 | MachineBasicBlock *DefMBB = DefMI->getParent(); |
226 | if (DefMI->isMoveImmediate() || !DefVisited.insert(Ptr: DefMI).second) |
227 | continue; |
228 | if (DefMI->isPHI()) { |
229 | for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2) |
230 | if (isLoopBackEdge(Header: DefMBB, Bottom: DefMI->getOperand(i: I + 1).getMBB())) |
231 | RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def. |
232 | else |
233 | WorkList.push_back(Elt: DefMI->getOperand(i: I).getReg()); |
234 | } else { |
235 | RecordShape(DefMI, DefMBB); |
236 | } |
237 | } |
238 | } |
239 | |
240 | bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) { |
241 | X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>(); |
242 | // Early exit in the common case of non-AMX code. |
243 | if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA) |
244 | return false; |
245 | |
246 | const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>(); |
247 | const TargetInstrInfo *TII = ST.getInstrInfo(); |
248 | const TargetRegisterInfo *TRI = ST.getRegisterInfo(); |
249 | const TargetRegisterClass *RC = TRI->getRegClass(i: X86::TILERegClassID); |
250 | |
251 | BitVector AMXRegs(TRI->getNumRegs()); |
252 | for (unsigned I = 0; I < RC->getNumRegs(); I++) |
253 | AMXRegs.set(X86::TMM0 + I); |
254 | |
255 | // Iterate MF to collect information. |
256 | MRI = &MF.getRegInfo(); |
257 | MLI = &getAnalysis<MachineLoopInfoWrapperPass>().getLI(); |
258 | SmallSet<MIRef, 8> CfgNeedInsert; |
259 | SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs; |
260 | for (auto &MBB : MF) { |
261 | size_t Pos = 0; |
262 | for (auto &MI : MBB) { |
263 | ++Pos; |
264 | if (isAMXInstruction(MI)) { |
265 | // If there's call before the AMX, we need to reload tile config. |
266 | if (BBVisitedInfo[&MBB].LastCall) |
267 | CfgNeedInsert.insert(V: BBVisitedInfo[&MBB].LastCall); |
268 | else // Otherwise, we need tile config to live in this BB. |
269 | BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true; |
270 | // Always record the first AMX in case there's shape def after it. |
271 | if (!BBVisitedInfo[&MBB].FirstAMX) |
272 | BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos); |
273 | } else if (MI.isCall() && isDestructiveCall(MI, UsableRegs: AMXRegs)) { |
274 | // Record the call only if the callee clobbers all AMX registers. |
275 | BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos); |
276 | } |
277 | } |
278 | if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) { |
279 | if (&MBB == &MF.front()) |
280 | CfgNeedInsert.insert(V: MIRef(&MBB)); |
281 | else |
282 | CfgLiveInBBs.push_back(Elt: &MBB); |
283 | } |
284 | if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn) |
285 | for (auto *Succ : MBB.successors()) |
286 | if (!isLoopBackEdge(Header: Succ, Bottom: &MBB)) |
287 | BBVisitedInfo[Succ].HasAMXRegLiveIn = true; |
288 | } |
289 | |
290 | // Update NeedTileCfgLiveIn for predecessors. |
291 | while (!CfgLiveInBBs.empty()) { |
292 | MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val(); |
293 | for (auto *Pred : MBB->predecessors()) { |
294 | if (BBVisitedInfo[Pred].LastCall) { |
295 | CfgNeedInsert.insert(V: BBVisitedInfo[Pred].LastCall); |
296 | } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) { |
297 | BBVisitedInfo[Pred].NeedTileCfgLiveIn = true; |
298 | if (Pred == &MF.front()) |
299 | CfgNeedInsert.insert(V: MIRef(Pred)); |
300 | else |
301 | CfgLiveInBBs.push_back(Elt: Pred); |
302 | } |
303 | } |
304 | } |
305 | |
306 | // There's no AMX instruction if we didn't find a tile config live in point. |
307 | if (CfgNeedInsert.empty()) |
308 | return false; |
309 | |
310 | // Avoid to insert ldtilecfg before any shape defs. |
311 | SmallVector<MachineBasicBlock *, 8> WorkList; |
312 | for (auto &I : ShapeBBs) { |
313 | // TODO: We can hoist shapes across BBs here. |
314 | if (BBVisitedInfo[I.first].HasAMXRegLiveIn) { |
315 | // We are not able to config tile registers since the shape to config |
316 | // is not defined yet. Emit error message and continue. The function |
317 | // would not config tile registers. |
318 | emitErrorMsg(MF); |
319 | return false; |
320 | } |
321 | if (BBVisitedInfo[I.first].FirstAMX && |
322 | BBVisitedInfo[I.first].FirstAMX < I.second.back() && |
323 | !hoistShapesInBB(MBB: I.first, Shapes&: I.second)) { |
324 | emitErrorMsg(MF); |
325 | return false; |
326 | } |
327 | WorkList.push_back(Elt: I.first); |
328 | } |
329 | while (!WorkList.empty()) { |
330 | MachineBasicBlock *MBB = WorkList.pop_back_val(); |
331 | for (auto *Pred : MBB->predecessors()) { |
332 | if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(Header: MBB, Bottom: Pred)) { |
333 | BBVisitedInfo[Pred].TileCfgForbidden = true; |
334 | WorkList.push_back(Elt: Pred); |
335 | } |
336 | } |
337 | } |
338 | |
339 | DebugLoc DL; |
340 | SmallSet<MIRef, 8> VisitedOrInserted; |
341 | int SS = MF.getFrameInfo().CreateStackObject( |
342 | Size: ST.getTileConfigSize(), Alignment: ST.getTileConfigAlignment(), isSpillSlot: false); |
343 | |
344 | // Try to insert for the tile config live in points. |
345 | for (const auto &I : CfgNeedInsert) { |
346 | SmallSet<MIRef, 8> InsertPoints; |
347 | SmallVector<MIRef, 8> WorkList({I}); |
348 | while (!WorkList.empty()) { |
349 | MIRef I = WorkList.pop_back_val(); |
350 | if (!VisitedOrInserted.count(V: I)) { |
351 | if (!BBVisitedInfo[I.MBB].TileCfgForbidden) { |
352 | // If the BB is all shapes reachable, stop sink and try to insert. |
353 | InsertPoints.insert(V: I); |
354 | } else { |
355 | // Avoid the BB to be multi visited. |
356 | VisitedOrInserted.insert(V: I); |
357 | // Sink the inserting point along the chain with NeedTileCfgLiveIn = |
358 | // true when MBB isn't all shapes reachable. |
359 | for (auto *Succ : I.MBB->successors()) |
360 | if (BBVisitedInfo[Succ].NeedTileCfgLiveIn) |
361 | WorkList.push_back(Elt: MIRef(Succ)); |
362 | } |
363 | } |
364 | } |
365 | |
366 | // A given point might be forked due to shape conditions are not met. |
367 | for (MIRef I : InsertPoints) { |
368 | // Make sure we insert ldtilecfg after the last shape def in MBB. |
369 | if (ShapeBBs.count(Val: I.MBB) && I < ShapeBBs[I.MBB].back()) |
370 | I = ShapeBBs[I.MBB].back(); |
371 | // There're chances the MBB is sunk more than once. Record it to avoid |
372 | // multi insert. |
373 | if (VisitedOrInserted.insert(V: I).second) { |
374 | auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin(); |
375 | addFrameReference(MIB: BuildMI(BB&: *I.MBB, I: ++II, MIMD: DL, MCID: TII->get(Opcode: X86::PLDTILECFGV)), |
376 | FI: SS); |
377 | } |
378 | } |
379 | } |
380 | |
381 | // Zero stack slot. |
382 | MachineBasicBlock &MBB = MF.front(); |
383 | MachineInstr *MI = &*MBB.begin(); |
384 | if (ST.hasAVX512()) { |
385 | Register Zmm = MRI->createVirtualRegister(RegClass: &X86::VR512RegClass); |
386 | BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::AVX512_512_SET0), DestReg: Zmm); |
387 | addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::VMOVUPSZmr)), FI: SS) |
388 | .addReg(RegNo: Zmm); |
389 | } else if (ST.hasAVX2()) { |
390 | Register Ymm = MRI->createVirtualRegister(RegClass: &X86::VR256RegClass); |
391 | BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::AVX_SET0), DestReg: Ymm); |
392 | addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::VMOVUPSYmr)), FI: SS) |
393 | .addReg(RegNo: Ymm); |
394 | addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::VMOVUPSYmr)), FI: SS, Offset: 32) |
395 | .addReg(RegNo: Ymm); |
396 | } else { |
397 | assert(ST.hasSSE2() && "AMX should assume SSE2 enabled" ); |
398 | unsigned StoreOpc = ST.hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr; |
399 | Register Xmm = MRI->createVirtualRegister(RegClass: &X86::VR128RegClass); |
400 | BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::V_SET0), DestReg: Xmm); |
401 | addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS).addReg(RegNo: Xmm); |
402 | addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 16) |
403 | .addReg(RegNo: Xmm); |
404 | addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 32) |
405 | .addReg(RegNo: Xmm); |
406 | addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: StoreOpc)), FI: SS, Offset: 48) |
407 | .addReg(RegNo: Xmm); |
408 | } |
409 | // Fill in the palette first. |
410 | addFrameReference(MIB: BuildMI(BB&: MBB, I: MI, MIMD: DL, MCID: TII->get(Opcode: X86::MOV8mi)), FI: SS).addImm(Val: 1); |
411 | |
412 | return true; |
413 | } |
414 | |
415 | FunctionPass *llvm::createX86PreTileConfigPass() { |
416 | return new X86PreTileConfig(); |
417 | } |
418 | |