1//===-- NVPTXMarkKernelPtrsGlobal.cpp - Mark kernel pointers as global ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// For CUDA kernels, pointers loaded from byval parameters are known to be in
10// global address space. This pass inserts addrspacecast pairs to make that
11// explicit, enabling later address-space inference to propagate the global AS.
12// It also handles the pattern where a pointer is loaded as an integer and then
13// converted via inttoptr.
14//
15//===----------------------------------------------------------------------===//
16
17#include "NVPTX.h"
18#include "NVPTXUtilities.h"
19#include "llvm/Analysis/ValueTracking.h"
20#include "llvm/IR/InstIterator.h"
21#include "llvm/IR/Instructions.h"
22#include "llvm/InitializePasses.h"
23#include "llvm/Pass.h"
24#include "llvm/Support/NVPTXAddrSpace.h"
25
26using namespace llvm;
27using namespace NVPTXAS;
28
29static void markPointerAsAS(Value *Ptr, unsigned AS) {
30 if (Ptr->getType()->getPointerAddressSpace() != ADDRESS_SPACE_GENERIC)
31 return;
32
33 BasicBlock::iterator InsertPt;
34 if (auto *Arg = dyn_cast<Argument>(Val: Ptr)) {
35 InsertPt = Arg->getParent()->getEntryBlock().begin();
36 } else {
37 InsertPt = ++cast<Instruction>(Val: Ptr)->getIterator();
38 assert(InsertPt != InsertPt->getParent()->end() &&
39 "We don't call this function with Ptr being a terminator.");
40 }
41
42 Instruction *PtrInGlobal = new AddrSpaceCastInst(
43 Ptr, PointerType::get(C&: Ptr->getContext(), AddressSpace: AS), Ptr->getName(), InsertPt);
44 Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
45 Ptr->getName(), InsertPt);
46 Ptr->replaceAllUsesWith(V: PtrInGeneric);
47 PtrInGlobal->setOperand(i: 0, Val: Ptr);
48}
49
50static void markPointerAsGlobal(Value *Ptr) {
51 markPointerAsAS(Ptr, AS: ADDRESS_SPACE_GLOBAL);
52}
53
54static void handleIntToPtr(Value &V) {
55 if (!all_of(Range: V.users(), P: [](User *U) { return isa<IntToPtrInst>(Val: U); }))
56 return;
57
58 SmallVector<User *, 16> UsersToUpdate(V.users());
59 for (User *U : UsersToUpdate)
60 markPointerAsGlobal(Ptr: U);
61}
62
63static bool markKernelPtrsGlobal(Function &F) {
64 if (!isKernelFunction(F))
65 return false;
66
67 // Copying of byval aggregates + SROA may result in pointers being loaded as
68 // integers, followed by inttoptr. We mark those as global too, but only if
69 // the loaded integer is used exclusively for conversion to a pointer.
70 for (auto &I : instructions(F)) {
71 auto *LI = dyn_cast<LoadInst>(Val: &I);
72 if (!LI)
73 continue;
74
75 if (LI->getType()->isPointerTy() || LI->getType()->isIntegerTy()) {
76 Value *UO = getUnderlyingObject(V: LI->getPointerOperand());
77 if (auto *Arg = dyn_cast<Argument>(Val: UO)) {
78 if (Arg->hasByValAttr()) {
79 if (LI->getType()->isPointerTy())
80 markPointerAsGlobal(Ptr: LI);
81 else
82 handleIntToPtr(V&: *LI);
83 }
84 }
85 }
86 }
87
88 for (Argument &Arg : F.args())
89 if (Arg.getType()->isIntegerTy())
90 handleIntToPtr(V&: Arg);
91
92 return true;
93}
94
95namespace {
96
97class NVPTXMarkKernelPtrsGlobalLegacyPass : public FunctionPass {
98public:
99 static char ID;
100 NVPTXMarkKernelPtrsGlobalLegacyPass() : FunctionPass(ID) {}
101 bool runOnFunction(Function &F) override;
102};
103
104} // namespace
105
106INITIALIZE_PASS(NVPTXMarkKernelPtrsGlobalLegacyPass,
107 "nvptx-mark-kernel-ptrs-global",
108 "NVPTX Mark Kernel Pointers Global", false, false)
109
110bool NVPTXMarkKernelPtrsGlobalLegacyPass::runOnFunction(Function &F) {
111 return markKernelPtrsGlobal(F);
112}
113
114char NVPTXMarkKernelPtrsGlobalLegacyPass::ID = 0;
115
116FunctionPass *llvm::createNVPTXMarkKernelPtrsGlobalPass() {
117 return new NVPTXMarkKernelPtrsGlobalLegacyPass();
118}
119
120PreservedAnalyses
121NVPTXMarkKernelPtrsGlobalPass::run(Function &F, FunctionAnalysisManager &) {
122 return markKernelPtrsGlobal(F) ? PreservedAnalyses::none()
123 : PreservedAnalyses::all();
124}
125