//===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// Lower aggregate copies and the memset, memcpy, and memmove intrinsics into
// loops when the size is large or is not a compile-time constant.
//
//===----------------------------------------------------------------------===//
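
// As an illustrative sketch (the value names here are hypothetical; the
// actual IR is emitted by the shared helpers from LowerMemIntrinsics.h), a
// memcpy whose length is not a compile-time constant, e.g.
//
//   call void @llvm.memcpy.p0.p0.i64(ptr %dst, ptr %src, i64 %n, i1 false)
//
// ends up as a copy loop along these lines:
//
//   loop:
//     %i = phi i64 [ 0, %entry ], [ %i.next, %loop ]
//     %src.gep = getelementptr inbounds i8, ptr %src, i64 %i
//     %byte = load i8, ptr %src.gep
//     %dst.gep = getelementptr inbounds i8, ptr %dst, i64 %i
//     store i8 %byte, ptr %dst.gep
//     %i.next = add i64 %i, 1
//     %cond = icmp ult i64 %i.next, %n
//     br i1 %cond, label %loop, label %exit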

#include "NVPTXLowerAggrCopies.h"
#include "NVPTX.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/StackProtector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"

#define DEBUG_TYPE "nvptx"

using namespace llvm;

namespace {

// The transformation pass itself, implemented as a FunctionPass.
struct NVPTXLowerAggrCopies : public FunctionPass {
  static char ID;

  NVPTXLowerAggrCopies() : FunctionPass(ID) {}

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addPreserved<StackProtector>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

  bool runOnFunction(Function &F) override;

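  // Operations of at least this many bytes, or whose length is not a
  // compile-time constant, are expanded into loops; smaller constant-size
  // copies are left alone.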
  static const unsigned MaxAggrCopySize = 128;

  StringRef getPassName() const override {
    return "Lower aggregate copies/intrinsics into loops";
  }
};

char NVPTXLowerAggrCopies::ID = 0;

bool NVPTXLowerAggrCopies::runOnFunction(Function &F) {
  SmallVector<LoadInst *, 4> AggrLoads;
  SmallVector<MemIntrinsic *, 4> MemCalls;

  const DataLayout &DL = F.getDataLayout();
  LLVMContext &Context = F.getParent()->getContext();
  const TargetTransformInfo &TTI =
      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  // Collect all aggregate loads and mem* calls.
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
        if (!LI->hasOneUse())
          continue;

        if (DL.getTypeStoreSize(LI->getType()) < MaxAggrCopySize)
          continue;

        if (StoreInst *SI = dyn_cast<StoreInst>(LI->user_back())) {
          if (SI->getOperand(0) != LI)
            continue;
          AggrLoads.push_back(LI);
        }
      } else if (MemIntrinsic *IntrCall = dyn_cast<MemIntrinsic>(&I)) {
        // Convert intrinsic calls with variable size or with a constant size
        // that meets the MaxAggrCopySize threshold.
        if (ConstantInt *LenCI =
                dyn_cast<ConstantInt>(IntrCall->getLength())) {
          if (LenCI->getZExtValue() >= MaxAggrCopySize) {
            MemCalls.push_back(IntrCall);
          }
        } else {
          MemCalls.push_back(IntrCall);
        }
      }
    }
  }

  if (AggrLoads.empty() && MemCalls.empty())
    return false;

  //
  // Transform each collected aggregate load/store pair into a copy loop.
  //
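  // For example (a sketch; the struct type and value names are hypothetical),
  // a single-use load/store pair such as
  //
  //   %val = load %struct.Large, ptr %src
  //   store %struct.Large %val, ptr %dst
  //
  // is rewritten as a loop that copies getTypeStoreSize(%struct.Large) bytes
  // from %src to %dst.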
  for (LoadInst *LI : AggrLoads) {
    auto *SI = cast<StoreInst>(*LI->user_begin());
    Value *SrcAddr = LI->getOperand(0);
    Value *DstAddr = SI->getOperand(1);
    unsigned NumBytes = DL.getTypeStoreSize(LI->getType());
    ConstantInt *CopyLen =
        ConstantInt::get(Type::getInt32Ty(Context), NumBytes);

    // The two buffers are not known to be disjoint, so conservatively allow
    // them to overlap.
    createMemCpyLoopKnownSize(/* ConvertedInst */ SI,
                              /* SrcAddr */ SrcAddr, /* DstAddr */ DstAddr,
                              /* CopyLen */ CopyLen,
                              /* SrcAlign */ LI->getAlign(),
                              /* DestAlign */ SI->getAlign(),
                              /* SrcIsVolatile */ LI->isVolatile(),
                              /* DstIsVolatile */ SI->isVolatile(),
                              /* CanOverlap */ true, TTI);

    SI->eraseFromParent();
    LI->eraseFromParent();
  }

  // Transform mem* intrinsic calls.
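  // The expand* helpers below are shared utilities declared in
  // llvm/Transforms/Utils/LowerMemIntrinsics.h. Since memmove operands may
  // overlap, its expansion emits loops that pick the copy direction at run
  // time by comparing the source and destination pointers.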
  for (MemIntrinsic *MemCall : MemCalls) {
    if (MemCpyInst *Memcpy = dyn_cast<MemCpyInst>(MemCall)) {
      expandMemCpyAsLoop(Memcpy, TTI);
    } else if (MemMoveInst *Memmove = dyn_cast<MemMoveInst>(MemCall)) {
      expandMemMoveAsLoop(Memmove, TTI);
    } else if (MemSetInst *Memset = dyn_cast<MemSetInst>(MemCall)) {
      expandMemSetAsLoop(Memset);
    }
    MemCall->eraseFromParent();
  }

  return true;
}

} // namespace

INITIALIZE_PASS(NVPTXLowerAggrCopies, "nvptx-lower-aggr-copies",
                "Lower aggregate copies and llvm.mem* intrinsics into loops",
                false, false)

FunctionPass *llvm::createLowerAggrCopies() {
  return new NVPTXLowerAggrCopies();
}