//===- AArch64FalkorHWPFFix.cpp - Avoid HW prefetcher pitfalls on Falkor --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file For Falkor, we want to avoid HW prefetcher instruction tag collisions
/// that may inhibit the HW prefetching. This is done in two steps. Before
/// ISel, we mark strided loads (i.e. those that will likely benefit from
/// prefetching) with metadata. Then, after opcodes have been finalized, we
/// insert MOVs and re-write loads to prevent unintentional tag collisions.
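///
/// Illustrative example (tags only use the low 4 bits of each register
/// encoding, see makeTag below): `ldr x1, [x2]` and `ldr x17, [x18]` produce
/// the same tag, so one of them is rewritten to load through a scratch base
/// register whose encoding yields a non-colliding tag.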
//===----------------------------------------------------------------------===//

#include "AArch64.h"
#include "AArch64InstrInfo.h"
#include "AArch64Subtarget.h"
#include "AArch64TargetMachine.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/CodeGen/LiveRegUnits.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Support/raw_ostream.h"
#include <iterator>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "aarch64-falkor-hwpf-fix"

STATISTIC(NumStridedLoadsMarked, "Number of strided loads marked");
STATISTIC(NumCollisionsAvoided,
          "Number of HW prefetch tag collisions avoided");
STATISTIC(NumCollisionsNotAvoided,
          "Number of HW prefetch tag collisions not avoided due to lack of "
          "registers");
DEBUG_COUNTER(FixCounter, "falkor-hwpf",
              "Controls which tag collisions are avoided");

namespace {

class FalkorMarkStridedAccesses {
public:
  FalkorMarkStridedAccesses(LoopInfo &LI, ScalarEvolution &SE)
      : LI(LI), SE(SE) {}

  bool run();

private:
  bool runOnLoop(Loop &L);

  LoopInfo &LI;
  ScalarEvolution &SE;
};

class FalkorMarkStridedAccessesLegacy : public FunctionPass {
public:
  static char ID; // Pass ID, replacement for typeid

  FalkorMarkStridedAccessesLegacy() : FunctionPass(ID) {}

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
  }

  bool runOnFunction(Function &F) override;
};

} // end anonymous namespace

char FalkorMarkStridedAccessesLegacy::ID = 0;

INITIALIZE_PASS_BEGIN(FalkorMarkStridedAccessesLegacy, DEBUG_TYPE,
                      "Falkor HW Prefetch Fix", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(FalkorMarkStridedAccessesLegacy, DEBUG_TYPE,
                    "Falkor HW Prefetch Fix", false, false)

FunctionPass *llvm::createFalkorMarkStridedAccessesPass() {
  return new FalkorMarkStridedAccessesLegacy();
}

bool FalkorMarkStridedAccessesLegacy::runOnFunction(Function &F) {
  TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
  const AArch64Subtarget *ST =
      TPC.getTM<AArch64TargetMachine>().getSubtargetImpl(F);
  if (ST->getProcFamily() != AArch64Subtarget::Falkor)
    return false;

  if (skipFunction(F))
    return false;

  LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();

  FalkorMarkStridedAccesses LDP(LI, SE);
  return LDP.run();
}

bool FalkorMarkStridedAccesses::run() {
  bool MadeChange = false;

  for (Loop *L : LI)
    for (Loop *LIt : depth_first(L))
      MadeChange |= runOnLoop(*LIt);

  return MadeChange;
}

bool FalkorMarkStridedAccesses::runOnLoop(Loop &L) {
  // Only mark strided loads in the inner-most loop.
  if (!L.isInnermost())
    return false;

  bool MadeChange = false;

  for (BasicBlock *BB : L.blocks()) {
    for (Instruction &I : *BB) {
      LoadInst *LoadI = dyn_cast<LoadInst>(&I);
      if (!LoadI)
        continue;

      Value *PtrValue = LoadI->getPointerOperand();
      if (L.isLoopInvariant(PtrValue))
        continue;

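      // An affine SCEV add-recurrence (base + iteration * step) means the
      // address advances by a fixed stride on every iteration, which is
      // exactly the access pattern the HW prefetcher trains on.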
      const SCEV *LSCEV = SE.getSCEV(PtrValue);
      const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
      if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
        continue;

      LoadI->setMetadata(FALKOR_STRIDED_ACCESS_MD,
                         MDNode::get(LoadI->getContext(), {}));
      ++NumStridedLoadsMarked;
      LLVM_DEBUG(dbgs() << "Load: " << I << " marked as strided\n");
      MadeChange = true;
    }
  }

  return MadeChange;
}

namespace {

class FalkorHWPFFix : public MachineFunctionPass {
public:
  static char ID;

  FalkorHWPFFix() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &Fn) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<MachineLoopInfoWrapperPass>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  MachineFunctionProperties getRequiredProperties() const override {
    return MachineFunctionProperties().setNoVRegs();
  }

private:
  void runOnLoop(MachineLoop &L, MachineFunction &Fn);

  const AArch64InstrInfo *TII;
  const TargetRegisterInfo *TRI;
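  // Maps a HW prefetcher tag to the loads in the current loop that produce it.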
  DenseMap<unsigned, SmallVector<MachineInstr *, 4>> TagMap;
  bool Modified;
};

/// Bits from load opcodes used to compute HW prefetcher instruction tags.
struct LoadInfo {
  LoadInfo() = default;

  Register DestReg;
  Register BaseReg;
  int BaseRegIdx = -1;
  const MachineOperand *OffsetOpnd = nullptr;
  bool IsPrePost = false;
};

} // end anonymous namespace

char FalkorHWPFFix::ID = 0;

INITIALIZE_PASS_BEGIN(FalkorHWPFFix, "aarch64-falkor-hwpf-fix-late",
                      "Falkor HW Prefetch Fix Late Phase", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
INITIALIZE_PASS_END(FalkorHWPFFix, "aarch64-falkor-hwpf-fix-late",
                    "Falkor HW Prefetch Fix Late Phase", false, false)

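// Pack the fields the HW prefetcher hashes into its training-table tag: the
// low 4 bits of the destination register encoding, the low 4 bits of the base
// register encoding, and 6 bits derived from the offset.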
static unsigned makeTag(unsigned Dest, unsigned Base, unsigned Offset) {
  return (Dest & 0xf) | ((Base & 0xf) << 4) | ((Offset & 0x3f) << 8);
}

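// Return the destination/base/offset operand positions (and whether the load
// is a pre/post-increment form) for each load opcode this pass cares about,
// or std::nullopt for any other instruction.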
static std::optional<LoadInfo> getLoadInfo(const MachineInstr &MI) {
  int DestRegIdx;
  int BaseRegIdx;
  int OffsetIdx;
  bool IsPrePost;

  switch (MI.getOpcode()) {
  default:
    return std::nullopt;

  case AArch64::LD1i64:
  case AArch64::LD2i64:
    DestRegIdx = 0;
    BaseRegIdx = 3;
    OffsetIdx = -1;
    IsPrePost = false;
    break;

  case AArch64::LD1i8:
  case AArch64::LD1i16:
  case AArch64::LD1i32:
  case AArch64::LD2i8:
  case AArch64::LD2i16:
  case AArch64::LD2i32:
  case AArch64::LD3i8:
  case AArch64::LD3i16:
  case AArch64::LD3i32:
  case AArch64::LD3i64:
  case AArch64::LD4i8:
  case AArch64::LD4i16:
  case AArch64::LD4i32:
  case AArch64::LD4i64:
    DestRegIdx = -1;
    BaseRegIdx = 3;
    OffsetIdx = -1;
    IsPrePost = false;
    break;

  case AArch64::LD1Onev1d:
  case AArch64::LD1Onev2s:
  case AArch64::LD1Onev4h:
  case AArch64::LD1Onev8b:
  case AArch64::LD1Onev2d:
  case AArch64::LD1Onev4s:
  case AArch64::LD1Onev8h:
  case AArch64::LD1Onev16b:
  case AArch64::LD1Rv1d:
  case AArch64::LD1Rv2s:
  case AArch64::LD1Rv4h:
  case AArch64::LD1Rv8b:
  case AArch64::LD1Rv2d:
  case AArch64::LD1Rv4s:
  case AArch64::LD1Rv8h:
  case AArch64::LD1Rv16b:
    DestRegIdx = 0;
    BaseRegIdx = 1;
    OffsetIdx = -1;
    IsPrePost = false;
    break;

  case AArch64::LD1Twov1d:
  case AArch64::LD1Twov2s:
  case AArch64::LD1Twov4h:
  case AArch64::LD1Twov8b:
  case AArch64::LD1Twov2d:
  case AArch64::LD1Twov4s:
  case AArch64::LD1Twov8h:
  case AArch64::LD1Twov16b:
  case AArch64::LD1Threev1d:
  case AArch64::LD1Threev2s:
  case AArch64::LD1Threev4h:
  case AArch64::LD1Threev8b:
  case AArch64::LD1Threev2d:
  case AArch64::LD1Threev4s:
  case AArch64::LD1Threev8h:
  case AArch64::LD1Threev16b:
  case AArch64::LD1Fourv1d:
  case AArch64::LD1Fourv2s:
  case AArch64::LD1Fourv4h:
  case AArch64::LD1Fourv8b:
  case AArch64::LD1Fourv2d:
  case AArch64::LD1Fourv4s:
  case AArch64::LD1Fourv8h:
  case AArch64::LD1Fourv16b:
  case AArch64::LD2Twov2s:
  case AArch64::LD2Twov4s:
  case AArch64::LD2Twov8b:
  case AArch64::LD2Twov2d:
  case AArch64::LD2Twov4h:
  case AArch64::LD2Twov8h:
  case AArch64::LD2Twov16b:
  case AArch64::LD2Rv1d:
  case AArch64::LD2Rv2s:
  case AArch64::LD2Rv4s:
  case AArch64::LD2Rv8b:
  case AArch64::LD2Rv2d:
  case AArch64::LD2Rv4h:
  case AArch64::LD2Rv8h:
  case AArch64::LD2Rv16b:
  case AArch64::LD3Threev2s:
  case AArch64::LD3Threev4h:
  case AArch64::LD3Threev8b:
  case AArch64::LD3Threev2d:
  case AArch64::LD3Threev4s:
  case AArch64::LD3Threev8h:
  case AArch64::LD3Threev16b:
  case AArch64::LD3Rv1d:
  case AArch64::LD3Rv2s:
  case AArch64::LD3Rv4h:
  case AArch64::LD3Rv8b:
  case AArch64::LD3Rv2d:
  case AArch64::LD3Rv4s:
  case AArch64::LD3Rv8h:
  case AArch64::LD3Rv16b:
  case AArch64::LD4Fourv2s:
  case AArch64::LD4Fourv4h:
  case AArch64::LD4Fourv8b:
  case AArch64::LD4Fourv2d:
  case AArch64::LD4Fourv4s:
  case AArch64::LD4Fourv8h:
  case AArch64::LD4Fourv16b:
  case AArch64::LD4Rv1d:
  case AArch64::LD4Rv2s:
  case AArch64::LD4Rv4h:
  case AArch64::LD4Rv8b:
  case AArch64::LD4Rv2d:
  case AArch64::LD4Rv4s:
  case AArch64::LD4Rv8h:
  case AArch64::LD4Rv16b:
    DestRegIdx = -1;
    BaseRegIdx = 1;
    OffsetIdx = -1;
    IsPrePost = false;
    break;

  case AArch64::LD1i64_POST:
  case AArch64::LD2i64_POST:
    DestRegIdx = 1;
    BaseRegIdx = 4;
    OffsetIdx = 5;
    IsPrePost = true;
    break;

  case AArch64::LD1i8_POST:
  case AArch64::LD1i16_POST:
  case AArch64::LD1i32_POST:
  case AArch64::LD2i8_POST:
  case AArch64::LD2i16_POST:
  case AArch64::LD2i32_POST:
  case AArch64::LD3i8_POST:
  case AArch64::LD3i16_POST:
  case AArch64::LD3i32_POST:
  case AArch64::LD3i64_POST:
  case AArch64::LD4i8_POST:
  case AArch64::LD4i16_POST:
  case AArch64::LD4i32_POST:
  case AArch64::LD4i64_POST:
    DestRegIdx = -1;
    BaseRegIdx = 4;
    OffsetIdx = 5;
    IsPrePost = true;
    break;

  case AArch64::LD1Onev1d_POST:
  case AArch64::LD1Onev2s_POST:
  case AArch64::LD1Onev4h_POST:
  case AArch64::LD1Onev8b_POST:
  case AArch64::LD1Onev2d_POST:
  case AArch64::LD1Onev4s_POST:
  case AArch64::LD1Onev8h_POST:
  case AArch64::LD1Onev16b_POST:
  case AArch64::LD1Rv1d_POST:
  case AArch64::LD1Rv2s_POST:
  case AArch64::LD1Rv4h_POST:
  case AArch64::LD1Rv8b_POST:
  case AArch64::LD1Rv2d_POST:
  case AArch64::LD1Rv4s_POST:
  case AArch64::LD1Rv8h_POST:
  case AArch64::LD1Rv16b_POST:
    DestRegIdx = 1;
    BaseRegIdx = 2;
    OffsetIdx = 3;
    IsPrePost = true;
    break;

  case AArch64::LD1Twov1d_POST:
  case AArch64::LD1Twov2s_POST:
  case AArch64::LD1Twov4h_POST:
  case AArch64::LD1Twov8b_POST:
  case AArch64::LD1Twov2d_POST:
  case AArch64::LD1Twov4s_POST:
  case AArch64::LD1Twov8h_POST:
  case AArch64::LD1Twov16b_POST:
  case AArch64::LD1Threev1d_POST:
  case AArch64::LD1Threev2s_POST:
  case AArch64::LD1Threev4h_POST:
  case AArch64::LD1Threev8b_POST:
  case AArch64::LD1Threev2d_POST:
  case AArch64::LD1Threev4s_POST:
  case AArch64::LD1Threev8h_POST:
  case AArch64::LD1Threev16b_POST:
  case AArch64::LD1Fourv1d_POST:
  case AArch64::LD1Fourv2s_POST:
  case AArch64::LD1Fourv4h_POST:
  case AArch64::LD1Fourv8b_POST:
  case AArch64::LD1Fourv2d_POST:
  case AArch64::LD1Fourv4s_POST:
  case AArch64::LD1Fourv8h_POST:
  case AArch64::LD1Fourv16b_POST:
  case AArch64::LD2Twov2s_POST:
  case AArch64::LD2Twov4s_POST:
  case AArch64::LD2Twov8b_POST:
  case AArch64::LD2Twov2d_POST:
  case AArch64::LD2Twov4h_POST:
  case AArch64::LD2Twov8h_POST:
  case AArch64::LD2Twov16b_POST:
  case AArch64::LD2Rv1d_POST:
  case AArch64::LD2Rv2s_POST:
  case AArch64::LD2Rv4s_POST:
  case AArch64::LD2Rv8b_POST:
  case AArch64::LD2Rv2d_POST:
  case AArch64::LD2Rv4h_POST:
  case AArch64::LD2Rv8h_POST:
  case AArch64::LD2Rv16b_POST:
  case AArch64::LD3Threev2s_POST:
  case AArch64::LD3Threev4h_POST:
  case AArch64::LD3Threev8b_POST:
  case AArch64::LD3Threev2d_POST:
  case AArch64::LD3Threev4s_POST:
  case AArch64::LD3Threev8h_POST:
  case AArch64::LD3Threev16b_POST:
  case AArch64::LD3Rv1d_POST:
  case AArch64::LD3Rv2s_POST:
  case AArch64::LD3Rv4h_POST:
  case AArch64::LD3Rv8b_POST:
  case AArch64::LD3Rv2d_POST:
  case AArch64::LD3Rv4s_POST:
  case AArch64::LD3Rv8h_POST:
  case AArch64::LD3Rv16b_POST:
  case AArch64::LD4Fourv2s_POST:
  case AArch64::LD4Fourv4h_POST:
  case AArch64::LD4Fourv8b_POST:
  case AArch64::LD4Fourv2d_POST:
  case AArch64::LD4Fourv4s_POST:
  case AArch64::LD4Fourv8h_POST:
  case AArch64::LD4Fourv16b_POST:
  case AArch64::LD4Rv1d_POST:
  case AArch64::LD4Rv2s_POST:
  case AArch64::LD4Rv4h_POST:
  case AArch64::LD4Rv8b_POST:
  case AArch64::LD4Rv2d_POST:
  case AArch64::LD4Rv4s_POST:
  case AArch64::LD4Rv8h_POST:
  case AArch64::LD4Rv16b_POST:
    DestRegIdx = -1;
    BaseRegIdx = 2;
    OffsetIdx = 3;
    IsPrePost = true;
    break;

  case AArch64::LDRBBroW:
  case AArch64::LDRBBroX:
  case AArch64::LDRBBui:
  case AArch64::LDRBroW:
  case AArch64::LDRBroX:
  case AArch64::LDRBui:
  case AArch64::LDRDl:
  case AArch64::LDRDroW:
  case AArch64::LDRDroX:
  case AArch64::LDRDui:
  case AArch64::LDRHHroW:
  case AArch64::LDRHHroX:
  case AArch64::LDRHHui:
  case AArch64::LDRHroW:
  case AArch64::LDRHroX:
  case AArch64::LDRHui:
  case AArch64::LDRQl:
  case AArch64::LDRQroW:
  case AArch64::LDRQroX:
  case AArch64::LDRQui:
  case AArch64::LDRSBWroW:
  case AArch64::LDRSBWroX:
  case AArch64::LDRSBWui:
  case AArch64::LDRSBXroW:
  case AArch64::LDRSBXroX:
  case AArch64::LDRSBXui:
  case AArch64::LDRSHWroW:
  case AArch64::LDRSHWroX:
  case AArch64::LDRSHWui:
  case AArch64::LDRSHXroW:
  case AArch64::LDRSHXroX:
  case AArch64::LDRSHXui:
  case AArch64::LDRSWl:
  case AArch64::LDRSWroW:
  case AArch64::LDRSWroX:
  case AArch64::LDRSWui:
  case AArch64::LDRSl:
  case AArch64::LDRSroW:
  case AArch64::LDRSroX:
  case AArch64::LDRSui:
  case AArch64::LDRWl:
  case AArch64::LDRWroW:
  case AArch64::LDRWroX:
  case AArch64::LDRWui:
  case AArch64::LDRXl:
  case AArch64::LDRXroW:
  case AArch64::LDRXroX:
  case AArch64::LDRXui:
  case AArch64::LDURBBi:
  case AArch64::LDURBi:
  case AArch64::LDURDi:
  case AArch64::LDURHHi:
  case AArch64::LDURHi:
  case AArch64::LDURQi:
  case AArch64::LDURSBWi:
  case AArch64::LDURSBXi:
  case AArch64::LDURSHWi:
  case AArch64::LDURSHXi:
  case AArch64::LDURSWi:
  case AArch64::LDURSi:
  case AArch64::LDURWi:
  case AArch64::LDURXi:
    DestRegIdx = 0;
    BaseRegIdx = 1;
    OffsetIdx = 2;
    IsPrePost = false;
    break;

  case AArch64::LDRBBpost:
  case AArch64::LDRBBpre:
  case AArch64::LDRBpost:
  case AArch64::LDRBpre:
  case AArch64::LDRDpost:
  case AArch64::LDRDpre:
  case AArch64::LDRHHpost:
  case AArch64::LDRHHpre:
  case AArch64::LDRHpost:
  case AArch64::LDRHpre:
  case AArch64::LDRQpost:
  case AArch64::LDRQpre:
  case AArch64::LDRSBWpost:
  case AArch64::LDRSBWpre:
  case AArch64::LDRSBXpost:
  case AArch64::LDRSBXpre:
  case AArch64::LDRSHWpost:
  case AArch64::LDRSHWpre:
  case AArch64::LDRSHXpost:
  case AArch64::LDRSHXpre:
  case AArch64::LDRSWpost:
  case AArch64::LDRSWpre:
  case AArch64::LDRSpost:
  case AArch64::LDRSpre:
  case AArch64::LDRWpost:
  case AArch64::LDRWpre:
  case AArch64::LDRXpost:
  case AArch64::LDRXpre:
    DestRegIdx = 1;
    BaseRegIdx = 2;
    OffsetIdx = 3;
    IsPrePost = true;
    break;

  case AArch64::LDNPDi:
  case AArch64::LDNPQi:
  case AArch64::LDNPSi:
  case AArch64::LDPQi:
  case AArch64::LDPDi:
  case AArch64::LDPSi:
    DestRegIdx = -1;
    BaseRegIdx = 2;
    OffsetIdx = 3;
    IsPrePost = false;
    break;

  case AArch64::LDPSWi:
  case AArch64::LDPWi:
  case AArch64::LDPXi:
    DestRegIdx = 0;
    BaseRegIdx = 2;
    OffsetIdx = 3;
    IsPrePost = false;
    break;

  case AArch64::LDPQpost:
  case AArch64::LDPQpre:
  case AArch64::LDPDpost:
  case AArch64::LDPDpre:
  case AArch64::LDPSpost:
  case AArch64::LDPSpre:
    DestRegIdx = -1;
    BaseRegIdx = 3;
    OffsetIdx = 4;
    IsPrePost = true;
    break;

  case AArch64::LDPSWpost:
  case AArch64::LDPSWpre:
  case AArch64::LDPWpost:
  case AArch64::LDPWpre:
  case AArch64::LDPXpost:
  case AArch64::LDPXpre:
    DestRegIdx = 1;
    BaseRegIdx = 3;
    OffsetIdx = 4;
    IsPrePost = true;
    break;
  }

  // Loads from the stack pointer don't get prefetched.
  Register BaseReg = MI.getOperand(BaseRegIdx).getReg();
  if (BaseReg == AArch64::SP || BaseReg == AArch64::WSP)
    return std::nullopt;

  LoadInfo LI;
  LI.DestReg =
      DestRegIdx == -1 ? Register() : MI.getOperand(DestRegIdx).getReg();
  LI.BaseReg = BaseReg;
  LI.BaseRegIdx = BaseRegIdx;
  LI.OffsetOpnd = OffsetIdx == -1 ? nullptr : &MI.getOperand(OffsetIdx);
  LI.IsPrePost = IsPrePost;
  return LI;
}

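// Compute the tag for a load, or std::nullopt if the offset is not a plain
// immediate or register (e.g. a global, symbol, or constant-pool reference).
// Bit 5 is set for register offsets so they do not alias small immediate
// offsets.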
static std::optional<unsigned> getTag(const TargetRegisterInfo *TRI,
                                      const MachineInstr &MI,
                                      const LoadInfo &LI) {
  unsigned Dest = LI.DestReg ? TRI->getEncodingValue(LI.DestReg) : 0;
  unsigned Base = TRI->getEncodingValue(LI.BaseReg);
  unsigned Off;
  if (LI.OffsetOpnd == nullptr)
    Off = 0;
  else if (LI.OffsetOpnd->isGlobal() || LI.OffsetOpnd->isSymbol() ||
           LI.OffsetOpnd->isCPI())
    return std::nullopt;
  else if (LI.OffsetOpnd->isReg())
    Off = (1 << 5) | TRI->getEncodingValue(LI.OffsetOpnd->getReg());
  else
    Off = LI.OffsetOpnd->getImm() >> 2;

  return makeTag(Dest, Base, Off);
}

void FalkorHWPFFix::runOnLoop(MachineLoop &L, MachineFunction &Fn) {
  // Build the initial tag map for the whole loop.
  TagMap.clear();
  for (MachineBasicBlock *MBB : L.getBlocks())
    for (MachineInstr &MI : *MBB) {
      std::optional<LoadInfo> LInfo = getLoadInfo(MI);
      if (!LInfo)
        continue;
      std::optional<unsigned> Tag = getTag(TRI, MI, *LInfo);
      if (!Tag)
        continue;
      TagMap[*Tag].push_back(&MI);
    }

  bool AnyCollisions = false;
  for (auto &P : TagMap) {
    auto Size = P.second.size();
    if (Size > 1) {
      for (auto *MI : P.second) {
        if (TII->isStridedAccess(*MI)) {
          AnyCollisions = true;
          break;
        }
      }
    }
    if (AnyCollisions)
      break;
  }
  // Nothing to fix.
  if (!AnyCollisions)
    return;

  MachineRegisterInfo &MRI = Fn.getRegInfo();

  // Go through all the basic blocks in the current loop and fix any streaming
  // loads to avoid collisions with any other loads.
  LiveRegUnits LR(*TRI);
  for (MachineBasicBlock *MBB : L.getBlocks()) {
    LR.clear();
    LR.addLiveOuts(*MBB);
    for (auto I = MBB->rbegin(); I != MBB->rend(); LR.stepBackward(*I), ++I) {
      MachineInstr &MI = *I;
      if (!TII->isStridedAccess(MI))
        continue;

      std::optional<LoadInfo> OptLdI = getLoadInfo(MI);
      if (!OptLdI)
        continue;
      LoadInfo LdI = *OptLdI;
      std::optional<unsigned> OptOldTag = getTag(TRI, MI, LdI);
      if (!OptOldTag)
        continue;
      auto &OldCollisions = TagMap[*OptOldTag];
      if (OldCollisions.size() <= 1)
        continue;

      bool Fixed = false;
      LLVM_DEBUG(dbgs() << "Attempting to fix tag collision: " << MI);

      if (!DebugCounter::shouldExecute(FixCounter)) {
        LLVM_DEBUG(dbgs() << "Skipping fix due to debug counter:\n  " << MI);
        continue;
      }

      // Add the non-base registers of MI as live so we don't use them as
      // scratch registers.
      for (unsigned OpI = 0, OpE = MI.getNumOperands(); OpI < OpE; ++OpI) {
        if (OpI == static_cast<unsigned>(LdI.BaseRegIdx))
          continue;
        MachineOperand &MO = MI.getOperand(OpI);
        if (MO.isReg() && MO.readsReg())
          LR.addReg(MO.getReg());
      }

      for (unsigned ScratchReg : AArch64::GPR64RegClass) {
        if (!LR.available(ScratchReg) || MRI.isReserved(ScratchReg))
          continue;

        LoadInfo NewLdI(LdI);
        NewLdI.BaseReg = ScratchReg;
        unsigned NewTag = *getTag(TRI, MI, NewLdI);
        // Scratch reg tag would collide too, so don't use it.
        if (TagMap.count(NewTag))
          continue;

        LLVM_DEBUG(dbgs() << "Changing base reg to: "
                          << printReg(ScratchReg, TRI) << '\n');

        // Rewrite:
        //   Xd = LOAD Xb, off
        // to:
        //   Xc = MOV Xb
        //   Xd = LOAD Xc, off
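        // (ORR Xd, XZR, Xm, LSL #0 is the canonical register-register MOV
        // alias on AArch64.)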
        DebugLoc DL = MI.getDebugLoc();
        BuildMI(*MBB, &MI, DL, TII->get(AArch64::ORRXrs), ScratchReg)
            .addReg(AArch64::XZR)
            .addReg(LdI.BaseReg)
            .addImm(0);
        MachineOperand &BaseOpnd = MI.getOperand(LdI.BaseRegIdx);
        BaseOpnd.setReg(ScratchReg);

        // If the load does a pre/post increment, then insert a MOV after as
        // well to update the real base register.
        if (LdI.IsPrePost) {
          LLVM_DEBUG(dbgs() << "Doing post MOV of incremented reg: "
                            << printReg(ScratchReg, TRI) << '\n');
          MI.getOperand(0).setReg(
              ScratchReg); // Change tied operand pre/post update dest.
          BuildMI(*MBB, std::next(MachineBasicBlock::iterator(MI)), DL,
                  TII->get(AArch64::ORRXrs), LdI.BaseReg)
              .addReg(AArch64::XZR)
              .addReg(ScratchReg)
              .addImm(0);
        }

        for (int I = 0, E = OldCollisions.size(); I != E; ++I)
          if (OldCollisions[I] == &MI) {
            std::swap(OldCollisions[I], OldCollisions[E - 1]);
            OldCollisions.pop_back();
            break;
          }

        // Update TagMap to reflect instruction changes to reduce the number
        // of later MOVs to be inserted. This needs to be done after
        // OldCollisions is updated since it may be relocated by this
        // insertion.
        TagMap[NewTag].push_back(&MI);
        ++NumCollisionsAvoided;
        Fixed = true;
        Modified = true;
        break;
      }
      if (!Fixed)
        ++NumCollisionsNotAvoided;
    }
  }
}

bool FalkorHWPFFix::runOnMachineFunction(MachineFunction &Fn) {
  auto &ST = Fn.getSubtarget<AArch64Subtarget>();
  if (ST.getProcFamily() != AArch64Subtarget::Falkor)
    return false;

  if (skipFunction(Fn.getFunction()))
    return false;

  TII = static_cast<const AArch64InstrInfo *>(ST.getInstrInfo());
  TRI = ST.getRegisterInfo();

  MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();

  Modified = false;

  for (MachineLoop *I : LI)
    for (MachineLoop *L : depth_first(I))
      // Only process inner-loops.
      if (L->isInnermost())
        runOnLoop(*L, Fn);

  return Modified;
}

FunctionPass *llvm::createFalkorHWPFFixPass() { return new FalkorHWPFFix(); }