1//===- MachineSMEABIPass.cpp ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements the SME ABI requirements for ZA state. This includes
10// implementing the lazy (and agnostic) ZA state save schemes around calls.
11//
12//===----------------------------------------------------------------------===//
13//
14// This pass works by collecting instructions that require ZA to be in a
15// specific state (e.g., "ACTIVE" or "SAVED") and inserting the necessary state
16// transitions to ensure ZA is in the required state before instructions. State
17// transitions represent actions such as setting up or restoring a lazy save.
18// Certain points within a function may also have predefined states independent
19// of any instructions, for example, a "shared_za" function is always entered
20// and exited in the "ACTIVE" state.
21//
22// To handle ZA state across control flow, we make use of edge bundling. This
23// assigns each block an "incoming" and "outgoing" edge bundle (representing
24// incoming and outgoing edges). Initially, these are unique to each block;
25// then, in the process of forming bundles, the outgoing bundle of a block is
26// joined with the incoming bundle of all successors. The result is that each
27// bundle can be assigned a single ZA state, which ensures the state required by
// all of a block's successors is the same, and that each basic block will always
29// be entered with the same ZA state. This eliminates the need for splitting
30// edges to insert state transitions or "phi" nodes for ZA states.
31//
32// See below for a simple example of edge bundling.
33//
34// The following shows a conditionally executed basic block (BB1):
35//
36// if (cond)
37// BB1
38// BB2
39//
40// Initial Bundles Joined Bundles
41//
42// ┌──0──┐ ┌──0──┐
43// │ BB0 │ │ BB0 │
44// └──1──┘ └──1──┘
45// ├───────┐ ├───────┐
46// ▼ │ ▼ │
47// ┌──2──┐ │ ─────► ┌──1──┐ │
48// │ BB1 │ ▼ │ BB1 │ ▼
49// └──3──┘ ┌──4──┐ └──1──┘ ┌──1──┐
50// └───►4 BB2 │ └───►1 BB2 │
51// └──5──┘ └──2──┘
52//
53// On the left are the initial per-block bundles, and on the right are the
54// joined bundles (which are the result of the EdgeBundles analysis).
55
56#include "AArch64InstrInfo.h"
57#include "AArch64MachineFunctionInfo.h"
58#include "AArch64Subtarget.h"
59#include "MCTargetDesc/AArch64AddressingModes.h"
60#include "llvm/ADT/BitmaskEnum.h"
61#include "llvm/ADT/SmallVector.h"
62#include "llvm/CodeGen/EdgeBundles.h"
63#include "llvm/CodeGen/LivePhysRegs.h"
64#include "llvm/CodeGen/MachineBasicBlock.h"
65#include "llvm/CodeGen/MachineFunctionPass.h"
66#include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
67#include "llvm/CodeGen/MachineRegisterInfo.h"
68#include "llvm/CodeGen/TargetRegisterInfo.h"
69
70using namespace llvm;
71
72#define DEBUG_TYPE "aarch64-machine-sme-abi"
73
74namespace {
75
76// Note: For agnostic ZA, we assume the function is always entered/exited in the
77// "ACTIVE" state -- this _may_ not be the case (since OFF is also a
78// possibility, but for the purpose of placing ZA saves/restores, that does not
79// matter).
enum ZAState : uint8_t {
  // Any/unknown state (not valid). Also used while collecting needed states
  // to mean "no constraint on the ZA state".
  ANY = 0,

  // ZA is in use and active (i.e. within the accumulator)
  ACTIVE,

  // ZA is active, but ZT0 has been saved.
  // This handles the edge case of sharedZA && !sharesZT0.
  ACTIVE_ZT0_SAVED,

  // A ZA save has been set up or committed (i.e. ZA is dormant or off)
  // If the function uses ZT0 it must also be saved.
  LOCAL_SAVED,

  // ZA has been committed to the lazy save buffer of the current function.
  // If the function uses ZT0 it must also be saved.
  // ZA is off.
  LOCAL_COMMITTED,

  // The ZA/ZT0 state on entry to the function.
  ENTRY,

  // ZA is off.
  OFF,

  // The number of ZA states (not a valid state)
  NUM_ZA_STATE
};
109
/// A bitmask enum to record live physical registers that the "emit*" routines
/// may need to preserve. Note: This only tracks registers we may clobber.
enum LiveRegs : uint8_t {
  None = 0,
  NZCV = 1 << 0,
  // W0 and its top half (W0_HI) are tracked separately so that a live W0
  // does not force the whole of X0 to be preserved.
  W0 = 1 << 1,
  W0_HI = 1 << 2,
  // Convenience mask: both halves of X0.
  X0 = W0 | W0_HI,
  LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ W0_HI)
};
120
/// Holds the virtual registers live physical registers have been saved to.
struct PhysRegSave {
  // The set of physical registers that were live (and saved).
  LiveRegs PhysLiveRegs;
  // Virtual register holding the saved NZCV flags (read via MRS), if NZCV
  // was live; otherwise NoRegister.
  Register StatusFlags = AArch64::NoRegister;
  // Virtual register holding the saved value of W0/X0, if live; otherwise
  // NoRegister.
  Register X0Save = AArch64::NoRegister;
};
127
/// Contains the needed ZA state (and live registers) at an instruction. That is
/// the state ZA must be in _before_ "InsertPt".
struct InstInfo {
  // The ZA state required before InsertPt (ANY means no requirement).
  ZAState NeededState{ZAState::ANY};
  // Where code to change the ZA state should be inserted.
  MachineBasicBlock::iterator InsertPt;
  // Tracked physical registers (NZCV, W0, W0_HI) live at InsertPt.
  LiveRegs PhysLiveRegs = LiveRegs::None;
};
135
/// Contains the needed ZA state for each instruction in a block. Instructions
/// that do not require a ZA state are not recorded.
struct BlockInfo {
  // Instructions with a ZA state requirement, in block order.
  SmallVector<InstInfo> Insts;
  // A state this block must be entered in (e.g. ENTRY for the function entry
  // block, LOCAL_COMMITTED for EH pads); ANY if unconstrained.
  ZAState FixedEntryState{ZAState::ANY};
  // The incoming/outgoing states that would avoid a transition at the block
  // boundary (taken from the first/last recorded instruction, if any).
  ZAState DesiredIncomingState{ZAState::ANY};
  ZAState DesiredOutgoingState{ZAState::ANY};
  // Tracked physical registers live at block entry/exit.
  LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
  LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
};
146
/// Contains the needed ZA state information for all blocks within a function.
struct FunctionInfo {
  // Per-block info, indexed by basic block number.
  SmallVector<BlockInfo> Blocks;
  // Point just after the SMEStateAllocPseudo marker (if present) -- a safe
  // spot to insert save-buffer/TPIDR2 block setup.
  std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
  // Tracked physical registers live at AfterSMEProloguePt.
  LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
};
153
154/// State/helpers that is only needed when emitting code to handle
155/// saving/restoring ZA.
156class EmitContext {
157public:
158 EmitContext() = default;
159
160 /// Get or create a TPIDR2 block in \p MF.
161 int getTPIDR2Block(MachineFunction &MF) {
162 if (TPIDR2BlockFI)
163 return *TPIDR2BlockFI;
164 MachineFrameInfo &MFI = MF.getFrameInfo();
165 TPIDR2BlockFI = MFI.CreateStackObject(Size: 16, Alignment: Align(16), isSpillSlot: false);
166 return *TPIDR2BlockFI;
167 }
168
169 /// Get or create agnostic ZA buffer pointer in \p MF.
170 Register getAgnosticZABufferPtr(MachineFunction &MF) {
171 if (AgnosticZABufferPtr != AArch64::NoRegister)
172 return AgnosticZABufferPtr;
173 Register BufferPtr =
174 MF.getInfo<AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer();
175 AgnosticZABufferPtr =
176 BufferPtr != AArch64::NoRegister
177 ? BufferPtr
178 : MF.getRegInfo().createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
179 return AgnosticZABufferPtr;
180 }
181
182 int getZT0SaveSlot(MachineFunction &MF) {
183 if (ZT0SaveFI)
184 return *ZT0SaveFI;
185 MachineFrameInfo &MFI = MF.getFrameInfo();
186 ZT0SaveFI = MFI.CreateSpillStackObject(Size: 64, Alignment: Align(16));
187 return *ZT0SaveFI;
188 }
189
190 /// Returns true if the function must allocate a ZA save buffer on entry. This
191 /// will be the case if, at any point in the function, a ZA save was emitted.
192 bool needsSaveBuffer() const {
193 assert(!(TPIDR2BlockFI && AgnosticZABufferPtr) &&
194 "Cannot have both a TPIDR2 block and agnostic ZA buffer");
195 return TPIDR2BlockFI || AgnosticZABufferPtr != AArch64::NoRegister;
196 }
197
198private:
199 std::optional<int> ZT0SaveFI;
200 std::optional<int> TPIDR2BlockFI;
201 Register AgnosticZABufferPtr = AArch64::NoRegister;
202};
203
204StringRef getZAStateString(ZAState State) {
205#define MAKE_CASE(V) \
206 case V: \
207 return #V;
208 switch (State) {
209 MAKE_CASE(ZAState::ANY)
210 MAKE_CASE(ZAState::ACTIVE)
211 MAKE_CASE(ZAState::ACTIVE_ZT0_SAVED)
212 MAKE_CASE(ZAState::LOCAL_SAVED)
213 MAKE_CASE(ZAState::LOCAL_COMMITTED)
214 MAKE_CASE(ZAState::ENTRY)
215 MAKE_CASE(ZAState::OFF)
216 default:
217 llvm_unreachable("Unexpected ZAState");
218 }
219#undef MAKE_CASE
220}
221
222static bool isZAorZTRegOp(const TargetRegisterInfo &TRI,
223 const MachineOperand &MO) {
224 if (!MO.isReg() || !MO.getReg().isPhysical())
225 return false;
226 return any_of(Range: TRI.subregs_inclusive(Reg: MO.getReg()), P: [](const MCPhysReg &SR) {
227 return AArch64::MPR128RegClass.contains(Reg: SR) ||
228 AArch64::ZTRRegClass.contains(Reg: SR);
229 });
230}
231
232/// Returns the required ZA state needed before \p MI and an iterator pointing
233/// to where any code required to change the ZA state should be inserted.
234static std::pair<ZAState, MachineBasicBlock::iterator>
235getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI,
236 SMEAttrs SMEFnAttrs) {
237 MachineBasicBlock::iterator InsertPt(MI);
238
239 // Note: InOutZAUsePseudo, RequiresZASavePseudo, and RequiresZT0SavePseudo are
240 // intended to mark the position immediately before a call. Due to
241 // SelectionDAG constraints, these markers occur after the ADJCALLSTACKDOWN,
242 // so we use std::prev(InsertPt) to get the position before the call.
243
244 if (MI.getOpcode() == AArch64::InOutZAUsePseudo)
245 return {ZAState::ACTIVE, std::prev(x: InsertPt)};
246
247 // Note: If we need to save both ZA and ZT0 we use RequiresZASavePseudo.
248 if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
249 return {ZAState::LOCAL_SAVED, std::prev(x: InsertPt)};
250
251 // If we only need to save ZT0 there's two cases to consider:
252 // 1. The function has ZA state (that we don't need to save).
253 // - In this case we switch to the "ACTIVE_ZT0_SAVED" state.
254 // This only saves ZT0.
255 // 2. The function does not have ZA state
256 // - In this case we switch to "LOCAL_COMMITTED" state.
257 // This saves ZT0 and turns ZA off.
258 if (MI.getOpcode() == AArch64::RequiresZT0SavePseudo) {
259 return {SMEFnAttrs.hasZAState() ? ZAState::ACTIVE_ZT0_SAVED
260 : ZAState::LOCAL_COMMITTED,
261 std::prev(x: InsertPt)};
262 }
263
264 if (MI.isReturn()) {
265 bool ZAOffAtReturn = SMEFnAttrs.hasPrivateZAInterface();
266 return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
267 }
268
269 for (auto &MO : MI.operands()) {
270 if (isZAorZTRegOp(TRI, MO))
271 return {ZAState::ACTIVE, InsertPt};
272 }
273
274 return {ZAState::ANY, InsertPt};
275}
276
struct MachineSMEABI : public MachineFunctionPass {
  // Pass identification (address used as a unique ID).
  inline static char ID = 0;

