| 1 | //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===// | 
|---|
| 2 | // | 
|---|
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|---|
| 4 | // See https://llvm.org/LICENSE.txt for license information. | 
|---|
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|---|
| 6 | // | 
|---|
| 7 | //===----------------------------------------------------------------------===// | 
|---|
| 8 | // This file contains the AArch64 / Cortex-A57 specific register allocation | 
|---|
| 9 | // constraints for use by the PBQP register allocator. | 
|---|
| 10 | // | 
|---|
| 11 | // It is essentially a transcription of what is contained in | 
|---|
| 12 | // AArch64A57FPLoadBalancing, which tries to use a balanced | 
|---|
| 13 | // mix of odd and even D-registers when performing a critical sequence of | 
|---|
| 14 | // independent, non-quadword FP/ASIMD floating-point multiply-accumulates. | 
|---|
| 15 | //===----------------------------------------------------------------------===// | 
|---|
| 16 |  | 
|---|
| 17 | #include "AArch64PBQPRegAlloc.h" | 
|---|
| 18 | #include "AArch64InstrInfo.h" | 
|---|
| 19 | #include "AArch64RegisterInfo.h" | 
|---|
| 20 | #include "llvm/CodeGen/LiveIntervals.h" | 
|---|
| 21 | #include "llvm/CodeGen/MachineBasicBlock.h" | 
|---|
| 22 | #include "llvm/CodeGen/MachineFunction.h" | 
|---|
| 23 | #include "llvm/CodeGen/RegAllocPBQP.h" | 
|---|
| 24 | #include "llvm/Support/Debug.h" | 
|---|
| 25 | #include "llvm/Support/ErrorHandling.h" | 
|---|
| 26 | #include "llvm/Support/raw_ostream.h" | 
|---|
| 27 |  | 
|---|
| 28 | #define DEBUG_TYPE "aarch64-pbqp" | 
|---|
| 29 |  | 
|---|
| 30 | using namespace llvm; | 
|---|
| 31 |  | 
|---|
| 32 | namespace { | 
|---|
| 33 |  | 
|---|
| 34 | bool isOdd(unsigned reg) { | 
|---|
| 35 | switch (reg) { | 
|---|
| 36 | default: | 
|---|
| 37 | llvm_unreachable( "Register is not from the expected class !"); | 
|---|
| 38 | case AArch64::S1: | 
|---|
| 39 | case AArch64::S3: | 
|---|
| 40 | case AArch64::S5: | 
|---|
| 41 | case AArch64::S7: | 
|---|
| 42 | case AArch64::S9: | 
|---|
| 43 | case AArch64::S11: | 
|---|
| 44 | case AArch64::S13: | 
|---|
| 45 | case AArch64::S15: | 
|---|
| 46 | case AArch64::S17: | 
|---|
| 47 | case AArch64::S19: | 
|---|
| 48 | case AArch64::S21: | 
|---|
| 49 | case AArch64::S23: | 
|---|
| 50 | case AArch64::S25: | 
|---|
| 51 | case AArch64::S27: | 
|---|
| 52 | case AArch64::S29: | 
|---|
| 53 | case AArch64::S31: | 
|---|
| 54 | case AArch64::D1: | 
|---|
| 55 | case AArch64::D3: | 
|---|
| 56 | case AArch64::D5: | 
|---|
| 57 | case AArch64::D7: | 
|---|
| 58 | case AArch64::D9: | 
|---|
| 59 | case AArch64::D11: | 
|---|
| 60 | case AArch64::D13: | 
|---|
| 61 | case AArch64::D15: | 
|---|
| 62 | case AArch64::D17: | 
|---|
| 63 | case AArch64::D19: | 
|---|
| 64 | case AArch64::D21: | 
|---|
| 65 | case AArch64::D23: | 
|---|
| 66 | case AArch64::D25: | 
|---|
| 67 | case AArch64::D27: | 
|---|
| 68 | case AArch64::D29: | 
|---|
| 69 | case AArch64::D31: | 
|---|
| 70 | case AArch64::Q1: | 
|---|
| 71 | case AArch64::Q3: | 
|---|
| 72 | case AArch64::Q5: | 
|---|
| 73 | case AArch64::Q7: | 
|---|
| 74 | case AArch64::Q9: | 
|---|
| 75 | case AArch64::Q11: | 
|---|
| 76 | case AArch64::Q13: | 
|---|
| 77 | case AArch64::Q15: | 
|---|
| 78 | case AArch64::Q17: | 
|---|
| 79 | case AArch64::Q19: | 
|---|
| 80 | case AArch64::Q21: | 
|---|
| 81 | case AArch64::Q23: | 
|---|
| 82 | case AArch64::Q25: | 
|---|
| 83 | case AArch64::Q27: | 
|---|
| 84 | case AArch64::Q29: | 
|---|
| 85 | case AArch64::Q31: | 
|---|
| 86 | return true; | 
|---|
| 87 | case AArch64::S0: | 
|---|
| 88 | case AArch64::S2: | 
|---|
| 89 | case AArch64::S4: | 
|---|
| 90 | case AArch64::S6: | 
|---|
| 91 | case AArch64::S8: | 
|---|
| 92 | case AArch64::S10: | 
|---|
| 93 | case AArch64::S12: | 
|---|
| 94 | case AArch64::S14: | 
|---|
| 95 | case AArch64::S16: | 
|---|
| 96 | case AArch64::S18: | 
|---|
| 97 | case AArch64::S20: | 
|---|
| 98 | case AArch64::S22: | 
|---|
| 99 | case AArch64::S24: | 
|---|
| 100 | case AArch64::S26: | 
|---|
| 101 | case AArch64::S28: | 
|---|
| 102 | case AArch64::S30: | 
|---|
| 103 | case AArch64::D0: | 
|---|
| 104 | case AArch64::D2: | 
|---|
| 105 | case AArch64::D4: | 
|---|
| 106 | case AArch64::D6: | 
|---|
| 107 | case AArch64::D8: | 
|---|
| 108 | case AArch64::D10: | 
|---|
| 109 | case AArch64::D12: | 
|---|
| 110 | case AArch64::D14: | 
|---|
| 111 | case AArch64::D16: | 
|---|
| 112 | case AArch64::D18: | 
|---|
| 113 | case AArch64::D20: | 
|---|
| 114 | case AArch64::D22: | 
|---|
| 115 | case AArch64::D24: | 
|---|
| 116 | case AArch64::D26: | 
|---|
| 117 | case AArch64::D28: | 
|---|
| 118 | case AArch64::D30: | 
|---|
| 119 | case AArch64::Q0: | 
|---|
| 120 | case AArch64::Q2: | 
|---|
| 121 | case AArch64::Q4: | 
|---|
| 122 | case AArch64::Q6: | 
|---|
| 123 | case AArch64::Q8: | 
|---|
| 124 | case AArch64::Q10: | 
|---|
| 125 | case AArch64::Q12: | 
|---|
| 126 | case AArch64::Q14: | 
|---|
| 127 | case AArch64::Q16: | 
|---|
| 128 | case AArch64::Q18: | 
|---|
| 129 | case AArch64::Q20: | 
|---|
| 130 | case AArch64::Q22: | 
|---|
| 131 | case AArch64::Q24: | 
|---|
| 132 | case AArch64::Q26: | 
|---|
| 133 | case AArch64::Q28: | 
|---|
| 134 | case AArch64::Q30: | 
|---|
| 135 | return false; | 
|---|
| 136 |  | 
|---|
| 137 | } | 
|---|
| 138 | } | 
|---|
| 139 |  | 
|---|
| 140 | bool haveSameParity(unsigned reg1, unsigned reg2) { | 
|---|
| 141 | assert(AArch64InstrInfo::isFpOrNEON(reg1) && | 
|---|
| 142 | "Expecting an FP register for reg1"); | 
|---|
| 143 | assert(AArch64InstrInfo::isFpOrNEON(reg2) && | 
|---|
| 144 | "Expecting an FP register for reg2"); | 
|---|
| 145 |  | 
|---|
| 146 | return isOdd(reg: reg1) == isOdd(reg: reg2); | 
|---|
| 147 | } | 
|---|
| 148 |  | 
|---|
| 149 | } | 
|---|
| 150 |  | 
|---|
| 151 | bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd, | 
|---|
| 152 | unsigned Ra) { | 
|---|
| 153 | if (Rd == Ra) | 
|---|
| 154 | return false; | 
|---|
| 155 |  | 
|---|
| 156 | LiveIntervals &LIs = G.getMetadata().LIS; | 
|---|
| 157 |  | 
|---|
| 158 | if (Register::isPhysicalRegister(Reg: Rd) || Register::isPhysicalRegister(Reg: Ra)) { | 
|---|
| 159 | LLVM_DEBUG(dbgs() << "Rd is a physical reg:" | 
|---|
| 160 | << Register::isPhysicalRegister(Rd) << '\n'); | 
|---|
| 161 | LLVM_DEBUG(dbgs() << "Ra is a physical reg:" | 
|---|
| 162 | << Register::isPhysicalRegister(Ra) << '\n'); | 
|---|
| 163 | return false; | 
|---|
| 164 | } | 
|---|
| 165 |  | 
|---|
| 166 | PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(VReg: Rd); | 
|---|
| 167 | PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(VReg: Ra); | 
|---|
| 168 |  | 
|---|
| 169 | const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = | 
|---|
| 170 | &G.getNodeMetadata(NId: node1).getAllowedRegs(); | 
|---|
| 171 | const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed = | 
|---|
| 172 | &G.getNodeMetadata(NId: node2).getAllowedRegs(); | 
|---|
| 173 |  | 
|---|
| 174 | PBQPRAGraph::EdgeId edge = G.findEdge(N1Id: node1, N2Id: node2); | 
|---|
| 175 |  | 
|---|
| 176 | // The edge does not exist. Create one with the appropriate interference | 
|---|
| 177 | // costs. | 
|---|
| 178 | if (edge == G.invalidEdgeId()) { | 
|---|
| 179 | const LiveInterval &ld = LIs.getInterval(Reg: Rd); | 
|---|
| 180 | const LiveInterval &la = LIs.getInterval(Reg: Ra); | 
|---|
| 181 | bool livesOverlap = ld.overlaps(other: la); | 
|---|
| 182 |  | 
|---|
| 183 | PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1, | 
|---|
| 184 | vRaAllowed->size() + 1, 0); | 
|---|
| 185 | for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { | 
|---|
| 186 | unsigned pRd = (*vRdAllowed)[i]; | 
|---|
| 187 | for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { | 
|---|
| 188 | unsigned pRa = (*vRaAllowed)[j]; | 
|---|
| 189 | if (livesOverlap && TRI->regsOverlap(RegA: pRd, RegB: pRa)) | 
|---|
| 190 | costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity(); | 
|---|
| 191 | else | 
|---|
| 192 | costs[i + 1][j + 1] = haveSameParity(reg1: pRd, reg2: pRa) ? 0.0 : 1.0; | 
|---|
| 193 | } | 
|---|
| 194 | } | 
|---|
| 195 | G.addEdge(N1Id: node1, N2Id: node2, Costs: std::move(costs)); | 
|---|
| 196 | return true; | 
|---|
| 197 | } | 
|---|
| 198 |  | 
|---|
| 199 | if (G.getEdgeNode1Id(EId: edge) == node2) { | 
|---|
| 200 | std::swap(a&: node1, b&: node2); | 
|---|
| 201 | std::swap(a&: vRdAllowed, b&: vRaAllowed); | 
|---|
| 202 | } | 
|---|
| 203 |  | 
|---|
| 204 | // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass)) | 
|---|
| 205 | PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(EId: edge)); | 
|---|
| 206 | for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { | 
|---|
| 207 | unsigned pRd = (*vRdAllowed)[i]; | 
|---|
| 208 |  | 
|---|
| 209 | // Get the maximum cost (excluding unallocatable reg) for same parity | 
|---|
| 210 | // registers | 
|---|
| 211 | PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); | 
|---|
| 212 | for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { | 
|---|
| 213 | unsigned pRa = (*vRaAllowed)[j]; | 
|---|
| 214 | if (haveSameParity(reg1: pRd, reg2: pRa)) | 
|---|
| 215 | if (costs[i + 1][j + 1] != | 
|---|
| 216 | std::numeric_limits<PBQP::PBQPNum>::infinity() && | 
|---|
| 217 | costs[i + 1][j + 1] > sameParityMax) | 
|---|
| 218 | sameParityMax = costs[i + 1][j + 1]; | 
|---|
| 219 | } | 
|---|
| 220 |  | 
|---|
| 221 | // Ensure all registers with a different parity have a higher cost | 
|---|
| 222 | // than sameParityMax | 
|---|
| 223 | for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { | 
|---|
| 224 | unsigned pRa = (*vRaAllowed)[j]; | 
|---|
| 225 | if (!haveSameParity(reg1: pRd, reg2: pRa)) | 
|---|
| 226 | if (sameParityMax > costs[i + 1][j + 1]) | 
|---|
| 227 | costs[i + 1][j + 1] = sameParityMax + 1.0; | 
|---|
| 228 | } | 
|---|
| 229 | } | 
|---|
| 230 | G.updateEdgeCosts(EId: edge, Costs: std::move(costs)); | 
|---|
| 231 |  | 
|---|
| 232 | return true; | 
|---|
| 233 | } | 
|---|
| 234 |  | 
|---|
| 235 | void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd, | 
|---|
| 236 | unsigned Ra) { | 
|---|
| 237 | LiveIntervals &LIs = G.getMetadata().LIS; | 
|---|
| 238 |  | 
|---|
| 239 | // Do some Chain management | 
|---|
| 240 | if (Chains.count(key: Ra)) { | 
|---|
| 241 | if (Rd != Ra) { | 
|---|
| 242 | LLVM_DEBUG(dbgs() << "Moving acc chain from "<< printReg(Ra, TRI) | 
|---|
| 243 | << " to "<< printReg(Rd, TRI) << '\n'); | 
|---|
| 244 | Chains.remove(X: Ra); | 
|---|
| 245 | Chains.insert(X: Rd); | 
|---|
| 246 | } | 
|---|
| 247 | } else { | 
|---|
| 248 | LLVM_DEBUG(dbgs() << "Creating new acc chain for "<< printReg(Rd, TRI) | 
|---|
| 249 | << '\n'); | 
|---|
| 250 | Chains.insert(X: Rd); | 
|---|
| 251 | } | 
|---|
| 252 |  | 
|---|
| 253 | PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(VReg: Rd); | 
|---|
| 254 |  | 
|---|
| 255 | const LiveInterval &ld = LIs.getInterval(Reg: Rd); | 
|---|
| 256 | for (auto r : Chains) { | 
|---|
| 257 | // Skip self | 
|---|
| 258 | if (r == Rd) | 
|---|
| 259 | continue; | 
|---|
| 260 |  | 
|---|
| 261 | const LiveInterval &lr = LIs.getInterval(Reg: r); | 
|---|
| 262 | if (ld.overlaps(other: lr)) { | 
|---|
| 263 | const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = | 
|---|
| 264 | &G.getNodeMetadata(NId: node1).getAllowedRegs(); | 
|---|
| 265 |  | 
|---|
| 266 | PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(VReg: r); | 
|---|
| 267 | const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed = | 
|---|
| 268 | &G.getNodeMetadata(NId: node2).getAllowedRegs(); | 
|---|
| 269 |  | 
|---|
| 270 | PBQPRAGraph::EdgeId edge = G.findEdge(N1Id: node1, N2Id: node2); | 
|---|
| 271 | assert(edge != G.invalidEdgeId() && | 
|---|
| 272 | "PBQP error ! The edge should exist !"); | 
|---|
| 273 |  | 
|---|
| 274 | LLVM_DEBUG(dbgs() << "Refining constraint !\n"); | 
|---|
| 275 |  | 
|---|
| 276 | if (G.getEdgeNode1Id(EId: edge) == node2) { | 
|---|
| 277 | std::swap(a&: node1, b&: node2); | 
|---|
| 278 | std::swap(a&: vRdAllowed, b&: vRrAllowed); | 
|---|
| 279 | } | 
|---|
| 280 |  | 
|---|
| 281 | // Enforce that cost is higher with all other Chains of the same parity | 
|---|
| 282 | PBQP::Matrix costs(G.getEdgeCosts(EId: edge)); | 
|---|
| 283 | for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { | 
|---|
| 284 | unsigned pRd = (*vRdAllowed)[i]; | 
|---|
| 285 |  | 
|---|
| 286 | // Get the maximum cost (excluding unallocatable reg) for all other | 
|---|
| 287 | // parity registers | 
|---|
| 288 | PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); | 
|---|
| 289 | for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) { | 
|---|
| 290 | unsigned pRa = (*vRrAllowed)[j]; | 
|---|
| 291 | if (!haveSameParity(reg1: pRd, reg2: pRa)) | 
|---|
| 292 | if (costs[i + 1][j + 1] != | 
|---|
| 293 | std::numeric_limits<PBQP::PBQPNum>::infinity() && | 
|---|
| 294 | costs[i + 1][j + 1] > sameParityMax) | 
|---|
| 295 | sameParityMax = costs[i + 1][j + 1]; | 
|---|
| 296 | } | 
|---|
| 297 |  | 
|---|
| 298 | // Ensure all registers with same parity have a higher cost | 
|---|
| 299 | // than sameParityMax | 
|---|
| 300 | for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) { | 
|---|
| 301 | unsigned pRa = (*vRrAllowed)[j]; | 
|---|
| 302 | if (haveSameParity(reg1: pRd, reg2: pRa)) | 
|---|
| 303 | if (sameParityMax > costs[i + 1][j + 1]) | 
|---|
| 304 | costs[i + 1][j + 1] = sameParityMax + 1.0; | 
|---|
| 305 | } | 
|---|
| 306 | } | 
|---|
| 307 | G.updateEdgeCosts(EId: edge, Costs: std::move(costs)); | 
|---|
| 308 | } | 
|---|
| 309 | } | 
|---|
| 310 | } | 
|---|
| 311 |  | 
|---|
| 312 | static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg, | 
|---|
| 313 | const MachineInstr &MI) { | 
|---|
| 314 | const LiveInterval &LI = LIs.getInterval(Reg: reg); | 
|---|
| 315 | SlotIndex SI = LIs.getInstructionIndex(Instr: MI); | 
|---|
| 316 | return LI.expiredAt(index: SI); | 
|---|
| 317 | } | 
|---|
| 318 |  | 
|---|
| 319 | void A57ChainingConstraint::apply(PBQPRAGraph &G) { | 
|---|
| 320 | const MachineFunction &MF = G.getMetadata().MF; | 
|---|
| 321 | LiveIntervals &LIs = G.getMetadata().LIS; | 
|---|
| 322 |  | 
|---|
| 323 | TRI = MF.getSubtarget().getRegisterInfo(); | 
|---|
| 324 | LLVM_DEBUG(MF.dump()); | 
|---|
| 325 |  | 
|---|
| 326 | for (const auto &MBB: MF) { | 
|---|
| 327 | Chains.clear(); // FIXME: really needed ? Could not work at MF level ? | 
|---|
| 328 |  | 
|---|
| 329 | for (const auto &MI: MBB) { | 
|---|
| 330 |  | 
|---|
| 331 | // Forget Chains which have expired | 
|---|
| 332 | for (auto r : Chains) { | 
|---|
| 333 | SmallVector<unsigned, 8> toDel; | 
|---|
| 334 | if(regJustKilledBefore(LIs, reg: r, MI)) { | 
|---|
| 335 | LLVM_DEBUG(dbgs() << "Killing chain "<< printReg(r, TRI) << " at "; | 
|---|
| 336 | MI.print(dbgs())); | 
|---|
| 337 | toDel.push_back(Elt: r); | 
|---|
| 338 | } | 
|---|
| 339 |  | 
|---|
| 340 | while (!toDel.empty()) { | 
|---|
| 341 | Chains.remove(X: toDel.back()); | 
|---|
| 342 | toDel.pop_back(); | 
|---|
| 343 | } | 
|---|
| 344 | } | 
|---|
| 345 |  | 
|---|
| 346 | switch (MI.getOpcode()) { | 
|---|
| 347 | case AArch64::FMSUBSrrr: | 
|---|
| 348 | case AArch64::FMADDSrrr: | 
|---|
| 349 | case AArch64::FNMSUBSrrr: | 
|---|
| 350 | case AArch64::FNMADDSrrr: | 
|---|
| 351 | case AArch64::FMSUBDrrr: | 
|---|
| 352 | case AArch64::FMADDDrrr: | 
|---|
| 353 | case AArch64::FNMSUBDrrr: | 
|---|
| 354 | case AArch64::FNMADDDrrr: { | 
|---|
| 355 | Register Rd = MI.getOperand(i: 0).getReg(); | 
|---|
| 356 | Register Ra = MI.getOperand(i: 3).getReg(); | 
|---|
| 357 |  | 
|---|
| 358 | if (addIntraChainConstraint(G, Rd, Ra)) | 
|---|
| 359 | addInterChainConstraint(G, Rd, Ra); | 
|---|
| 360 | break; | 
|---|
| 361 | } | 
|---|
| 362 |  | 
|---|
| 363 | case AArch64::FMLAv2f32: | 
|---|
| 364 | case AArch64::FMLSv2f32: { | 
|---|
| 365 | Register Rd = MI.getOperand(i: 0).getReg(); | 
|---|
| 366 | addInterChainConstraint(G, Rd, Ra: Rd); | 
|---|
| 367 | break; | 
|---|
| 368 | } | 
|---|
| 369 |  | 
|---|
| 370 | default: | 
|---|
| 371 | break; | 
|---|
| 372 | } | 
|---|
| 373 | } | 
|---|
| 374 | } | 
|---|
| 375 | } | 
|---|
| 376 |  | 
|---|