1 | //===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // \file |
// Lower aggregate copies and memset, memcpy, memmove intrinsics into loops
// when the size is large or is not a compile-time constant.
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "NVPTXLowerAggrCopies.h" |
16 | #include "NVPTX.h" |
17 | #include "llvm/Analysis/TargetTransformInfo.h" |
18 | #include "llvm/CodeGen/StackProtector.h" |
19 | #include "llvm/IR/Constants.h" |
20 | #include "llvm/IR/DataLayout.h" |
21 | #include "llvm/IR/Function.h" |
22 | #include "llvm/IR/Instructions.h" |
23 | #include "llvm/IR/IntrinsicInst.h" |
24 | #include "llvm/IR/Intrinsics.h" |
25 | #include "llvm/IR/LLVMContext.h" |
26 | #include "llvm/IR/Module.h" |
27 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
28 | #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" |
29 | |
30 | #define DEBUG_TYPE "nvptx" |
31 | |
32 | using namespace llvm; |
33 | |
34 | namespace { |
35 | |
36 | // actual analysis class, which is a functionpass |
37 | struct NVPTXLowerAggrCopies : public FunctionPass { |
38 | static char ID; |
39 | |
40 | NVPTXLowerAggrCopies() : FunctionPass(ID) {} |
41 | |
42 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
43 | AU.addPreserved<StackProtector>(); |
44 | AU.addRequired<TargetTransformInfoWrapperPass>(); |
45 | } |
46 | |
47 | bool runOnFunction(Function &F) override; |
48 | |
49 | static const unsigned MaxAggrCopySize = 128; |
50 | |
51 | StringRef getPassName() const override { |
52 | return "Lower aggregate copies/intrinsics into loops" ; |
53 | } |
54 | }; |
55 | |
56 | char NVPTXLowerAggrCopies::ID = 0; |
57 | |
58 | bool NVPTXLowerAggrCopies::runOnFunction(Function &F) { |
59 | SmallVector<LoadInst *, 4> AggrLoads; |
60 | SmallVector<MemIntrinsic *, 4> MemCalls; |
61 | |
62 | const DataLayout &DL = F.getDataLayout(); |
63 | LLVMContext &Context = F.getParent()->getContext(); |
64 | const TargetTransformInfo &TTI = |
65 | getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); |
66 | |
67 | // Collect all aggregate loads and mem* calls. |
68 | for (BasicBlock &BB : F) { |
69 | for (Instruction &I : BB) { |
70 | if (LoadInst *LI = dyn_cast<LoadInst>(Val: &I)) { |
71 | if (!LI->hasOneUse()) |
72 | continue; |
73 | |
74 | if (DL.getTypeStoreSize(Ty: LI->getType()) < MaxAggrCopySize) |
75 | continue; |
76 | |
77 | if (StoreInst *SI = dyn_cast<StoreInst>(Val: LI->user_back())) { |
78 | if (SI->getOperand(i_nocapture: 0) != LI) |
79 | continue; |
80 | AggrLoads.push_back(Elt: LI); |
81 | } |
82 | } else if (MemIntrinsic *IntrCall = dyn_cast<MemIntrinsic>(Val: &I)) { |
83 | // Convert intrinsic calls with variable size or with constant size |
84 | // larger than the MaxAggrCopySize threshold. |
85 | if (ConstantInt *LenCI = dyn_cast<ConstantInt>(Val: IntrCall->getLength())) { |
86 | if (LenCI->getZExtValue() >= MaxAggrCopySize) { |
87 | MemCalls.push_back(Elt: IntrCall); |
88 | } |
89 | } else { |
90 | MemCalls.push_back(Elt: IntrCall); |
91 | } |
92 | } |
93 | } |
94 | } |
95 | |
96 | if (AggrLoads.size() == 0 && MemCalls.size() == 0) { |
97 | return false; |
98 | } |
99 | |
100 | // |
101 | // Do the transformation of an aggr load/copy/set to a loop |
102 | // |
103 | for (LoadInst *LI : AggrLoads) { |
104 | auto *SI = cast<StoreInst>(Val: *LI->user_begin()); |
105 | Value *SrcAddr = LI->getOperand(i_nocapture: 0); |
106 | Value *DstAddr = SI->getOperand(i_nocapture: 1); |
107 | unsigned NumLoads = DL.getTypeStoreSize(Ty: LI->getType()); |
108 | ConstantInt *CopyLen = |
109 | ConstantInt::get(Ty: Type::getInt32Ty(C&: Context), V: NumLoads); |
110 | |
111 | createMemCpyLoopKnownSize(/* ConvertedInst */ InsertBefore: SI, |
112 | /* SrcAddr */ SrcAddr, /* DstAddr */ DstAddr, |
113 | /* CopyLen */ CopyLen, |
114 | /* SrcAlign */ LI->getAlign(), |
115 | /* DestAlign */ SI->getAlign(), |
116 | /* SrcIsVolatile */ LI->isVolatile(), |
117 | /* DstIsVolatile */ SI->isVolatile(), |
118 | /* CanOverlap */ true, TTI); |
119 | |
120 | SI->eraseFromParent(); |
121 | LI->eraseFromParent(); |
122 | } |
123 | |
124 | // Transform mem* intrinsic calls. |
125 | for (MemIntrinsic *MemCall : MemCalls) { |
126 | if (MemCpyInst *Memcpy = dyn_cast<MemCpyInst>(Val: MemCall)) { |
127 | expandMemCpyAsLoop(MemCpy: Memcpy, TTI); |
128 | } else if (MemMoveInst *Memmove = dyn_cast<MemMoveInst>(Val: MemCall)) { |
129 | expandMemMoveAsLoop(MemMove: Memmove, TTI); |
130 | } else if (MemSetInst *Memset = dyn_cast<MemSetInst>(Val: MemCall)) { |
131 | expandMemSetAsLoop(MemSet: Memset); |
132 | } |
133 | MemCall->eraseFromParent(); |
134 | } |
135 | |
136 | return true; |
137 | } |
138 | |
139 | } // namespace |
140 | |
141 | INITIALIZE_PASS(NVPTXLowerAggrCopies, "nvptx-lower-aggr-copies" , |
142 | "Lower aggregate copies, and llvm.mem* intrinsics into loops" , |
143 | false, false) |
144 | |
145 | FunctionPass *llvm::createLowerAggrCopies() { |
146 | return new NVPTXLowerAggrCopies(); |
147 | } |
148 | |