  MachineSMEABI(CodeGenOptLevel OptLevel = CodeGenOptLevel::Default)
      : MachineFunctionPass(ID), OptLevel(OptLevel) {}

  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override { return "Machine SME ABI pass"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<EdgeBundlesWrapperLegacy>();
    AU.addRequired<MachineOptimizationRemarkEmitterPass>();
    AU.addRequired<LibcallLoweringInfoWrapper>();
    AU.addPreservedID(ID&: MachineLoopInfoID);
    AU.addPreservedID(ID&: MachineDominatorsID);
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Collects the needed ZA state (and live registers) before each instruction
  /// within the machine function.
  FunctionInfo collectNeededZAStates(SMEAttrs SMEFnAttrs);

  /// Assigns each edge bundle a ZA state based on the needed states of blocks
  /// that have incoming or outgoing edges in that bundle.
  SmallVector<ZAState> assignBundleZAStates(const EdgeBundles &Bundles,
                                            const FunctionInfo &FnInfo);

  /// Inserts code to handle changes between ZA states within the function.
  /// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
  void insertStateChanges(EmitContext &, const FunctionInfo &FnInfo,
                          const EdgeBundles &Bundles,
                          ArrayRef<ZAState> BundleStates);

  /// Appends the callee symbol and register mask for the SME runtime routine
  /// \p LC to \p MIB, verifying the routine uses \p ExpectedCC.
  void addSMELibCall(MachineInstrBuilder &MIB, RTLIB::Libcall LC,
                     CallingConv::ID ExpectedCC);

  /// Emits a save (\p IsSave == true) or restore of ZT0.
  void emitZT0SaveRestore(EmitContext &, MachineBasicBlock &MBB,
                          MachineBasicBlock::iterator MBBI, bool IsSave);

  // Emission routines for private and shared ZA functions (using lazy saves).
  void emitSMEPrologue(MachineBasicBlock &MBB,
                       MachineBasicBlock::iterator MBBI);
  void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB,
                           MachineBasicBlock::iterator MBBI,
                           LiveRegs PhysLiveRegs);
  void emitSetupLazySave(EmitContext &, MachineBasicBlock &MBB,
                         MachineBasicBlock::iterator MBBI);
  void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB,
                                  MachineBasicBlock::iterator MBBI);
  /// Emits code to switch the ZA mode on/off (\p On), optionally clearing
  /// TPIDR2_EL0 (\p ClearTPIDR2).
  void emitZAMode(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                  bool ClearTPIDR2, bool On);

  // Emission routines for agnostic ZA functions.
  void emitSetupFullZASave(MachineBasicBlock &MBB,
                           MachineBasicBlock::iterator MBBI,
                           LiveRegs PhysLiveRegs);
  // Emit a "full" ZA save or restore. It is "full" in the sense that this
  // function will emit a call to __arm_sme_save or __arm_sme_restore, which
  // handles saving and restoring both ZA and ZT0.
  void emitFullZASaveRestore(EmitContext &, MachineBasicBlock &MBB,
                             MachineBasicBlock::iterator MBBI,
                             LiveRegs PhysLiveRegs, bool IsSave);
  void emitAllocateFullZASaveBuffer(EmitContext &, MachineBasicBlock &MBB,
                                    MachineBasicBlock::iterator MBBI,
                                    LiveRegs PhysLiveRegs);

  /// Attempts to find an insertion point before \p Inst where the status flags
  /// are not live. If \p Inst is `Block.Insts.end()` a point before the end of
  /// the block is found.
  std::pair<MachineBasicBlock::iterator, LiveRegs>
  findStateChangeInsertionPoint(MachineBasicBlock &MBB, const BlockInfo &Block,
                                SmallVectorImpl<InstInfo>::const_iterator Inst);
  /// Emits the code for a single \p From -> \p To ZA state transition at
  /// \p MBBI, preserving \p PhysLiveRegs.
  void emitStateChange(EmitContext &, MachineBasicBlock &MBB,
                       MachineBasicBlock::iterator MBBI, ZAState From,
                       ZAState To, LiveRegs PhysLiveRegs);

  // Helpers for switching between lazy/full ZA save/restore routines.
  void emitZASave(EmitContext &Context, MachineBasicBlock &MBB,
                  MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
      return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
                                   /*IsSave=*/true);
    return emitSetupLazySave(Context, MBB, MBBI);
  }
  void emitZARestore(EmitContext &Context, MachineBasicBlock &MBB,
                     MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
      return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
                                   /*IsSave=*/false);
    return emitRestoreLazySave(Context, MBB, MBBI, PhysLiveRegs);
  }
  void emitAllocateZASaveBuffer(EmitContext &Context, MachineBasicBlock &MBB,
                                MachineBasicBlock::iterator MBBI,
                                LiveRegs PhysLiveRegs) {
    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
      return emitAllocateFullZASaveBuffer(Context, MBB, MBBI, PhysLiveRegs);
    return emitAllocateLazySaveBuffer(Context, MBB, MBBI);
  }

  /// Collects the reachable calls from \p MBBI marked with \p Marker. This is
  /// intended to be used to emit lazy save remarks. Note: This stops at the
  /// first marked call along any path.
  void collectReachableMarkedCalls(const MachineBasicBlock &MBB,
                                   MachineBasicBlock::const_iterator MBBI,
                                   SmallVectorImpl<const MachineInstr *> &Calls,
                                   unsigned Marker) const;

  /// Emits optimization remarks for the ZA/ZT0 saves reachable from \p MBBI.
  void emitCallSaveRemarks(const MachineBasicBlock &MBB,
                           MachineBasicBlock::const_iterator MBBI, DebugLoc DL,
                           unsigned Marker, StringRef RemarkName,
                           StringRef SaveName) const;

  /// Reports a function-level error to the LLVMContext (compilation
  /// continues; the error is surfaced to the user).
  void emitError(const Twine &Message) {
    LLVMContext &Context = MF->getFunction().getContext();
    Context.emitError(ErrorStr: MF->getName() + ": " + Message);
  }

  /// Save live physical registers to virtual registers.
  PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
                                MachineBasicBlock::iterator MBBI, DebugLoc DL);
  /// Restore physical registers from a save of their previous values.
  void restorePhyRegSave(const PhysRegSave &RegSave, MachineBasicBlock &MBB,
                         MachineBasicBlock::iterator MBBI, DebugLoc DL);

