1//===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//
10// Arguments to kernel and device functions are passed via param space,
11// which imposes certain restrictions:
12// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
13//
14// Kernel parameters are read-only and accessible only via ld.param
15// instruction, directly or via a pointer.
16//
17// Device function parameters are directly accessible via
18// ld.param/st.param, but taking the address of one returns a pointer
19// to a copy created in local space which *can't* be used with
20// ld.param/st.param.
21//
22// Copying a byval struct into local memory in IR allows us to enforce
23// the param space restrictions, gives the rest of IR a pointer w/o
24// param space restrictions, and gives us an opportunity to eliminate
25// the copy.
26//
27// Pointer arguments to kernel functions need more work to be lowered:
28//
29// 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
30// global address space. This allows later optimizations to emit
31// ld.global.*/st.global.* for accessing these pointer arguments. For
32// example,
33//
34// define void @foo(float* %input) {
35// %v = load float, float* %input, align 4
36// ...
37// }
38//
39// becomes
40//
41// define void @foo(float* %input) {
42// %input2 = addrspacecast float* %input to float addrspace(1)*
43// %input3 = addrspacecast float addrspace(1)* %input2 to float*
44// %v = load float, float* %input3, align 4
45// ...
46// }
47//
48// Later, NVPTXInferAddressSpaces will optimize it to
49//
50// define void @foo(float* %input) {
51// %input2 = addrspacecast float* %input to float addrspace(1)*
52// %v = load float, float addrspace(1)* %input2, align 4
53// ...
54// }
55//
56// 2. Convert byval kernel parameters to pointers in the param address space
57// (so that NVPTX emits ld/st.param). Convert pointers *within* a byval
58// kernel parameter to pointers in the global address space. This allows
59// NVPTX to emit ld/st.global.
60//
61// struct S {
62// int *x;
63// int *y;
64// };
65// __global__ void foo(S s) {
66// int *b = s.y;
67// // use b
68// }
69//
70// "b" points to the global address space. In the IR level,
71//
72// define void @foo(ptr byval %input) {
73// %b_ptr = getelementptr {ptr, ptr}, ptr %input, i64 0, i32 1
74// %b = load ptr, ptr %b_ptr
75// ; use %b
76// }
77//
78// becomes
79//
80// define void @foo({i32*, i32*}* byval %input) {
//     %b_param = addrspacecast ptr %input to ptr addrspace(101)
82// %b_ptr = getelementptr {ptr, ptr}, ptr addrspace(101) %b_param, i64 0, i32 1
83// %b = load ptr, ptr addrspace(101) %b_ptr
84// %b_global = addrspacecast ptr %b to ptr addrspace(1)
//     ; use %b_global
86// }
87//
88// Create a local copy of kernel byval parameters used in a way that *might* mutate
89// the parameter, by storing it in an alloca. Mutations to "grid_constant" parameters
90// are undefined behaviour, and don't require local copies.
91//
92// define void @foo(ptr byval(%struct.s) align 4 %input) {
93// store i32 42, ptr %input
94// ret void
95// }
96//
97// becomes
98//
99// define void @foo(ptr byval(%struct.s) align 4 %input) #1 {
100// %input1 = alloca %struct.s, align 4
101// %input2 = addrspacecast ptr %input to ptr addrspace(101)
102// %input3 = load %struct.s, ptr addrspace(101) %input2, align 4
103// store %struct.s %input3, ptr %input1, align 4
104// store i32 42, ptr %input1, align 4
105// ret void
106// }
107//
108// If %input were passed to a device function, or written to memory,
109// conservatively assume that %input gets mutated, and create a local copy.
110//
111// Convert param pointers to grid_constant byval kernel parameters that are
112// passed into calls (device functions, intrinsics, inline asm), or otherwise
113// "escape" (into stores/ptrtoints) to the generic address space, using the
114// `nvvm.ptr.param.to.gen` intrinsic, so that NVPTX emits cvta.param
115// (available for sm70+)
116//
117// define void @foo(ptr byval(%struct.s) %input) {
118// ; %input is a grid_constant
119// %call = call i32 @escape(ptr %input)
120// ret void
121// }
122//
123// becomes
124//
125// define void @foo(ptr byval(%struct.s) %input) {
126// %input1 = addrspacecast ptr %input to ptr addrspace(101)
127// ; the following intrinsic converts pointer to generic. We don't use an addrspacecast
128// ; to prevent generic -> param -> generic from getting cancelled out
129// %input1.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input1)
130// %call = call i32 @escape(ptr %input1.gen)
131// ret void
132// }
133//
134// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
135// cancel the addrspacecast pair this pass emits.
136//===----------------------------------------------------------------------===//
137
138#include "NVPTX.h"
139#include "NVPTXTargetMachine.h"
140#include "NVPTXUtilities.h"
141#include "llvm/ADT/STLExtras.h"
142#include "llvm/ADT/SmallVectorExtras.h"
143#include "llvm/Analysis/PtrUseVisitor.h"
144#include "llvm/CodeGen/TargetPassConfig.h"
145#include "llvm/IR/Attributes.h"
146#include "llvm/IR/Function.h"
147#include "llvm/IR/IRBuilder.h"
148#include "llvm/IR/Instructions.h"
149#include "llvm/IR/IntrinsicInst.h"
150#include "llvm/IR/IntrinsicsNVPTX.h"
151#include "llvm/IR/Type.h"
152#include "llvm/InitializePasses.h"
153#include "llvm/Pass.h"
154#include "llvm/Support/Debug.h"
155#include "llvm/Support/ErrorHandling.h"
156#include "llvm/Support/NVPTXAddrSpace.h"
157#include <queue>
158
159#define DEBUG_TYPE "nvptx-lower-args"
160
161using namespace llvm;
162using namespace NVPTXAS;
163
164namespace {
165class NVPTXLowerArgsLegacyPass : public FunctionPass {
166 bool runOnFunction(Function &F) override;
167
168public:
169 static char ID; // Pass identification, replacement for typeid
170 NVPTXLowerArgsLegacyPass() : FunctionPass(ID) {}
171 StringRef getPassName() const override {
172 return "Lower pointer arguments of CUDA kernels";
173 }
174 void getAnalysisUsage(AnalysisUsage &AU) const override {
175 AU.addRequired<TargetPassConfig>();
176 }
177};
178} // namespace
179
// Out-of-line definition of the pass-identification anchor; its address (not
// its value) is what uniquely identifies the pass.
char NVPTXLowerArgsLegacyPass::ID = 1;

