| 1 | //===-- AArch64SMEAttributes.h - Helper for interpreting SME attributes -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H |
| 10 | #define LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H |
| 11 | |
| 12 | #include "llvm/IR/Function.h" |
| 13 | |
| 14 | namespace llvm { |
| 15 | namespace RTLIB { |
| 16 | struct RuntimeLibcallsInfo; |
| 17 | } |
| 18 | |
| 19 | class Function; |
| 20 | class CallBase; |
| 21 | class AttributeList; |
| 22 | |
| 23 | /// SMEAttrs is a utility class to parse the SME ACLE attributes on functions. |
| 24 | /// It helps determine a function's requirements for PSTATE.ZA and PSTATE.SM. |
| 25 | class SMEAttrs { |
| 26 | unsigned Bitmask = Normal; |
| 27 | |
| 28 | public: |
| 29 | enum class StateValue { |
| 30 | None = 0, |
| 31 | In = 1, // aarch64_in_zt0 |
| 32 | Out = 2, // aarch64_out_zt0 |
| 33 | InOut = 3, // aarch64_inout_zt0 |
| 34 | Preserved = 4, // aarch64_preserves_zt0 |
| 35 | New = 5 // aarch64_new_zt0 |
| 36 | }; |
| 37 | |
| 38 | // Enum with bitmasks for each individual SME feature. |
| 39 | enum Mask { |
| 40 | Normal = 0, |
| 41 | SM_Enabled = 1 << 0, // aarch64_pstate_sm_enabled |
| 42 | SM_Compatible = 1 << 1, // aarch64_pstate_sm_compatible |
| 43 | SM_Body = 1 << 2, // aarch64_pstate_sm_body |
| 44 | SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves |
| 45 | ZA_State_Agnostic = 1 << 4, |
| 46 | ZT0_Undef = 1 << 5, // Use to mark ZT0 as undef to avoid spills |
| 47 | ZA_Shift = 6, |
| 48 | ZA_Mask = 0b111 << ZA_Shift, |
| 49 | ZT0_Shift = 9, |
| 50 | ZT0_Mask = 0b111 << ZT0_Shift, |
| 51 | CallSiteFlags_Mask = ZT0_Undef |
| 52 | }; |
| 53 | |
| 54 | SMEAttrs() = default; |
| 55 | SMEAttrs(unsigned Mask) { set(M: Mask); } |
| 56 | SMEAttrs(const Function &F, const RTLIB::RuntimeLibcallsInfo *RTLCI = nullptr) |
| 57 | : SMEAttrs(F.getAttributes()) { |
| 58 | if (RTLCI) |
| 59 | addKnownFunctionAttrs(FuncName: F.getName(), RTLCI: *RTLCI); |
| 60 | } |
| 61 | SMEAttrs(const AttributeList &L); |
| 62 | SMEAttrs(StringRef FuncName, const RTLIB::RuntimeLibcallsInfo &RTLCI) { |
| 63 | addKnownFunctionAttrs(FuncName, RTLCI); |
| 64 | }; |
| 65 | |
| 66 | void set(unsigned M, bool Enable = true) { |
| 67 | if (Enable) |
| 68 | Bitmask |= M; |
| 69 | else |
| 70 | Bitmask &= ~M; |
| 71 | #ifndef NDEBUG |
| 72 | validate(); |
| 73 | #endif |
| 74 | } |
| 75 | |
| 76 | // Interfaces to query PSTATE.SM |
| 77 | bool hasStreamingBody() const { return Bitmask & SM_Body; } |
| 78 | bool hasStreamingInterface() const { return Bitmask & SM_Enabled; } |
| 79 | bool hasStreamingInterfaceOrBody() const { |
| 80 | return hasStreamingBody() || hasStreamingInterface(); |
| 81 | } |
| 82 | bool hasStreamingCompatibleInterface() const { |
| 83 | return Bitmask & SM_Compatible; |
| 84 | } |
| 85 | bool hasNonStreamingInterface() const { |
| 86 | return !hasStreamingInterface() && !hasStreamingCompatibleInterface(); |
| 87 | } |
| 88 | bool hasNonStreamingInterfaceAndBody() const { |
| 89 | return hasNonStreamingInterface() && !hasStreamingBody(); |
| 90 | } |
| 91 | |
| 92 | // Interfaces to query ZA |
| 93 | static StateValue decodeZAState(unsigned Bitmask) { |
| 94 | return static_cast<StateValue>((Bitmask & ZA_Mask) >> ZA_Shift); |
| 95 | } |
| 96 | static unsigned encodeZAState(StateValue S) { |
| 97 | return static_cast<unsigned>(S) << ZA_Shift; |
| 98 | } |
| 99 | |
| 100 | bool isNewZA() const { return decodeZAState(Bitmask) == StateValue::New; } |
| 101 | bool isInZA() const { return decodeZAState(Bitmask) == StateValue::In; } |
| 102 | bool isOutZA() const { return decodeZAState(Bitmask) == StateValue::Out; } |
| 103 | bool isInOutZA() const { return decodeZAState(Bitmask) == StateValue::InOut; } |
| 104 | bool isPreservesZA() const { |
| 105 | return decodeZAState(Bitmask) == StateValue::Preserved; |
| 106 | } |
| 107 | bool sharesZA() const { |
| 108 | StateValue State = decodeZAState(Bitmask); |
| 109 | return State == StateValue::In || State == StateValue::Out || |
| 110 | State == StateValue::InOut || State == StateValue::Preserved; |
| 111 | } |
| 112 | bool hasAgnosticZAInterface() const { return Bitmask & ZA_State_Agnostic; } |
| 113 | bool hasSharedZAInterface() const { return sharesZA() || sharesZT0(); } |
| 114 | bool hasPrivateZAInterface() const { |
| 115 | return !hasSharedZAInterface() && !hasAgnosticZAInterface(); |
| 116 | } |
| 117 | bool hasZAState() const { return isNewZA() || sharesZA(); } |
| 118 | bool isSMEABIRoutine() const { return Bitmask & SME_ABI_Routine; } |
| 119 | |
| 120 | // Interfaces to query ZT0 State |
| 121 | static StateValue decodeZT0State(unsigned Bitmask) { |
| 122 | return static_cast<StateValue>((Bitmask & ZT0_Mask) >> ZT0_Shift); |
| 123 | } |
| 124 | static unsigned encodeZT0State(StateValue S) { |
| 125 | return static_cast<unsigned>(S) << ZT0_Shift; |
| 126 | } |
| 127 | |
| 128 | bool isNewZT0() const { return decodeZT0State(Bitmask) == StateValue::New; } |
| 129 | bool isInZT0() const { return decodeZT0State(Bitmask) == StateValue::In; } |
| 130 | bool isOutZT0() const { return decodeZT0State(Bitmask) == StateValue::Out; } |
| 131 | bool isInOutZT0() const { |
| 132 | return decodeZT0State(Bitmask) == StateValue::InOut; |
| 133 | } |
| 134 | bool isPreservesZT0() const { |
| 135 | return decodeZT0State(Bitmask) == StateValue::Preserved; |
| 136 | } |
| 137 | bool hasUndefZT0() const { return Bitmask & ZT0_Undef; } |
| 138 | bool sharesZT0() const { |
| 139 | StateValue State = decodeZT0State(Bitmask); |
| 140 | return State == StateValue::In || State == StateValue::Out || |
| 141 | State == StateValue::InOut || State == StateValue::Preserved; |
| 142 | } |
| 143 | bool hasZT0State() const { return isNewZT0() || sharesZT0(); } |
| 144 | |
| 145 | SMEAttrs operator|(SMEAttrs Other) const { |
| 146 | SMEAttrs Merged(*this); |
| 147 | Merged.set(M: Other.Bitmask); |
| 148 | return Merged; |
| 149 | } |
| 150 | |
| 151 | SMEAttrs withoutPerCallsiteFlags() const { |
| 152 | return (Bitmask & ~CallSiteFlags_Mask); |
| 153 | } |
| 154 | |
| 155 | bool operator==(SMEAttrs const &Other) const { |
| 156 | return Bitmask == Other.Bitmask; |
| 157 | } |
| 158 | |
| 159 | private: |
| 160 | void addKnownFunctionAttrs(StringRef FuncName, |
| 161 | const RTLIB::RuntimeLibcallsInfo &RTLCI); |
| 162 | void validate() const; |
| 163 | }; |
| 164 | |
| 165 | /// SMECallAttrs is a utility class to hold the SMEAttrs for a callsite. It has |
| 166 | /// interfaces to query whether a streaming mode change or lazy-save mechanism |
| 167 | /// is required when going from one function to another (e.g. through a call). |
| 168 | class SMECallAttrs { |
| 169 | SMEAttrs CallerFn; |
| 170 | SMEAttrs CalledFn; |
| 171 | SMEAttrs Callsite; |
| 172 | bool IsIndirect = false; |
| 173 | |
| 174 | public: |
| 175 | SMECallAttrs(SMEAttrs Caller, SMEAttrs Callee, |
| 176 | SMEAttrs Callsite = SMEAttrs::Normal) |
| 177 | : CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {} |
| 178 | |
| 179 | SMECallAttrs(const CallBase &CB, const RTLIB::RuntimeLibcallsInfo *RTLCI); |
| 180 | |
| 181 | SMEAttrs &caller() { return CallerFn; } |
| 182 | SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; } |
| 183 | SMEAttrs &callsite() { return Callsite; } |
| 184 | SMEAttrs const &caller() const { return CallerFn; } |
| 185 | SMEAttrs const &callee() const { |
| 186 | return const_cast<SMECallAttrs *>(this)->callee(); |
| 187 | } |
| 188 | SMEAttrs const &callsite() const { return Callsite; } |
| 189 | |
| 190 | /// \return true if a call from Caller -> Callee requires a change in |
| 191 | /// streaming mode. |
| 192 | bool requiresSMChange() const; |
| 193 | |
| 194 | bool requiresLazySave() const { |
| 195 | return caller().hasZAState() && callee().hasPrivateZAInterface() && |
| 196 | !callee().isSMEABIRoutine(); |
| 197 | } |
| 198 | |
| 199 | bool requiresPreservingZT0() const { |
| 200 | return caller().hasZT0State() && !callsite().hasUndefZT0() && |
| 201 | !callee().sharesZT0() && !callee().hasAgnosticZAInterface(); |
| 202 | } |
| 203 | |
| 204 | bool requiresDisablingZABeforeCall() const { |
| 205 | return caller().hasZT0State() && !caller().hasZAState() && |
| 206 | callee().hasPrivateZAInterface() && !callee().isSMEABIRoutine(); |
| 207 | } |
| 208 | |
| 209 | bool requiresEnablingZAAfterCall() const { |
| 210 | return requiresDisablingZABeforeCall(); |
| 211 | } |
| 212 | |
| 213 | bool requiresPreservingAllZAState() const { |
| 214 | return caller().hasAgnosticZAInterface() && |
| 215 | !callee().hasAgnosticZAInterface() && !callee().isSMEABIRoutine(); |
| 216 | } |
| 217 | }; |
| 218 | |
| 219 | } // namespace llvm |
| 220 | |
| 221 | #endif // LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H |
| 222 | |