1//===-- AMDGPUMemoryUtils.cpp - -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "AMDGPUMemoryUtils.h"
10#include "AMDGPU.h"
11#include "Utils/AMDGPUBaseInfo.h"
12#include "llvm/ADT/SetOperations.h"
13#include "llvm/Analysis/AliasAnalysis.h"
14#include "llvm/Analysis/CallGraph.h"
15#include "llvm/Analysis/MemorySSA.h"
16#include "llvm/IR/DataLayout.h"
17#include "llvm/IR/Instructions.h"
18#include "llvm/IR/IntrinsicInst.h"
19#include "llvm/IR/IntrinsicsAMDGPU.h"
20#include "llvm/IR/LLVMContext.h"
21#include "llvm/IR/ReplaceConstant.h"
22
23#define DEBUG_TYPE "amdgpu-memory-utils"
24
25using namespace llvm;
26
27namespace llvm::AMDGPU {
28
29Align getAlign(const DataLayout &DL, const GlobalVariable *GV) {
30 return DL.getValueOrABITypeAlignment(Alignment: GV->getPointerAlignment(DL),
31 Ty: GV->getValueType());
32}
33
34void copyMetadataForWidenedLoad(LoadInst &Dest, const LoadInst &Source) {
35 SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
36 Source.getAllMetadata(MDs&: MD);
37 for (const auto [ID, N] : MD) {
38 switch (ID) {
39 case LLVMContext::MD_dbg:
40 case LLVMContext::MD_invariant_load:
41 case LLVMContext::MD_nontemporal:
42 Dest.setMetadata(KindID: ID, Node: N);
43 break;
44 default:
45 break;
46 }
47 }
48}
49
50// Returns the target extension type of a global variable,
51// which can only be a TargetExtType, an array or single-element struct of it,
52// or their nesting combination.
53// TODO: allow struct of multiple TargetExtType elements of the same type.
54// TODO: Disallow other uses of target("amdgcn.named.barrier") including:
55// - Structs containing barriers in different scope/rank
56// - Structs containing a mixture of barriers and other data.
57// - Globals in other address spaces.
58// - Allocas.
59static TargetExtType *getTargetExtType(const GlobalVariable &GV) {
60 Type *Ty = GV.getValueType();
61 while (true) {
62 if (auto *TTy = dyn_cast<TargetExtType>(Val: Ty))
63 return TTy;
64 if (auto *STy = dyn_cast<StructType>(Val: Ty)) {
65 if (STy->getNumElements() != 1)
66 return nullptr;
67 Ty = STy->getElementType(N: 0);
68 continue;
69 }
70 if (auto *ATy = dyn_cast<ArrayType>(Val: Ty)) {
71 Ty = ATy->getElementType();
72 continue;
73 }
74 return nullptr;
75 }
76}
77
78TargetExtType *isNamedBarrier(const GlobalVariable &GV) {
79 if (TargetExtType *Ty = getTargetExtType(GV))
80 return Ty->getName() == "amdgcn.named.barrier" ? Ty : nullptr;
81 return nullptr;
82}
83
84bool isDynamicLDS(const GlobalVariable &GV) {
85 // external zero size addrspace(3) without initializer is dynlds.
86 const Module *M = GV.getParent();
87 const DataLayout &DL = M->getDataLayout();
88 if (GV.getType()->getPointerAddressSpace() != AMDGPUAS::LOCAL_ADDRESS)
89 return false;
90 return GV.getGlobalSize(DL) == 0;
91}
92
93bool isLDSVariableToLower(const GlobalVariable &GV) {
94 if (GV.getType()->getPointerAddressSpace() != AMDGPUAS::LOCAL_ADDRESS) {
95 return false;
96 }
97 if (isDynamicLDS(GV)) {
98 return true;
99 }
100 if (GV.isConstant()) {
101 // A constant undef variable can't be written to, and any load is
102 // undef, so it should be eliminated by the optimizer. It could be
103 // dropped by the back end if not. This pass skips over it.
104 return false;
105 }
106 if (GV.hasInitializer() && !isa<UndefValue>(Val: GV.getInitializer())) {
107 // Initializers are unimplemented for LDS address space.
108 // Leave such variables in place for consistent error reporting.
109 return false;
110 }
111 return true;
112}
113
114bool eliminateGVConstantExprUsesFromAllInstructions(
115 Module &M, function_ref<bool(const GlobalVariable &)> Filter) {
116 SmallVector<Constant *> Worklist;
117 for (auto &GV : M.globals())
118 if (Filter(GV))
119 Worklist.push_back(Elt: &GV);
120 return convertUsersOfConstantsToInstructions(Consts: Worklist);
121}
122
123void getUsesOfGVByFunction(const CallGraph &CG, Module &M,
124 function_ref<bool(const GlobalVariable &)> Filter,
125 FunctionVariableMap &Kernels,
126 FunctionVariableMap &Functions) {
127 // Get uses from the current function, excluding uses by called Functions
128 // Two output variables to avoid walking the globals list twice
129 for (auto &GV : M.globals()) {
130 if (!Filter(GV))
131 continue;
132 for (User *V : GV.users()) {
133 if (auto *I = dyn_cast<Instruction>(Val: V)) {
134 Function *F = I->getFunction();
135 if (isKernel(F: *F))
136 Kernels[F].insert(V: &GV);
137 else
138 Functions[F].insert(V: &GV);
139 }
140 }
141 }
142}
143
144GVUsesInfoTy
145getTransitiveUsesOfGV(const CallGraph &CG, Module &M,
146 function_ref<bool(const GlobalVariable &)> Filter) {
147
148 FunctionVariableMap DirectMapKernel;
149 FunctionVariableMap DirectMapFunction;
150 getUsesOfGVByFunction(CG, M, Filter, Kernels&: DirectMapKernel, Functions&: DirectMapFunction);
151
152 // Collect functions whose address has escaped
153 DenseSet<Function *> AddressTakenFuncs;
154 for (Function &F : M.functions()) {
155 if (!isKernel(F))
156 if (F.hasAddressTaken(nullptr,
157 /* IgnoreCallbackUses */ false,
158 /* IgnoreAssumeLikeCalls */ false,
159 /* IgnoreLLVMUsed */ IngoreLLVMUsed: true,
160 /* IgnoreArcAttachedCall */ IgnoreARCAttachedCall: false)) {
161 AddressTakenFuncs.insert(V: &F);
162 }
163 }
164
165 // Collect variables that are used by functions whose address has escaped
166 DenseSet<GlobalVariable *> VariablesReachableThroughFunctionPointer;
167 for (Function *F : AddressTakenFuncs) {
168 set_union(S1&: VariablesReachableThroughFunctionPointer, S2: DirectMapFunction[F]);
169 }
170
171 auto FunctionMakesUnknownCall = [&](const Function *F) -> bool {
172 assert(!F->isDeclaration());
173 for (const CallGraphNode::CallRecord &R : *CG[F]) {
174 if (!R.second->getFunction())
175 return true;
176 }
177 return false;
178 };
179
180 // Work out which variables are reachable through function calls
181 FunctionVariableMap TransitiveMapFunction = DirectMapFunction;
182
183 // If the function makes any unknown call, assume the worst case that it can
184 // access all variables accessed by functions whose address escaped
185 for (Function &F : M.functions()) {
186 if (!F.isDeclaration() && FunctionMakesUnknownCall(&F)) {
187 if (!isKernel(F)) {
188 set_union(S1&: TransitiveMapFunction[&F],
189 S2: VariablesReachableThroughFunctionPointer);
190 }
191 }
192 }
193
194 // Direct implementation of collecting all variables reachable from each
195 // function
196 for (Function &Func : M.functions()) {
197 if (Func.isDeclaration() || isKernel(F: Func))
198 continue;
199
200 DenseSet<Function *> seen; // catches cycles
201 SmallVector<Function *, 4> wip = {&Func};
202
203 while (!wip.empty()) {
204 Function *F = wip.pop_back_val();
205
206 // Can accelerate this by referring to transitive map for functions that
207 // have already been computed, with more care than this
208 set_union(S1&: TransitiveMapFunction[&Func], S2: DirectMapFunction[F]);
209
210 for (const CallGraphNode::CallRecord &R : *CG[F]) {
211 Function *Ith = R.second->getFunction();
212 if (Ith) {
213 if (!seen.contains(V: Ith)) {
214 seen.insert(V: Ith);
215 wip.push_back(Elt: Ith);
216 }
217 }
218 }
219 }
220 }
221
222 // Collect variables that are transitively used by functions whose address has
223 // escaped
224 for (Function *F : AddressTakenFuncs) {
225 set_union(S1&: VariablesReachableThroughFunctionPointer,
226 S2: TransitiveMapFunction[F]);
227 }
228
229 // DirectMapKernel lists which variables are used by the kernel
230 // find the variables which are used through a function call
231 FunctionVariableMap IndirectMapKernel;
232
233 for (Function &Func : M.functions()) {
234 if (Func.isDeclaration() || !isKernel(F: Func))
235 continue;
236
237 for (const CallGraphNode::CallRecord &R : *CG[&Func]) {
238 Function *Ith = R.second->getFunction();
239 if (Ith) {
240 set_union(S1&: IndirectMapKernel[&Func], S2: TransitiveMapFunction[Ith]);
241 }
242 }
243
244 // Check if the kernel encounters unknows calls, wheher directly or
245 // indirectly.
246 bool SeesUnknownCalls = [&]() {
247 SmallVector<Function *> WorkList = {CG[&Func]->getFunction()};
248 SmallPtrSet<Function *, 8> Visited;
249
250 while (!WorkList.empty()) {
251 Function *F = WorkList.pop_back_val();
252
253 for (const CallGraphNode::CallRecord &CallRecord : *CG[F]) {
254 if (!CallRecord.second)
255 continue;
256
257 Function *Callee = CallRecord.second->getFunction();
258 if (!Callee)
259 return true;
260
261 if (Visited.insert(Ptr: Callee).second)
262 WorkList.push_back(Elt: Callee);
263 }
264 }
265 return false;
266 }();
267
268 if (SeesUnknownCalls) {
269 set_union(S1&: IndirectMapKernel[&Func],
270 S2: VariablesReachableThroughFunctionPointer);
271 }
272 }
273
274 return {.DirectAccess: std::move(DirectMapKernel), .IndirectAccess: std::move(IndirectMapKernel)};
275}
276
277GVUsesInfoTy getTransitiveUsesOfLDSForLowering(const CallGraph &CG, Module &M) {
278 GVUsesInfoTy UsesInfo = getTransitiveUsesOfGV(CG, M, Filter: isLDSVariableToLower);
279 // Verify that we fall into one of 2 cases:
280 // - All variables are either absolute
281 // or direct mapped dynamic LDS that is not lowered.
282 // - No variables are absolute.
283 // Named-barriers which are absolute symbols are removed
284 // from the maps.
285 std::optional<bool> HasAbsoluteGVs;
286 for (auto &Map : {UsesInfo.DirectAccess, UsesInfo.IndirectAccess}) {
287 for (auto &[Fn, GVs] : Map) {
288 for (auto *GV : GVs) {
289 bool IsAbsolute = GV->isAbsoluteSymbolRef();
290 bool IsDirectMapDynLDSGV =
291 AMDGPU::isDynamicLDS(GV: *GV) && UsesInfo.DirectAccess.contains(Val: Fn);
292 if (IsDirectMapDynLDSGV)
293 continue;
294
295 // TODO: Remove once barriers are no longer in the LDS AS.
296 if (isNamedBarrier(GV: *GV)) {
297 if (IsAbsolute) {
298 UsesInfo.DirectAccess[Fn].erase(V: GV);
299 UsesInfo.IndirectAccess[Fn].erase(V: GV);
300 }
301 continue;
302 }
303
304 if (HasAbsoluteGVs.has_value()) {
305 if (*HasAbsoluteGVs != IsAbsolute) {
306 reportFatalUsageError(
307 reason: "module cannot mix absolute and non-absolute LDS GVs");
308 }
309 } else
310 HasAbsoluteGVs = IsAbsolute;
311 }
312 }
313 }
314
315 // If we only had absolute GVs, we have nothing to do, return an empty
316 // result.
317 if (HasAbsoluteGVs && *HasAbsoluteGVs)
318 return GVUsesInfoTy();
319
320 return UsesInfo;
321}
322
323void removeFnAttrFromReachable(CallGraph &CG, Function *KernelRoot,
324 ArrayRef<StringRef> FnAttrs) {
325 for (StringRef Attr : FnAttrs)
326 KernelRoot->removeFnAttr(Kind: Attr);
327
328 SmallVector<Function *> WorkList = {CG[KernelRoot]->getFunction()};
329 SmallPtrSet<Function *, 8> Visited;
330 bool SeenUnknownCall = false;
331
332 while (!WorkList.empty()) {
333 Function *F = WorkList.pop_back_val();
334
335 for (auto &CallRecord : *CG[F]) {
336 if (!CallRecord.second)
337 continue;
338
339 Function *Callee = CallRecord.second->getFunction();
340 if (!Callee) {
341 if (!SeenUnknownCall) {
342 SeenUnknownCall = true;
343
344 // If we see any indirect calls, assume nothing about potential
345 // targets.
346 // TODO: This could be refined to possible LDS global users.
347 for (auto &ExternalCallRecord : *CG.getExternalCallingNode()) {
348 Function *PotentialCallee =
349 ExternalCallRecord.second->getFunction();
350 assert(PotentialCallee);
351 if (!isKernel(F: *PotentialCallee)) {
352 for (StringRef Attr : FnAttrs)
353 PotentialCallee->removeFnAttr(Kind: Attr);
354 }
355 }
356 }
357 } else {
358 for (StringRef Attr : FnAttrs)
359 Callee->removeFnAttr(Kind: Attr);
360 if (Visited.insert(Ptr: Callee).second)
361 WorkList.push_back(Elt: Callee);
362 }
363 }
364 }
365}
366
367bool isReallyAClobber(const Value *Ptr, MemoryDef *Def, AAResults *AA) {
368 Instruction *DefInst = Def->getMemoryInst();
369
370 if (isa<FenceInst>(Val: DefInst))
371 return false;
372
373 if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: DefInst)) {
374 switch (II->getIntrinsicID()) {
375 case Intrinsic::amdgcn_s_barrier:
376 case Intrinsic::amdgcn_s_cluster_barrier:
377 case Intrinsic::amdgcn_s_barrier_signal:
378 case Intrinsic::amdgcn_s_barrier_signal_var:
379 case Intrinsic::amdgcn_s_barrier_signal_isfirst:
380 case Intrinsic::amdgcn_s_barrier_init:
381 case Intrinsic::amdgcn_s_barrier_join:
382 case Intrinsic::amdgcn_s_barrier_wait:
383 case Intrinsic::amdgcn_s_barrier_leave:
384 case Intrinsic::amdgcn_s_get_barrier_state:
385 case Intrinsic::amdgcn_s_wakeup_barrier:
386 case Intrinsic::amdgcn_wave_barrier:
387 case Intrinsic::amdgcn_sched_barrier:
388 case Intrinsic::amdgcn_sched_group_barrier:
389 case Intrinsic::amdgcn_iglp_opt:
390 return false;
391 default:
392 break;
393 }
394 }
395
396 // Ignore atomics not aliasing with the original load, any atomic is a
397 // universal MemoryDef from MSSA's point of view too, just like a fence.
398 const auto checkNoAlias = [AA, Ptr](auto I) -> bool {
399 return I && AA->isNoAlias(I->getPointerOperand(), Ptr);
400 };
401
402 if (checkNoAlias(dyn_cast<AtomicCmpXchgInst>(Val: DefInst)) ||
403 checkNoAlias(dyn_cast<AtomicRMWInst>(Val: DefInst)))
404 return false;
405
406 return true;
407}
408
409bool isClobberedInFunction(const LoadInst *Load, MemorySSA *MSSA,
410 AAResults *AA) {
411 MemorySSAWalker *Walker = MSSA->getWalker();
412 SmallVector<MemoryAccess *> WorkList{Walker->getClobberingMemoryAccess(I: Load)};
413 SmallPtrSet<MemoryAccess *, 8> Visited;
414 MemoryLocation Loc(MemoryLocation::get(LI: Load));
415
416 LLVM_DEBUG(dbgs() << "Checking clobbering of: " << *Load << '\n');
417
418 // Start with a nearest dominating clobbering access, it will be either
419 // live on entry (nothing to do, load is not clobbered), MemoryDef, or
420 // MemoryPhi if several MemoryDefs can define this memory state. In that
421 // case add all Defs to WorkList and continue going up and checking all
422 // the definitions of this memory location until the root. When all the
423 // defs are exhausted and came to the entry state we have no clobber.
424 // Along the scan ignore barriers and fences which are considered clobbers
425 // by the MemorySSA, but not really writing anything into the memory.
426 while (!WorkList.empty()) {
427 MemoryAccess *MA = WorkList.pop_back_val();
428 if (!Visited.insert(Ptr: MA).second)
429 continue;
430
431 if (MSSA->isLiveOnEntryDef(MA))
432 continue;
433
434 if (MemoryDef *Def = dyn_cast<MemoryDef>(Val: MA)) {
435 LLVM_DEBUG(dbgs() << " Def: " << *Def->getMemoryInst() << '\n');
436
437 if (isReallyAClobber(Ptr: Load->getPointerOperand(), Def, AA)) {
438 LLVM_DEBUG(dbgs() << " -> load is clobbered\n");
439 return true;
440 }
441
442 WorkList.push_back(
443 Elt: Walker->getClobberingMemoryAccess(MA: Def->getDefiningAccess(), Loc));
444 continue;
445 }
446
447 const MemoryPhi *Phi = cast<MemoryPhi>(Val: MA);
448 for (const auto &Use : Phi->incoming_values())
449 WorkList.push_back(Elt: cast<MemoryAccess>(Val: &Use));
450 }
451
452 LLVM_DEBUG(dbgs() << " -> no clobber\n");
453 return false;
454}
455
456} // end namespace llvm::AMDGPU
457