//===-- AMDGPULowerKernelArguments.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass replaces accesses to kernel arguments with loads from
/// offsets from the kernarg base pointer.
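///
/// As a rough sketch (the names and offsets here are illustrative, not
/// verbatim pass output), a kernel such as
///
///   define amdgpu_kernel void @k(i32 %x)
///
/// has every use of %x rewritten to use an explicit, invariant load from the
/// kernarg segment:
///
///   %k.kernarg.segment = call nonnull ptr addrspace(4)
///       @llvm.amdgcn.kernarg.segment.ptr()
///   %x.kernarg.offset = getelementptr inbounds i8, ptr addrspace(4)
///       %k.kernarg.segment, i64 0
///   %x.load = load i32, ptr addrspace(4) %x.kernarg.offset, align 16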
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUAsanInstrumentation.h"
#include "GCNSubtarget.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/Analysis/ScopedNoAliasAA.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Target/TargetMachine.h"
#include <optional>
#include <string>

#define DEBUG_TYPE "amdgpu-lower-kernel-arguments"

using namespace llvm;

namespace {

class AMDGPULowerKernelArguments : public FunctionPass {
public:
  static char ID;

  AMDGPULowerKernelArguments() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace

// Skip over the static allocas at the start of the entry block, so that
// kernarg loads are inserted after them.
static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
  BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
  for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
    AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);

    // If this is a dynamic alloca, the value may depend on the loaded
    // kernargs, so loads will need to be inserted before it.
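    // E.g. (hypothetical IR sketch):
    //   %n = load i32, ptr addrspace(4) %gep  ; a lowered kernarg
    //   %a = alloca i32, i32 %n, addrspace(5) ; depends on the loaded value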
    if (!AI || !AI->isStaticAlloca())
      break;
  }

  return InsPt;
}

static void addAliasScopeMetadata(Function &F, const DataLayout &DL,
                                  DominatorTree &DT) {
  // Collect noalias arguments.
  SmallVector<const Argument *, 4u> NoAliasArgs;

  for (Argument &Arg : F.args())
    if (Arg.hasNoAliasAttr() && !Arg.use_empty())
      NoAliasArgs.push_back(&Arg);

  if (NoAliasArgs.empty())
    return;

  // Add an alias scope for each noalias argument.
  MDBuilder MDB(F.getContext());
  DenseMap<const Argument *, MDNode *> NewScopes;
  MDNode *NewDomain = MDB.createAnonymousAliasScopeDomain(F.getName());

  for (unsigned I = 0u; I < NoAliasArgs.size(); ++I) {
    const Argument *Arg = NoAliasArgs[I];
    MDNode *NewScope = MDB.createAnonymousAliasScope(NewDomain, Arg->getName());
    NewScopes.insert({Arg, NewScope});
  }
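
  // The resulting metadata has roughly this shape (sketch; the names and
  // node numbers are illustrative):
  //   !0 = distinct !{!0, !"<function>"}      ; scope domain
  //   !1 = distinct !{!1, !0, !"<argument>"}  ; one scope per noalias arg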

  // Iterate over all instructions.
  for (inst_iterator Inst = inst_begin(F), InstEnd = inst_end(F);
       Inst != InstEnd; ++Inst) {
    // If the instruction accesses memory, collect its pointer arguments.
    Instruction *I = &(*Inst);
    SmallVector<const Value *, 2u> PtrArgs;

    if (std::optional<MemoryLocation> MO = MemoryLocation::getOrNone(I))
      PtrArgs.push_back(MO->Ptr);
    else if (const CallBase *Call = dyn_cast<CallBase>(I)) {
      if (Call->doesNotAccessMemory())
        continue;

      for (Value *Arg : Call->args()) {
        if (!Arg->getType()->isPointerTy())
          continue;

        PtrArgs.push_back(Arg);
      }
    } else {
      // Not a memory access and not a call; nothing to annotate.
      continue;
    }

    // Collect underlying objects of pointer arguments.
    SmallVector<Metadata *, 4u> Scopes;
    SmallPtrSet<const Value *, 4u> ObjSet;
    SmallVector<Metadata *, 4u> NoAliases;

    if (!PtrArgs.empty()) {
      // Trace pointer arguments back to underlying objects and decide which
      // noalias scopes apply based on provenance and capture analysis.
      for (const Value *Val : PtrArgs) {
        SmallVector<const Value *, 4u> Objects;
        getUnderlyingObjects(Val, Objects);
        ObjSet.insert_range(Objects);
      }

      bool RequiresNoCaptureBefore = false;
      bool UsesUnknownObject = false;
      bool UsesAliasingPtr = false;

      for (const Value *Val : ObjSet) {
        if (isa<ConstantData>(Val))
          continue;

        if (const Argument *Arg = dyn_cast<Argument>(Val)) {
          if (!Arg->hasAttribute(Attribute::NoAlias))
            UsesAliasingPtr = true;
        } else {
          UsesAliasingPtr = true;
        }

        if (isEscapeSource(Val))
          RequiresNoCaptureBefore = true;
        else if (!isa<Argument>(Val) && !isIdentifiedObject(Val))
          UsesUnknownObject = true;
      }

      if (UsesUnknownObject)
        continue;

      // Collect noalias scopes for the instruction.
      for (const Argument *Arg : NoAliasArgs) {
        if (ObjSet.contains(Arg))
          continue;

        if (!RequiresNoCaptureBefore ||
            !capturesAnything(PointerMayBeCapturedBefore(
                Arg, /*ReturnCaptures=*/false, I, &DT, /*IncludeI=*/false,
                CaptureComponents::Provenance)))
          NoAliases.push_back(NewScopes[Arg]);
      }

      // Collect scopes for alias.scope metadata.
      if (!UsesAliasingPtr)
        for (const Argument *Arg : NoAliasArgs) {
          if (ObjSet.count(Arg))
            Scopes.push_back(NewScopes[Arg]);
        }
    } else {
      // The instruction accesses memory but has no pointer arguments. Since
      // none of its operands derive from any noalias kernel argument, it
      // cannot possibly alias them. Mark it as !noalias w.r.t. every noalias
      // scope so that ScopedNoAliasAA can prove non-aliasing when other
      // instructions reference those scopes via !alias.scope.
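      // For example (sketch), a call known to access only inaccessible
      // memory has no pointer arguments but still reaches here, and receives
      // a !noalias list naming every argument scope.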
      for (const Argument *Arg : NoAliasArgs)
        NoAliases.push_back(NewScopes[Arg]);
    }

    // Add noalias metadata to the instruction.
    if (!NoAliases.empty()) {
      MDNode *NewMD =
          MDNode::concatenate(Inst->getMetadata(LLVMContext::MD_noalias),
                              MDNode::get(F.getContext(), NoAliases));
      Inst->setMetadata(LLVMContext::MD_noalias, NewMD);
    }

    // Add alias.scope metadata to the instruction.
    if (!Scopes.empty()) {
      MDNode *NewMD =
          MDNode::concatenate(Inst->getMetadata(LLVMContext::MD_alias_scope),
                              MDNode::get(F.getContext(), Scopes));
      Inst->setMetadata(LLVMContext::MD_alias_scope, NewMD);
    }
  }
}

static bool lowerKernelArguments(Function &F, const TargetMachine &TM,
                                 DominatorTree &DT) {
  CallingConv::ID CC = F.getCallingConv();
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
  LLVMContext &Ctx = F.getContext();
  const DataLayout &DL = F.getDataLayout();
  BasicBlock &EntryBlock = *F.begin();
  IRBuilder<> Builder(&EntryBlock, getInsertPt(EntryBlock));

  const Align KernArgBaseAlign(16); // FIXME: Increase if necessary
  const uint64_t BaseOffset = ST.getExplicitKernelArgOffset();

  Align MaxAlign;
  // FIXME: Alignment is broken with explicit arg offset.
  const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
  if (TotalKernArgSize == 0)
    return false;

  CallInst *KernArgSegment =
      Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {},
                              nullptr, F.getName() + ".kernarg.segment");
  KernArgSegment->addRetAttr(Attribute::NonNull);
  KernArgSegment->addRetAttr(
      Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));

  uint64_t ExplicitArgOffset = 0;

  addAliasScopeMetadata(F, DL, DT);

  for (Argument &Arg : F.args()) {
    const bool IsByRef = Arg.hasByRefAttr();
    Type *ArgTy = IsByRef ? Arg.getParamByRefType() : Arg.getType();
    MaybeAlign ParamAlign = IsByRef ? Arg.getParamAlign() : std::nullopt;
    Align ABITypeAlign = DL.getValueOrABITypeAlignment(ParamAlign, ArgTy);

    uint64_t Size = DL.getTypeSizeInBits(ArgTy);
    uint64_t AllocSize = DL.getTypeAllocSize(ArgTy);

    uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
    ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;

    // Skip inreg arguments which should be preloaded.
    if (Arg.use_empty() || Arg.hasInRegAttr())
      continue;

    // If this is byref, the loads are already explicit in the function. We
    // just need to rewrite the pointer values.
    if (IsByRef) {
      Value *ArgOffsetPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, EltOffset,
          Arg.getName() + ".byval.kernarg.offset");

      Value *CastOffsetPtr =
          Builder.CreateAddrSpaceCast(ArgOffsetPtr, Arg.getType());
      Arg.replaceAllUsesWith(CastOffsetPtr);
      continue;
    }

    if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
      // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
      // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
      // can't represent this with range metadata because it's only allowed for
      // integer types.
      if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
           PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
          !ST.hasUsableDSOffset())
        continue;
    }

    auto *VT = dyn_cast<FixedVectorType>(ArgTy);
    bool IsV3 = VT && VT->getNumElements() == 3;
    bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();

    VectorType *V4Ty = nullptr;

    int64_t AlignDownOffset = alignDown(EltOffset, 4);
    int64_t OffsetDiff = EltOffset - AlignDownOffset;
    Align AdjustedAlign = commonAlignment(
        KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset);

    Value *ArgPtr;
    Type *AdjustedArgTy;
    if (DoShiftOpt) { // FIXME: Handle aggregate types
      // Since we don't have sub-dword scalar loads, avoid doing an extload by
      // loading earlier than the argument address, and extracting the relevant
      // bits.
      // TODO: Update this for GFX12 which does have scalar sub-dword loads.
      //
      // Additionally widen any sub-dword load to i32 even if suitably aligned,
      // so that CSE between different argument loads works easily.
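      // E.g. (sketch; hypothetical offsets) an i16 argument at byte offset 38
      // becomes an aligned i32 load at offset 36, then a shift and truncate:
      //   %w = load i32, ptr addrspace(4) %gep.align.down, align 4
      //   %s = lshr i32 %w, 16   ; OffsetDiff (2 bytes) * 8
      //   %x = trunc i32 %s to i16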
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
          Arg.getName() + ".kernarg.offset.align.down");
      AdjustedArgTy = Builder.getInt32Ty();
    } else {
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, EltOffset,
          Arg.getName() + ".kernarg.offset");
      AdjustedArgTy = ArgTy;
    }

    if (IsV3 && Size >= 32) {
      V4Ty = FixedVectorType::get(VT->getElementType(), 4);
      // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads.
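      // E.g. (sketch) a <3 x i32> argument is loaded as <4 x i32> here and
      // narrowed back later with a <0, 1, 2> shufflevector mask:
      //   %w = load <4 x i32>, ptr addrspace(4) %gep, align 16
      //   %v = shufflevector <4 x i32> %w, <4 x i32> poison,
      //                      <3 x i32> <i32 0, i32 1, i32 2>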
      AdjustedArgTy = V4Ty;
    }

    LoadInst *Load =
        Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign);
    Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));

    MDBuilder MDB(Ctx);

    if (Arg.hasAttribute(Attribute::NoUndef))
      Load->setMetadata(LLVMContext::MD_noundef, MDNode::get(Ctx, {}));

    if (Arg.hasAttribute(Attribute::Range)) {
      const ConstantRange &Range =
          Arg.getAttribute(Attribute::Range).getValueAsConstantRange();
      Load->setMetadata(LLVMContext::MD_range,
                        MDB.createRange(Range.getLower(), Range.getUpper()));
    }

    if (isa<PointerType>(ArgTy)) {
      if (Arg.hasNonNullAttr())
        Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));

      uint64_t DerefBytes = Arg.getDereferenceableBytes();
      if (DerefBytes != 0) {
        Load->setMetadata(
            LLVMContext::MD_dereferenceable,
            MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
                                 Builder.getInt64Ty(), DerefBytes))));
      }

      uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
      if (DerefOrNullBytes != 0) {
        Load->setMetadata(
            LLVMContext::MD_dereferenceable_or_null,
            MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
                                 Builder.getInt64Ty(), DerefOrNullBytes))));
      }

      if (MaybeAlign ParamAlign = Arg.getParamAlign()) {
        Load->setMetadata(
            LLVMContext::MD_align,
            MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
                                 Builder.getInt64Ty(), ParamAlign->value()))));
      }
    }

    if (DoShiftOpt) {
      Value *ExtractBits =
          OffsetDiff == 0 ? Load : Builder.CreateLShr(Load, OffsetDiff * 8);

      IntegerType *ArgIntTy = Builder.getIntNTy(Size);
      Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
      Value *NewVal =
          Builder.CreateBitCast(Trunc, ArgTy, Arg.getName() + ".load");
      Arg.replaceAllUsesWith(NewVal);
    } else if (IsV3) {
      Value *Shuf = Builder.CreateShuffleVector(Load, ArrayRef<int>{0, 1, 2},
                                                Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Shuf);
    } else {
      Load->setName(Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Load);
    }
  }

  KernArgSegment->addRetAttr(
      Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));

  return true;
}

bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
  auto &TPC = getAnalysis<TargetPassConfig>();
  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  return lowerKernelArguments(F, TM, DT);
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
                      "AMDGPU Lower Kernel Arguments", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE,
                    "AMDGPU Lower Kernel Arguments", false, false)

char AMDGPULowerKernelArguments::ID = 0;

FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
  return new AMDGPULowerKernelArguments();
}

PreservedAnalyses
AMDGPULowerKernelArgumentsPass::run(Function &F, FunctionAnalysisManager &AM) {
  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
  bool Changed = lowerKernelArguments(F, TM, DT);
  if (Changed) {
    // TODO: Preserve a lot more.
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }

  return PreservedAnalyses::all();
}