private:
  CodeGenOptLevel OptLevel = CodeGenOptLevel::Default;

  // Cached pointers for the function currently being processed (presumably
  // set up in runOnMachineFunction -- its body is outside this chunk).
  MachineFunction *MF = nullptr;
  const AArch64Subtarget *Subtarget = nullptr;
  const AArch64RegisterInfo *TRI = nullptr;
  const AArch64FunctionInfo *AFI = nullptr;
  const AArch64InstrInfo *TII = nullptr;
  const LibcallLoweringInfo *LLI = nullptr;

  MachineOptimizationRemarkEmitter *ORE = nullptr;
  MachineRegisterInfo *MRI = nullptr;
  MachineLoopInfo *MLI = nullptr;
};
417
418static LiveRegs getPhysLiveRegs(LiveRegUnits const &LiveUnits) {
419 LiveRegs PhysLiveRegs = LiveRegs::None;
420 if (!LiveUnits.available(Reg: AArch64::NZCV))
421 PhysLiveRegs |= LiveRegs::NZCV;
422 // We have to track W0 and X0 separately as otherwise things can get
423 // confused if we attempt to preserve X0 but only W0 was defined.
424 if (!LiveUnits.available(Reg: AArch64::W0))
425 PhysLiveRegs |= LiveRegs::W0;
426 if (!LiveUnits.available(Reg: AArch64::W0_HI))
427 PhysLiveRegs |= LiveRegs::W0_HI;
428 return PhysLiveRegs;
429}
430
431static void setPhysLiveRegs(LiveRegUnits &LiveUnits, LiveRegs PhysLiveRegs) {
432 if (PhysLiveRegs & LiveRegs::NZCV)
433 LiveUnits.addReg(Reg: AArch64::NZCV);
434 if (PhysLiveRegs & LiveRegs::W0)
435 LiveUnits.addReg(Reg: AArch64::W0);
436 if (PhysLiveRegs & LiveRegs::W0_HI)
437 LiveUnits.addReg(Reg: AArch64::W0_HI);
438}
439
440[[maybe_unused]] bool isCallStartOpcode(unsigned Opc) {
441 switch (Opc) {
442 case AArch64::TLSDESC_CALLSEQ:
443 case AArch64::TLSDESC_AUTH_CALLSEQ:
444 case AArch64::ADJCALLSTACKDOWN:
445 return true;
446 default:
447 return false;
448 }
449}
450
FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
  assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
          SMEFnAttrs.hasZAState()) &&
         "Expected function to have ZA/ZT0 state!");

  // One BlockInfo per basic block, indexed by block number.
  SmallVector<BlockInfo> Blocks(MF->getNumBlockIDs());
  LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
  std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;

  for (MachineBasicBlock &MBB : *MF) {
    BlockInfo &Block = Blocks[MBB.getNumber()];

    // Some blocks have a predefined entry state (see ZAState docs).
    if (MBB.isEntryBlock()) {
      // Entry block:
      Block.FixedEntryState = ZAState::ENTRY;
    } else if (MBB.isEHPad()) {
      // EH entry block:
      Block.FixedEntryState = ZAState::LOCAL_COMMITTED;
    }

    // Walk the block bottom-up so LiveRegUnits can compute liveness at each
    // instruction via stepBackward().
    LiveRegUnits LiveUnits(*TRI);
    LiveUnits.addLiveOuts(MBB);

    Block.PhysLiveRegsAtExit = getPhysLiveRegs(LiveUnits);
    auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
    auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
    for (MachineInstr &MI : reverse(C&: MBB)) {
      MachineBasicBlock::iterator MBBI(MI);
      LiveUnits.stepBackward(MI);
      LiveRegs PhysLiveRegs = getPhysLiveRegs(LiveUnits);
      // The SMEStateAllocPseudo marker is added to a function if the save
      // buffer was allocated in SelectionDAG. It marks the end of the
      // allocation -- which is a safe point for this pass to insert any TPIDR2
      // block setup.
      if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
        AfterSMEProloguePt = MBBI;
        PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
      }
      // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
      auto [NeededState, InsertPt] = getInstNeededZAState(TRI: *TRI, MI, SMEFnAttrs);
      assert((InsertPt == MBBI || isCallStartOpcode(InsertPt->getOpcode())) &&
             "Unexpected state change insertion point!");
      // TODO: Do something to avoid state changes where NZCV is live.
      // Record the liveness at the block-boundary insertion points.
      if (MBBI == FirstTerminatorInsertPt)
        Block.PhysLiveRegsAtExit = PhysLiveRegs;
      if (MBBI == FirstNonPhiInsertPt)
        Block.PhysLiveRegsAtEntry = PhysLiveRegs;
      // Only instructions with an actual requirement are recorded.
      if (NeededState != ZAState::ANY)
        Block.Insts.push_back(Elt: {.NeededState: NeededState, .InsertPt: InsertPt, .PhysLiveRegs: PhysLiveRegs});
    }

    // Reverse vector (as we had to iterate backwards for liveness).
    std::reverse(first: Block.Insts.begin(), last: Block.Insts.end());

    // Record the desired states on entry/exit of this block. These are the
    // states that would not incur a state transition.
    if (!Block.Insts.empty()) {
      Block.DesiredIncomingState = Block.Insts.front().NeededState;
      Block.DesiredOutgoingState = Block.Insts.back().NeededState;
    }
  }

  return FunctionInfo{.Blocks: std::move(Blocks), .AfterSMEProloguePt: AfterSMEProloguePt,
                      .PhysLiveRegsAfterSMEPrologue: PhysLiveRegsAfterSMEPrologue};
}
516
/// Assigns each edge bundle a ZA state based on the needed states of blocks
/// that have incoming or outgoing edges in that bundle.
519SmallVector<ZAState>
520MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
521 const FunctionInfo &FnInfo) {
522 SmallVector<ZAState> BundleStates(Bundles.getNumBundles());
523 for (unsigned I = 0, E = Bundles.getNumBundles(); I != E; ++I) {
524 std::optional<ZAState> BundleState;
525 for (unsigned BlockID : Bundles.getBlocks(Bundle: I)) {
526 const BlockInfo &Block = FnInfo.Blocks[BlockID];
527 // Check if the block is an incoming block in the bundle. Note: We skip
528 // Block.FixedEntryState != ANY to ignore EH pads (which are only
529 // reachable via exceptions).
530 if (Block.FixedEntryState != ZAState::ANY ||
531 Bundles.getBundle(N: BlockID, /*Out=*/false) != I)
532 continue;
533
534 // Pick a state that matches all incoming blocks. Fallback to "ACTIVE" if
535 // any blocks doesn't match. This will hoist the state from incoming
536 // blocks to outgoing blocks.
537 if (!BundleState)
538 BundleState = Block.DesiredIncomingState;
539 else if (BundleState != Block.DesiredIncomingState)
540 BundleState = ZAState::ACTIVE;
541 }
542
543 if (!BundleState || BundleState == ZAState::ANY)
544 BundleState = ZAState::ACTIVE;
545
546 BundleStates[I] = *BundleState;
547 }
548
549 return BundleStates;
550}
551
std::pair<MachineBasicBlock::iterator, LiveRegs>
MachineSMEABI::findStateChangeInsertionPoint(
    MachineBasicBlock &MBB, const BlockInfo &Block,
    SmallVectorImpl<InstInfo>::const_iterator Inst) {
  LiveRegs PhysLiveRegs;
  MachineBasicBlock::iterator InsertPt;
  // Default insertion point: before Inst itself, or (for Insts.end()) before
  // the block's first terminator.
  if (Inst != Block.Insts.end()) {
    InsertPt = Inst->InsertPt;
    PhysLiveRegs = Inst->PhysLiveRegs;
  } else {
    InsertPt = MBB.getFirstTerminator();
    PhysLiveRegs = Block.PhysLiveRegsAtExit;
  }

  if (PhysLiveRegs == LiveRegs::None)
    return {InsertPt, PhysLiveRegs}; // Nothing to do (no live regs).

  // Find the previous state change. We can not move before this point.
  MachineBasicBlock::iterator PrevStateChangeI;
  if (Inst == Block.Insts.begin()) {
    PrevStateChangeI = MBB.begin();
  } else {
    // Note: `std::prev(Inst)` is the previous InstInfo. We only create an
    // InstInfo object for instructions that require a specific ZA state, so the
    // InstInfo is the site of the previous state change in the block (which can
    // be several MIs earlier).
    PrevStateChangeI = std::prev(x: Inst)->InsertPt;
  }

  // Note: LiveUnits will only accurately track X0 and NZCV.
  LiveRegUnits LiveUnits(*TRI);
  setPhysLiveRegs(LiveUnits, PhysLiveRegs);
  // Scan backwards from the default point looking for a spot with fewer live
  // registers; keep the best candidate seen so far.
  auto BestCandidate = std::make_pair(x&: InsertPt, y&: PhysLiveRegs);
  for (MachineBasicBlock::iterator I = InsertPt; I != PrevStateChangeI; --I) {
    // Don't move before/into a call (which may have a state change before it).
    if (I->getOpcode() == TII->getCallFrameDestroyOpcode() || I->isCall())
      break;
    LiveUnits.stepBackward(MI: *I);
    LiveRegs CurrentPhysLiveRegs = getPhysLiveRegs(LiveUnits);
    // Find places where NZCV is available, but keep looking for locations where
    // both NZCV and X0 are available, which can avoid some copies.
    if (!(CurrentPhysLiveRegs & LiveRegs::NZCV))
      BestCandidate = {I, CurrentPhysLiveRegs};
    if (CurrentPhysLiveRegs == LiveRegs::None)
      break;
  }
  return BestCandidate;
}
600
601void MachineSMEABI::insertStateChanges(EmitContext &Context,
602 const FunctionInfo &FnInfo,
603 const EdgeBundles &Bundles,
604 ArrayRef<ZAState> BundleStates) {
605 for (MachineBasicBlock &MBB : *MF) {
606 const BlockInfo &Block = FnInfo.Blocks[MBB.getNumber()];
607 ZAState InState = BundleStates[Bundles.getBundle(N: MBB.getNumber(),
608 /*Out=*/false)];
609
610 ZAState CurrentState = Block.FixedEntryState;
611 if (CurrentState == ZAState::ANY)
612 CurrentState = InState;
613
614 for (auto &Inst : Block.Insts) {
615 if (CurrentState != Inst.NeededState) {
616 auto [InsertPt, PhysLiveRegs] =
617 findStateChangeInsertionPoint(MBB, Block, Inst: &Inst);
618 emitStateChange(Context, MBB, MBBI: InsertPt, From: CurrentState, To: Inst.NeededState,
619 PhysLiveRegs);
620 CurrentState = Inst.NeededState;
621 }
622 }
623
624 if (MBB.succ_empty())
625 continue;
626
627 ZAState OutState =
628 BundleStates[Bundles.getBundle(N: MBB.getNumber(), /*Out=*/true)];
629 if (CurrentState != OutState) {
630 auto [InsertPt, PhysLiveRegs] =
631 findStateChangeInsertionPoint(MBB, Block, Inst: Block.Insts.end());
632 emitStateChange(Context, MBB, MBBI: InsertPt, From: CurrentState, To: OutState,
633 PhysLiveRegs);
634 }
635 }
636}
637
638static DebugLoc getDebugLoc(MachineBasicBlock &MBB,
639 MachineBasicBlock::iterator MBBI) {
640 if (MBB.empty())
641 return DebugLoc();
642 return MBBI != MBB.end() ? MBBI->getDebugLoc() : MBB.back().getDebugLoc();
643}
644
645/// Finds the first call (as determined by MachineInstr::isCall()) starting from
646/// \p MBBI in \p MBB marked with \p Marker (which is a marker opcode such as
647/// RequiresZASavePseudo). If a marked call is found, it is pushed to \p Calls
648/// and the function returns true.
649static bool findMarkedCall(const MachineBasicBlock &MBB,
650 MachineBasicBlock::const_iterator MBBI,
651 SmallVectorImpl<const MachineInstr *> &Calls,
652 unsigned Marker, unsigned CallDestroyOpcode) {
653 auto IsMarker = [&](auto &MI) { return MI.getOpcode() == Marker; };
654 auto MarkerInst = std::find_if(first: MBBI, last: MBB.end(), pred: IsMarker);
655 if (MarkerInst == MBB.end())
656 return false;
657 MachineBasicBlock::const_iterator I = MarkerInst;
658 while (++I != MBB.end()) {
659 if (I->isCall() || I->getOpcode() == CallDestroyOpcode)
660 break;
661 }
662 if (I != MBB.end() && I->isCall())
663 Calls.push_back(Elt: &*I);
664 // Note: This function always returns true if a "Marker" was found.
665 return true;
666}
667
668void MachineSMEABI::collectReachableMarkedCalls(
669 const MachineBasicBlock &StartMBB,
670 MachineBasicBlock::const_iterator StartInst,
671 SmallVectorImpl<const MachineInstr *> &Calls, unsigned Marker) const {
672 assert(Marker == AArch64::InOutZAUsePseudo ||
673 Marker == AArch64::RequiresZASavePseudo ||
674 Marker == AArch64::RequiresZT0SavePseudo);
675 unsigned CallDestroyOpcode = TII->getCallFrameDestroyOpcode();
676 if (findMarkedCall(MBB: StartMBB, MBBI: StartInst, Calls, Marker, CallDestroyOpcode))
677 return;
678
679 SmallPtrSet<const MachineBasicBlock *, 4> Visited;
680 SmallVector<const MachineBasicBlock *> Worklist(StartMBB.succ_rbegin(),
681 StartMBB.succ_rend());
682 while (!Worklist.empty()) {
683 const MachineBasicBlock *MBB = Worklist.pop_back_val();
684 auto [_, Inserted] = Visited.insert(Ptr: MBB);
685 if (!Inserted)
686 continue;
687
688 if (!findMarkedCall(MBB: *MBB, MBBI: MBB->begin(), Calls, Marker, CallDestroyOpcode))
689 Worklist.append(in_start: MBB->succ_rbegin(), in_end: MBB->succ_rend());
690 }
691}
692
693static StringRef getCalleeName(const MachineInstr &CallInst) {
694 assert(CallInst.isCall() && "expected a call");
695 for (const MachineOperand &MO : CallInst.operands()) {
696 if (MO.isSymbol())
697 return MO.getSymbolName();
698 if (MO.isGlobal())
699 return MO.getGlobal()->getName();
700 }
701 return {};
702}
703
void MachineSMEABI::emitCallSaveRemarks(const MachineBasicBlock &MBB,
                                        MachineBasicBlock::const_iterator MBBI,
                                        DebugLoc DL, unsigned Marker,
                                        StringRef RemarkName,
                                        StringRef SaveName) const {
  // Helper to build a remark of the requested kind at a given location.
  auto SaveRemark = [&](DebugLoc DL, const MachineBasicBlock &MBB) {
    return MachineOptimizationRemarkAnalysis("sme", RemarkName, DL, &MBB);
  };
  StateName StateName = Marker == AArch64::RequiresZT0SavePseudo ? "ZT0" : "ZA";
  // Always emit a single remark at the save point itself.
  ORE->emit(RemarkBuilder: [&] {
    return SaveRemark(DL, MBB) << SaveName << " of " << StateName
                               << " emitted in '" << MF->getName() << "'";
  });
  // Per-call remarks are only produced when extra analysis is enabled (they
  // require a CFG walk to find the calls needing this save).
  if (!ORE->allowExtraAnalysis(PassName: "sme"))
    return;
  SmallVector<const MachineInstr *> CallsRequiringSaves;
  collectReachableMarkedCalls(StartMBB: MBB, StartInst: MBBI, Calls&: CallsRequiringSaves, Marker);
  for (const MachineInstr *CallInst : CallsRequiringSaves) {
    auto R = SaveRemark(CallInst->getDebugLoc(), *CallInst->getParent());
    R << "call";
    if (StringRef CalleeName = getCalleeName(CallInst: *CallInst); !CalleeName.empty())
      R << " to '" << CalleeName << "'";
    R << " requires " << StateName << " save";
    ORE->emit(OptDiag&: R);
  }
}
730
void MachineSMEABI::emitSetupLazySave(EmitContext &Context,
                                      MachineBasicBlock &MBB,
                                      MachineBasicBlock::iterator MBBI) {
  DebugLoc DL = getDebugLoc(MBB, MBBI);

  // Report that a lazy save of ZA is being set up here.
  emitCallSaveRemarks(MBB, MBBI, DL, Marker: AArch64::RequiresZASavePseudo,
                      RemarkName: "SMELazySaveZA", SaveName: "lazy save");

  // Get pointer to TPIDR2 block.
  Register TPIDR2 = MRI->createVirtualRegister(RegClass: &AArch64::GPR64spRegClass);
  Register TPIDR2Ptr = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
  // ADDXri of the frame index with zero offset materializes the address.
  BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::ADDXri), DestReg: TPIDR2)
      .addFrameIndex(Idx: Context.getTPIDR2Block(MF&: *MF))
      .addImm(Val: 0)
      .addImm(Val: 0);
  // Copy to a plain GPR64 so it can be written to the system register.
  BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: TPIDR2Ptr)
      .addReg(RegNo: TPIDR2);
  // Set TPIDR2_EL0 to point to TPIDR2 block.
  BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSR))
      .addImm(Val: AArch64SysReg::TPIDR2_EL0)
      .addReg(RegNo: TPIDR2Ptr);
}
753
PhysRegSave MachineSMEABI::createPhysRegSave(LiveRegs PhysLiveRegs,
                                             MachineBasicBlock &MBB,
                                             MachineBasicBlock::iterator MBBI,
                                             DebugLoc DL) {
  PhysRegSave RegSave{.PhysLiveRegs: PhysLiveRegs};
  // NZCV is saved by reading it into a GPR via MRS.
  if (PhysLiveRegs & LiveRegs::NZCV) {
    RegSave.StatusFlags = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MRS), DestReg: RegSave.StatusFlags)
        .addImm(Val: AArch64SysReg::NZCV)
        .addReg(RegNo: AArch64::NZCV, Flags: RegState::Implicit);
  }
  // Note: Preserving X0 is "free" as this is before register allocation, so
  // the register allocator is still able to optimize these copies.
  if (PhysLiveRegs & LiveRegs::W0) {
    // Save the full X0 only if its top half (W0_HI) is also live.
    RegSave.X0Save = MRI->createVirtualRegister(RegClass: PhysLiveRegs & LiveRegs::W0_HI
                                                    ? &AArch64::GPR64RegClass
                                                    : &AArch64::GPR32RegClass);
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: RegSave.X0Save)
        .addReg(RegNo: PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0);
  }
  return RegSave;
}
776
void MachineSMEABI::restorePhyRegSave(const PhysRegSave &RegSave,
                                      MachineBasicBlock &MBB,
                                      MachineBasicBlock::iterator MBBI,
                                      DebugLoc DL) {
  // Write the saved flags back to NZCV via MSR.
  if (RegSave.StatusFlags != AArch64::NoRegister)
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSR))
        .addImm(Val: AArch64SysReg::NZCV)
        .addReg(RegNo: RegSave.StatusFlags)
        .addReg(RegNo: AArch64::NZCV, Flags: RegState::ImplicitDefine);

  // Copy the saved value back into X0 (or just W0 if only the low half was
  // live -- mirroring createPhysRegSave).
  if (RegSave.X0Save != AArch64::NoRegister)
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY),
            DestReg: RegSave.PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0)
        .addReg(RegNo: RegSave.X0Save);
}
792
void MachineSMEABI::addSMELibCall(MachineInstrBuilder &MIB, RTLIB::Libcall LC,
                                  CallingConv::ID ExpectedCC) {
  RTLIB::LibcallImpl LCImpl = LLI->getLibcallImpl(Call: LC);
  // Note: emitError does not abort, so the remaining code still runs after
  // reporting an unsupported routine.
  if (LCImpl == RTLIB::Unsupported)
    emitError(Message: "cannot lower SME ABI (SME routines unsupported)");
  CallingConv::ID CC = LLI->getLibcallImplCallingConv(Call: LCImpl);
  StringRef ImplName = RTLIB::RuntimeLibcallsInfo::getLibcallImplName(CallImpl: LCImpl);
  // The caller expects a specific SME ABI calling convention; anything else
  // would clobber state the routine is required to preserve.
  if (CC != ExpectedCC)
    emitError(Message: "invalid calling convention for SME routine: '" + ImplName + "'");
  // FIXME: This assumes the ImplName StringRef is null-terminated.
  MIB.addExternalSymbol(FnName: ImplName.data());
  MIB.addRegMask(Mask: TRI->getCallPreservedMask(MF: *MF, CC));
}
806
/// Emits the lazy-save restore sequence at \p MBBI: re-enable PSTATE.ZA,
/// read TPIDR2_EL0, materialize a pointer to this function's TPIDR2 block,
/// conditionally call the TPIDR2 restore routine (RestoreZAPseudo with
/// RTLIB::SMEABI_TPIDR2_RESTORE), then zero TPIDR2_EL0 to mark the lazy
/// save as no longer pending. Live NZCV/X0 are preserved around it.
void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
                                        MachineBasicBlock &MBB,
                                        MachineBasicBlock::iterator MBBI,
                                        LiveRegs PhysLiveRegs) {
  DebugLoc DL = getDebugLoc(MBB, MBBI);
  Register TPIDR2EL0 = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
  // X0 is the argument register expected by the restore routine.
  Register TPIDR2 = AArch64::X0;

  // TODO: Emit these within the restore MBB to prevent unnecessary saves.
  PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);

  // Enable ZA.
  BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSRpstatesvcrImm1))
      .addImm(Val: AArch64SVCR::SVCRZA)
      .addImm(Val: 1);
  // Get current TPIDR2_EL0.
  BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MRS), DestReg: TPIDR2EL0)
      .addImm(Val: AArch64SysReg::TPIDR2_EL0);
  // Get pointer to TPIDR2 block.
  BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::ADDXri), DestReg: TPIDR2)
      .addFrameIndex(Idx: Context.getTPIDR2Block(MF&: *MF))
      .addImm(Val: 0)
      .addImm(Val: 0);
  // (Conditionally) restore ZA state.
  auto RestoreZA = BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::RestoreZAPseudo))
                       .addReg(RegNo: TPIDR2EL0)
                       .addReg(RegNo: TPIDR2);
  addSMELibCall(
      MIB&: RestoreZA, LC: RTLIB::SMEABI_TPIDR2_RESTORE,
      ExpectedCC: CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);
  // Zero TPIDR2_EL0.
  BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSR))
      .addImm(Val: AArch64SysReg::TPIDR2_EL0)
      .addReg(RegNo: AArch64::XZR);

  restorePhyRegSave(RegSave, MBB, MBBI, DL);
}
844
/// Sets PSTATE.ZA to \p On (true = enable, false = disable), optionally
/// zeroing TPIDR2_EL0 first (presumably to cancel a pending lazy save --
/// confirm against callers, which pass ClearTPIDR2 when leaving
/// LOCAL_SAVED).
void MachineSMEABI::emitZAMode(MachineBasicBlock &MBB,
                               MachineBasicBlock::iterator MBBI,
                               bool ClearTPIDR2, bool On) {
  DebugLoc DL = getDebugLoc(MBB, MBBI);

  if (ClearTPIDR2)
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSR))
        .addImm(Val: AArch64SysReg::TPIDR2_EL0)
        .addReg(RegNo: AArch64::XZR);

  // Set PSTATE.ZA (1 = enable, 0 = disable).
  BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSRpstatesvcrImm1))
      .addImm(Val: AArch64SVCR::SVCRZA)
      .addImm(Val: On ? 1 : 0);
}
860
/// Allocates the lazy-save buffer (SVL * SVL bytes taken from the stack,
/// unless SelectionDAG already allocated it) and initializes the TPIDR2
/// block with the buffer pointer and the number of save slices (SVL).
void MachineSMEABI::emitAllocateLazySaveBuffer(
    EmitContext &Context, MachineBasicBlock &MBB,
    MachineBasicBlock::iterator MBBI) {
  MachineFrameInfo &MFI = MF->getFrameInfo();
  DebugLoc DL = getDebugLoc(MBB, MBBI);
  Register SP = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
  Register SVL = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
  Register Buffer = AFI->getEarlyAllocSMESaveBuffer();

  // Calculate SVL.
  BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::RDSVLI_XI), DestReg: SVL).addImm(Val: 1);

  // 1. Allocate the lazy save buffer.
  if (Buffer == AArch64::NoRegister) {
    // TODO: On Windows, we allocate the lazy save buffer in SelectionDAG (so
    // Buffer != AArch64::NoRegister). This is done to reuse the existing
    // expansions (which can insert stack checks). This works, but it means we
    // will always allocate the lazy save buffer (even if the function contains
    // no lazy saves). If we want to handle Windows here, we'll need to
    // implement something similar to LowerWindowsDYNAMIC_STACKALLOC.
    assert(!Subtarget->isTargetWindows() &&
           "Lazy ZA save is not yet supported on Windows");
    Buffer = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
    // Get original stack pointer.
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: SP)
        .addReg(RegNo: AArch64::SP);
    // Allocate a lazy-save buffer object of the size given, normally SVL * SVL
    // (MSUB computes SP - SVL * SVL, i.e. the stack grows downwards).
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSUBXrrr), DestReg: Buffer)
        .addReg(RegNo: SVL)
        .addReg(RegNo: SVL)
        .addReg(RegNo: SP);
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: AArch64::SP)
        .addReg(RegNo: Buffer);
    // We have just allocated a variable sized object, tell this to PEI.
    MFI.CreateVariableSizedObject(Alignment: Align(16), Alloca: nullptr);
  }

  // 2. Setup the TPIDR2 block.
  {
    // Note: This case just needs to do `SVL << 48`. It is not implemented as we
    // generally don't support big-endian SVE/SME.
    if (!Subtarget->isLittleEndian())
      reportFatalInternalError(
          reason: "TPIDR2 block initialization is not supported on big-endian targets");

    // Store buffer pointer and num_za_save_slices.
    // Bytes 10-15 are implicitly zeroed.
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::STPXi))
        .addReg(RegNo: Buffer)
        .addReg(RegNo: SVL)
        .addFrameIndex(Idx: Context.getTPIDR2Block(MF&: *MF))
        .addImm(Val: 0);
  }
}
915
/// Immediate mask for the ZERO instruction with all eight mask bits set,
/// used below to clear the whole of ZA.
static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;
917
/// Emits the function-entry SME setup at \p MBBI.
/// - Private-ZA interface: read TPIDR2_EL0 and commit any pending lazy save
///   from the caller (CommitZASavePseudo + RTLIB::SMEABI_TPIDR2_SAVE),
///   zeroing ZA/ZT0 for "new" state, then enable PSTATE.ZA.
/// - Shared-ZA interface: just zero ZA and/or ZT0 for "new" state.
void MachineSMEABI::emitSMEPrologue(MachineBasicBlock &MBB,
                                    MachineBasicBlock::iterator MBBI) {
  DebugLoc DL = getDebugLoc(MBB, MBBI);

  bool ZeroZA = AFI->getSMEFnAttrs().isNewZA();
  bool ZeroZT0 = AFI->getSMEFnAttrs().isNewZT0();
  if (AFI->getSMEFnAttrs().hasPrivateZAInterface()) {
    // Get current TPIDR2_EL0.
    Register TPIDR2EL0 = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MRS))
        .addReg(RegNo: TPIDR2EL0, Flags: RegState::Define)
        .addImm(Val: AArch64SysReg::TPIDR2_EL0);
    // If TPIDR2_EL0 is non-zero, commit the lazy save.
    // NOTE: Functions that only use ZT0 don't need to zero ZA.
    auto CommitZASave =
        BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::CommitZASavePseudo))
            .addReg(RegNo: TPIDR2EL0)
            .addImm(Val: ZeroZA)
            .addImm(Val: ZeroZT0);
    addSMELibCall(
        MIB&: CommitZASave, LC: RTLIB::SMEABI_TPIDR2_SAVE,
        ExpectedCC: CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);
    // Record the implicit defs so later passes know ZA/ZT0 were written.
    if (ZeroZA)
      CommitZASave.addDef(RegNo: AArch64::ZAB0, Flags: RegState::ImplicitDefine);
    if (ZeroZT0)
      CommitZASave.addDef(RegNo: AArch64::ZT0, Flags: RegState::ImplicitDefine);
    // Enable ZA (as ZA could have previously been in the OFF state).
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSRpstatesvcrImm1))
        .addImm(Val: AArch64SVCR::SVCRZA)
        .addImm(Val: 1);
  } else if (AFI->getSMEFnAttrs().hasSharedZAInterface()) {
    if (ZeroZA)
      BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::ZERO_M))
          .addImm(Val: ZERO_ALL_ZA_MASK)
          .addDef(RegNo: AArch64::ZAB0, Flags: RegState::ImplicitDefine);
    if (ZeroZT0)
      BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::ZERO_T)).addDef(RegNo: AArch64::ZT0);
  }
}
957
/// Saves (\p IsSave) or restores the full ZA state for agnostic-ZA
/// functions by calling __arm_sme_save / __arm_sme_restore with the save
/// buffer pointer in X0. Live NZCV/X0 are preserved around the call.
void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
                                          MachineBasicBlock &MBB,
                                          MachineBasicBlock::iterator MBBI,
                                          LiveRegs PhysLiveRegs, bool IsSave) {
  DebugLoc DL = getDebugLoc(MBB, MBBI);

  if (IsSave)
    emitCallSaveRemarks(MBB, MBBI, DL, Marker: AArch64::RequiresZASavePseudo,
                        RemarkName: "SMEFullZASave", SaveName: "full save");

  PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);

  // Copy the buffer pointer into X0.
  Register BufferPtr = AArch64::X0;
  BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: BufferPtr)
      .addReg(RegNo: Context.getAgnosticZABufferPtr(MF&: *MF));

  // Call __arm_sme_save/__arm_sme_restore.
  auto SaveRestoreZA = BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::BL))
                           .addReg(RegNo: BufferPtr, Flags: RegState::Implicit);
  addSMELibCall(
      MIB&: SaveRestoreZA,
      LC: IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE,
      ExpectedCC: CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1);

  restorePhyRegSave(RegSave, MBB, MBBI, DL);
}
985
/// Stores (\p IsSave) or loads ZT0 to/from the function's ZT0 save slot,
/// using STR_TX / LDR_TX through a pointer materialized with ADDXri.
void MachineSMEABI::emitZT0SaveRestore(EmitContext &Context,
                                       MachineBasicBlock &MBB,
                                       MachineBasicBlock::iterator MBBI,
                                       bool IsSave) {
  DebugLoc DL = getDebugLoc(MBB, MBBI);

  // Note: This will report calls that _only_ need ZT0 saved. Calls that save
  // both ZA and ZT0 will be under the SMELazySaveZA remark. This prevents
  // reporting the same calls twice.
  if (IsSave)
    emitCallSaveRemarks(MBB, MBBI, DL, Marker: AArch64::RequiresZT0SavePseudo,
                        RemarkName: "SMEZT0Save", SaveName: "spill");

  Register ZT0Save = MRI->createVirtualRegister(RegClass: &AArch64::GPR64spRegClass);

  // Pointer to the ZT0 save slot.
  BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::ADDXri), DestReg: ZT0Save)
      .addFrameIndex(Idx: Context.getZT0SaveSlot(MF&: *MF))
      .addImm(Val: 0)
      .addImm(Val: 0);

  if (IsSave) {
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::STR_TX))
        .addReg(RegNo: AArch64::ZT0)
        .addReg(RegNo: ZT0Save);
  } else {
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::LDR_TX), DestReg: AArch64::ZT0)
        .addReg(RegNo: ZT0Save);
  }
}
1015
/// Allocates the agnostic-ZA save buffer on the stack, sized by a call to
/// __arm_sme_state_size (RTLIB::SMEABI_SME_STATE_SIZE), unless SelectionDAG
/// already allocated it early. Live NZCV/X0 are preserved around the size
/// call.
void MachineSMEABI::emitAllocateFullZASaveBuffer(
    EmitContext &Context, MachineBasicBlock &MBB,
    MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
  // Buffer already allocated in SelectionDAG.
  if (AFI->getEarlyAllocSMESaveBuffer())
    return;

  DebugLoc DL = getDebugLoc(MBB, MBBI);
  Register BufferPtr = Context.getAgnosticZABufferPtr(MF&: *MF);
  Register BufferSize = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);

  PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);

  // Calculate the SME state size (returned in X0).
  {
    auto SMEStateSize = BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::BL))
                            .addReg(RegNo: AArch64::X0, Flags: RegState::ImplicitDefine);
    addSMELibCall(
        MIB&: SMEStateSize, LC: RTLIB::SMEABI_SME_STATE_SIZE,
        ExpectedCC: CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1);
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: BufferSize)
        .addReg(RegNo: AArch64::X0);
  }

  // Allocate a buffer object of the size given __arm_sme_state_size.
  {
    MachineFrameInfo &MFI = MF->getFrameInfo();
    // SP -= BufferSize (zero-extended), then take SP as the buffer pointer.
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::SUBXrx64), DestReg: AArch64::SP)
        .addReg(RegNo: AArch64::SP)
        .addReg(RegNo: BufferSize)
        .addImm(Val: AArch64_AM::getArithExtendImm(ET: AArch64_AM::UXTX, Imm: 0));
    BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: BufferPtr)
        .addReg(RegNo: AArch64::SP);

    // We have just allocated a variable sized object, tell this to PEI.
    MFI.CreateVariableSizedObject(Alignment: Align(16), Alloca: nullptr);
  }

  restorePhyRegSave(RegSave, MBB, MBBI, DL);
}
1056
1057struct FromState {
1058 ZAState From;
1059
1060 constexpr uint8_t to(ZAState To) const {
1061 static_assert(NUM_ZA_STATE < 16, "expected ZAState to fit in 4-bits");
1062 return uint8_t(From) << 4 | uint8_t(To);
1063 }
1064};
1065
/// Starts a transition key from \p From; chain with .to(To) for the
/// combined 8-bit (From, To) value.
constexpr FromState transitionFrom(ZAState From) { return FromState{.From: From}; }
1067
/// Emits the instructions that move ZA from state \p From to state \p To at
/// \p InsertPt. Transitions involving ANY are no-ops; leaving ENTRY first
/// emits the SME prologue. Each remaining (From, To) pair is dispatched via
/// the compact nibble-pair key built by transitionFrom().to().
void MachineSMEABI::emitStateChange(EmitContext &Context,
                                    MachineBasicBlock &MBB,
                                    MachineBasicBlock::iterator InsertPt,
                                    ZAState From, ZAState To,
                                    LiveRegs PhysLiveRegs) {
  // ZA not used.
  if (From == ZAState::ANY || To == ZAState::ANY)
    return;

  // If we're exiting from the ENTRY state that means that the function has not
  // used ZA, so in the case of private ZA/ZT0 functions we can omit any set up.
  if (From == ZAState::ENTRY && To == ZAState::OFF)
    return;

  // TODO: Avoid setting up the save buffer if there's no transition to
  // LOCAL_SAVED.
  if (From == ZAState::ENTRY) {
    assert(&MBB == &MBB.getParent()->front() &&
           "ENTRY state only valid in entry block");
    emitSMEPrologue(MBB, MBBI: MBB.getFirstNonPHI());
    if (To == ZAState::ACTIVE)
      return; // Nothing more to do (ZA is active after the prologue).

    // Note: "emitNewZAPrologue" zeros ZA, so we may need to setup a lazy save
    // if "To" is "ZAState::LOCAL_SAVED". It may be possible to improve this
    // case by changing the placement of the zero instruction.
    From = ZAState::ACTIVE;
  }

  SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
  bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
  bool HasZT0State = SMEFnAttrs.hasZT0State();
  bool HasZAState = IsAgnosticZA || SMEFnAttrs.hasZAState();

  switch (transitionFrom(From).to(To)) {
  // This section handles: ACTIVE <-> ACTIVE_ZT0_SAVED
  case transitionFrom(From: ZAState::ACTIVE).to(To: ZAState::ACTIVE_ZT0_SAVED):
    emitZT0SaveRestore(Context, MBB, MBBI: InsertPt, /*IsSave=*/true);
    break;
  case transitionFrom(From: ZAState::ACTIVE_ZT0_SAVED).to(To: ZAState::ACTIVE):
    emitZT0SaveRestore(Context, MBB, MBBI: InsertPt, /*IsSave=*/false);
    break;

  // This section handles: ACTIVE[_ZT0_SAVED] -> LOCAL_SAVED
  case transitionFrom(From: ZAState::ACTIVE).to(To: ZAState::LOCAL_SAVED):
  case transitionFrom(From: ZAState::ACTIVE_ZT0_SAVED).to(To: ZAState::LOCAL_SAVED):
    // ZT0 only needs saving if it was not already saved (From == ACTIVE).
    if (HasZT0State && From == ZAState::ACTIVE)
      emitZT0SaveRestore(Context, MBB, MBBI: InsertPt, /*IsSave=*/true);
    if (HasZAState)
      emitZASave(Context, MBB, MBBI: InsertPt, PhysLiveRegs);
    break;

  // This section handles: ACTIVE -> LOCAL_COMMITTED
  case transitionFrom(From: ZAState::ACTIVE).to(To: ZAState::LOCAL_COMMITTED):
    // TODO: We could support ZA state here, but this transition is currently
    // only possible when we _don't_ have ZA state.
    assert(HasZT0State && !HasZAState && "Expect to only have ZT0 state.");
    emitZT0SaveRestore(Context, MBB, MBBI: InsertPt, /*IsSave=*/true);
    emitZAMode(MBB, MBBI: InsertPt, /*ClearTPIDR2=*/false, /*On=*/false);
    break;

  // This section handles: LOCAL_COMMITTED -> (OFF|LOCAL_SAVED)
  case transitionFrom(From: ZAState::LOCAL_COMMITTED).to(To: ZAState::OFF):
  case transitionFrom(From: ZAState::LOCAL_COMMITTED).to(To: ZAState::LOCAL_SAVED):
    // These transitions are a no-op.
    break;

  // This section handles: LOCAL_(SAVED|COMMITTED) -> ACTIVE[_ZT0_SAVED]
  case transitionFrom(From: ZAState::LOCAL_COMMITTED).to(To: ZAState::ACTIVE):
  case transitionFrom(From: ZAState::LOCAL_COMMITTED).to(To: ZAState::ACTIVE_ZT0_SAVED):
  case transitionFrom(From: ZAState::LOCAL_SAVED).to(To: ZAState::ACTIVE):
    if (HasZAState)
      emitZARestore(Context, MBB, MBBI: InsertPt, PhysLiveRegs);
    else
      emitZAMode(MBB, MBBI: InsertPt, /*ClearTPIDR2=*/false, /*On=*/true);
    if (HasZT0State && To == ZAState::ACTIVE)
      emitZT0SaveRestore(Context, MBB, MBBI: InsertPt, /*IsSave=*/false);
    break;

  // This section handles transitions to OFF (not previously covered)
  case transitionFrom(From: ZAState::ACTIVE).to(To: ZAState::OFF):
  case transitionFrom(From: ZAState::ACTIVE_ZT0_SAVED).to(To: ZAState::OFF):
  case transitionFrom(From: ZAState::LOCAL_SAVED).to(To: ZAState::OFF):
    assert(SMEFnAttrs.hasPrivateZAInterface() &&
           "Did not expect to turn ZA off in shared/agnostic ZA function");
    // Only a pending lazy save (LOCAL_SAVED) needs TPIDR2_EL0 cleared.
    emitZAMode(MBB, MBBI: InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED,
               /*On=*/false);
    break;

  default:
    dbgs() << "Error: Transition from " << getZAStateString(State: From) << " to "
           << getZAStateString(State: To) << '\n';
    llvm_unreachable("Unimplemented state transition");
  }
}
1163
1164} // end anonymous namespace
1165
// Register the pass with the legacy pass manager under the command-line name
// "aarch64-machine-sme-abi".
INITIALIZE_PASS(MachineSMEABI, "aarch64-machine-sme-abi", "Machine SME ABI",
                false, false)
