//===- ARMParallelDSP.cpp - Parallel DSP Pass -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
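///
/// As an illustrative sketch (not a verbatim example from a test), a loop
/// body such as:
///
///   acc += (int)a[i] * (int)b[i] + (int)a[i+1] * (int)b[i+1];
///
/// can have its two i16 loads per array widened into a single i32 load and
/// its multiply-accumulates replaced by one call to the llvm.arm.smlad
/// intrinsic.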
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

static cl::opt<unsigned>
NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
             cl::desc("Limit the number of loads analysed"));

namespace {
  struct MulCandidate;
  class Reduction;

  using MulCandList = SmallVector<std::unique_ptr<MulCandidate>, 8>;
  using MemInstList = SmallVectorImpl<LoadInst*>;
  using MulPairList = SmallVector<std::pair<MulCandidate*, MulCandidate*>, 8>;

  // 'MulCandidate' holds the multiplication instructions that are candidates
  // for parallel execution.
  struct MulCandidate {
    Instruction *Root;
    Value *LHS;
    Value *RHS;
    bool Exchange = false;
    bool Paired = false;
    SmallVector<LoadInst*, 2> VecLd; // Container for loads to widen.

    MulCandidate(Instruction *I, Value *lhs, Value *rhs) :
      Root(I), LHS(lhs), RHS(rhs) { }

    bool HasTwoLoadInputs() const {
      return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
    }

    LoadInst *getBaseLoad() const {
      return VecLd.front();
    }
  };

  /// Represent a sequence of multiply-accumulate operations with the aim of
  /// performing the multiplications in parallel.
  class Reduction {
    Instruction *Root = nullptr;
    Value *Acc = nullptr;
    MulCandList Muls;
    MulPairList MulPairs;
    SetVector<Instruction*> Adds;

  public:
    Reduction() = delete;

    Reduction(Instruction *Add) : Root(Add) { }

    /// Record an Add instruction that is part of this reduction.
    void InsertAdd(Instruction *I) { Adds.insert(I); }

    /// Create MulCandidates, each rooted at a Mul instruction, that are part
    /// of this reduction.
    void InsertMuls() {
      auto GetMulOperand = [](Value *V) -> Instruction* {
        if (auto *SExt = dyn_cast<SExtInst>(V)) {
          if (auto *I = dyn_cast<Instruction>(SExt->getOperand(0)))
            if (I->getOpcode() == Instruction::Mul)
              return I;
        } else if (auto *I = dyn_cast<Instruction>(V)) {
          if (I->getOpcode() == Instruction::Mul)
            return I;
        }
        return nullptr;
      };

      auto InsertMul = [this](Instruction *I) {
        Value *LHS = cast<Instruction>(I->getOperand(0))->getOperand(0);
        Value *RHS = cast<Instruction>(I->getOperand(1))->getOperand(0);
        Muls.push_back(std::make_unique<MulCandidate>(I, LHS, RHS));
      };

      for (auto *Add : Adds) {
        if (Add == Acc)
          continue;
        if (auto *Mul = GetMulOperand(Add->getOperand(0)))
          InsertMul(Mul);
        if (auto *Mul = GetMulOperand(Add->getOperand(1)))
          InsertMul(Mul);
      }
    }

    /// Add the incoming accumulator value; returns true if a value had not
    /// already been added. Returning false signals to the user that this
    /// reduction already has a value to initialise the accumulator.
    bool InsertAcc(Value *V) {
      if (Acc)
        return false;
      Acc = V;
      return true;
    }

    /// Record two MulCandidates, rooted at muls, that can be executed as a
    /// single parallel operation.
    void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1,
                    bool Exchange = false) {
      LLVM_DEBUG(dbgs() << "Pairing:\n"
                 << *Mul0->Root << "\n"
                 << *Mul1->Root << "\n");
      Mul0->Paired = true;
      Mul1->Paired = true;
      if (Exchange)
        Mul1->Exchange = true;
      MulPairs.push_back(std::make_pair(Mul0, Mul1));
    }

    /// Return the add instruction which is the root of the reduction.
    Instruction *getRoot() { return Root; }

    bool is64Bit() const { return Root->getType()->isIntegerTy(64); }

    Type *getType() const { return Root->getType(); }

    /// Return the incoming value to be accumulated. This may be null.
    Value *getAccumulator() { return Acc; }

    /// Return the set of adds that comprise the reduction.
    SetVector<Instruction*> &getAdds() { return Adds; }

    /// Return the MulCandidates, each rooted at a mul instruction, that
    /// comprise the reduction.
    MulCandList &getMuls() { return Muls; }

    /// Return the MulCandidates, rooted at mul instructions, that have been
    /// paired for parallel execution.
    MulPairList &getMulPairs() { return MulPairs; }

    /// To finalise, replace the uses of the root with the intrinsic call.
    void UpdateRoot(Instruction *SMLAD) {
      Root->replaceAllUsesWith(SMLAD);
    }

    void dump() {
      LLVM_DEBUG(dbgs() << "Reduction:\n";
        for (auto *Add : Adds)
          LLVM_DEBUG(dbgs() << *Add << "\n");
        for (auto &Mul : Muls)
          LLVM_DEBUG(dbgs() << *Mul->Root << "\n"
                     << "  " << *Mul->LHS << "\n"
                     << "  " << *Mul->RHS << "\n");
        LLVM_DEBUG(if (Acc) dbgs() << "Acc in: " << *Acc << "\n")
      );
    }
  };

  class WidenedLoad {
    LoadInst *NewLd = nullptr;
    SmallVector<LoadInst*, 4> Loads;

  public:
    WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
      : NewLd(Wide) {
      append_range(Loads, Lds);
    }
    LoadInst *getLoad() {
      return NewLd;
    }
  };

  class ARMParallelDSP : public FunctionPass {
    ScalarEvolution *SE;
    AliasAnalysis *AA;
    TargetLibraryInfo *TLI;
    DominatorTree *DT;
    const DataLayout *DL;
    Module *M;
    std::map<LoadInst*, LoadInst*> LoadPairs;
    SmallPtrSet<LoadInst*, 4> OffsetLoads;
    std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

    template<unsigned>
    bool IsNarrowSequence(Value *V);
    bool Search(Value *V, BasicBlock *BB, Reduction &R);
    bool RecordMemoryOps(BasicBlock *BB);
    void InsertParallelMACs(Reduction &Reduction);
    bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
    LoadInst* CreateWideLoad(MemInstList &Loads, IntegerType *LoadTy);
    bool CreateParallelPairs(Reduction &R);

    /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
    /// Dual performs two signed 16x16-bit multiplications. It adds the
    /// products to a 32-bit accumulate operand. Optionally, the instruction
    /// can exchange the halfwords of the second operand before performing
    /// the arithmetic.
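    ///
    /// As a rough sketch of the 32-bit accumulating form (assuming the
    /// little-endian halfword numbering this pass requires):
    ///   Res = Acc + (Op0[15:0] * Op1[15:0]) + (Op0[31:16] * Op1[31:16])
    /// For SMLADX, the halfwords of Op1 are exchanged first.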
    bool MatchSMLAD(Function &F);

  public:
    static char ID;

    ARMParallelDSP() : FunctionPass(ID) { }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      FunctionPass::getAnalysisUsage(AU);
      AU.addRequired<AssumptionCacheTracker>();
      AU.addRequired<ScalarEvolutionWrapperPass>();
      AU.addRequired<AAResultsWrapperPass>();
      AU.addRequired<TargetLibraryInfoWrapperPass>();
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<TargetPassConfig>();
      AU.addPreserved<ScalarEvolutionWrapperPass>();
      AU.addPreserved<GlobalsAAWrapperPass>();
      AU.setPreservesCFG();
    }

    bool runOnFunction(Function &F) override {
      if (DisableParallelDSP)
        return false;
      if (skipFunction(F))
        return false;

      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
      AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
      DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
      auto &TPC = getAnalysis<TargetPassConfig>();

      M = F.getParent();
      DL = &M->getDataLayout();

      auto &TM = TPC.getTM<TargetMachine>();
      auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

      if (!ST->allowsUnalignedMem()) {
        LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!ST->hasDSP()) {
        LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      if (!ST->isLittle()) {
        LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                          << "ARMParallelDSP\n");
        return false;
      }

      LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
      LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

      bool Changes = MatchSMLAD(F);
      return Changes;
    }
  };
}

bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!Ld0 || !Ld1)
    return false;

  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
             dbgs() << "Ld0:"; Ld0->dump();
             dbgs() << "Ld1:"; Ld1->dump();
            );

  VecMem.clear();
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}

// MaxBitWidth: the maximum supported bitwidth of the elements in the DSP
// instructions, which is set to 16. So here we should collect all i8 and i16
// narrow operations.
// TODO: we currently only collect i16, and will support i8 later, so that's
// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
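//
// For example, with MaxBitWidth = 16, each mul operand is expected to be fed
// by a sign-extended, pairable i16 load, along the lines of:
//   %ld = load i16
//   %sext = sext i16 %ld to i32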
template<unsigned MaxBitWidth>
bool ARMParallelDSP::IsNarrowSequence(Value *V) {
  if (auto *SExt = dyn_cast<SExtInst>(V)) {
    if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
      // Check that this load could be paired.
      return LoadPairs.count(Ld) || OffsetLoads.count(Ld);
    }
  }
  return false;
}

/// Iterate through the block and record base, offset pairs of loads which can
/// be widened into a single load.
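/// For example (an illustrative sketch), given two adjacent i16 loads:
///   %ld0 = load i16 ; from %a
///   %ld1 = load i16 ; from %a plus two bytes
/// %ld0 is recorded as the base and LoadPairs maps it to %ld1, its offset
/// load, so that the pair can later be widened into one i32 load.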
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
  SmallVector<LoadInst*, 8> Loads;
  SmallVector<Instruction*, 8> Writes;
  LoadPairs.clear();
  WideLoads.clear();

  // Collect loads and instructions that may write to memory. For now we only
  // record loads which are simple, sign-extended and have a single user.
  // TODO: Allow zero-extended loads.
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }

  if (Loads.empty() || Loads.size() > NumLoadLimit)
    return false;

  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
  DepMap RAWDeps;

  // Record any writes that may alias a load.
  const auto Size = LocationSize::beforeOrAfterPointer();
  for (auto *Write : Writes) {
    for (auto *Read : Loads) {
      MemoryLocation ReadLoc =
        MemoryLocation(Read->getPointerOperand(), Size);

      if (!isModOrRefSet(AA->getModRefInfo(Write, ReadLoc)))
        continue;
      if (Write->comesBefore(Read))
        RAWDeps[Read].insert(Write);
    }
  }

  // Check that there is no write between the two loads which would prevent
  // them from being safely merged.
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    bool BaseFirst = Base->comesBefore(Offset);
    LoadInst *Dominator = BaseFirst ? Base : Offset;
    LoadInst *Dominated = BaseFirst ? Offset : Base;

    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];

      for (auto *Before : WritesBefore) {
        // We can't move the second load backward, past a write, to merge
        // with the first load.
        if (Dominator->comesBefore(Before))
          return false;
      }
    }
    return true;
  };

  // Record base, offset load pairs.
  for (auto *Base : Loads) {
    for (auto *Offset : Loads) {
      if (Base == Offset || OffsetLoads.count(Offset))
        continue;

      if (isConsecutiveAccess(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        OffsetLoads.insert(Offset);
        break;
      }
    }
  }

  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs) {
                 LLVM_DEBUG(dbgs() << *MapIt.first << ", "
                            << *MapIt.second << "\n");
               }
             });
  return LoadPairs.size() > 1;
}

// Search recursively back through the operands to find a tree of values that
// form a multiply-accumulate chain. The search records the Add and Mul
// instructions that form the reduction and allows us to find a single value
// to be used as the initial input to the accumulator.
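// For example, given the chain from the comment above MatchSMLAD:
//   %add0 = add i32 %mul0, %acc0
//   %acc1 = add i32 %add0, %mul1
// searching from %acc1 records both adds, recurses into both muls, and
// takes %acc0 as the incoming accumulator value.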
bool ARMParallelDSP::Search(Value *V, BasicBlock *BB, Reduction &R) {
  // If we find a non-instruction, try to use it as the initial accumulator
  // value. This may have already been found during the search in which case
  // this function will return false, signaling a search fail.
  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return R.InsertAcc(V);

  if (I->getParent() != BB)
    return false;

  switch (I->getOpcode()) {
  default:
    break;
  case Instruction::PHI:
    // Could be the accumulator value.
    return R.InsertAcc(V);
  case Instruction::Add: {
    // Adds should be adding together two muls, or another add and a mul to
    // be within the mac chain. One of the operands may also be the
    // accumulator value at which point we should stop searching.
    R.InsertAdd(I);
    Value *LHS = I->getOperand(0);
    Value *RHS = I->getOperand(1);
    bool ValidLHS = Search(LHS, BB, R);
    bool ValidRHS = Search(RHS, BB, R);

    if (ValidLHS && ValidRHS)
      return true;

    // Ensure we don't add the root as the incoming accumulator.
    if (R.getRoot() == I)
      return false;

    return R.InsertAcc(I);
  }
  case Instruction::Mul: {
    Value *MulOp0 = I->getOperand(0);
    Value *MulOp1 = I->getOperand(1);
    return IsNarrowSequence<16>(MulOp0) && IsNarrowSequence<16>(MulOp1);
  }
  case Instruction::SExt:
    return Search(I->getOperand(0), BB, R);
  }
  return false;
}

// The pass needs to identify integer add/sub reductions of 16-bit vector
// multiplications.
// To use SMLAD:
// 1) we first need to find an integer add, then look for this pattern:
//
// acc0 = ...
// ld0 = load i16
// sext0 = sext i16 %ld0 to i32
// ld1 = load i16
// sext1 = sext i16 %ld1 to i32
// mul0 = mul i32 %sext0, %sext1
// ld2 = load i16
// sext2 = sext i16 %ld2 to i32
// ld3 = load i16
// sext3 = sext i16 %ld3 to i32
// mul1 = mul i32 %sext2, %sext3
// add0 = add i32 %mul0, %acc0
// acc1 = add i32 %add0, %mul1
//
// Which can be selected to:
//
// ldr r0
// ldr r1
// smlad r2, r0, r1, r2
//
// If constants are used instead of loads, these will need to be hoisted
// out and into a register.
//
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
//
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  bool Changed = false;

  for (auto &BB : F) {
    SmallPtrSet<Instruction*, 4> AllAdds;
    if (!RecordMemoryOps(&BB))
      continue;

    for (Instruction &I : reverse(BB)) {
      if (I.getOpcode() != Instruction::Add)
        continue;

      if (AllAdds.count(&I))
        continue;

      const auto *Ty = I.getType();
      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
        continue;

      Reduction R(&I);
      if (!Search(&I, &BB, R))
        continue;

      R.InsertMuls();
      LLVM_DEBUG(dbgs() << "After search, Reduction:\n"; R.dump());

      if (!CreateParallelPairs(R))
        continue;

      InsertParallelMACs(R);
      Changed = true;
      AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
      LLVM_DEBUG(dbgs() << "BB after inserting parallel MACs:\n" << BB);
    }
  }

  return Changed;
}

bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {

  // Not enough mul operations to make a pair.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly upon sign extended loads.
  for (auto &MulCand : R.getMuls()) {
    if (!MulCand->HasTwoLoadInputs())
      return false;
  }

  auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
    // The first elements of each vector should be loads with sexts. If we
    // find that these are two pairs of consecutive loads, then the pairs can
    // be transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
    auto Ld0 = static_cast<LoadInst*>(PMul0->LHS);
    auto Ld1 = static_cast<LoadInst*>(PMul1->LHS);
    auto Ld2 = static_cast<LoadInst*>(PMul0->RHS);
    auto Ld3 = static_cast<LoadInst*>(PMul1->RHS);

    // Check that each mul is operating on two different loads.
    if (Ld0 == Ld2 || Ld1 == Ld3)
      return false;

    if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
      if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        R.AddMulPair(PMul0, PMul1);
        return true;
      } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
        R.AddMulPair(PMul0, PMul1, true);
        return true;
      }
    } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
               AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
      LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
      LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
      LLVM_DEBUG(dbgs() << "    and swapping muls\n");
      // Only the second operand can be exchanged, so swap the muls.
      R.AddMulPair(PMul1, PMul0, true);
      return true;
    }
    return false;
  };

  MulCandList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  for (unsigned i = 0; i < Elems; ++i) {
    MulCandidate *PMul0 = static_cast<MulCandidate*>(Muls[i].get());
    if (PMul0->Paired)
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      MulCandidate *PMul1 = static_cast<MulCandidate*>(Muls[j].get());
      if (PMul1->Paired)
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1))
        break;
    }
  }
  return !R.getMulPairs().empty();
}

void ARMParallelDSP::InsertParallelMACs(Reduction &R) {

  auto CreateSMLAD = [&](LoadInst* WideLd0, LoadInst *WideLd1,
                         Value *Acc, bool Exchange,
                         Instruction *InsertAfter) {
    // Replace the reduction chain with an intrinsic call.

    Value* Args[] = { WideLd0, WideLd1, Acc };
    Function *SMLAD = nullptr;
    if (Exchange)
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
    else
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                BasicBlock::iterator(InsertAfter));
    Instruction *Call = Builder.CreateCall(SMLAD, Args);
    NumSMLAD++;
    return Call;
  };

  // Return the instruction after the dominated instruction.
  auto GetInsertPoint = [this](Value *A, Value *B) {
    assert((isa<Instruction>(A) || isa<Instruction>(B)) &&
           "expected at least one instruction");

    Value *V = nullptr;
    if (!isa<Instruction>(A))
      V = B;
    else if (!isa<Instruction>(B))
      V = A;
    else
      V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A;

    return &*++BasicBlock::iterator(cast<Instruction>(V));
  };

  Value *Acc = R.getAccumulator();

  // For any muls that were discovered but not paired, accumulate their values
  // as before.
  IRBuilder<NoFolder> Builder(R.getRoot()->getParent());
  MulCandList &MulCands = R.getMuls();
  for (auto &MulCand : MulCands) {
    if (MulCand->Paired)
      continue;

    Instruction *Mul = cast<Instruction>(MulCand->Root);
    LLVM_DEBUG(dbgs() << "Accumulating unpaired mul: " << *Mul << "\n");

    if (R.getType() != Mul->getType()) {
      assert(R.is64Bit() && "expected 64-bit result");
      Builder.SetInsertPoint(&*++BasicBlock::iterator(Mul));
      Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType()));
    }

    if (!Acc) {
      Acc = Mul;
      continue;
    }

    // If Acc is the original incoming value to the reduction, it could be a
    // phi. But the phi will dominate Mul, meaning that Mul will be the
    // insertion point.
    Builder.SetInsertPoint(GetInsertPoint(Mul, Acc));
    Acc = Builder.CreateAdd(Mul, Acc);
  }

  if (!Acc) {
    Acc = R.is64Bit() ?
      ConstantInt::get(IntegerType::get(M->getContext(), 64), 0) :
      ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
  } else if (Acc->getType() != R.getType()) {
    Builder.SetInsertPoint(R.getRoot());
    Acc = Builder.CreateSExt(Acc, R.getType());
  }

  // Roughly sort the mul pairs in their program order.
  llvm::sort(R.getMulPairs(), [](auto &PairA, auto &PairB) {
    const Instruction *A = PairA.first->Root;
    const Instruction *B = PairB.first->Root;
    return A->comesBefore(B);
  });

  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
  for (auto &Pair : R.getMulPairs()) {
    MulCandidate *LHSMul = Pair.first;
    MulCandidate *RHSMul = Pair.second;
    LoadInst *BaseLHS = LHSMul->getBaseLoad();
    LoadInst *BaseRHS = RHSMul->getBaseLoad();
    LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
      WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty);
    LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
      WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);

    Instruction *InsertAfter = GetInsertPoint(WideLHS, WideRHS);
    InsertAfter = GetInsertPoint(InsertAfter, Acc);
    Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}

LoadInst* ARMParallelDSP::CreateWideLoad(MemInstList &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");

  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &Op : Source->operands())
        MoveBefore(Op, Source);
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Create the wide load, while making sure to maintain the original
  // alignment as this prevents ldrd from being generated when it could be
  // illegal due to memory alignment.
  Value *VecPtr = Base->getPointerOperand();
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr, Base->getAlign());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads.
  // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
  // TODO: Support big-endian as well.
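  // For a pair of i16 loads widened to i32 on little endian, the code below
  // produces, roughly:
  //   %wide = load i32            ; stands in for both narrow loads
  //   %bottom = trunc i32 %wide to i16
  //   %shifted = lshr i32 %wide, 16
  //   %top = trunc i32 %shifted to i16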
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  Value *NewBaseSExt = IRB.CreateSExt(Bottom, BaseSExt->getType());
  BaseSExt->replaceAllUsesWith(NewBaseSExt);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  Value *NewOffsetSExt = IRB.CreateSExt(Trunc, OffsetSExt->getType());
  OffsetSExt->replaceAllUsesWith(NewOffsetSExt);

  LLVM_DEBUG(dbgs() << "From Base and Offset:\n"
             << *Base << "\n" << *Offset << "\n"
             << "Created Wide Load:\n"
             << *WideLoad << "\n"
             << *Bottom << "\n"
             << *NewBaseSExt << "\n"
             << *Top << "\n"
             << *Trunc << "\n"
             << *NewOffsetSExt << "\n");
  WideLoads.emplace(std::make_pair(Base,
                                   std::make_unique<WidenedLoad>(Loads,
                                                                 WideLoad)));
  return WideLoad;
}

Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform functions to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform functions to use DSP intrinsics", false, false)