// Register the pass with the legacy pass manager and record its dependency on
// TargetPassConfig (used to obtain the NVPTX subtarget).
INITIALIZE_PASS_BEGIN(NVPTXLowerArgsLegacyPass, "nvptx-lower-args",
                      "Lower arguments (NVPTX)", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(NVPTXLowerArgsLegacyPass, "nvptx-lower-args",
                    "Lower arguments (NVPTX)", false, false)
187
188// =============================================================================
189// If the function had a byval struct ptr arg, say foo(ptr byval(%struct.x) %d),
190// and we can't guarantee that the only accesses are loads,
191// then add the following instructions to the first basic block:
192//
193// %temp = alloca %struct.x, align 8
194// %tempd = addrspacecast ptr %d to ptr addrspace(101)
195// %tv = load %struct.x, ptr addrspace(101) %tempd
196// store %struct.x %tv, ptr %temp, align 8
197//
198// The above code allocates some space in the stack and copies the incoming
199// struct from param space to local space.
200// Then replace all occurrences of %d by %temp.
201//
202// In case we know that all users are GEPs or Loads, replace them with the same
203// ones in parameter AS, so we can access them using ld.param.
204// =============================================================================
205
206/// Recursively convert the users of a param to the param address space.
207static void convertToParamAS(ArrayRef<Use *> OldUses, Value *Param) {
208 struct IP {
209 Use *OldUse;
210 Value *NewParam;
211 };
212
213 const auto CloneInstInParamAS = [](const IP &I) -> Value * {
214 auto *OldInst = cast<Instruction>(Val: I.OldUse->getUser());
215 if (auto *LI = dyn_cast<LoadInst>(Val: OldInst)) {
216 LI->setOperand(i_nocapture: 0, Val_nocapture: I.NewParam);
217 return LI;
218 }
219 if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: OldInst)) {
220 SmallVector<Value *, 4> Indices(GEP->indices());
221 auto *NewGEP = GetElementPtrInst::Create(
222 PointeeType: GEP->getSourceElementType(), Ptr: I.NewParam, IdxList: Indices, NameStr: GEP->getName(),
223 InsertBefore: GEP->getIterator());
224 NewGEP->setNoWrapFlags(GEP->getNoWrapFlags());
225 return NewGEP;
226 }
227 if (auto *BC = dyn_cast<BitCastInst>(Val: OldInst)) {
228 auto *NewBCType = PointerType::get(C&: BC->getContext(), AddressSpace: ADDRESS_SPACE_PARAM);
229 return BitCastInst::Create(BC->getOpcode(), S: I.NewParam, Ty: NewBCType,
230 Name: BC->getName(), InsertBefore: BC->getIterator());
231 }
232 if (auto *ASC = dyn_cast<AddrSpaceCastInst>(Val: OldInst)) {
233 assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
234 (void)ASC;
235 // Just pass through the argument, the old ASC is no longer needed.
236 return I.NewParam;
237 }
238 if (auto *MI = dyn_cast<MemTransferInst>(Val: OldInst)) {
239 if (MI->getRawSource() == I.OldUse->get()) {
240 // convert to memcpy/memmove from param space.
241 IRBuilder<> Builder(OldInst);
242 Intrinsic::ID ID = MI->getIntrinsicID();
243
244 CallInst *B = Builder.CreateMemTransferInst(
245 IntrID: ID, Dst: MI->getRawDest(), DstAlign: MI->getDestAlign(), Src: I.NewParam,
246 SrcAlign: MI->getSourceAlign(), Size: MI->getLength(), isVolatile: MI->isVolatile());
247 for (unsigned I : {0, 1})
248 if (uint64_t Bytes = MI->getParamDereferenceableBytes(i: I))
249 B->addDereferenceableParamAttr(i: I, Bytes);
250 return B;
251 }
252 }
253
254 llvm_unreachable("Unsupported instruction");
255 };
256
257 auto ItemsToConvert =
258 map_to_vector(C&: OldUses, F: [=](Use *U) -> IP { return {.OldUse: U, .NewParam: Param}; });
259 SmallVector<Instruction *> InstructionsToDelete;
260
261 while (!ItemsToConvert.empty()) {
262 IP I = ItemsToConvert.pop_back_val();
263 Value *NewInst = CloneInstInParamAS(I);
264 Instruction *OldInst = cast<Instruction>(Val: I.OldUse->getUser());
265
266 if (NewInst && NewInst != OldInst) {
267 // We've created a new instruction. Queue users of the old instruction to
268 // be converted and the instruction itself to be deleted. We can't delete
269 // the old instruction yet, because it's still in use by a load somewhere.
270 for (Use &U : OldInst->uses())
271 ItemsToConvert.push_back(Elt: {.OldUse: &U, .NewParam: NewInst});
272
273 InstructionsToDelete.push_back(Elt: OldInst);
274 }
275 }
276
277 // Now we know that all argument loads are using addresses in parameter space
278 // and we can finally remove the old instructions in generic AS. Instructions
279 // scheduled for removal should be processed in reverse order so the ones
280 // closest to the load are deleted first. Otherwise they may still be in use.
281 // E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
282 // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
283 // the BitCast.
284 for (Instruction *I : llvm::reverse(C&: InstructionsToDelete))
285 I->eraseFromParent();
286}
287
288static Align setByValParamAlign(Argument *Arg) {
289 Function *F = Arg->getParent();
290 Type *ByValType = Arg->getParamByValType();
291 const DataLayout &DL = F->getDataLayout();
292
293 const Align OptimizedAlign = getFunctionParamOptimizedAlign(F, ArgTy: ByValType, DL);
294 const Align CurrentAlign = Arg->getParamAlign().valueOrOne();
295
296 if (CurrentAlign >= OptimizedAlign)
297 return CurrentAlign;
298
299 LLVM_DEBUG(dbgs() << "Try to use alignment " << OptimizedAlign.value()
300 << " instead of " << CurrentAlign.value() << " for " << *Arg
301 << '\n');
302
303 Arg->removeAttr(Kind: Attribute::Alignment);
304 Arg->addAttr(Attr: Attribute::getWithAlignment(Context&: F->getContext(), Alignment: OptimizedAlign));
305
306 return OptimizedAlign;
307}
308
309// Adjust alignment of arguments passed byval in .param address space. We can
310// increase alignment of such arguments in a way that ensures that we can
311// effectively vectorize their loads. We should also traverse all loads from
312// byval pointer and adjust their alignment, if those were using known offset.
313// Such alignment changes must be conformed with parameter store and load in
314// NVPTXTargetLowering::LowerCall.
315static void propagateAlignmentToLoads(Value *Val, Align NewAlign,
316 const DataLayout &DL) {
317 struct Load {
318 LoadInst *Inst;
319 uint64_t Offset;
320 };
321
322 struct LoadContext {
323 Value *InitialVal;
324 uint64_t Offset;
325 };
326
327 SmallVector<Load> Loads;
328 std::queue<LoadContext> Worklist;
329 Worklist.push(x: {.InitialVal: Val, .Offset: 0});
330
331 while (!Worklist.empty()) {
332 LoadContext Ctx = Worklist.front();
333 Worklist.pop();
334
335 for (User *CurUser : Ctx.InitialVal->users()) {
336 if (auto *I = dyn_cast<LoadInst>(Val: CurUser))
337 Loads.push_back(Elt: {.Inst: I, .Offset: Ctx.Offset});
338 else if (isa<BitCastInst>(Val: CurUser) || isa<AddrSpaceCastInst>(Val: CurUser))
339 Worklist.push(x: {.InitialVal: cast<Instruction>(Val: CurUser), .Offset: Ctx.Offset});
340 else if (auto *I = dyn_cast<GetElementPtrInst>(Val: CurUser)) {
341 APInt OffsetAccumulated =
342 APInt::getZero(numBits: DL.getIndexSizeInBits(AS: ADDRESS_SPACE_PARAM));
343
344 if (!I->accumulateConstantOffset(DL, Offset&: OffsetAccumulated))
345 continue;
346
347 uint64_t OffsetLimit = -1;
348 uint64_t Offset = OffsetAccumulated.getLimitedValue(Limit: OffsetLimit);
349 assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");
350
351 Worklist.push(x: {.InitialVal: I, .Offset: Ctx.Offset + Offset});
352 }
353 }
354 }
355
356 for (Load &CurLoad : Loads) {
357 Align NewLoadAlign = commonAlignment(A: NewAlign, Offset: CurLoad.Offset);
358 Align CurLoadAlign = CurLoad.Inst->getAlign();
359 CurLoad.Inst->setAlignment(std::max(a: NewLoadAlign, b: CurLoadAlign));
360 }
361}
362
363// Create a call to the nvvm_internal_addrspace_wrap intrinsic and set the
364// alignment of the return value based on the alignment of the argument.
365static CallInst *createNVVMInternalAddrspaceWrap(IRBuilder<> &IRB,
366 Argument &Arg) {
367 CallInst *ArgInParam =
368 IRB.CreateIntrinsic(ID: Intrinsic::nvvm_internal_addrspace_wrap,
369 Types: {IRB.getPtrTy(AddrSpace: ADDRESS_SPACE_PARAM), Arg.getType()},
370 Args: &Arg, FMFSource: {}, Name: Arg.getName() + ".param");
371
372 if (MaybeAlign ParamAlign = Arg.getParamAlign())
373 ArgInParam->addRetAttr(
374 Attr: Attribute::getWithAlignment(Context&: ArgInParam->getContext(), Alignment: *ParamAlign));
375
376 Arg.addAttr(Attr: Attribute::get(Context&: Arg.getContext(), Kind: "nvvm.grid_constant"));
377 Arg.addAttr(Kind: Attribute::ReadOnly);
378
379 return ArgInParam;
380}
381
382namespace {
383struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
384 using Base = PtrUseVisitor<ArgUseChecker>;
385 // Set of phi/select instructions using the Arg
386 SmallPtrSet<Instruction *, 4> Conditionals;
387
388 ArgUseChecker(const DataLayout &DL) : PtrUseVisitor(DL) {}
389
390 PtrInfo visitArgPtr(Argument &A) {
391 assert(A.getType()->isPointerTy());
392 IntegerType *IntIdxTy = cast<IntegerType>(Val: DL.getIndexType(PtrTy: A.getType()));
393 IsOffsetKnown = false;
394 Offset = APInt(IntIdxTy->getBitWidth(), 0);
395 PI.reset();
396
397 LLVM_DEBUG(dbgs() << "Checking Argument " << A << "\n");
398 // Enqueue the uses of this pointer.
399 enqueueUsers(I&: A);
400
401 // Visit all the uses off the worklist until it is empty.
402 // Note that unlike PtrUseVisitor we intentionally do not track offsets.
403 // We're only interested in how we use the pointer.
404 while (!(Worklist.empty() || PI.isAborted())) {
405 UseToVisit ToVisit = Worklist.pop_back_val();
406 U = ToVisit.UseAndIsOffsetKnown.getPointer();
407 Instruction *I = cast<Instruction>(Val: U->getUser());
408 LLVM_DEBUG(dbgs() << "Processing " << *I << "\n");
409 Base::visit(I);
410 }
411 if (PI.isEscaped())
412 LLVM_DEBUG(dbgs() << "Argument pointer escaped: " << *PI.getEscapingInst()
413 << "\n");
414 else if (PI.isAborted())
415 LLVM_DEBUG(dbgs() << "Pointer use needs a copy: " << *PI.getAbortingInst()
416 << "\n");
417 LLVM_DEBUG(dbgs() << "Traversed " << Conditionals.size()
418 << " conditionals\n");
419 return PI;
420 }
421
422 void visitStoreInst(StoreInst &SI) {
423 // Storing the pointer escapes it.
424 if (U->get() == SI.getValueOperand())
425 return PI.setEscapedAndAborted(&SI);
426
427 PI.setAborted(&SI);
428 }
429
430 void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
431 // ASC to param space are no-ops and do not need a copy
432 if (ASC.getDestAddressSpace() != ADDRESS_SPACE_PARAM)
433 return PI.setEscapedAndAborted(&ASC);
434 Base::visitAddrSpaceCastInst(ASC);
435 }
436
437 void visitPtrToIntInst(PtrToIntInst &I) { Base::visitPtrToIntInst(I); }
438
439 void visitPHINodeOrSelectInst(Instruction &I) {
440 assert(isa<PHINode>(I) || isa<SelectInst>(I));
441 enqueueUsers(I);
442 Conditionals.insert(Ptr: &I);
443 }
444 // PHI and select just pass through the pointers.
445 void visitPHINode(PHINode &PN) { visitPHINodeOrSelectInst(I&: PN); }
446 void visitSelectInst(SelectInst &SI) { visitPHINodeOrSelectInst(I&: SI); }
447
448 // memcpy/memmove are OK when the pointer is source. We can convert them to
449 // AS-specific memcpy.
450 void visitMemTransferInst(MemTransferInst &II) {
451 if (*U == II.getRawDest())
452 PI.setAborted(&II);
453 }
454
455 void visitMemSetInst(MemSetInst &II) { PI.setAborted(&II); }
456}; // struct ArgUseChecker
457
458void copyByValParam(Function &F, Argument &Arg) {
459 LLVM_DEBUG(dbgs() << "Creating a local copy of " << Arg << "\n");
460 Type *ByValType = Arg.getParamByValType();
461 const DataLayout &DL = F.getDataLayout();
462 IRBuilder<> IRB(&F.getEntryBlock().front());
463 AllocaInst *AllocA = IRB.CreateAlloca(Ty: ByValType, ArraySize: nullptr, Name: Arg.getName());
464 // Set the alignment to alignment of the byval parameter. This is because,
465 // later load/stores assume that alignment, and we are going to replace
466 // the use of the byval parameter with this alloca instruction.
467 AllocA->setAlignment(
468 Arg.getParamAlign().value_or(u: DL.getPrefTypeAlign(Ty: ByValType)));
469 Arg.replaceAllUsesWith(V: AllocA);
470
471 Value *ArgInParamAS = createNVVMInternalAddrspaceWrap(IRB, Arg);
472
473 // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
474 // addrspacecast preserves alignment. Since params are constant, this load
475 // is definitely not volatile.
476 const auto ArgSize = *AllocA->getAllocationSize(DL);
477 IRB.CreateMemCpy(Dst: AllocA, DstAlign: AllocA->getAlign(), Src: ArgInParamAS, SrcAlign: AllocA->getAlign(),
478 Size: ArgSize);
479}
480} // namespace
481
482static bool argIsProcessed(Argument *Arg) {
483 if (Arg->use_empty())
484 return true;
485
486 // If the argument is already wrapped, it was processed by this pass before.
487 if (Arg->hasOneUse())
488 if (const auto *II = dyn_cast<IntrinsicInst>(Val: *Arg->user_begin()))
489 if (II->getIntrinsicID() == Intrinsic::nvvm_internal_addrspace_wrap)
490 return true;
491
492 return false;
493}
494
495static void lowerKernelByValParam(Argument *Arg, Function &F,
496 const bool HasCvtaParam) {
497 assert(isKernelFunction(F));
498
499 const DataLayout &DL = F.getDataLayout();
500 IRBuilder<> IRB(&F.getEntryBlock().front());
501
502 if (argIsProcessed(Arg))
503 return;
504
505 const Align NewArgAlign = setByValParamAlign(Arg);
506
507 // (1) First check the easy case, if were able to trace through all the uses
508 // and we can convert them all to param AS, then we'll do this.
509 ArgUseChecker AUC(DL);
510 ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(A&: *Arg);
511 const bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
512 if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
513 // Convert all loads and intermediate operations to use parameter AS and
514 // skip creation of a local copy of the argument.
515 SmallVector<Use *, 16> UsesToUpdate(llvm::make_pointer_range(Range: Arg->uses()));
516 Value *ArgInParamAS = createNVVMInternalAddrspaceWrap(IRB, Arg&: *Arg);
517 for (Use *U : UsesToUpdate)
518 convertToParamAS(OldUses: U, Param: ArgInParamAS);
519
520 propagateAlignmentToLoads(Val: ArgInParamAS, NewAlign: NewArgAlign, DL);
521 return;
522 }
523
524 // (2) If the argument is grid constant, we get to use the pointer directly.
525 if (HasCvtaParam && (ArgUseIsReadOnly || isParamGridConstant(*Arg))) {
526 LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
527
528 // Cast argument to param address space. Because the backend will emit the
529 // argument already in the param address space, we need to use the noop
530 // intrinsic, this had the added benefit of preventing other optimizations
531 // from folding away this pair of addrspacecasts.
532 Instruction *ArgInParamAS = createNVVMInternalAddrspaceWrap(IRB, Arg&: *Arg);
533
534 // Cast param address to generic address space.
535 Value *GenericArg = IRB.CreateAddrSpaceCast(
536 V: ArgInParamAS, DestTy: IRB.getPtrTy(AddrSpace: ADDRESS_SPACE_GENERIC),
537 Name: Arg->getName() + ".gen");
538
539 Arg->replaceAllUsesWith(V: GenericArg);
540
541 // Do not replace Arg in the cast to param space
542 ArgInParamAS->setOperand(i: 0, Val: Arg);
543 return;
544 }
545
546 // (3) Otherwise we have to create a copy of the argument in local memory.
547 copyByValParam(F, Arg&: *Arg);
548}
549
550// =============================================================================
551// Main function for this pass.
552// =============================================================================
553static bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F) {
554 const NVPTXSubtarget *ST = TM.getSubtargetImpl(F);
555 const bool HasCvtaParam = ST->hasCvtaParam();
556
557 LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
558 bool Changed = false;
559 for (Argument &Arg : F.args())
560 if (Arg.hasByValAttr()) {
561 lowerKernelByValParam(Arg: &Arg, F, HasCvtaParam);
562 Changed = true;
563 }
564
565 return Changed;
566}
567
568// Device functions only need to copy byval args into local memory.
569static bool runOnDeviceFunction(Function &F) {
570 LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
571
572 const DataLayout &DL = F.getDataLayout();
573
574 bool Changed = false;
575 for (Argument &Arg : F.args())
576 if (Arg.hasByValAttr()) {
577 const Align NewArgAlign = setByValParamAlign(&Arg);
578 propagateAlignmentToLoads(Val: &Arg, NewAlign: NewArgAlign, DL);
579 Changed = true;
580 }
581
582 return Changed;
583}
584
585static bool processFunction(Function &F, NVPTXTargetMachine &TM) {
586 return isKernelFunction(F) ? runOnKernelFunction(TM, F)
587 : runOnDeviceFunction(F);
588}
589
590bool NVPTXLowerArgsLegacyPass::runOnFunction(Function &F) {
591 auto &TM = getAnalysis<TargetPassConfig>().getTM<NVPTXTargetMachine>();
592 return processFunction(F, TM);
593}
594
595FunctionPass *llvm::createNVPTXLowerArgsPass() {
596 return new NVPTXLowerArgsLegacyPass();
597}
598
599static bool copyFunctionByValArgs(Function &F) {
600 LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()
601 << "\n");
602 bool Changed = false;
603 if (isKernelFunction(F)) {
604 for (Argument &Arg : F.args())
605 if (Arg.hasByValAttr() && !isParamGridConstant(Arg)) {
606 copyByValParam(F, Arg);
607 Changed = true;
608 }
609 }
610 return Changed;
611}
612
613PreservedAnalyses NVPTXCopyByValArgsPass::run(Function &F,
614 FunctionAnalysisManager &AM) {
615 return copyFunctionByValArgs(F) ? PreservedAnalyses::none()
616 : PreservedAnalyses::all();
617}
618
619PreservedAnalyses NVPTXLowerArgsPass::run(Function &F,
620 FunctionAnalysisManager &AM) {
621 auto &NTM = static_cast<NVPTXTargetMachine &>(TM);
622 bool Changed = processFunction(F, TM&: NTM);
623 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
624}
625