1//===-- AArch64SMEAttributes.h - Helper for interpreting SME attributes -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H
10#define LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H
11
12#include "llvm/IR/Function.h"
13
14namespace llvm {
15
// Forward declarations for the IR types that appear in the interfaces below.
// (SMEAttrs(const Function &) calls members of Function inline, so its complete
// definition is still required and comes from llvm/IR/Function.h.)
class AttributeList;
class CallBase;
class Function;
19
20/// SMEAttrs is a utility class to parse the SME ACLE attributes on functions.
21/// It helps determine a function's requirements for PSTATE.ZA and PSTATE.SM.
22class SMEAttrs {
23 unsigned Bitmask = Normal;
24
25public:
26 enum class StateValue {
27 None = 0,
28 In = 1, // aarch64_in_zt0
29 Out = 2, // aarch64_out_zt0
30 InOut = 3, // aarch64_inout_zt0
31 Preserved = 4, // aarch64_preserves_zt0
32 New = 5 // aarch64_new_zt0
33 };
34
35 // Enum with bitmasks for each individual SME feature.
36 enum Mask {
37 Normal = 0,
38 SM_Enabled = 1 << 0, // aarch64_pstate_sm_enabled
39 SM_Compatible = 1 << 1, // aarch64_pstate_sm_compatible
40 SM_Body = 1 << 2, // aarch64_pstate_sm_body
41 SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
42 ZA_State_Agnostic = 1 << 4,
43 ZT0_Undef = 1 << 5, // Use to mark ZT0 as undef to avoid spills
44 ZA_Shift = 6,
45 ZA_Mask = 0b111 << ZA_Shift,
46 ZT0_Shift = 9,
47 ZT0_Mask = 0b111 << ZT0_Shift,
48 CallSiteFlags_Mask = ZT0_Undef
49 };
50
51 enum class InferAttrsFromName { No, Yes };
52
53 SMEAttrs() = default;
54 SMEAttrs(unsigned Mask) { set(M: Mask); }
55 SMEAttrs(const Function &F, InferAttrsFromName Infer = InferAttrsFromName::No)
56 : SMEAttrs(F.getAttributes()) {
57 if (Infer == InferAttrsFromName::Yes)
58 addKnownFunctionAttrs(FuncName: F.getName());
59 }
60 SMEAttrs(const AttributeList &L);
61 SMEAttrs(StringRef FuncName) { addKnownFunctionAttrs(FuncName); };
62
63 void set(unsigned M, bool Enable = true);
64
65 // Interfaces to query PSTATE.SM
66 bool hasStreamingBody() const { return Bitmask & SM_Body; }
67 bool hasStreamingInterface() const { return Bitmask & SM_Enabled; }
68 bool hasStreamingInterfaceOrBody() const {
69 return hasStreamingBody() || hasStreamingInterface();
70 }
71 bool hasStreamingCompatibleInterface() const {
72 return Bitmask & SM_Compatible;
73 }
74 bool hasNonStreamingInterface() const {
75 return !hasStreamingInterface() && !hasStreamingCompatibleInterface();
76 }
77 bool hasNonStreamingInterfaceAndBody() const {
78 return hasNonStreamingInterface() && !hasStreamingBody();
79 }
80
81 // Interfaces to query ZA
82 static StateValue decodeZAState(unsigned Bitmask) {
83 return static_cast<StateValue>((Bitmask & ZA_Mask) >> ZA_Shift);
84 }
85 static unsigned encodeZAState(StateValue S) {
86 return static_cast<unsigned>(S) << ZA_Shift;
87 }
88
89 bool isNewZA() const { return decodeZAState(Bitmask) == StateValue::New; }
90 bool isInZA() const { return decodeZAState(Bitmask) == StateValue::In; }
91 bool isOutZA() const { return decodeZAState(Bitmask) == StateValue::Out; }
92 bool isInOutZA() const { return decodeZAState(Bitmask) == StateValue::InOut; }
93 bool isPreservesZA() const {
94 return decodeZAState(Bitmask) == StateValue::Preserved;
95 }
96 bool sharesZA() const {
97 StateValue State = decodeZAState(Bitmask);
98 return State == StateValue::In || State == StateValue::Out ||
99 State == StateValue::InOut || State == StateValue::Preserved;
100 }
101 bool hasAgnosticZAInterface() const { return Bitmask & ZA_State_Agnostic; }
102 bool hasSharedZAInterface() const { return sharesZA() || sharesZT0(); }
103 bool hasPrivateZAInterface() const {
104 return !hasSharedZAInterface() && !hasAgnosticZAInterface();
105 }
106 bool hasZAState() const { return isNewZA() || sharesZA(); }
107 bool isSMEABIRoutine() const { return Bitmask & SME_ABI_Routine; }
108
109 // Interfaces to query ZT0 State
110 static StateValue decodeZT0State(unsigned Bitmask) {
111 return static_cast<StateValue>((Bitmask & ZT0_Mask) >> ZT0_Shift);
112 }
113 static unsigned encodeZT0State(StateValue S) {
114 return static_cast<unsigned>(S) << ZT0_Shift;
115 }
116
117 bool isNewZT0() const { return decodeZT0State(Bitmask) == StateValue::New; }
118 bool isInZT0() const { return decodeZT0State(Bitmask) == StateValue::In; }
119 bool isOutZT0() const { return decodeZT0State(Bitmask) == StateValue::Out; }
120 bool isInOutZT0() const {
121 return decodeZT0State(Bitmask) == StateValue::InOut;
122 }
123 bool isPreservesZT0() const {
124 return decodeZT0State(Bitmask) == StateValue::Preserved;
125 }
126 bool hasUndefZT0() const { return Bitmask & ZT0_Undef; }
127 bool sharesZT0() const {
128 StateValue State = decodeZT0State(Bitmask);
129 return State == StateValue::In || State == StateValue::Out ||
130 State == StateValue::InOut || State == StateValue::Preserved;
131 }
132 bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
133
134 SMEAttrs operator|(SMEAttrs Other) const {
135 SMEAttrs Merged(*this);
136 Merged.set(M: Other.Bitmask);
137 return Merged;
138 }
139
140 SMEAttrs withoutPerCallsiteFlags() const {
141 return (Bitmask & ~CallSiteFlags_Mask);
142 }
143
144 bool operator==(SMEAttrs const &Other) const {
145 return Bitmask == Other.Bitmask;
146 }
147
148private:
149 void addKnownFunctionAttrs(StringRef FuncName);
150};
151
152/// SMECallAttrs is a utility class to hold the SMEAttrs for a callsite. It has
153/// interfaces to query whether a streaming mode change or lazy-save mechanism
154/// is required when going from one function to another (e.g. through a call).
155class SMECallAttrs {
156 SMEAttrs CallerFn;
157 SMEAttrs CalledFn;
158 SMEAttrs Callsite;
159 bool IsIndirect = false;
160
161public:
162 SMECallAttrs(SMEAttrs Caller, SMEAttrs Callee,
163 SMEAttrs Callsite = SMEAttrs::Normal)
164 : CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {}
165
166 SMECallAttrs(const CallBase &CB);
167
168 SMEAttrs &caller() { return CallerFn; }
169 SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; }
170 SMEAttrs &callsite() { return Callsite; }
171 SMEAttrs const &caller() const { return CallerFn; }
172 SMEAttrs const &callee() const {
173 return const_cast<SMECallAttrs *>(this)->callee();
174 }
175 SMEAttrs const &callsite() const { return Callsite; }
176
177 /// \return true if a call from Caller -> Callee requires a change in
178 /// streaming mode.
179 bool requiresSMChange() const;
180
181 bool requiresLazySave() const {
182 return caller().hasZAState() && callee().hasPrivateZAInterface() &&
183 !callee().isSMEABIRoutine();
184 }
185
186 bool requiresPreservingZT0() const {
187 return caller().hasZT0State() && !callsite().hasUndefZT0() &&
188 !callee().sharesZT0() && !callee().hasAgnosticZAInterface();
189 }
190
191 bool requiresDisablingZABeforeCall() const {
192 return caller().hasZT0State() && !caller().hasZAState() &&
193 callee().hasPrivateZAInterface() && !callee().isSMEABIRoutine();
194 }
195
196 bool requiresEnablingZAAfterCall() const {
197 return requiresLazySave() || requiresDisablingZABeforeCall();
198 }
199
200 bool requiresPreservingAllZAState() const {
201 return caller().hasAgnosticZAInterface() &&
202 !callee().hasAgnosticZAInterface() && !callee().isSMEABIRoutine();
203 }
204};
205
206} // namespace llvm
207
208#endif // LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H
209