1//===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Arguments to kernel functions are passed via param space, which imposes
10// certain restrictions:
11// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
12//
13// Kernel parameters are read-only and accessible only via ld.param
14// instruction, directly or via a pointer.
15//
16// Copying a byval struct into local memory in IR allows us to enforce
17// the param space restrictions, gives the rest of IR a pointer w/o
18// param space restrictions, and gives us an opportunity to eliminate
19// the copy.
20//
21// This pass lowers byval parameters of kernel functions. It rewrites the
22// kernel's signature so that each byval argument is declared directly as a
23// pointer in the param address space (`ptr addrspace(101)`), then adjusts the
24// body to match. The parameter symbols occupy this space when lowered during
25// ISel, so making the IR type honest avoids the need for a cast or intrinsic to
26// reinterpret a generic pointer as a param-space pointer.
27//
28// This pass uses 1 of 3 possible strategies to lower byval parameters:
29//
30// 1. Direct readonly nocapture uses: If we can trace through all the uses and
31// we can convert them all to param AS, then we'll do this. This is useful
32// for pre-SM70 targets where cvta.param is not available.
33//
34// 2. Grid constant: If the argument is a grid constant (and the target supports
35// cvta.param), we can cast back to generic address space to use the pointer
36// directly.
37//
// 3. Local copy: If neither of the above applies (the uses can't all be
//    converted to param AS and the pointer can't be used directly), then
//    we'll create a local copy of the argument in local memory. This is
//    useful for arguments that are mutated.
41//
42//===----------------------------------------------------------------------===//
43
44#include "NVPTX.h"
45#include "NVPTXTargetMachine.h"
46#include "NVPTXUtilities.h"
47#include "NVVMProperties.h"
48#include "llvm/ADT/STLExtras.h"
49#include "llvm/ADT/SmallVectorExtras.h"
50#include "llvm/Analysis/PtrUseVisitor.h"
51#include "llvm/CodeGen/TargetPassConfig.h"
52#include "llvm/IR/Attributes.h"
53#include "llvm/IR/DebugInfo.h"
54#include "llvm/IR/Function.h"
55#include "llvm/IR/IRBuilder.h"
56#include "llvm/IR/Instructions.h"
57#include "llvm/IR/IntrinsicInst.h"
58#include "llvm/IR/Type.h"
59#include "llvm/InitializePasses.h"
60#include "llvm/Pass.h"
61#include "llvm/Support/Debug.h"
62#include "llvm/Support/ErrorHandling.h"
63#include "llvm/Support/NVPTXAddrSpace.h"
64
65#define DEBUG_TYPE "nvptx-lower-args"
66
67using namespace llvm;
68using namespace NVPTXAS;
69
70namespace {
71class NVPTXLowerArgsLegacyPass : public ModulePass {
72 bool runOnModule(Module &M) override;
73
74public:
75 static char ID; // Pass identification, replacement for typeid
76 NVPTXLowerArgsLegacyPass() : ModulePass(ID) {}
77 StringRef getPassName() const override {
78 return "Lower pointer arguments of CUDA kernels";
79 }
80 void getAnalysisUsage(AnalysisUsage &AU) const override {
81 AU.addRequired<TargetPassConfig>();
82 }
83};
84} // namespace
85
86char NVPTXLowerArgsLegacyPass::ID = 1;
87
88INITIALIZE_PASS_BEGIN(NVPTXLowerArgsLegacyPass, "nvptx-lower-args",
89 "Lower arguments (NVPTX)", false, false)
90INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
91INITIALIZE_PASS_END(NVPTXLowerArgsLegacyPass, "nvptx-lower-args",
92 "Lower arguments (NVPTX)", false, false)
93
94/// Recursively convert the users of a param to the param address space.
95static void convertToParamAS(ArrayRef<Use *> OldUses, Value *Param) {
96 struct IP {
97 Use *OldUse;
98 Value *NewParam;
99 };
100
101 const auto CloneInstInParamAS = [](const IP &I) -> Value * {
102 auto *OldInst = cast<Instruction>(Val: I.OldUse->getUser());
103 if (auto *LI = dyn_cast<LoadInst>(Val: OldInst)) {
104 LI->setOperand(i_nocapture: 0, Val_nocapture: I.NewParam);
105 return LI;
106 }
107 if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: OldInst)) {
108 SmallVector<Value *, 4> Indices(GEP->indices());
109 auto *NewGEP = GetElementPtrInst::Create(
110 PointeeType: GEP->getSourceElementType(), Ptr: I.NewParam, IdxList: Indices, NameStr: GEP->getName(),
111 InsertBefore: GEP->getIterator());
112 NewGEP->setNoWrapFlags(GEP->getNoWrapFlags());
113 return NewGEP;
114 }
115 if (auto *BC = dyn_cast<BitCastInst>(Val: OldInst)) {
116 auto *NewBCType =
117 PointerType::get(C&: BC->getContext(), AddressSpace: ADDRESS_SPACE_ENTRY_PARAM);
118 return BitCastInst::Create(BC->getOpcode(), S: I.NewParam, Ty: NewBCType,
119 Name: BC->getName(), InsertBefore: BC->getIterator());
120 }
121 if (auto *ASC = dyn_cast<AddrSpaceCastInst>(Val: OldInst)) {
122 assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_ENTRY_PARAM);
123 (void)ASC;
124 // Just pass through the argument, the old ASC is no longer needed.
125 return I.NewParam;
126 }
127 if (auto *MI = dyn_cast<MemTransferInst>(Val: OldInst)) {
128 if (MI->getRawSource() == I.OldUse->get()) {
129 // convert to memcpy/memmove from param space.
130 IRBuilder<> Builder(OldInst);
131 Intrinsic::ID ID = MI->getIntrinsicID();
132
133 CallInst *B = Builder.CreateMemTransferInst(
134 IntrID: ID, Dst: MI->getRawDest(), DstAlign: MI->getDestAlign(), Src: I.NewParam,
135 SrcAlign: MI->getSourceAlign(), Size: MI->getLength(), isVolatile: MI->isVolatile());
136 for (unsigned I : {0, 1})
137 if (uint64_t Bytes = MI->getParamDereferenceableBytes(i: I))
138 B->addDereferenceableParamAttr(i: I, Bytes);
139 return B;
140 }
141 }
142
143 llvm_unreachable("Unsupported instruction");
144 };
145
146 auto ItemsToConvert =
147 map_to_vector(C&: OldUses, F: [=](Use *U) -> IP { return {.OldUse: U, .NewParam: Param}; });
148 SmallVector<Instruction *> InstructionsToDelete;
149
150 while (!ItemsToConvert.empty()) {
151 IP I = ItemsToConvert.pop_back_val();
152 Value *NewInst = CloneInstInParamAS(I);
153 Instruction *OldInst = cast<Instruction>(Val: I.OldUse->getUser());
154
155 if (NewInst && NewInst != OldInst) {
156 // We've created a new instruction. Queue users of the old instruction to
157 // be converted and the instruction itself to be deleted. We can't delete
158 // the old instruction yet, because it's still in use by a load somewhere.
159 for (Use &U : OldInst->uses())
160 ItemsToConvert.push_back(Elt: {.OldUse: &U, .NewParam: NewInst});
161
162 InstructionsToDelete.push_back(Elt: OldInst);
163 }
164 }
165
166 // Now we know that all argument loads are using addresses in parameter space
167 // and we can finally remove the old instructions in generic AS. Instructions
168 // scheduled for removal should be processed in reverse order so the ones
169 // closest to the load are deleted first. Otherwise they may still be in use.
170 // E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
171 // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
172 // the BitCast.
173 for (Instruction *I : llvm::reverse(C&: InstructionsToDelete))
174 I->eraseFromParent();
175}
176
177namespace {
178struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
179 using Base = PtrUseVisitor<ArgUseChecker>;
180 // Set of phi/select instructions using the Arg
181 SmallPtrSet<Instruction *, 4> Conditionals;
182
183 ArgUseChecker(const DataLayout &DL) : PtrUseVisitor(DL) {}
184
185 PtrInfo visitArgPtr(Argument &A) {
186 assert(A.getType()->isPointerTy());
187 IntegerType *IntIdxTy = cast<IntegerType>(Val: DL.getIndexType(PtrTy: A.getType()));
188 IsOffsetKnown = false;
189 Offset = APInt(IntIdxTy->getBitWidth(), 0);
190 PI.reset();
191
192 LLVM_DEBUG(dbgs() << "Checking Argument " << A << "\n");
193 // Enqueue the uses of this pointer.
194 enqueueUsers(I&: A);
195
196 // Visit all the uses off the worklist until it is empty.
197 // Note that unlike PtrUseVisitor we intentionally do not track offsets.
198 // We're only interested in how we use the pointer.
199 while (!(Worklist.empty() || PI.isAborted())) {
200 UseToVisit ToVisit = Worklist.pop_back_val();
201 U = ToVisit.UseAndIsOffsetKnown.getPointer();
202 Instruction *I = cast<Instruction>(Val: U->getUser());
203 LLVM_DEBUG(dbgs() << "Processing " << *I << "\n");
204 Base::visit(I);
205 }
206 if (PI.isEscaped())
207 LLVM_DEBUG(dbgs() << "Argument pointer escaped: " << *PI.getEscapingInst()
208 << "\n");
209 else if (PI.isAborted())
210 LLVM_DEBUG(dbgs() << "Pointer use needs a copy: " << *PI.getAbortingInst()
211 << "\n");
212 LLVM_DEBUG(dbgs() << "Traversed " << Conditionals.size()
213 << " conditionals\n");
214 return PI;
215 }
216
217 void visitStoreInst(StoreInst &SI) {
218 // Storing the pointer escapes it.
219 if (U->get() == SI.getValueOperand())
220 return PI.setEscapedAndAborted(&SI);
221
222 PI.setAborted(&SI);
223 }
224
225 void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
226 // ASC to param space are no-ops and do not need a copy
227 if (ASC.getDestAddressSpace() != ADDRESS_SPACE_ENTRY_PARAM)
228 return PI.setEscapedAndAborted(&ASC);
229 Base::visitAddrSpaceCastInst(ASC);
230 }
231
232 void visitPtrToIntInst(PtrToIntInst &I) { Base::visitPtrToIntInst(I); }
233
234 void visitPHINodeOrSelectInst(Instruction &I) {
235 assert(isa<PHINode>(I) || isa<SelectInst>(I));
236 enqueueUsers(I);
237 Conditionals.insert(Ptr: &I);
238 }
239 // PHI and select just pass through the pointers.
240 void visitPHINode(PHINode &PN) { visitPHINodeOrSelectInst(I&: PN); }
241 void visitSelectInst(SelectInst &SI) { visitPHINodeOrSelectInst(I&: SI); }
242
243 // memcpy/memmove are OK when the pointer is source. We can convert them to
244 // AS-specific memcpy.
245 void visitMemTransferInst(MemTransferInst &II) {
246 if (*U == II.getRawDest())
247 PI.setAborted(&II);
248 }
249
250 void visitMemSetInst(MemSetInst &II) { PI.setAborted(&II); }
251}; // struct ArgUseChecker
252
253// Create a local copy of the byval parameter \p Arg in an alloca, filled by a
254// copy from \p ParamPtr (a pointer to the parameter), and replace all uses of
255// \p Arg with the alloca. \p ParamPtr is either the natively param-space
256// argument (when called from the signature rewrite) or the generic byval
257// argument itself (when called early, before the signature has been rewritten).
258void copyByValParam(Function &F, Argument &Arg, Value &ParamPtr) {
259 LLVM_DEBUG(dbgs() << "Creating a local copy of " << Arg << "\n");
260 Type *ByValType = Arg.getParamByValType();
261 const DataLayout &DL = F.getDataLayout();
262 IRBuilder<> IRB(&F.getEntryBlock().front());
263 AllocaInst *AllocA = IRB.CreateAlloca(Ty: ByValType, ArraySize: nullptr, Name: Arg.getName());
264 // Set the alignment to alignment of the byval parameter. This is because,
265 // later load/stores assume that alignment, and we are going to replace
266 // the use of the byval parameter with this alloca instruction.
267 AllocA->setAlignment(
268 Arg.getParamAlign().value_or(u: DL.getPrefTypeAlign(Ty: ByValType)));
269 Arg.replaceAllUsesWith(V: AllocA);
270
271 // Be sure to propagate alignment to this copy; LLVM doesn't know that NVPTX
272 // addrspacecast preserves alignment. Since params are constant, this copy
273 // is definitely not volatile.
274 const auto ArgSize = *AllocA->getAllocationSize(DL);
275 IRB.CreateMemCpy(Dst: AllocA, DstAlign: AllocA->getAlign(), Src: &ParamPtr, SrcAlign: AllocA->getAlign(),
276 Size: ArgSize);
277}
278} // namespace
279
280// Returns true if F has a byval argument not yet in the param address space.
281// Such arguments are lowered exactly once, so one already in param space means
282// the kernel has already been processed.
283static bool kernelNeedsByValLowering(const Function &F) {
284 return any_of(Range: F.args(), P: [](const Argument &A) {
285 return A.hasByValAttr() &&
286 A.getType()->getPointerAddressSpace() != ADDRESS_SPACE_ENTRY_PARAM;
287 });
288}
289
290// Lower the uses of a single kernel byval argument. \p OldArg is the original
291// (generic) argument whose uses are being rewritten; \p NewParamArg is its
292// replacement, natively in the param address space.
293static void lowerKernelByValParam(Argument &OldArg, Argument &NewParamArg,
294 Function &F, const bool HasCvtaParam) {
295 assert(isKernelFunction(F));
296
297 const DataLayout &DL = F.getDataLayout();
298 IRBuilder<> IRB(&F.getEntryBlock().front());
299
300 if (OldArg.use_empty())
301 return;
302
303 // (1) First check the easy case, if were able to trace through all the uses
304 // and we can convert them all to param AS, then we'll do this.
305 ArgUseChecker AUC(DL);
306 ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(A&: OldArg);
307 const bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
308 if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
309 // Convert all loads and intermediate operations to use parameter AS and
310 // skip creation of a local copy of the argument.
311 SmallVector<Use *, 16> UsesToUpdate(make_pointer_range(Range: OldArg.uses()));
312 for (Use *U : UsesToUpdate)
313 convertToParamAS(OldUses: U, Param: &NewParamArg);
314 // This path does not replaceAllUsesWith the old argument, so any debug-info
315 // uses would be left dangling and reset to poison when the old function is
316 // erased. Point them at the new param-space argument instead.
317 if (OldArg.isUsedByMetadata()) {
318 SmallVector<DbgVariableRecord *, 4> DbgUsers;
319 findDbgUsers(V: &OldArg, DbgVariableRecords&: DbgUsers);
320 for (DbgVariableRecord *DVR : DbgUsers)
321 DVR->replaceVariableLocationOp(OldValue: &OldArg, NewValue: &NewParamArg);
322 }
323 return;
324 }
325
326 // (2) If the argument is grid constant, we get to use the pointer directly.
327 if (HasCvtaParam && (ArgUseIsReadOnly || isParamGridConstant(OldArg))) {
328 LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << OldArg << "\n");
329
330 // Cast the param-space argument to the generic address space. Because the
331 // argument is natively in param space, this cast only ever goes
332 // param -> generic and lowers to cvta.param; there is no inverse cast for
333 // InferAddressSpaces to fold it away with.
334 Value *GenericArg = IRB.CreateAddrSpaceCast(
335 V: &NewParamArg, DestTy: IRB.getPtrTy(AddrSpace: ADDRESS_SPACE_GENERIC),
336 Name: OldArg.getName() + ".gen");
337
338 OldArg.replaceAllUsesWith(V: GenericArg);
339 return;
340 }
341
342 // (3) Otherwise we have to create a copy of the argument in local memory.
343 copyByValParam(F, Arg&: OldArg, ParamPtr&: NewParamArg);
344}
345
346// Rewrite a kernel's signature so that each byval argument is declared directly
347// as a pointer in the param address space, then lower the body to match. This
348// creates a new function, moves the body across, and erases \p F.
349static void rewriteKernelByValSignature(Function &F, const bool HasCvtaParam) {
350 LLVMContext &Ctx = F.getContext();
351 FunctionType *FTy = F.getFunctionType();
352
353 // Build the new signature: byval pointer arguments move to the param address
354 // space; all other arguments are unchanged.
355 SmallVector<Type *> Params(FTy->params());
356 for (const Argument &Arg : F.args())
357 if (Arg.hasByValAttr())
358 Params[Arg.getArgNo()] = PointerType::get(C&: Ctx, AddressSpace: ADDRESS_SPACE_ENTRY_PARAM);
359
360 Function *NF = Function::Create(
361 Ty: FunctionType::get(Result: FTy->getReturnType(), Params, isVarArg: FTy->isVarArg()),
362 Linkage: F.getLinkage(), AddrSpace: F.getAddressSpace());
363 NF->copyAttributesFrom(Src: &F);
364 NF->setComdat(F.getComdat());
365 F.getParent()->getFunctionList().insert(where: F.getIterator(), New: NF);
366
367 // ISel reads the param symbol directly for kernel byval arguments; this is
368 // valid because the signature rewrite above puts them in the param address
369 // space. Mark them readonly: any mutation is redirected to a local copy
370 // below, so the param itself is never written.
371 for (Argument &NewArg : NF->args())
372 if (NewArg.hasByValAttr())
373 NewArg.addAttr(Kind: Attribute::ReadOnly);
374
375 // Take over F's name and uses (e.g. @llvm.used, nvvm.annotations metadata),
376 // then move the body across.
377 F.replaceAllUsesWith(V: NF);
378 NF->takeName(V: &F);
379 NF->splice(ToIt: NF->begin(), FromF: &F);
380
381 // Remap arguments. Non-byval arguments keep their type and are replaced
382 // directly; byval arguments change address space, so their uses are lowered
383 // to operate on the new param-space argument.
384 for (auto [OldArg, NewArg] : zip_equal(t: F.args(), u: NF->args())) {
385 if (OldArg.hasByValAttr())
386 lowerKernelByValParam(OldArg, NewParamArg&: NewArg, F&: *NF, HasCvtaParam);
387 else
388 OldArg.replaceAllUsesWith(V: &NewArg);
389 NewArg.takeName(V: &OldArg);
390 }
391
392 // Move function-level metadata (debug info, etc.) to the new function.
393 NF->copyMetadata(Src: &F, /*Offset=*/0);
394 F.clearMetadata();
395
396 F.eraseFromParent();
397}
398
399// =============================================================================
400// Main function for this pass.
401// =============================================================================
402static bool processFunction(Function &F, NVPTXTargetMachine &TM) {
403 if (!isKernelFunction(F) || F.isDeclaration())
404 return false;
405
406 // Skip kernels with no byval arguments, and those already lowered (byval
407 // arguments sitting in the param address space).
408 if (!kernelNeedsByValLowering(F))
409 return false;
410
411 LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
412 const NVPTXSubtarget *ST = TM.getSubtargetImpl(F);
413 rewriteKernelByValSignature(F, HasCvtaParam: ST->hasCvtaParam());
414 return true;
415}
416
417static bool processModule(Module &M, NVPTXTargetMachine &TM) {
418 bool Changed = false;
419 for (Function &F : make_early_inc_range(Range&: M))
420 Changed |= processFunction(F, TM);
421 return Changed;
422}
423
424bool NVPTXLowerArgsLegacyPass::runOnModule(Module &M) {
425 auto &TM = getAnalysis<TargetPassConfig>().getTM<NVPTXTargetMachine>();
426 return processModule(M, TM);
427}
428
429ModulePass *llvm::createNVPTXLowerArgsPass() {
430 return new NVPTXLowerArgsLegacyPass();
431}
432
433static bool copyFunctionByValArgs(Function &F) {
434 LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()
435 << "\n");
436 bool Changed = false;
437 if (isKernelFunction(F)) {
438 for (Argument &Arg : F.args())
439 if (Arg.hasByValAttr() && !isParamGridConstant(Arg)) {
440 copyByValParam(F, Arg, ParamPtr&: Arg);
441 Changed = true;
442 }
443 }
444 return Changed;
445}
446
447PreservedAnalyses NVPTXCopyByValArgsPass::run(Function &F,
448 FunctionAnalysisManager &AM) {
449 return copyFunctionByValArgs(F) ? PreservedAnalyses::none()
450 : PreservedAnalyses::all();
451}
452
453PreservedAnalyses NVPTXLowerArgsPass::run(Module &M,
454 ModuleAnalysisManager &AM) {
455 auto &NTM = static_cast<NVPTXTargetMachine &>(TM);
456 bool Changed = processModule(M, TM&: NTM);
457 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
458}
459