//===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// Lower aggregate copies and the memset, memcpy, and memmove intrinsics into
// loops when the size is large or is not a compile-time constant.
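//
// For illustration (a hypothetical IR sketch, not taken from a test):
//
//   call void @llvm.memcpy.p0.p0.i64(ptr %dst, ptr %src, i64 %n, i1 false)
//
// is expanded into a copy loop roughly equivalent to:
//
//   loop:
//     %i = phi i64 [ 0, %entry ], [ %i.next, %loop ]
//     %s.gep = getelementptr inbounds i8, ptr %src, i64 %i
//     %byte = load i8, ptr %s.gep
//     %d.gep = getelementptr inbounds i8, ptr %dst, i64 %i
//     store i8 %byte, ptr %d.gep
//     %i.next = add i64 %i, 1
//     %cont = icmp ult i64 %i.next, %n
//     br i1 %cont, label %loop, label %exit
//
// The loops are actually emitted by the shared LowerMemIntrinsics utilities,
// which may choose wider element types based on TargetTransformInfo.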
//
//===----------------------------------------------------------------------===//

#include "NVPTXLowerAggrCopies.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/StackProtector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"

#define DEBUG_TYPE "nvptx"

using namespace llvm;

namespace {

// The actual transformation pass, implemented as a FunctionPass.
struct NVPTXLowerAggrCopies : public FunctionPass {
  static char ID;

  NVPTXLowerAggrCopies() : FunctionPass(ID) {}

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addPreserved<StackProtector>();
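    // TTI is required because the loop expansion consults it (e.g. to pick
    // the per-iteration operand type).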
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

  bool runOnFunction(Function &F) override;

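  // Aggregate copies and mem* calls of at least this many bytes are expanded
  // into loops (as are mem* calls with a non-constant size); smaller ones are
  // left for normal instruction selection.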
  static const unsigned MaxAggrCopySize = 128;

  StringRef getPassName() const override {
    return "Lower aggregate copies/intrinsics into loops";
  }
};

char NVPTXLowerAggrCopies::ID = 0;

bool NVPTXLowerAggrCopies::runOnFunction(Function &F) {
  SmallVector<LoadInst *, 4> AggrLoads;
  SmallVector<MemIntrinsic *, 4> MemCalls;

  const DataLayout &DL = F.getDataLayout();
  LLVMContext &Context = F.getParent()->getContext();
  const TargetTransformInfo &TTI =
      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  // Collect all aggregate loads and mem* calls.
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
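      // Case 1: a load whose single use is a store of the loaded value, i.e.
      // a plain aggregate copy large enough to be worth expanding.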
      if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
        if (!LI->hasOneUse())
          continue;

        if (DL.getTypeStoreSize(LI->getType()) < MaxAggrCopySize)
          continue;

        if (StoreInst *SI = dyn_cast<StoreInst>(LI->user_back())) {
          if (SI->getOperand(0) != LI)
            continue;
          AggrLoads.push_back(LI);
        }
      } else if (MemIntrinsic *IntrCall = dyn_cast<MemIntrinsic>(&I)) {
        // Convert intrinsic calls with a variable size or with a constant
        // size at or above the MaxAggrCopySize threshold.
        if (ConstantInt *LenCI = dyn_cast<ConstantInt>(IntrCall->getLength())) {
          if (LenCI->getZExtValue() >= MaxAggrCopySize)
            MemCalls.push_back(IntrCall);
        } else {
          MemCalls.push_back(IntrCall);
        }
      }
    }
  }

  if (AggrLoads.empty() && MemCalls.empty())
    return false;

  // Transform each collected aggregate load/store pair into an explicit
  // copy loop.
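  // For example (an illustrative sketch, not from a test case):
  //   %v = load [256 x i8], ptr %src
  //   store [256 x i8] %v, ptr %dst
  // becomes a 256-byte copy loop inserted before the store, after which the
  // original pair is deleted.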
  for (LoadInst *LI : AggrLoads) {
    auto *SI = cast<StoreInst>(*LI->user_begin());
    Value *SrcAddr = LI->getOperand(0);
    Value *DstAddr = SI->getOperand(1);
    unsigned NumLoads = DL.getTypeStoreSize(LI->getType());
    ConstantInt *CopyLen =
        ConstantInt::get(Type::getInt32Ty(Context), NumLoads);

    createMemCpyLoopKnownSize(/* ConvertedInst */ SI,
                              /* SrcAddr */ SrcAddr, /* DstAddr */ DstAddr,
                              /* CopyLen */ CopyLen,
                              /* SrcAlign */ LI->getAlign(),
                              /* DestAlign */ SI->getAlign(),
                              /* SrcIsVolatile */ LI->isVolatile(),
                              /* DstIsVolatile */ SI->isVolatile(),
                              /* CanOverlap */ true, TTI);

    SI->eraseFromParent();
    LI->eraseFromParent();
  }

  // Transform mem* intrinsic calls.
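  // Each expansion helper emits its loop before the call; expandMemMoveAsLoop
  // additionally emits a runtime direction check so that overlapping source
  // and destination ranges are copied correctly.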
  for (MemIntrinsic *MemCall : MemCalls) {
    if (MemCpyInst *Memcpy = dyn_cast<MemCpyInst>(MemCall)) {
      expandMemCpyAsLoop(Memcpy, TTI);
    } else if (MemMoveInst *Memmove = dyn_cast<MemMoveInst>(MemCall)) {
      expandMemMoveAsLoop(Memmove, TTI);
    } else if (MemSetInst *Memset = dyn_cast<MemSetInst>(MemCall)) {
      expandMemSetAsLoop(Memset);
    }
    MemCall->eraseFromParent();
  }

  return true;
}

} // namespace

namespace llvm {
void initializeNVPTXLowerAggrCopiesPass(PassRegistry &);
}

INITIALIZE_PASS(NVPTXLowerAggrCopies, "nvptx-lower-aggr-copies",
                "Lower aggregate copies and llvm.mem* intrinsics into loops",
                false, false)

FunctionPass *llvm::createLowerAggrCopies() {
  return new NVPTXLowerAggrCopies();
}