1 | //===-- AArch64SMEAttributes.h - Helper for interpreting SME attributes -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H |
10 | #define LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H |
11 | |
12 | #include "llvm/IR/Function.h" |
13 | |
14 | namespace llvm { |
15 | |
16 | class Function; |
17 | class CallBase; |
18 | class AttributeList; |
19 | |
20 | /// SMEAttrs is a utility class to parse the SME ACLE attributes on functions. |
21 | /// It helps determine a function's requirements for PSTATE.ZA and PSTATE.SM. |
22 | class SMEAttrs { |
23 | unsigned Bitmask = Normal; |
24 | |
25 | public: |
26 | enum class StateValue { |
27 | None = 0, |
28 | In = 1, // aarch64_in_zt0 |
29 | Out = 2, // aarch64_out_zt0 |
30 | InOut = 3, // aarch64_inout_zt0 |
31 | Preserved = 4, // aarch64_preserves_zt0 |
32 | New = 5 // aarch64_new_zt0 |
33 | }; |
34 | |
35 | // Enum with bitmasks for each individual SME feature. |
36 | enum Mask { |
37 | Normal = 0, |
38 | SM_Enabled = 1 << 0, // aarch64_pstate_sm_enabled |
39 | SM_Compatible = 1 << 1, // aarch64_pstate_sm_compatible |
40 | SM_Body = 1 << 2, // aarch64_pstate_sm_body |
41 | SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves |
42 | ZA_State_Agnostic = 1 << 4, |
43 | ZT0_Undef = 1 << 5, // Use to mark ZT0 as undef to avoid spills |
44 | ZA_Shift = 6, |
45 | ZA_Mask = 0b111 << ZA_Shift, |
46 | ZT0_Shift = 9, |
47 | ZT0_Mask = 0b111 << ZT0_Shift, |
48 | CallSiteFlags_Mask = ZT0_Undef |
49 | }; |
50 | |
51 | enum class InferAttrsFromName { No, Yes }; |
52 | |
53 | SMEAttrs() = default; |
54 | SMEAttrs(unsigned Mask) { set(M: Mask); } |
55 | SMEAttrs(const Function &F, InferAttrsFromName Infer = InferAttrsFromName::No) |
56 | : SMEAttrs(F.getAttributes()) { |
57 | if (Infer == InferAttrsFromName::Yes) |
58 | addKnownFunctionAttrs(FuncName: F.getName()); |
59 | } |
60 | SMEAttrs(const AttributeList &L); |
61 | SMEAttrs(StringRef FuncName) { addKnownFunctionAttrs(FuncName); }; |
62 | |
63 | void set(unsigned M, bool Enable = true); |
64 | |
65 | // Interfaces to query PSTATE.SM |
66 | bool hasStreamingBody() const { return Bitmask & SM_Body; } |
67 | bool hasStreamingInterface() const { return Bitmask & SM_Enabled; } |
68 | bool hasStreamingInterfaceOrBody() const { |
69 | return hasStreamingBody() || hasStreamingInterface(); |
70 | } |
71 | bool hasStreamingCompatibleInterface() const { |
72 | return Bitmask & SM_Compatible; |
73 | } |
74 | bool hasNonStreamingInterface() const { |
75 | return !hasStreamingInterface() && !hasStreamingCompatibleInterface(); |
76 | } |
77 | bool hasNonStreamingInterfaceAndBody() const { |
78 | return hasNonStreamingInterface() && !hasStreamingBody(); |
79 | } |
80 | |
81 | // Interfaces to query ZA |
82 | static StateValue decodeZAState(unsigned Bitmask) { |
83 | return static_cast<StateValue>((Bitmask & ZA_Mask) >> ZA_Shift); |
84 | } |
85 | static unsigned encodeZAState(StateValue S) { |
86 | return static_cast<unsigned>(S) << ZA_Shift; |
87 | } |
88 | |
89 | bool isNewZA() const { return decodeZAState(Bitmask) == StateValue::New; } |
90 | bool isInZA() const { return decodeZAState(Bitmask) == StateValue::In; } |
91 | bool isOutZA() const { return decodeZAState(Bitmask) == StateValue::Out; } |
92 | bool isInOutZA() const { return decodeZAState(Bitmask) == StateValue::InOut; } |
93 | bool isPreservesZA() const { |
94 | return decodeZAState(Bitmask) == StateValue::Preserved; |
95 | } |
96 | bool sharesZA() const { |
97 | StateValue State = decodeZAState(Bitmask); |
98 | return State == StateValue::In || State == StateValue::Out || |
99 | State == StateValue::InOut || State == StateValue::Preserved; |
100 | } |
101 | bool hasAgnosticZAInterface() const { return Bitmask & ZA_State_Agnostic; } |
102 | bool hasSharedZAInterface() const { return sharesZA() || sharesZT0(); } |
103 | bool hasPrivateZAInterface() const { |
104 | return !hasSharedZAInterface() && !hasAgnosticZAInterface(); |
105 | } |
106 | bool hasZAState() const { return isNewZA() || sharesZA(); } |
107 | bool isSMEABIRoutine() const { return Bitmask & SME_ABI_Routine; } |
108 | |
109 | // Interfaces to query ZT0 State |
110 | static StateValue decodeZT0State(unsigned Bitmask) { |
111 | return static_cast<StateValue>((Bitmask & ZT0_Mask) >> ZT0_Shift); |
112 | } |
113 | static unsigned encodeZT0State(StateValue S) { |
114 | return static_cast<unsigned>(S) << ZT0_Shift; |
115 | } |
116 | |
117 | bool isNewZT0() const { return decodeZT0State(Bitmask) == StateValue::New; } |
118 | bool isInZT0() const { return decodeZT0State(Bitmask) == StateValue::In; } |
119 | bool isOutZT0() const { return decodeZT0State(Bitmask) == StateValue::Out; } |
120 | bool isInOutZT0() const { |
121 | return decodeZT0State(Bitmask) == StateValue::InOut; |
122 | } |
123 | bool isPreservesZT0() const { |
124 | return decodeZT0State(Bitmask) == StateValue::Preserved; |
125 | } |
126 | bool hasUndefZT0() const { return Bitmask & ZT0_Undef; } |
127 | bool sharesZT0() const { |
128 | StateValue State = decodeZT0State(Bitmask); |
129 | return State == StateValue::In || State == StateValue::Out || |
130 | State == StateValue::InOut || State == StateValue::Preserved; |
131 | } |
132 | bool hasZT0State() const { return isNewZT0() || sharesZT0(); } |
133 | |
134 | SMEAttrs operator|(SMEAttrs Other) const { |
135 | SMEAttrs Merged(*this); |
136 | Merged.set(M: Other.Bitmask); |
137 | return Merged; |
138 | } |
139 | |
140 | SMEAttrs withoutPerCallsiteFlags() const { |
141 | return (Bitmask & ~CallSiteFlags_Mask); |
142 | } |
143 | |
144 | bool operator==(SMEAttrs const &Other) const { |
145 | return Bitmask == Other.Bitmask; |
146 | } |
147 | |
148 | private: |
149 | void addKnownFunctionAttrs(StringRef FuncName); |
150 | }; |
151 | |
152 | /// SMECallAttrs is a utility class to hold the SMEAttrs for a callsite. It has |
153 | /// interfaces to query whether a streaming mode change or lazy-save mechanism |
154 | /// is required when going from one function to another (e.g. through a call). |
155 | class SMECallAttrs { |
156 | SMEAttrs CallerFn; |
157 | SMEAttrs CalledFn; |
158 | SMEAttrs Callsite; |
159 | bool IsIndirect = false; |
160 | |
161 | public: |
162 | SMECallAttrs(SMEAttrs Caller, SMEAttrs Callee, |
163 | SMEAttrs Callsite = SMEAttrs::Normal) |
164 | : CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {} |
165 | |
166 | SMECallAttrs(const CallBase &CB); |
167 | |
168 | SMEAttrs &caller() { return CallerFn; } |
169 | SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; } |
170 | SMEAttrs &callsite() { return Callsite; } |
171 | SMEAttrs const &caller() const { return CallerFn; } |
172 | SMEAttrs const &callee() const { |
173 | return const_cast<SMECallAttrs *>(this)->callee(); |
174 | } |
175 | SMEAttrs const &callsite() const { return Callsite; } |
176 | |
177 | /// \return true if a call from Caller -> Callee requires a change in |
178 | /// streaming mode. |
179 | bool requiresSMChange() const; |
180 | |
181 | bool requiresLazySave() const { |
182 | return caller().hasZAState() && callee().hasPrivateZAInterface() && |
183 | !callee().isSMEABIRoutine(); |
184 | } |
185 | |
186 | bool requiresPreservingZT0() const { |
187 | return caller().hasZT0State() && !callsite().hasUndefZT0() && |
188 | !callee().sharesZT0() && !callee().hasAgnosticZAInterface(); |
189 | } |
190 | |
191 | bool requiresDisablingZABeforeCall() const { |
192 | return caller().hasZT0State() && !caller().hasZAState() && |
193 | callee().hasPrivateZAInterface() && !callee().isSMEABIRoutine(); |
194 | } |
195 | |
196 | bool requiresEnablingZAAfterCall() const { |
197 | return requiresLazySave() || requiresDisablingZABeforeCall(); |
198 | } |
199 | |
200 | bool requiresPreservingAllZAState() const { |
201 | return caller().hasAgnosticZAInterface() && |
202 | !callee().hasAgnosticZAInterface() && !callee().isSMEABIRoutine(); |
203 | } |
204 | }; |
205 | |
206 | } // namespace llvm |
207 | |
208 | #endif // LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H |
209 | |