1//===- MachineSMEABIPass.cpp ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements the SME ABI requirements for ZA state. This includes
10// implementing the lazy (and agnostic) ZA state save schemes around calls.
11//
12//===----------------------------------------------------------------------===//
13//
14// This pass works by collecting instructions that require ZA to be in a
15// specific state (e.g., "ACTIVE" or "SAVED") and inserting the necessary state
16// transitions to ensure ZA is in the required state before instructions. State
17// transitions represent actions such as setting up or restoring a lazy save.
18// Certain points within a function may also have predefined states independent
19// of any instructions, for example, a "shared_za" function is always entered
20// and exited in the "ACTIVE" state.
21//
22// To handle ZA state across control flow, we make use of edge bundling. This
23// assigns each block an "incoming" and "outgoing" edge bundle (representing
24// incoming and outgoing edges). Initially, these are unique to each block;
25// then, in the process of forming bundles, the outgoing bundle of a block is
26// joined with the incoming bundle of all successors. The result is that each
// bundle can be assigned a single ZA state, which ensures the state required by
// all of a block's successors is the same, and that each basic block will
// always be entered with the same ZA state. This eliminates the need for
// splitting edges to insert state transitions or "phi" nodes for ZA states.
31//
32// See below for a simple example of edge bundling.
33//
//  The following shows a conditionally executed basic block (BB1):
//
//  if (cond)
//    BB1
//  BB2
//
//                Initial Bundles         Joined Bundles
//
//                   ┌──0──┐                ┌──0──┐
//                   │ BB0 │                │ BB0 │
//                   └──1──┘                └──1──┘
//                      ├───────┐              ├───────┐
//                      ▼       │              ▼       │
//                   ┌──2──┐    │   ─────►  ┌──1──┐    │
//                   │ BB1 │    ▼           │ BB1 │    ▼
//                   └──3──┘ ┌──4──┐        └──1──┘ ┌──1──┐
//                      └───►4 BB2 │           └───►1 BB2 │
//                           └──5──┘                └──2──┘
52//
53// On the left are the initial per-block bundles, and on the right are the
54// joined bundles (which are the result of the EdgeBundles analysis).
55
56#include "AArch64InstrInfo.h"
57#include "AArch64MachineFunctionInfo.h"
58#include "AArch64Subtarget.h"
59#include "MCTargetDesc/AArch64AddressingModes.h"
60#include "llvm/ADT/BitmaskEnum.h"
61#include "llvm/ADT/SmallVector.h"
62#include "llvm/CodeGen/EdgeBundles.h"
63#include "llvm/CodeGen/LivePhysRegs.h"
64#include "llvm/CodeGen/MachineBasicBlock.h"
65#include "llvm/CodeGen/MachineFunctionPass.h"
66#include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
67#include "llvm/CodeGen/MachineRegisterInfo.h"
68#include "llvm/CodeGen/TargetRegisterInfo.h"
69
70using namespace llvm;
71
72#define DEBUG_TYPE "aarch64-machine-sme-abi"
73
74namespace {
75
/// The ZA states this pass tracks. Each instruction that needs ZA in a
/// particular state is tagged with one of these, and state transitions are
/// emitted wherever the current state differs from the needed one.
///
/// Note: For agnostic ZA, we assume the function is always entered/exited in
/// the "ACTIVE" state -- this _may_ not be the case (since OFF is also a
/// possibility, but for the purpose of placing ZA saves/restores, that does not
/// matter).
enum ZAState : uint8_t {
  // Any/unknown state (not valid). Used as the "no requirement" value.
  ANY = 0,

  // ZA is in use and active (i.e. within the accumulator)
  ACTIVE,

  // ZA is active, but ZT0 has been saved.
  // This handles the edge case of sharedZA && !sharesZT0.
  ACTIVE_ZT0_SAVED,

  // A ZA save has been set up or committed (i.e. ZA is dormant or off)
  // If the function uses ZT0 it must also be saved.
  LOCAL_SAVED,

  // ZA has been committed to the lazy save buffer of the current function.
  // If the function uses ZT0 it must also be saved.
  // ZA is off.
  LOCAL_COMMITTED,

  // The ZA/ZT0 state on entry to the function.
  ENTRY,

  // ZA is off.
  OFF,

  // The number of ZA states (not a valid state). Must remain the last
  // enumerator.
  NUM_ZA_STATE
};
109
/// A bitmask enum to record live physical registers that the "emit*" routines
/// may need to preserve. Note: This only tracks registers we may clobber.
///
/// W0 and W0_HI are separate bits so that "only W0 is live" can be told apart
/// from "all of X0 is live"; X0 is simply both bits set.
enum LiveRegs : uint8_t {
  None = 0,
  NZCV = 1 << 0,
  W0 = 1 << 1,
  W0_HI = 1 << 2,
  X0 = W0 | W0_HI,
  // W0_HI is the highest single bit in use.
  LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ W0_HI)
};
120
121/// Holds the virtual registers live physical registers have been saved to.
122struct PhysRegSave {
123 LiveRegs PhysLiveRegs;
124 Register StatusFlags = AArch64::NoRegister;
125 Register X0Save = AArch64::NoRegister;
126};
127
128/// Contains the needed ZA state (and live registers) at an instruction. That is
129/// the state ZA must be in _before_ "InsertPt".
130struct InstInfo {
131 ZAState NeededState{ZAState::ANY};
132 MachineBasicBlock::iterator InsertPt;
133 LiveRegs PhysLiveRegs = LiveRegs::None;
134};
135
136/// Contains the needed ZA state for each instruction in a block. Instructions
137/// that do not require a ZA state are not recorded.
138struct BlockInfo {
139 SmallVector<InstInfo> Insts;
140 ZAState FixedEntryState{ZAState::ANY};
141 ZAState DesiredIncomingState{ZAState::ANY};
142 ZAState DesiredOutgoingState{ZAState::ANY};
143 LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
144 LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
145};
146
147/// Contains the needed ZA state information for all blocks within a function.
148struct FunctionInfo {
149 SmallVector<BlockInfo> Blocks;
150 std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
151 LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
152};
153
154/// State/helpers that is only needed when emitting code to handle
155/// saving/restoring ZA.
156class EmitContext {
157public:
158 EmitContext() = default;
159
160 /// Get or create a TPIDR2 block in \p MF.
161 int getTPIDR2Block(MachineFunction &MF) {
162 if (TPIDR2BlockFI)
163 return *TPIDR2BlockFI;
164 MachineFrameInfo &MFI = MF.getFrameInfo();
165 TPIDR2BlockFI = MFI.CreateStackObject(Size: 16, Alignment: Align(16), isSpillSlot: false);
166 return *TPIDR2BlockFI;
167 }
168
169 /// Get or create agnostic ZA buffer pointer in \p MF.
170 Register getAgnosticZABufferPtr(MachineFunction &MF) {
171 if (AgnosticZABufferPtr != AArch64::NoRegister)
172 return AgnosticZABufferPtr;
173 Register BufferPtr =
174 MF.getInfo<AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer();
175 AgnosticZABufferPtr =
176 BufferPtr != AArch64::NoRegister
177 ? BufferPtr
178 : MF.getRegInfo().createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
179 return AgnosticZABufferPtr;
180 }
181
182 int getZT0SaveSlot(MachineFunction &MF) {
183 if (ZT0SaveFI)
184 return *ZT0SaveFI;
185 MachineFrameInfo &MFI = MF.getFrameInfo();
186 ZT0SaveFI = MFI.CreateSpillStackObject(Size: 64, Alignment: Align(16));
187 return *ZT0SaveFI;
188 }
189
190 /// Returns true if the function must allocate a ZA save buffer on entry. This
191 /// will be the case if, at any point in the function, a ZA save was emitted.
192 bool needsSaveBuffer() const {
193 assert(!(TPIDR2BlockFI && AgnosticZABufferPtr) &&
194 "Cannot have both a TPIDR2 block and agnostic ZA buffer");
195 return TPIDR2BlockFI || AgnosticZABufferPtr != AArch64::NoRegister;
196 }
197
198private:
199 std::optional<int> ZT0SaveFI;
200 std::optional<int> TPIDR2BlockFI;
201 Register AgnosticZABufferPtr = AArch64::NoRegister;
202};
203
204StringRef getZAStateString(ZAState State) {
205#define MAKE_CASE(V) \
206 case V: \
207 return #V;
208 switch (State) {
209 MAKE_CASE(ZAState::ANY)
210 MAKE_CASE(ZAState::ACTIVE)
211 MAKE_CASE(ZAState::ACTIVE_ZT0_SAVED)
212 MAKE_CASE(ZAState::LOCAL_SAVED)
213 MAKE_CASE(ZAState::LOCAL_COMMITTED)
214 MAKE_CASE(ZAState::ENTRY)
215 MAKE_CASE(ZAState::OFF)
216 default:
217 llvm_unreachable("Unexpected ZAState");
218 }
219#undef MAKE_CASE
220}
221
222static bool isZAorZTRegOp(const TargetRegisterInfo &TRI,
223 const MachineOperand &MO) {
224 if (!MO.isReg() || !MO.getReg().isPhysical())
225 return false;
226 return any_of(Range: TRI.subregs_inclusive(Reg: MO.getReg()), P: [](const MCPhysReg &SR) {
227 return AArch64::MPR128RegClass.contains(Reg: SR) ||
228 AArch64::ZTRRegClass.contains(Reg: SR);
229 });
230}
231
232/// Returns the required ZA state needed before \p MI and an iterator pointing
233/// to where any code required to change the ZA state should be inserted.
234static std::pair<ZAState, MachineBasicBlock::iterator>
235getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI,
236 SMEAttrs SMEFnAttrs) {
237 MachineBasicBlock::iterator InsertPt(MI);
238
239 // Note: InOutZAUsePseudo, RequiresZASavePseudo, and RequiresZT0SavePseudo are
240 // intended to mark the position immediately before a call. Due to
241 // SelectionDAG constraints, these markers occur after the ADJCALLSTACKDOWN,
242 // so we use std::prev(InsertPt) to get the position before the call.
243
244 if (MI.getOpcode() == AArch64::InOutZAUsePseudo)
245 return {ZAState::ACTIVE, std::prev(x: InsertPt)};
246
247 // Note: If we need to save both ZA and ZT0 we use RequiresZASavePseudo.
248 if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
249 return {ZAState::LOCAL_SAVED, std::prev(x: InsertPt)};
250
251 // If we only need to save ZT0 there's two cases to consider:
252 // 1. The function has ZA state (that we don't need to save).
253 // - In this case we switch to the "ACTIVE_ZT0_SAVED" state.
254 // This only saves ZT0.
255 // 2. The function does not have ZA state
256 // - In this case we switch to "LOCAL_COMMITTED" state.
257 // This saves ZT0 and turns ZA off.
258 if (MI.getOpcode() == AArch64::RequiresZT0SavePseudo) {
259 return {SMEFnAttrs.hasZAState() ? ZAState::ACTIVE_ZT0_SAVED
260 : ZAState::LOCAL_COMMITTED,
261 std::prev(x: InsertPt)};
262 }
263
264 if (MI.isReturn()) {
265 bool ZAOffAtReturn = SMEFnAttrs.hasPrivateZAInterface();
266 return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
267 }
268
269 for (auto &MO : MI.operands()) {
270 if (isZAorZTRegOp(TRI, MO))
271 return {ZAState::ACTIVE, InsertPt};
272 }
273
274 return {ZAState::ANY, InsertPt};
275}
276
277struct MachineSMEABI : public MachineFunctionPass {
278 inline static char ID = 0;
279
280 MachineSMEABI(CodeGenOptLevel OptLevel = CodeGenOptLevel::Default)
281 : MachineFunctionPass(ID), OptLevel(OptLevel) {}
282
283 bool runOnMachineFunction(MachineFunction &MF) override;
284
285 StringRef getPassName() const override { return "Machine SME ABI pass"; }
286
287 void getAnalysisUsage(AnalysisUsage &AU) const override {
288 AU.setPreservesCFG();
289 AU.addRequired<EdgeBundlesWrapperLegacy>();
290 AU.addRequired<MachineOptimizationRemarkEmitterPass>();
291 AU.addRequired<LibcallLoweringInfoWrapper>();
292 AU.addPreservedID(ID&: MachineLoopInfoID);
293 AU.addPreservedID(ID&: MachineDominatorsID);
294 MachineFunctionPass::getAnalysisUsage(AU);
295 }
296
297 /// Collects the needed ZA state (and live registers) before each instruction
298 /// within the machine function.
299 FunctionInfo collectNeededZAStates(SMEAttrs SMEFnAttrs);
300
301 /// Assigns each edge bundle a ZA state based on the desired states of
302 /// incoming and outgoing blocks in the bundle.
303 SmallVector<ZAState> assignBundleZAStates(const EdgeBundles &Bundles,
304 const FunctionInfo &FnInfo);
305
306 /// Inserts code to handle changes between ZA states within the function.
307 /// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
308 void insertStateChanges(EmitContext &, const FunctionInfo &FnInfo,
309 const EdgeBundles &Bundles,
310 ArrayRef<ZAState> BundleStates);
311
312 void addSMELibCall(MachineInstrBuilder &MIB, RTLIB::Libcall LC,
313 CallingConv::ID ExpectedCC);
314
315 void emitZT0SaveRestore(EmitContext &, MachineBasicBlock &MBB,
316 MachineBasicBlock::iterator MBBI, bool IsSave);
317
318 // Emission routines for private and shared ZA functions (using lazy saves).
319 void emitSMEPrologue(MachineBasicBlock &MBB,
320 MachineBasicBlock::iterator MBBI);
321 void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB,
322 MachineBasicBlock::iterator MBBI,
323 LiveRegs PhysLiveRegs);
324 void emitSetupLazySave(EmitContext &, MachineBasicBlock &MBB,
325 MachineBasicBlock::iterator MBBI);
326 void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB,
327 MachineBasicBlock::iterator MBBI);
328 void emitZAMode(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
329 bool ClearTPIDR2, bool On);
330
331 // Emission routines for agnostic ZA functions.
332 void emitSetupFullZASave(MachineBasicBlock &MBB,
333 MachineBasicBlock::iterator MBBI,
334 LiveRegs PhysLiveRegs);
335 // Emit a "full" ZA save or restore. It is "full" in the sense that this
336 // function will emit a call to __arm_sme_save or __arm_sme_restore, which
337 // handles saving and restoring both ZA and ZT0.
338 void emitFullZASaveRestore(EmitContext &, MachineBasicBlock &MBB,
339 MachineBasicBlock::iterator MBBI,
340 LiveRegs PhysLiveRegs, bool IsSave);
341 void emitAllocateFullZASaveBuffer(EmitContext &, MachineBasicBlock &MBB,
342 MachineBasicBlock::iterator MBBI,
343 LiveRegs PhysLiveRegs);
344
345 /// Attempts to find an insertion point before \p Inst where the status flags
346 /// are not live. If \p Inst is `Block.Insts.end()` a point before the end of
347 /// the block is found.
348 std::pair<MachineBasicBlock::iterator, LiveRegs>
349 findStateChangeInsertionPoint(MachineBasicBlock &MBB, const BlockInfo &Block,
350 SmallVectorImpl<InstInfo>::const_iterator Inst);
351 void emitStateChange(EmitContext &, MachineBasicBlock &MBB,
352 MachineBasicBlock::iterator MBBI, ZAState From,
353 ZAState To, LiveRegs PhysLiveRegs);
354
355 // Helpers for switching between lazy/full ZA save/restore routines.
356 void emitZASave(EmitContext &Context, MachineBasicBlock &MBB,
357 MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
358 if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
359 return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
360 /*IsSave=*/true);
361 return emitSetupLazySave(Context, MBB, MBBI);
362 }
363 void emitZARestore(EmitContext &Context, MachineBasicBlock &MBB,
364 MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
365 if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
366 return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
367 /*IsSave=*/false);
368 return emitRestoreLazySave(Context, MBB, MBBI, PhysLiveRegs);
369 }
370 void emitAllocateZASaveBuffer(EmitContext &Context, MachineBasicBlock &MBB,
371 MachineBasicBlock::iterator MBBI,
372 LiveRegs PhysLiveRegs) {
373 if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
374 return emitAllocateFullZASaveBuffer(Context, MBB, MBBI, PhysLiveRegs);
375 return emitAllocateLazySaveBuffer(Context, MBB, MBBI);
376 }
377
378 /// Collects the reachable calls from \p MBBI marked with \p Marker. This is
379 /// intended to be used to emit lazy save remarks. Note: This stops at the
380 /// first marked call along any path.
381 void collectReachableMarkedCalls(const MachineBasicBlock &MBB,
382 MachineBasicBlock::const_iterator MBBI,
383 SmallVectorImpl<const MachineInstr *> &Calls,
384 unsigned Marker) const;
385
386 void emitCallSaveRemarks(const MachineBasicBlock &MBB,
387 MachineBasicBlock::const_iterator MBBI, DebugLoc DL,
388 unsigned Marker, StringRef RemarkName,
389 StringRef SaveName) const;
390
391 void emitError(const Twine &Message) {
392 LLVMContext &Context = MF->getFunction().getContext();
393 Context.emitError(ErrorStr: MF->getName() + ": " + Message);
394 }
395
396 /// Save live physical registers to virtual registers.
397 PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
398 MachineBasicBlock::iterator MBBI, DebugLoc DL);
399 /// Restore physical registers from a save of their previous values.
400 void restorePhyRegSave(const PhysRegSave &RegSave, MachineBasicBlock &MBB,
401 MachineBasicBlock::iterator MBBI, DebugLoc DL);
402
403private:
404 CodeGenOptLevel OptLevel = CodeGenOptLevel::Default;
405
406 MachineFunction *MF = nullptr;
407 const AArch64Subtarget *Subtarget = nullptr;
408 const AArch64RegisterInfo *TRI = nullptr;
409 const AArch64FunctionInfo *AFI = nullptr;
410 const AArch64InstrInfo *TII = nullptr;
411 const LibcallLoweringInfo *LLI = nullptr;
412
413 MachineOptimizationRemarkEmitter *ORE = nullptr;
414 MachineRegisterInfo *MRI = nullptr;
415 MachineLoopInfo *MLI = nullptr;
416};
417
418static LiveRegs getPhysLiveRegs(LiveRegUnits const &LiveUnits) {
419 LiveRegs PhysLiveRegs = LiveRegs::None;
420 if (!LiveUnits.available(Reg: AArch64::NZCV))
421 PhysLiveRegs |= LiveRegs::NZCV;
422 // We have to track W0 and X0 separately as otherwise things can get
423 // confused if we attempt to preserve X0 but only W0 was defined.
424 if (!LiveUnits.available(Reg: AArch64::W0))
425 PhysLiveRegs |= LiveRegs::W0;
426 if (!LiveUnits.available(Reg: AArch64::W0_HI))
427 PhysLiveRegs |= LiveRegs::W0_HI;
428 return PhysLiveRegs;
429}
430
431static void setPhysLiveRegs(LiveRegUnits &LiveUnits, LiveRegs PhysLiveRegs) {
432 if (PhysLiveRegs & LiveRegs::NZCV)
433 LiveUnits.addReg(Reg: AArch64::NZCV);
434 if (PhysLiveRegs & LiveRegs::W0)
435 LiveUnits.addReg(Reg: AArch64::W0);
436 if (PhysLiveRegs & LiveRegs::W0_HI)
437 LiveUnits.addReg(Reg: AArch64::W0_HI);
438}
439
440[[maybe_unused]] bool isCallStartOpcode(unsigned Opc) {
441 switch (Opc) {
442 case AArch64::TLSDESC_CALLSEQ:
443 case AArch64::TLSDESC_AUTH_CALLSEQ:
444 case AArch64::ADJCALLSTACKDOWN:
445 return true;
446 default:
447 return false;
448 }
449}
450
451FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
452 assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
453 SMEFnAttrs.hasZAState()) &&
454 "Expected function to have ZA/ZT0 state!");
455
456 SmallVector<BlockInfo> Blocks(MF->getNumBlockIDs());
457 LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
458 std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
459
460 for (MachineBasicBlock &MBB : *MF) {
461 BlockInfo &Block = Blocks[MBB.getNumber()];
462
463 if (MBB.isEntryBlock()) {
464 // Entry block:
465 Block.FixedEntryState = ZAState::ENTRY;
466 } else if (MBB.isEHPad()) {
467 // EH entry block:
468 Block.FixedEntryState = ZAState::LOCAL_COMMITTED;
469 }
470
471 LiveRegUnits LiveUnits(*TRI);
472 LiveUnits.addLiveOuts(MBB);
473
474 Block.PhysLiveRegsAtExit = getPhysLiveRegs(LiveUnits);
475 auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
476 auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
477 for (MachineInstr &MI : reverse(C&: MBB)) {
478 if (MI.isDebugInstr())
479 continue;
480
481 MachineBasicBlock::iterator MBBI(MI);
482 LiveUnits.stepBackward(MI);
483 LiveRegs PhysLiveRegs = getPhysLiveRegs(LiveUnits);
484 // The SMEStateAllocPseudo marker is added to a function if the save
485 // buffer was allocated in SelectionDAG. It marks the end of the
486 // allocation -- which is a safe point for this pass to insert any TPIDR2
487 // block setup.
488 if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
489 AfterSMEProloguePt = MBBI;
490 PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
491 }
492 // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
493 auto [NeededState, InsertPt] = getInstNeededZAState(TRI: *TRI, MI, SMEFnAttrs);
494 assert((InsertPt == MBBI || isCallStartOpcode(InsertPt->getOpcode())) &&
495 "Unexpected state change insertion point!");
496 if (MBBI == FirstTerminatorInsertPt)
497 Block.PhysLiveRegsAtExit = PhysLiveRegs;
498 if (MBBI == FirstNonPhiInsertPt)
499 Block.PhysLiveRegsAtEntry = PhysLiveRegs;
500 if (NeededState != ZAState::ANY)
501 Block.Insts.push_back(Elt: {.NeededState: NeededState, .InsertPt: InsertPt, .PhysLiveRegs: PhysLiveRegs});
502 }
503
504 // Reverse vector (as we had to iterate backwards for liveness).
505 std::reverse(first: Block.Insts.begin(), last: Block.Insts.end());
506
507 // Record the desired states on entry/exit of this block. These are the
508 // states that would not incur a state transition.
509 if (!Block.Insts.empty()) {
510 Block.DesiredIncomingState = Block.Insts.front().NeededState;
511 Block.DesiredOutgoingState = Block.Insts.back().NeededState;
512 }
513 }
514
515 return FunctionInfo{.Blocks: std::move(Blocks), .AfterSMEProloguePt: AfterSMEProloguePt,
516 .PhysLiveRegsAfterSMEPrologue: PhysLiveRegsAfterSMEPrologue};
517}
518
519/// Assigns each edge bundle a ZA state based on the desired states of incoming
520/// and outgoing blocks in the bundle.
521SmallVector<ZAState>
522MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
523 const FunctionInfo &FnInfo) {
524 SmallVector<ZAState> BundleStates(Bundles.getNumBundles());
525 for (unsigned I = 0, E = Bundles.getNumBundles(); I != E; ++I) {
526 std::optional<ZAState> BundleState;
527 for (unsigned BlockID : Bundles.getBlocks(Bundle: I)) {
528 const BlockInfo &Block = FnInfo.Blocks[BlockID];
529 // Check if the block is an incoming block in the bundle. Note: We skip
530 // Block.FixedEntryState != ANY to ignore EH pads (which are only
531 // reachable via exceptions).
532 if (Block.FixedEntryState != ZAState::ANY ||
533 Bundles.getBundle(N: BlockID, /*Out=*/false) != I)
534 continue;
535
536 // Pick a state that matches all incoming blocks. Fall back to "ACTIVE" if
537 // any incoming state doesn't match. This will hoist the state from
538 // incoming blocks to outgoing blocks.
539 if (!BundleState)
540 BundleState = Block.DesiredIncomingState;
541 else if (BundleState != Block.DesiredIncomingState)
542 BundleState = ZAState::ACTIVE;
543 }
544
545 if (!BundleState || BundleState == ZAState::ANY)
546 BundleState = ZAState::ACTIVE;
547
548 BundleStates[I] = *BundleState;
549 }
550
551 return BundleStates;
552}
553
554std::pair<MachineBasicBlock::iterator, LiveRegs>
555MachineSMEABI::findStateChangeInsertionPoint(
556 MachineBasicBlock &MBB, const BlockInfo &Block,
557 SmallVectorImpl<InstInfo>::const_iterator Inst) {
558 LiveRegs PhysLiveRegs;
559 MachineBasicBlock::iterator InsertPt;
560 if (Inst != Block.Insts.end()) {
561 InsertPt = Inst->InsertPt;
562 PhysLiveRegs = Inst->PhysLiveRegs;
563 } else {
564 InsertPt = MBB.getFirstTerminator();
565 PhysLiveRegs = Block.PhysLiveRegsAtExit;
566 }
567
568 if (PhysLiveRegs == LiveRegs::None)
569 return {InsertPt, PhysLiveRegs}; // Nothing to do (no live regs).
570
571 // Find the previous state change. We can not move before this point.
572 MachineBasicBlock::iterator PrevStateChangeI;
573 if (Inst == Block.Insts.begin()) {
574 PrevStateChangeI = MBB.begin();
575 } else {
576 // Note: `std::prev(Inst)` is the previous InstInfo. We only create an
577 // InstInfo object for instructions that require a specific ZA state, so the
578 // InstInfo is the site of the previous state change in the block (which can
579 // be several MIs earlier).
580 PrevStateChangeI = std::prev(x: Inst)->InsertPt;
581 }
582
583 // Note: LiveUnits will only accurately track X0 and NZCV.
584 LiveRegUnits LiveUnits(*TRI);
585 setPhysLiveRegs(LiveUnits, PhysLiveRegs);
586 auto BestCandidate = std::make_pair(x&: InsertPt, y&: PhysLiveRegs);
587 for (MachineBasicBlock::iterator I = InsertPt; I != PrevStateChangeI; --I) {
588 if (I->isDebugInstr())
589 continue;
590
591 // Don't move before/into a call (which may have a state change before it).
592 if (I->getOpcode() == TII->getCallFrameDestroyOpcode() || I->isCall())
593 break;
594 LiveUnits.stepBackward(MI: *I);
595 LiveRegs CurrentPhysLiveRegs = getPhysLiveRegs(LiveUnits);
596 // Find places where NZCV is available, but keep looking for locations where
597 // both NZCV and X0 are available, which can avoid some copies.
598 if (!(CurrentPhysLiveRegs & LiveRegs::NZCV))
599 BestCandidate = {I, CurrentPhysLiveRegs};
600 if (CurrentPhysLiveRegs == LiveRegs::None)
601 break;
602 }
603 return BestCandidate;
604}
605
606void MachineSMEABI::insertStateChanges(EmitContext &Context,
607 const FunctionInfo &FnInfo,
608 const EdgeBundles &Bundles,
609 ArrayRef<ZAState> BundleStates) {
610 for (MachineBasicBlock &MBB : *MF) {
611 const BlockInfo &Block = FnInfo.Blocks[MBB.getNumber()];
612 ZAState InState = BundleStates[Bundles.getBundle(N: MBB.getNumber(),
613 /*Out=*/false)];
614
615 ZAState CurrentState = Block.FixedEntryState;
616 if (CurrentState == ZAState::ANY)
617 CurrentState = InState;
618
619 for (auto &Inst : Block.Insts) {
620 if (CurrentState != Inst.NeededState) {
621 auto [InsertPt, PhysLiveRegs] =
622 findStateChangeInsertionPoint(MBB, Block, Inst: &Inst);
623 emitStateChange(Context, MBB, MBBI: InsertPt, From: CurrentState, To: Inst.NeededState,
624 PhysLiveRegs);
625 CurrentState = Inst.NeededState;
626 }
627 }
628
629 if (MBB.succ_empty())
630 continue;
631
632 ZAState OutState =
633 BundleStates[Bundles.getBundle(N: MBB.getNumber(), /*Out=*/true)];
634 if (CurrentState != OutState) {
635 auto [InsertPt, PhysLiveRegs] =
636 findStateChangeInsertionPoint(MBB, Block, Inst: Block.Insts.end());
637 emitStateChange(Context, MBB, MBBI: InsertPt, From: CurrentState, To: OutState,
638 PhysLiveRegs);
639 }
640 }
641}
642
643static DebugLoc getDebugLoc(MachineBasicBlock &MBB,
644 MachineBasicBlock::iterator MBBI) {
645 if (MBB.empty())
646 return DebugLoc();
647 return MBBI != MBB.end() ? MBBI->getDebugLoc() : MBB.back().getDebugLoc();
648}
649
650/// Finds the first call (as determined by MachineInstr::isCall()) starting from
651/// \p MBBI in \p MBB marked with \p Marker (which is a marker opcode such as
652/// RequiresZASavePseudo). If a marked call is found, it is pushed to \p Calls
653/// and the function returns true.
654static bool findMarkedCall(const MachineBasicBlock &MBB,
655 MachineBasicBlock::const_iterator MBBI,
656 SmallVectorImpl<const MachineInstr *> &Calls,
657 unsigned Marker, unsigned CallDestroyOpcode) {
658 auto IsMarker = [&](auto &MI) { return MI.getOpcode() == Marker; };
659 auto MarkerInst = std::find_if(first: MBBI, last: MBB.end(), pred: IsMarker);
660 if (MarkerInst == MBB.end())
661 return false;
662 MachineBasicBlock::const_iterator I = MarkerInst;
663 while (++I != MBB.end()) {
664 if (I->isCall() || I->getOpcode() == CallDestroyOpcode)
665 break;
666 }
667 if (I != MBB.end() && I->isCall())
668 Calls.push_back(Elt: &*I);
669 // Note: This function always returns true if a "Marker" was found.
670 return true;
671}
672
673void MachineSMEABI::collectReachableMarkedCalls(
674 const MachineBasicBlock &StartMBB,
675 MachineBasicBlock::const_iterator StartInst,
676 SmallVectorImpl<const MachineInstr *> &Calls, unsigned Marker) const {
677 assert(Marker == AArch64::InOutZAUsePseudo ||
678 Marker == AArch64::RequiresZASavePseudo ||
679 Marker == AArch64::RequiresZT0SavePseudo);
680 unsigned CallDestroyOpcode = TII->getCallFrameDestroyOpcode();
681 if (findMarkedCall(MBB: StartMBB, MBBI: StartInst, Calls, Marker, CallDestroyOpcode))
682 return;
683
684 SmallPtrSet<const MachineBasicBlock *, 4> Visited;
685 SmallVector<const MachineBasicBlock *> Worklist(StartMBB.succ_rbegin(),
686 StartMBB.succ_rend());
687 while (!Worklist.empty()) {
688 const MachineBasicBlock *MBB = Worklist.pop_back_val();
689 auto [_, Inserted] = Visited.insert(Ptr: MBB);
690 if (!Inserted)
691 continue;
692
693 if (!findMarkedCall(MBB: *MBB, MBBI: MBB->begin(), Calls, Marker, CallDestroyOpcode))
694 Worklist.append(in_start: MBB->succ_rbegin(), in_end: MBB->succ_rend());
695 }
696}
697
698static StringRef getCalleeName(const MachineInstr &CallInst) {
699 assert(CallInst.isCall() && "expected a call");
700 for (const MachineOperand &MO : CallInst.operands()) {
701 if (MO.isSymbol())
702 return MO.getSymbolName();
703 if (MO.isGlobal())
704 return MO.getGlobal()->getName();
705 }
706 return {};
707}
708
709void MachineSMEABI::emitCallSaveRemarks(const MachineBasicBlock &MBB,
710 MachineBasicBlock::const_iterator MBBI,
711 DebugLoc DL, unsigned Marker,
712 StringRef RemarkName,
713 StringRef SaveName) const {
714 auto SaveRemark = [&](DebugLoc DL, const MachineBasicBlock &MBB) {
715 return MachineOptimizationRemarkAnalysis("sme", RemarkName, DL, &MBB);
716 };
717 StringRef StateName = Marker == AArch64::RequiresZT0SavePseudo ? "ZT0" : "ZA";
718 ORE->emit(RemarkBuilder: [&] {
719 return SaveRemark(DL, MBB) << SaveName << " of " << StateName
720 << " emitted in '" << MF->getName() << "'";
721 });
722 if (!ORE->allowExtraAnalysis(PassName: "sme"))
723 return;
724 SmallVector<const MachineInstr *> CallsRequiringSaves;
725 collectReachableMarkedCalls(StartMBB: MBB, StartInst: MBBI, Calls&: CallsRequiringSaves, Marker);
726 for (const MachineInstr *CallInst : CallsRequiringSaves) {
727 auto R = SaveRemark(CallInst->getDebugLoc(), *CallInst->getParent());
728 R << "call";
729 if (StringRef CalleeName = getCalleeName(CallInst: *CallInst); !CalleeName.empty())
730 R << " to '" << CalleeName << "'";
731 R << " requires " << StateName << " save";
732 ORE->emit(OptDiag&: R);
733 }
734}
735
736void MachineSMEABI::emitSetupLazySave(EmitContext &Context,
737 MachineBasicBlock &MBB,
738 MachineBasicBlock::iterator MBBI) {
739 DebugLoc DL = getDebugLoc(MBB, MBBI);
740
741 emitCallSaveRemarks(MBB, MBBI, DL, Marker: AArch64::RequiresZASavePseudo,
742 RemarkName: "SMELazySaveZA", SaveName: "lazy save");
743
744 // Get pointer to TPIDR2 block.
745 Register TPIDR2 = MRI->createVirtualRegister(RegClass: &AArch64::GPR64spRegClass);
746 Register TPIDR2Ptr = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
747 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::ADDXri), DestReg: TPIDR2)
748 .addFrameIndex(Idx: Context.getTPIDR2Block(MF&: *MF))
749 .addImm(Val: 0)
750 .addImm(Val: 0);
751 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: TPIDR2Ptr)
752 .addReg(RegNo: TPIDR2);
753 // Set TPIDR2_EL0 to point to TPIDR2 block.
754 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSR))
755 .addImm(Val: AArch64SysReg::TPIDR2_EL0)
756 .addReg(RegNo: TPIDR2Ptr);
757}
758
759PhysRegSave MachineSMEABI::createPhysRegSave(LiveRegs PhysLiveRegs,
760 MachineBasicBlock &MBB,
761 MachineBasicBlock::iterator MBBI,
762 DebugLoc DL) {
763 PhysRegSave RegSave{.PhysLiveRegs: PhysLiveRegs};
764 if (PhysLiveRegs & LiveRegs::NZCV) {
765 RegSave.StatusFlags = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
766 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MRS), DestReg: RegSave.StatusFlags)
767 .addImm(Val: AArch64SysReg::NZCV)
768 .addReg(RegNo: AArch64::NZCV, Flags: RegState::Implicit);
769 }
770 // Note: Preserving X0 is "free" as this is before register allocation, so
771 // the register allocator is still able to optimize these copies.
772 if (PhysLiveRegs & LiveRegs::W0) {
773 RegSave.X0Save = MRI->createVirtualRegister(RegClass: PhysLiveRegs & LiveRegs::W0_HI
774 ? &AArch64::GPR64RegClass
775 : &AArch64::GPR32RegClass);
776 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: RegSave.X0Save)
777 .addReg(RegNo: PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0);
778 }
779 return RegSave;
780}
781
782void MachineSMEABI::restorePhyRegSave(const PhysRegSave &RegSave,
783 MachineBasicBlock &MBB,
784 MachineBasicBlock::iterator MBBI,
785 DebugLoc DL) {
786 if (RegSave.StatusFlags != AArch64::NoRegister)
787 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSR))
788 .addImm(Val: AArch64SysReg::NZCV)
789 .addReg(RegNo: RegSave.StatusFlags)
790 .addReg(RegNo: AArch64::NZCV, Flags: RegState::ImplicitDefine);
791
792 if (RegSave.X0Save != AArch64::NoRegister)
793 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY),
794 DestReg: RegSave.PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0)
795 .addReg(RegNo: RegSave.X0Save);
796}
797
798void MachineSMEABI::addSMELibCall(MachineInstrBuilder &MIB, RTLIB::Libcall LC,
799 CallingConv::ID ExpectedCC) {
800 RTLIB::LibcallImpl LCImpl = LLI->getLibcallImpl(Call: LC);
801 if (LCImpl == RTLIB::Unsupported)
802 emitError(Message: "cannot lower SME ABI (SME routines unsupported)");
803 CallingConv::ID CC = LLI->getLibcallImplCallingConv(Call: LCImpl);
804 StringRef ImplName = RTLIB::RuntimeLibcallsInfo::getLibcallImplName(CallImpl: LCImpl);
805 if (CC != ExpectedCC)
806 emitError(Message: "invalid calling convention for SME routine: '" + ImplName + "'");
807 // FIXME: This assumes the ImplName StringRef is null-terminated.
808 MIB.addExternalSymbol(FnName: ImplName.data());
809 MIB.addRegMask(Mask: TRI->getCallPreservedMask(MF: *MF, CC));
810}
811
812void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
813 MachineBasicBlock &MBB,
814 MachineBasicBlock::iterator MBBI,
815 LiveRegs PhysLiveRegs) {
816 DebugLoc DL = getDebugLoc(MBB, MBBI);
817 Register TPIDR2EL0 = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
818 Register TPIDR2 = AArch64::X0;
819
820 // TODO: Emit these within the restore MBB to prevent unnecessary saves.
821 PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
822
823 // Enable ZA.
824 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSRpstatesvcrImm1))
825 .addImm(Val: AArch64SVCR::SVCRZA)
826 .addImm(Val: 1);
827 // Get current TPIDR2_EL0.
828 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MRS), DestReg: TPIDR2EL0)
829 .addImm(Val: AArch64SysReg::TPIDR2_EL0);
830 // Get pointer to TPIDR2 block.
831 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::ADDXri), DestReg: TPIDR2)
832 .addFrameIndex(Idx: Context.getTPIDR2Block(MF&: *MF))
833 .addImm(Val: 0)
834 .addImm(Val: 0);
835 // (Conditionally) restore ZA state.
836 auto RestoreZA = BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::RestoreZAPseudo))
837 .addReg(RegNo: TPIDR2EL0)
838 .addReg(RegNo: TPIDR2);
839 addSMELibCall(
840 MIB&: RestoreZA, LC: RTLIB::SMEABI_TPIDR2_RESTORE,
841 ExpectedCC: CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);
842 // Zero TPIDR2_EL0.
843 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSR))
844 .addImm(Val: AArch64SysReg::TPIDR2_EL0)
845 .addReg(RegNo: AArch64::XZR);
846
847 restorePhyRegSave(RegSave, MBB, MBBI, DL);
848}
849
850void MachineSMEABI::emitZAMode(MachineBasicBlock &MBB,
851 MachineBasicBlock::iterator MBBI,
852 bool ClearTPIDR2, bool On) {
853 DebugLoc DL = getDebugLoc(MBB, MBBI);
854
855 if (ClearTPIDR2)
856 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSR))
857 .addImm(Val: AArch64SysReg::TPIDR2_EL0)
858 .addReg(RegNo: AArch64::XZR);
859
860 // Disable ZA.
861 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSRpstatesvcrImm1))
862 .addImm(Val: AArch64SVCR::SVCRZA)
863 .addImm(Val: On ? 1 : 0);
864}
865
866void MachineSMEABI::emitAllocateLazySaveBuffer(
867 EmitContext &Context, MachineBasicBlock &MBB,
868 MachineBasicBlock::iterator MBBI) {
869 MachineFrameInfo &MFI = MF->getFrameInfo();
870 DebugLoc DL = getDebugLoc(MBB, MBBI);
871 Register SP = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
872 Register SVL = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
873 Register Buffer = AFI->getEarlyAllocSMESaveBuffer();
874
875 // Calculate SVL.
876 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::RDSVLI_XI), DestReg: SVL).addImm(Val: 1);
877
878 // 1. Allocate the lazy save buffer.
879 if (Buffer == AArch64::NoRegister) {
880 // TODO: On Windows, we allocate the lazy save buffer in SelectionDAG (so
881 // Buffer != AArch64::NoRegister). This is done to reuse the existing
882 // expansions (which can insert stack checks). This works, but it means we
883 // will always allocate the lazy save buffer (even if the function contains
884 // no lazy saves). If we want to handle Windows here, we'll need to
885 // implement something similar to LowerWindowsDYNAMIC_STACKALLOC.
886 assert(!Subtarget->isTargetWindows() &&
887 "Lazy ZA save is not yet supported on Windows");
888 Buffer = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
889 // Get original stack pointer.
890 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: SP)
891 .addReg(RegNo: AArch64::SP);
892 // Allocate a lazy-save buffer object of the size given, normally SVL * SVL
893 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSUBXrrr), DestReg: Buffer)
894 .addReg(RegNo: SVL)
895 .addReg(RegNo: SVL)
896 .addReg(RegNo: SP);
897 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: AArch64::SP)
898 .addReg(RegNo: Buffer);
899 // We have just allocated a variable sized object, tell this to PEI.
900 MFI.CreateVariableSizedObject(Alignment: Align(16), Alloca: nullptr);
901 }
902
903 // 2. Setup the TPIDR2 block.
904 {
905 // Note: This case just needs to do `SVL << 48`. It is not implemented as we
906 // generally don't support big-endian SVE/SME.
907 if (!Subtarget->isLittleEndian())
908 reportFatalInternalError(
909 reason: "TPIDR2 block initialization is not supported on big-endian targets");
910
911 // Store buffer pointer and num_za_save_slices.
912 // Bytes 10-15 are implicitly zeroed.
913 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::STPXi))
914 .addReg(RegNo: Buffer)
915 .addReg(RegNo: SVL)
916 .addFrameIndex(Idx: Context.getTPIDR2Block(MF&: *MF))
917 .addImm(Val: 0);
918 }
919}
920
// Immediate operand given to ZERO_M when the SME prologue zeroes ZA: all
// eight low bits are set.
static constexpr unsigned ZERO_ALL_ZA_MASK = 0xFF;
922
923void MachineSMEABI::emitSMEPrologue(MachineBasicBlock &MBB,
924 MachineBasicBlock::iterator MBBI) {
925 DebugLoc DL = getDebugLoc(MBB, MBBI);
926
927 bool ZeroZA = AFI->getSMEFnAttrs().isNewZA();
928 bool ZeroZT0 = AFI->getSMEFnAttrs().isNewZT0();
929 if (AFI->getSMEFnAttrs().hasPrivateZAInterface()) {
930 // Get current TPIDR2_EL0.
931 Register TPIDR2EL0 = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
932 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MRS))
933 .addReg(RegNo: TPIDR2EL0, Flags: RegState::Define)
934 .addImm(Val: AArch64SysReg::TPIDR2_EL0);
935 // If TPIDR2_EL0 is non-zero, commit the lazy save.
936 // NOTE: Functions that only use ZT0 don't need to zero ZA.
937 auto CommitZASave =
938 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::CommitZASavePseudo))
939 .addReg(RegNo: TPIDR2EL0)
940 .addImm(Val: ZeroZA)
941 .addImm(Val: ZeroZT0);
942 addSMELibCall(
943 MIB&: CommitZASave, LC: RTLIB::SMEABI_TPIDR2_SAVE,
944 ExpectedCC: CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);
945 if (ZeroZA)
946 CommitZASave.addDef(RegNo: AArch64::ZAB0, Flags: RegState::ImplicitDefine);
947 if (ZeroZT0)
948 CommitZASave.addDef(RegNo: AArch64::ZT0, Flags: RegState::ImplicitDefine);
949 // Enable ZA (as ZA could have previously been in the OFF state).
950 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::MSRpstatesvcrImm1))
951 .addImm(Val: AArch64SVCR::SVCRZA)
952 .addImm(Val: 1);
953 } else if (AFI->getSMEFnAttrs().hasSharedZAInterface()) {
954 if (ZeroZA)
955 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::ZERO_M))
956 .addImm(Val: ZERO_ALL_ZA_MASK)
957 .addDef(RegNo: AArch64::ZAB0, Flags: RegState::ImplicitDefine);
958 if (ZeroZT0)
959 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::ZERO_T)).addDef(RegNo: AArch64::ZT0);
960 }
961}
962
963void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
964 MachineBasicBlock &MBB,
965 MachineBasicBlock::iterator MBBI,
966 LiveRegs PhysLiveRegs, bool IsSave) {
967 DebugLoc DL = getDebugLoc(MBB, MBBI);
968
969 if (IsSave)
970 emitCallSaveRemarks(MBB, MBBI, DL, Marker: AArch64::RequiresZASavePseudo,
971 RemarkName: "SMEFullZASave", SaveName: "full save");
972
973 PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
974
975 // Copy the buffer pointer into X0.
976 Register BufferPtr = AArch64::X0;
977 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: BufferPtr)
978 .addReg(RegNo: Context.getAgnosticZABufferPtr(MF&: *MF));
979
980 // Call __arm_sme_save/__arm_sme_restore.
981 auto SaveRestoreZA = BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::BL))
982 .addReg(RegNo: BufferPtr, Flags: RegState::Implicit);
983 addSMELibCall(
984 MIB&: SaveRestoreZA,
985 LC: IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE,
986 ExpectedCC: CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1);
987
988 restorePhyRegSave(RegSave, MBB, MBBI, DL);
989}
990
991void MachineSMEABI::emitZT0SaveRestore(EmitContext &Context,
992 MachineBasicBlock &MBB,
993 MachineBasicBlock::iterator MBBI,
994 bool IsSave) {
995 DebugLoc DL = getDebugLoc(MBB, MBBI);
996
997 // Note: This will report calls that _only_ need ZT0 saved. Call that save
998 // both ZA and ZT0 will be under the SMELazySaveZA remark. This prevents
999 // reporting the same calls twice.
1000 if (IsSave)
1001 emitCallSaveRemarks(MBB, MBBI, DL, Marker: AArch64::RequiresZT0SavePseudo,
1002 RemarkName: "SMEZT0Save", SaveName: "spill");
1003
1004 Register ZT0Save = MRI->createVirtualRegister(RegClass: &AArch64::GPR64spRegClass);
1005
1006 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::ADDXri), DestReg: ZT0Save)
1007 .addFrameIndex(Idx: Context.getZT0SaveSlot(MF&: *MF))
1008 .addImm(Val: 0)
1009 .addImm(Val: 0);
1010
1011 if (IsSave) {
1012 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::STR_TX))
1013 .addReg(RegNo: AArch64::ZT0)
1014 .addReg(RegNo: ZT0Save);
1015 } else {
1016 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::LDR_TX), DestReg: AArch64::ZT0)
1017 .addReg(RegNo: ZT0Save);
1018 }
1019}
1020
1021void MachineSMEABI::emitAllocateFullZASaveBuffer(
1022 EmitContext &Context, MachineBasicBlock &MBB,
1023 MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
1024 // Buffer already allocated in SelectionDAG.
1025 if (AFI->getEarlyAllocSMESaveBuffer())
1026 return;
1027
1028 DebugLoc DL = getDebugLoc(MBB, MBBI);
1029 Register BufferPtr = Context.getAgnosticZABufferPtr(MF&: *MF);
1030 Register BufferSize = MRI->createVirtualRegister(RegClass: &AArch64::GPR64RegClass);
1031
1032 PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
1033
1034 // Calculate the SME state size.
1035 {
1036 auto SMEStateSize = BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::BL))
1037 .addReg(RegNo: AArch64::X0, Flags: RegState::ImplicitDefine);
1038 addSMELibCall(
1039 MIB&: SMEStateSize, LC: RTLIB::SMEABI_SME_STATE_SIZE,
1040 ExpectedCC: CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1);
1041 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: BufferSize)
1042 .addReg(RegNo: AArch64::X0);
1043 }
1044
1045 // Allocate a buffer object of the size given __arm_sme_state_size.
1046 {
1047 MachineFrameInfo &MFI = MF->getFrameInfo();
1048 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: AArch64::SUBXrx64), DestReg: AArch64::SP)
1049 .addReg(RegNo: AArch64::SP)
1050 .addReg(RegNo: BufferSize)
1051 .addImm(Val: AArch64_AM::getArithExtendImm(ET: AArch64_AM::UXTX, Imm: 0));
1052 BuildMI(BB&: MBB, I: MBBI, MIMD: DL, MCID: TII->get(Opcode: TargetOpcode::COPY), DestReg: BufferPtr)
1053 .addReg(RegNo: AArch64::SP);
1054
1055 // We have just allocated a variable sized object, tell this to PEI.
1056 MFI.CreateVariableSizedObject(Alignment: Align(16), Alloca: nullptr);
1057 }
1058
1059 restorePhyRegSave(RegSave, MBB, MBBI, DL);
1060}
1061
1062struct FromState {
1063 ZAState From;
1064
1065 constexpr uint8_t to(ZAState To) const {
1066 static_assert(NUM_ZA_STATE < 16, "expected ZAState to fit in 4-bits");
1067 return uint8_t(From) << 4 | uint8_t(To);
1068 }
1069};
1070
1071constexpr FromState transitionFrom(ZAState From) { return FromState{.From: From}; }
1072
1073void MachineSMEABI::emitStateChange(EmitContext &Context,
1074 MachineBasicBlock &MBB,
1075 MachineBasicBlock::iterator InsertPt,
1076 ZAState From, ZAState To,
1077 LiveRegs PhysLiveRegs) {
1078 // ZA not used.
1079 if (From == ZAState::ANY || To == ZAState::ANY)
1080 return;
1081
1082 // If we're exiting from the ENTRY state that means that the function has not
1083 // used ZA, so in the case of private ZA/ZT0 functions we can omit any set up.
1084 if (From == ZAState::ENTRY && To == ZAState::OFF)
1085 return;
1086
1087 // TODO: Avoid setting up the save buffer if there's no transition to
1088 // LOCAL_SAVED.
1089 if (From == ZAState::ENTRY) {
1090 assert(&MBB == &MBB.getParent()->front() &&
1091 "ENTRY state only valid in entry block");
1092 emitSMEPrologue(MBB, MBBI: MBB.getFirstNonPHI());
1093 if (To == ZAState::ACTIVE)
1094 return; // Nothing more to do (ZA is active after the prologue).
1095
1096 // Note: "emitNewZAPrologue" zeros ZA, so we may need to setup a lazy save
1097 // if "To" is "ZAState::LOCAL_SAVED". It may be possible to improve this
1098 // case by changing the placement of the zero instruction.
1099 From = ZAState::ACTIVE;
1100 }
1101
1102 SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
1103 bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
1104 bool HasZT0State = SMEFnAttrs.hasZT0State();
1105 bool HasZAState = IsAgnosticZA || SMEFnAttrs.hasZAState();
1106
1107 switch (transitionFrom(From).to(To)) {
1108 // This section handles: ACTIVE <-> ACTIVE_ZT0_SAVED
1109 case transitionFrom(From: ZAState::ACTIVE).to(To: ZAState::ACTIVE_ZT0_SAVED):
1110 emitZT0SaveRestore(Context, MBB, MBBI: InsertPt, /*IsSave=*/true);
1111 break;
1112 case transitionFrom(From: ZAState::ACTIVE_ZT0_SAVED).to(To: ZAState::ACTIVE):
1113 emitZT0SaveRestore(Context, MBB, MBBI: InsertPt, /*IsSave=*/false);
1114 break;
1115
1116 // This section handles: ACTIVE[_ZT0_SAVED] -> LOCAL_SAVED
1117 case transitionFrom(From: ZAState::ACTIVE).to(To: ZAState::LOCAL_SAVED):
1118 case transitionFrom(From: ZAState::ACTIVE_ZT0_SAVED).to(To: ZAState::LOCAL_SAVED):
1119 if (HasZT0State && From == ZAState::ACTIVE)
1120 emitZT0SaveRestore(Context, MBB, MBBI: InsertPt, /*IsSave=*/true);
1121 if (HasZAState)
1122 emitZASave(Context, MBB, MBBI: InsertPt, PhysLiveRegs);
1123 break;
1124
1125 // This section handles: ACTIVE -> LOCAL_COMMITTED
1126 case transitionFrom(From: ZAState::ACTIVE).to(To: ZAState::LOCAL_COMMITTED):
1127 // TODO: We could support ZA state here, but this transition is currently
1128 // only possible when we _don't_ have ZA state.
1129 assert(HasZT0State && !HasZAState && "Expect to only have ZT0 state.");
1130 emitZT0SaveRestore(Context, MBB, MBBI: InsertPt, /*IsSave=*/true);
1131 emitZAMode(MBB, MBBI: InsertPt, /*ClearTPIDR2=*/false, /*On=*/false);
1132 break;
1133
1134 // This section handles: LOCAL_COMMITTED -> (OFF|LOCAL_SAVED)
1135 case transitionFrom(From: ZAState::LOCAL_COMMITTED).to(To: ZAState::OFF):
1136 case transitionFrom(From: ZAState::LOCAL_COMMITTED).to(To: ZAState::LOCAL_SAVED):
1137 // These transitions are a no-op.
1138 break;
1139
1140 // This section handles: LOCAL_(SAVED|COMMITTED) -> ACTIVE[_ZT0_SAVED]
1141 case transitionFrom(From: ZAState::LOCAL_COMMITTED).to(To: ZAState::ACTIVE):
1142 case transitionFrom(From: ZAState::LOCAL_COMMITTED).to(To: ZAState::ACTIVE_ZT0_SAVED):
1143 case transitionFrom(From: ZAState::LOCAL_SAVED).to(To: ZAState::ACTIVE):
1144 case transitionFrom(From: ZAState::LOCAL_SAVED).to(To: ZAState::ACTIVE_ZT0_SAVED):
1145 if (HasZAState)
1146 emitZARestore(Context, MBB, MBBI: InsertPt, PhysLiveRegs);
1147 else
1148 emitZAMode(MBB, MBBI: InsertPt, /*ClearTPIDR2=*/false, /*On=*/true);
1149 if (HasZT0State && To == ZAState::ACTIVE)
1150 emitZT0SaveRestore(Context, MBB, MBBI: InsertPt, /*IsSave=*/false);
1151 break;
1152
1153 // This section handles transitions to OFF (not previously covered)
1154 case transitionFrom(From: ZAState::ACTIVE).to(To: ZAState::OFF):
1155 case transitionFrom(From: ZAState::ACTIVE_ZT0_SAVED).to(To: ZAState::OFF):
1156 case transitionFrom(From: ZAState::LOCAL_SAVED).to(To: ZAState::OFF):
1157 assert(SMEFnAttrs.hasPrivateZAInterface() &&
1158 "Did not expect to turn ZA off in shared/agnostic ZA function");
1159 emitZAMode(MBB, MBBI: InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED,
1160 /*On=*/false);
1161 break;
1162
1163 default:
1164 dbgs() << "Error: Transition from " << getZAStateString(State: From) << " to "
1165 << getZAStateString(State: To) << '\n';
1166 llvm_unreachable("Unimplemented state transition");
1167 }
1168}
1169
1170/// Returns true if private ZA setup can be elided. This occurs when there is
1171/// no instruction within the function that requires ZA to be active.
1172static bool canElidePrivateZASetup(const FunctionInfo &FnInfo) {
1173 for (const BlockInfo &BlockInfo : FnInfo.Blocks) {
1174 for (const InstInfo &InstInfo : BlockInfo.Insts) {
1175 if (InstInfo.NeededState == ZAState::ACTIVE ||
1176 InstInfo.NeededState == ZAState::ACTIVE_ZT0_SAVED)
1177 return false;
1178 }
1179 }
1180 return true;
1181}
1182
1183} // end anonymous namespace
1184
1185INITIALIZE_PASS(MachineSMEABI, "aarch64-machine-sme-abi", "Machine SME ABI",
1186 false, false)
1187
1188bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
1189 AFI = MF.getInfo<AArch64FunctionInfo>();
1190 SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
1191 if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
1192 !SMEFnAttrs.hasAgnosticZAInterface())
1193 return false;
1194
1195 Subtarget = &MF.getSubtarget<AArch64Subtarget>();
1196 if (!Subtarget->hasSME() && !SMEFnAttrs.hasAgnosticZAInterface())
1197 return false;
1198
1199 assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
1200
1201 this->MF = &MF;
1202 ORE = &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE();
1203 LLI = &getAnalysis<LibcallLoweringInfoWrapper>().getLibcallLowering(
1204 M: *MF.getFunction().getParent(), Subtarget: *Subtarget);
1205 TII = Subtarget->getInstrInfo();
1206 TRI = Subtarget->getRegisterInfo();
1207 MRI = &MF.getRegInfo();
1208
1209 const EdgeBundles &Bundles =
1210 getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
1211
1212 FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);
1213
1214 if (SMEFnAttrs.hasPrivateZAInterface() && canElidePrivateZASetup(FnInfo))
1215 return false;
1216
1217 SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo);
1218
1219 EmitContext Context;
1220 insertStateChanges(Context, FnInfo, Bundles, BundleStates);
1221
1222 if (Context.needsSaveBuffer()) {
1223 if (FnInfo.AfterSMEProloguePt) {
1224 // Note: With inline stack probes the AfterSMEProloguePt may not be in the
1225 // entry block (due to the probing loop).
1226 MachineBasicBlock::iterator MBBI = *FnInfo.AfterSMEProloguePt;
1227 emitAllocateZASaveBuffer(Context, MBB&: *MBBI->getParent(), MBBI,
1228 PhysLiveRegs: FnInfo.PhysLiveRegsAfterSMEPrologue);
1229 } else {
1230 MachineBasicBlock &EntryBlock = MF.front();
1231 emitAllocateZASaveBuffer(
1232 Context, MBB&: EntryBlock, MBBI: EntryBlock.getFirstNonPHI(),
1233 PhysLiveRegs: FnInfo.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry);
1234 }
1235 }
1236
1237 return true;
1238}
1239
1240FunctionPass *llvm::createMachineSMEABIPass(CodeGenOptLevel OptLevel) {
1241 return new MachineSMEABI(OptLevel);
1242}
1243