1168
/// Pass entry point. Bails out when the subtarget lacks SME or the function
/// has no ZA/ZT0/agnostic-ZA state. Otherwise: collects the ZA state needed
/// at each point, assigns one state per edge bundle, inserts the state
/// transitions, and, if any transition needs a save buffer, allocates it in
/// the entry block (or just after the SME prologue).
bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
  Subtarget = &MF.getSubtarget<AArch64Subtarget>();
  if (!Subtarget->hasSME())
    return false;

  AFI = MF.getInfo<AArch64FunctionInfo>();
  SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
  if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
      !SMEFnAttrs.hasAgnosticZAInterface())
    return false;

  assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");

  // Cache analyses and target info used by the emit helpers.
  this->MF = &MF;
  ORE = &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE();
  LLI = &getAnalysis<LibcallLoweringInfoWrapper>().getLibcallLowering(
      M: *MF.getFunction().getParent(), Subtarget: *Subtarget);
  TII = Subtarget->getInstrInfo();
  TRI = Subtarget->getRegisterInfo();
  MRI = &MF.getRegInfo();

  const EdgeBundles &Bundles =
      getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();

  FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);

  // One ZA state per edge bundle (see the file header for the scheme).
  SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo);

  EmitContext Context;
  insertStateChanges(Context, FnInfo, Bundles, BundleStates);

  if (Context.needsSaveBuffer()) {
    if (FnInfo.AfterSMEProloguePt) {
      // Note: With inline stack probes the AfterSMEProloguePt may not be in the
      // entry block (due to the probing loop).
      MachineBasicBlock::iterator MBBI = *FnInfo.AfterSMEProloguePt;
      emitAllocateZASaveBuffer(Context, MBB&: *MBBI->getParent(), MBBI,
                               PhysLiveRegs: FnInfo.PhysLiveRegsAfterSMEPrologue);
    } else {
      MachineBasicBlock &EntryBlock = MF.front();
      emitAllocateZASaveBuffer(
          Context, MBB&: EntryBlock, MBBI: EntryBlock.getFirstNonPHI(),
          PhysLiveRegs: FnInfo.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry);
    }
  }

  return true;
}
1217
/// Factory for the Machine SME ABI pass; \p OptLevel is forwarded to the
/// pass constructor.
FunctionPass *llvm::createMachineSMEABIPass(CodeGenOptLevel OptLevel) {
  return new MachineSMEABI(OptLevel);
}
1221