1//===-- AMDGPULowerBufferFatPointers.cpp ---------------------------=//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass lowers operations on buffer fat pointers (addrspace 7) to
10// operations on buffer resources (addrspace 8) and is needed for correct
11// codegen.
12//
13// # Background
14//
15// Address space 7 (the buffer fat pointer) is a 160-bit pointer that consists
16// of a 128-bit buffer descriptor and a 32-bit offset into that descriptor.
// The buffer resource part needs to be a "raw" buffer resource
18// (it must have a stride of 0 and bounds checks must be in raw buffer mode
19// or disabled).
20//
21// When these requirements are met, a buffer resource can be treated as a
22// typical (though quite wide) pointer that follows typical LLVM pointer
23// semantics. This allows the frontend to reason about such buffers (which are
24// often encountered in the context of SPIR-V kernels).
25//
26// However, because of their non-power-of-2 size, these fat pointers cannot be
27// present during translation to MIR (though this restriction may be lifted
28// during the transition to GlobalISel). Therefore, this pass is needed in order
29// to correctly implement these fat pointers.
30//
31// The resource intrinsics take the resource part (the address space 8 pointer)
32// and the offset part (the 32-bit integer) as separate arguments. In addition,
33// many users of these buffers manipulate the offset while leaving the resource
34// part alone. For these reasons, we want to typically separate the resource
35// and offset parts into separate variables, but combine them together when
36// encountering cases where this is required, such as by inserting these values
// into aggregates or moving them to memory.
38//
39// Therefore, at a high level, `ptr addrspace(7) %x` becomes `ptr addrspace(8)
40// %x.rsrc` and `i32 %x.off`, which will be combined into `{ptr addrspace(8),
41// i32} %x = {%x.rsrc, %x.off}` if needed. Similarly, `vector<Nxp7>` becomes
42// `{vector<Nxp8>, vector<Nxi32 >}` and its component parts.
43//
44// # Implementation
45//
46// This pass proceeds in three main phases:
47//
48// ## Rewriting loads and stores of p7 and memcpy()-like handling
49//
// The first phase is to rewrite away all loads and stores of `ptr addrspace(7)`,
51// including aggregates containing such pointers, to ones that use `i160`. This
52// is handled by `StoreFatPtrsAsIntsAndExpandMemcpyVisitor` , which visits
53// loads, stores, and allocas and, if the loaded or stored type contains `ptr
54// addrspace(7)`, rewrites that type to one where the p7s are replaced by i160s,
55// copying other parts of aggregates as needed. In the case of a store, each
56// pointer is `ptrtoint`d to i160 before storing, and load integers are
57// `inttoptr`d back. This same transformation is applied to vectors of pointers.
58//
// Such a transformation allows the later phases of the pass to not need
// to handle buffer fat pointers moving to and from memory, where we would
// have to handle the incompatibility between a `{Nxp8, Nxi32}` representation
// and `Nxi160` directly. Instead, that transposing action (where the vectors
// of resources and vectors of offsets are concatenated before being stored to
// memory) is handled through implementing `inttoptr` and `ptrtoint` only.
65//
// Atomics operations on `ptr addrspace(7)` values are not supported, as the
67// hardware does not include a 160-bit atomic.
68//
69// In order to save on O(N) work and to ensure that the contents type
70// legalizer correctly splits up wide loads, also unconditionally lower
71// memcpy-like intrinsics into loops here.
72//
73// ## Buffer contents type legalization
74//
75// The underlying buffer intrinsics only support types up to 128 bits long,
76// and don't support complex types. If buffer operations were
77// standard pointer operations that could be represented as MIR-level loads,
78// this would be handled by the various legalization schemes in instruction
79// selection. However, because we have to do the conversion from `load` and
80// `store` to intrinsics at LLVM IR level, we must perform that legalization
81// ourselves.
82//
83// This involves a combination of
84// - Converting arrays to vectors where possible
85// - Otherwise, splitting loads and stores of aggregates into loads/stores of
86// each component.
87// - Zero-extending things to fill a whole number of bytes
88// - Casting values of types that don't neatly correspond to supported machine
89// value
90// (for example, an i96 or i256) into ones that would work (
91// like <3 x i32> and <8 x i32>, respectively)
92// - Splitting values that are too long (such as aforementioned <8 x i32>) into
93// multiple operations.
94//
95// ## Type remapping
96//
97// We use a `ValueMapper` to mangle uses of [vectors of] buffer fat pointers
98// to the corresponding struct type, which has a resource part and an offset
99// part.
100//
101// This uses a `BufferFatPtrToStructTypeMap` and a `FatPtrConstMaterializer`
102// to, usually by way of `setType`ing values. Constants are handled here
103// because there isn't a good way to fix them up later.
104//
105// This has the downside of leaving the IR in an invalid state (for example,
106// the instruction `getelementptr {ptr addrspace(8), i32} %p, ...` will exist),
107// but all such invalid states will be resolved by the third phase.
108//
109// Functions that don't take buffer fat pointers are modified in place. Those
110// that do take such pointers have their basic blocks moved to a new function
111// with arguments that are {ptr addrspace(8), i32} arguments and return values.
112// This phase also records intrinsics so that they can be remangled or deleted
113// later.
114//
115// ## Splitting pointer structs
116//
117// The meat of this pass consists of defining semantics for operations that
118// produce or consume [vectors of] buffer fat pointers in terms of their
// resource and offset parts. This is accomplished through the `SplitPtrStructs`
120// visitor.
121//
122// In the first pass through each function that is being lowered, the splitter
123// inserts new instructions to implement the split-structures behavior, which is
124// needed for correctness and performance. It records a list of "split users",
125// instructions that are being replaced by operations on the resource and offset
126// parts.
127//
128// Split users do not necessarily need to produce parts themselves (
129// a `load float, ptr addrspace(7)` does not, for example), but, if they do not
130// generate fat buffer pointers, they must RAUW in their replacement
131// instructions during the initial visit.
132//
133// When these new instructions are created, they use the split parts recorded
134// for their initial arguments in order to generate their replacements, creating
135// a parallel set of instructions that does not refer to the original fat
136// pointer values but instead to their resource and offset components.
137//
138// Instructions, such as `extractvalue`, that produce buffer fat pointers from
139// sources that do not have split parts, have such parts generated using
140// `extractvalue`. This is also the initial handling of PHI nodes, which
141// are then cleaned up.
142//
143// ### Conditionals
144//
145// PHI nodes are initially given resource parts via `extractvalue`. However,
146// this is not an efficient rewrite of such nodes, as, in most cases, the
147// resource part in a conditional or loop remains constant throughout the loop
148// and only the offset varies. Failing to optimize away these constant resources
149// would cause additional registers to be sent around loops and might lead to
150// waterfall loops being generated for buffer operations due to the
151// "non-uniform" resource argument.
152//
153// Therefore, after all instructions have been visited, the pointer splitter
154// post-processes all encountered conditionals. Given a PHI node or select,
155// getPossibleRsrcRoots() collects all values that the resource parts of that
156// conditional's input could come from as well as collecting all conditional
157// instructions encountered during the search. If, after filtering out the
158// initial node itself, the set of encountered conditionals is a subset of the
159// potential roots and there is a single potential resource that isn't in the
160// conditional set, that value is the only possible value the resource argument
161// could have throughout the control flow.
162//
163// If that condition is met, then a PHI node can have its resource part changed
164// to the singleton value and then be replaced by a PHI on the offsets.
165// Otherwise, each PHI node is split into two, one for the resource part and one
166// for the offset part, which replace the temporary `extractvalue` instructions
167// that were added during the first pass.
168//
// Similar logic applies to `select`, where
// `%z = select i1 %cond, ptr addrspace(7) %x, ptr addrspace(7) %y`
// can be split into `%z.rsrc = %x.rsrc` and
// `%z.off = select i1 %cond, i32 %x.off, i32 %y.off`
// if both `%x` and `%y` have the same resource part, but two `select`
// operations will be needed if they do not.
175//
176// ### Final processing
177//
178// After conditionals have been cleaned up, the IR for each function is
179// rewritten to remove all the old instructions that have been split up.
180//
181// Any instruction that used to produce a buffer fat pointer (and therefore now
182// produces a resource-and-offset struct after type remapping) is
183// replaced as follows:
184// 1. All debug value annotations are cloned to reflect that the resource part
185// and offset parts are computed separately and constitute different
186// fragments of the underlying source language variable.
187// 2. All uses that were themselves split are replaced by a `poison` of the
188// struct type, as they will themselves be erased soon. This rule, combined
189// with debug handling, should leave the use lists of split instructions
190// empty in almost all cases.
191// 3. If a user of the original struct-valued result remains, the structure
192// needed for the new types to work is constructed out of the newly-defined
193// parts, and the original instruction is replaced by this structure
194// before being erased. Instructions requiring this construction include
195// `ret` and `insertvalue`.
196//
197// # Consequences
198//
199// This pass does not alter the CFG.
200//
201// Alias analysis information will become coarser, as the LLVM alias analyzer
202// cannot handle the buffer intrinsics. Specifically, while we can determine
203// that the following two loads do not alias:
204// ```
205// %y = getelementptr i32, ptr addrspace(7) %x, i32 1
206// %a = load i32, ptr addrspace(7) %x
207// %b = load i32, ptr addrspace(7) %y
208// ```
209// we cannot (except through some code that runs during scheduling) determine
210// that the rewritten loads below do not alias.
211// ```
212// %y.off = add i32 %x.off, 1
213// %a = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8) %x.rsrc, i32
214// %x.off, ...)
215// %b = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8)
216// %x.rsrc, i32 %y.off, ...)
217// ```
218// However, existing alias information is preserved.
219//===----------------------------------------------------------------------===//
220
221#include "AMDGPU.h"
222#include "AMDGPUTargetMachine.h"
223#include "GCNSubtarget.h"
224#include "SIDefines.h"
225#include "llvm/ADT/SetOperations.h"
226#include "llvm/ADT/SmallVector.h"
227#include "llvm/Analysis/InstSimplifyFolder.h"
228#include "llvm/Analysis/TargetTransformInfo.h"
229#include "llvm/Analysis/Utils/Local.h"
230#include "llvm/CodeGen/TargetPassConfig.h"
231#include "llvm/IR/AttributeMask.h"
232#include "llvm/IR/Constants.h"
233#include "llvm/IR/DebugInfo.h"
234#include "llvm/IR/DerivedTypes.h"
235#include "llvm/IR/IRBuilder.h"
236#include "llvm/IR/InstIterator.h"
237#include "llvm/IR/InstVisitor.h"
238#include "llvm/IR/Instructions.h"
239#include "llvm/IR/IntrinsicInst.h"
240#include "llvm/IR/Intrinsics.h"
241#include "llvm/IR/IntrinsicsAMDGPU.h"
242#include "llvm/IR/Metadata.h"
243#include "llvm/IR/Operator.h"
244#include "llvm/IR/PatternMatch.h"
245#include "llvm/IR/ReplaceConstant.h"
246#include "llvm/IR/ValueHandle.h"
247#include "llvm/InitializePasses.h"
248#include "llvm/Pass.h"
249#include "llvm/Support/AMDGPUAddrSpace.h"
250#include "llvm/Support/Alignment.h"
251#include "llvm/Support/AtomicOrdering.h"
252#include "llvm/Support/Debug.h"
253#include "llvm/Support/ErrorHandling.h"
254#include "llvm/Transforms/Utils/Cloning.h"
255#include "llvm/Transforms/Utils/Local.h"
256#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
257#include "llvm/Transforms/Utils/ValueMapper.h"
258
259#define DEBUG_TYPE "amdgpu-lower-buffer-fat-pointers"
260
261using namespace llvm;
262
263static constexpr unsigned BufferOffsetWidth = 32;
264
265namespace {
/// Recursively replace instances of ptr addrspace(7) and vector<Nxptr
/// addrspace(7)> with some other type as defined by the relevant subclass.
class BufferFatPtrTypeLoweringBase : public ValueMapTypeRemapper {
  // Cache of previously computed remappings; also ensures that structurally
  // identical queries return pointer-identical results.
  DenseMap<Type *, Type *> Map;

  // Recursive worker behind remapType(); consults and fills `Map`.
  Type *remapTypeImpl(Type *Ty);

protected:
  /// Return the replacement type for a single `ptr addrspace(7)`.
  virtual Type *remapScalar(PointerType *PT) = 0;
  /// Return the replacement type for a `vector<N x ptr addrspace(7)>`.
  virtual Type *remapVector(VectorType *VT) = 0;

  const DataLayout &DL;

public:
  BufferFatPtrTypeLoweringBase(const DataLayout &DL) : DL(DL) {}
  Type *remapType(Type *SrcTy) override;
  /// Drop all cached remappings.
  void clear() { Map.clear(); }
};
284
/// Remap ptr addrspace(7) to i160 and vector<Nxptr addrspace(7)> to
/// vector<Nxi160> in order to correctly handle loading/storing these values
/// from memory.
class BufferFatPtrToIntTypeMap : public BufferFatPtrTypeLoweringBase {
  using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;

protected:
  // getIntPtrType() yields the pointer-sized integer (i160 for these 160-bit
  // fat pointers) and the matching integer vector for vector inputs, so the
  // in-memory layout is preserved.
  Type *remapScalar(PointerType *PT) override { return DL.getIntPtrType(PT); }
  Type *remapVector(VectorType *VT) override { return DL.getIntPtrType(VT); }
};
295
/// Remap ptr addrspace(7) to {ptr addrspace(8), i32} (the resource and offset
/// parts of the pointer) so that we can easily rewrite operations on these
/// values that aren't loading them from or storing them to memory.
class BufferFatPtrToStructTypeMap : public BufferFatPtrTypeLoweringBase {
  using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;

protected:
  // See the out-of-line definitions below for the exact struct shapes.
  Type *remapScalar(PointerType *PT) override;
  Type *remapVector(VectorType *VT) override;
};
306} // namespace
307
// This code is adapted from the type remapper in lib/Linker/IRMover.cpp
//
// Recursively rewrites `Ty`, replacing [vectors of] buffer fat pointers via
// the subclass hooks and rebuilding any aggregate/function type that contains
// them. Results are memoized in `Map`.
Type *BufferFatPtrTypeLoweringBase::remapTypeImpl(Type *Ty) {
  Type **Entry = &Map[Ty];
  if (*Entry)
    return *Entry;
  // A bare fat pointer is handled by the subclass hook.
  if (auto *PT = dyn_cast<PointerType>(Val: Ty)) {
    if (PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
      return *Entry = remapScalar(PT);
    }
  }
  // Vectors either contain fat pointers (subclass hook) or are returned
  // unchanged; vectors cannot contain aggregates, so no recursion is needed.
  if (auto *VT = dyn_cast<VectorType>(Val: Ty)) {
    auto *PT = dyn_cast<PointerType>(Val: VT->getElementType());
    if (PT && PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
      return *Entry = remapVector(VT);
    }
    return *Entry = Ty;
  }
  // Whether the type is one that is structurally uniqued - that is, if it is
  // not a named struct (the only kind of type where multiple structurally
  // identical types that have a distinct `Type*`)
  StructType *TyAsStruct = dyn_cast<StructType>(Val: Ty);
  bool IsUniqued = !TyAsStruct || TyAsStruct->isLiteral();
  // Base case for ints, floats, opaque pointers, and so on, which don't
  // require recursion.
  if (Ty->getNumContainedTypes() == 0 && IsUniqued)
    return *Entry = Ty;
  // Remap every contained type, tracking whether anything actually changed so
  // unchanged types can be returned as-is.
  bool Changed = false;
  SmallVector<Type *> ElementTypes(Ty->getNumContainedTypes(), nullptr);
  for (unsigned int I = 0, E = Ty->getNumContainedTypes(); I < E; ++I) {
    Type *OldElem = Ty->getContainedType(i: I);
    Type *NewElem = remapTypeImpl(Ty: OldElem);
    ElementTypes[I] = NewElem;
    Changed |= (OldElem != NewElem);
  }
  // Recursive calls to remapTypeImpl() may have invalidated pointer.
  Entry = &Map[Ty];
  if (!Changed) {
    return *Entry = Ty;
  }
  if (auto *ArrTy = dyn_cast<ArrayType>(Val: Ty))
    return *Entry = ArrayType::get(ElementType: ElementTypes[0], NumElements: ArrTy->getNumElements());
  // For function types, contained type 0 is the return type and the rest are
  // the parameter types.
  if (auto *FnTy = dyn_cast<FunctionType>(Val: Ty))
    return *Entry = FunctionType::get(Result: ElementTypes[0],
                                      Params: ArrayRef(ElementTypes).slice(N: 1),
                                      isVarArg: FnTy->isVarArg());
  if (auto *STy = dyn_cast<StructType>(Val: Ty)) {
    // Genuine opaque types don't have a remapping.
    if (STy->isOpaque())
      return *Entry = Ty;
    bool IsPacked = STy->isPacked();
    if (IsUniqued)
      return *Entry = StructType::get(Context&: Ty->getContext(), Elements: ElementTypes, isPacked: IsPacked);
    // Named struct: transfer the old name onto the remapped struct so the
    // resulting IR stays readable.
    SmallString<16> Name(STy->getName());
    STy->setName("");
    return *Entry = StructType::create(Context&: Ty->getContext(), Elements: ElementTypes, Name,
                                       isPacked: IsPacked);
  }
  llvm_unreachable("Unknown type of type that contains elements");
}
367
// Public ValueMapTypeRemapper entry point; delegates to the memoizing
// recursive implementation.
Type *BufferFatPtrTypeLoweringBase::remapType(Type *SrcTy) {
  return remapTypeImpl(Ty: SrcTy);
}
371
// A fat pointer becomes the literal struct {ptr addrspace(8), i32}: the
// buffer resource and the 32-bit offset into it.
Type *BufferFatPtrToStructTypeMap::remapScalar(PointerType *PT) {
  LLVMContext &Ctx = PT->getContext();
  return StructType::get(elt1: PointerType::get(C&: Ctx, AddressSpace: AMDGPUAS::BUFFER_RESOURCE),
                         elts: IntegerType::get(C&: Ctx, NumBits: BufferOffsetWidth));
}
377
// A vector of fat pointers becomes {<N x ptr addrspace(8)>, <N x i32>}:
// parallel vectors of resources and offsets with the same element count.
Type *BufferFatPtrToStructTypeMap::remapVector(VectorType *VT) {
  ElementCount EC = VT->getElementCount();
  LLVMContext &Ctx = VT->getContext();
  Type *RsrcVec =
      VectorType::get(ElementType: PointerType::get(C&: Ctx, AddressSpace: AMDGPUAS::BUFFER_RESOURCE), EC);
  Type *OffVec = VectorType::get(ElementType: IntegerType::get(C&: Ctx, NumBits: BufferOffsetWidth), EC);
  return StructType::get(elt1: RsrcVec, elts: OffVec);
}
386
387static bool isBufferFatPtrOrVector(Type *Ty) {
388 if (auto *PT = dyn_cast<PointerType>(Val: Ty->getScalarType()))
389 return PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER;
390 return false;
391}
392
393// True if the type is {ptr addrspace(8), i32} or a struct containing vectors of
394// those types. Used to quickly skip instructions we don't need to process.
395static bool isSplitFatPtr(Type *Ty) {
396 auto *ST = dyn_cast<StructType>(Val: Ty);
397 if (!ST)
398 return false;
399 if (!ST->isLiteral() || ST->getNumElements() != 2)
400 return false;
401 auto *MaybeRsrc =
402 dyn_cast<PointerType>(Val: ST->getElementType(N: 0)->getScalarType());
403 auto *MaybeOff =
404 dyn_cast<IntegerType>(Val: ST->getElementType(N: 1)->getScalarType());
405 return MaybeRsrc && MaybeOff &&
406 MaybeRsrc->getAddressSpace() == AMDGPUAS::BUFFER_RESOURCE &&
407 MaybeOff->getBitWidth() == BufferOffsetWidth;
408}
409
410// True if the result type or any argument types are buffer fat pointers.
411static bool isBufferFatPtrConst(Constant *C) {
412 Type *T = C->getType();
413 return isBufferFatPtrOrVector(Ty: T) || any_of(Range: C->operands(), P: [](const Use &U) {
414 return isBufferFatPtrOrVector(Ty: U.get()->getType());
415 });
416}
417
418namespace {
/// Convert [vectors of] buffer fat pointers to integers when they are read from
/// or stored to memory. This ensures that these pointers will have the same
/// memory layout as before they are lowered, even though they will no longer
/// have their previous layout in registers/in the program (they'll be broken
/// down into resource and offset parts). This has the downside of imposing
/// marshalling costs when reading or storing these values, but since placing
/// such pointers into memory is an uncommon operation at best, we feel that
/// this cost is acceptable for better performance in the common case.
class StoreFatPtrsAsIntsAndExpandMemcpyVisitor
    : public InstVisitor<StoreFatPtrsAsIntsAndExpandMemcpyVisitor, bool> {
  // Maps p7 (and types containing p7) to their i160-based equivalents.
  BufferFatPtrToIntTypeMap *TypeMap;

  // Memoizes fat-pointer-to-integer conversions so a value stored multiple
  // times is converted only once. Cleared at the end of processFunction().
  ValueToValueMapTy ConvertedForStore;

  IRBuilder<InstSimplifyFolder> IRB;

  // Needed to query TargetTransformInfo when expanding memcpy()/memset()-like
  // intrinsics into loops.
  const TargetMachine *TM;

  // Convert all the buffer fat pointers within the input value to integers
  // so that it can be stored in memory.
  Value *fatPtrsToInts(Value *V, Type *From, Type *To, const Twine &Name);
  // Convert all the i160s that need to be buffer fat pointers (as specified
  // by the To type) into those pointers to preserve the semantics of the rest
  // of the program.
  Value *intsToFatPtrs(Value *V, Type *From, Type *To, const Twine &Name);

public:
  StoreFatPtrsAsIntsAndExpandMemcpyVisitor(BufferFatPtrToIntTypeMap *TypeMap,
                                           const DataLayout &DL,
                                           LLVMContext &Ctx,
                                           const TargetMachine *TM)
      : TypeMap(TypeMap), IRB(Ctx, InstSimplifyFolder(DL)), TM(TM) {}
  bool processFunction(Function &F);

  bool visitInstruction(Instruction &I) { return false; }
  bool visitAllocaInst(AllocaInst &I);
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);
  bool visitGetElementPtrInst(GetElementPtrInst &I);

  // memcpy()-like intrinsics touching buffer fat pointers are expanded into
  // loops (memmove() is rejected; see the definitions below).
  bool visitMemCpyInst(MemCpyInst &MCI);
  bool visitMemMoveInst(MemMoveInst &MMI);
  bool visitMemSetInst(MemSetInst &MSI);
  bool visitMemSetPatternInst(MemSetPatternInst &MSPI);
};
464} // namespace
465
// Recursively `ptrtoint` every buffer fat pointer inside `V` (of type `From`)
// to produce a value of the already-remapped type `To`, so the value can be
// stored to memory. Results are memoized in `ConvertedForStore`.
Value *StoreFatPtrsAsIntsAndExpandMemcpyVisitor::fatPtrsToInts(
    Value *V, Type *From, Type *To, const Twine &Name) {
  // Types that contain no fat pointers are returned unchanged.
  if (From == To)
    return V;
  // Reuse a previously computed conversion of this exact value, if any.
  ValueToValueMapTy::iterator Find = ConvertedForStore.find(Val: V);
  if (Find != ConvertedForStore.end())
    return Find->second;
  if (isBufferFatPtrOrVector(Ty: From)) {
    Value *Cast = IRB.CreatePtrToInt(V, DestTy: To, Name: Name + ".int");
    ConvertedForStore[V] = Cast;
    return Cast;
  }
  if (From->getNumContainedTypes() == 0)
    return V;
  // Structs, arrays, and other compound types.
  Value *Ret = PoisonValue::get(T: To);
  if (auto *AT = dyn_cast<ArrayType>(Val: From)) {
    // Arrays have one shared element type; convert each element in turn.
    Type *FromPart = AT->getArrayElementType();
    Type *ToPart = cast<ArrayType>(Val: To)->getElementType();
    for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
      Value *Field = IRB.CreateExtractValue(Agg: V, Idxs: I);
      Value *NewField =
          fatPtrsToInts(V: Field, From: FromPart, To: ToPart, Name: Name + "." + Twine(I));
      Ret = IRB.CreateInsertValue(Agg: Ret, Val: NewField, Idxs: I);
    }
  } else {
    // Structs: walk the old and new field types in lockstep.
    for (auto [Idx, FromPart, ToPart] :
         enumerate(First: From->subtypes(), Rest: To->subtypes())) {
      Value *Field = IRB.CreateExtractValue(Agg: V, Idxs: Idx);
      Value *NewField =
          fatPtrsToInts(V: Field, From: FromPart, To: ToPart, Name: Name + "." + Twine(Idx));
      Ret = IRB.CreateInsertValue(Agg: Ret, Val: NewField, Idxs: Idx);
    }
  }
  ConvertedForStore[V] = Ret;
  return Ret;
}
503
// Inverse of fatPtrsToInts(): recursively `inttoptr` the integers inside `V`
// (of type `From`) that correspond to fat pointers in the target type `To`.
// Unlike fatPtrsToInts(), results are not memoized here.
Value *StoreFatPtrsAsIntsAndExpandMemcpyVisitor::intsToFatPtrs(
    Value *V, Type *From, Type *To, const Twine &Name) {
  // Types that contain no fat pointers are returned unchanged.
  if (From == To)
    return V;
  if (isBufferFatPtrOrVector(Ty: To)) {
    Value *Cast = IRB.CreateIntToPtr(V, DestTy: To, Name: Name + ".ptr");
    return Cast;
  }
  if (From->getNumContainedTypes() == 0)
    return V;
  // Structs, arrays, and other compound types.
  Value *Ret = PoisonValue::get(T: To);
  if (auto *AT = dyn_cast<ArrayType>(Val: From)) {
    // Arrays have one shared element type; convert each element in turn.
    Type *FromPart = AT->getArrayElementType();
    Type *ToPart = cast<ArrayType>(Val: To)->getElementType();
    for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
      Value *Field = IRB.CreateExtractValue(Agg: V, Idxs: I);
      Value *NewField =
          intsToFatPtrs(V: Field, From: FromPart, To: ToPart, Name: Name + "." + Twine(I));
      Ret = IRB.CreateInsertValue(Agg: Ret, Val: NewField, Idxs: I);
    }
  } else {
    // Structs: walk the old and new field types in lockstep.
    for (auto [Idx, FromPart, ToPart] :
         enumerate(First: From->subtypes(), Rest: To->subtypes())) {
      Value *Field = IRB.CreateExtractValue(Agg: V, Idxs: Idx);
      Value *NewField =
          intsToFatPtrs(V: Field, From: FromPart, To: ToPart, Name: Name + "." + Twine(Idx));
      Ret = IRB.CreateInsertValue(Agg: Ret, Val: NewField, Idxs: Idx);
    }
  }
  return Ret;
}
536
// Visit every instruction in `F`, rewriting fat-pointer loads/stores/allocas
// and expanding memcpy()-like intrinsics. Returns true if anything changed.
bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::processFunction(Function &F) {
  bool Changed = false;
  // Process memcpy-like instructions after the main iteration because they can
  // invalidate iterators.
  SmallVector<WeakTrackingVH> CanBecomeLoops;
  for (Instruction &I : make_early_inc_range(Range: instructions(F))) {
    if (isa<MemTransferInst, MemSetInst, MemSetPatternInst>(Val: I))
      CanBecomeLoops.push_back(Elt: &I);
    else
      Changed |= visit(I);
  }
  // WeakTrackingVH is used so the queued handles stay coherent across the
  // IR rewriting performed above and during loop expansion.
  for (WeakTrackingVH VH : make_early_inc_range(Range&: CanBecomeLoops)) {
    Changed |= visit(I: cast<Instruction>(Val&: VH));
  }
  // The conversion cache must not leak between functions.
  ConvertedForStore.clear();
  return Changed;
}
554
555bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitAllocaInst(AllocaInst &I) {
556 Type *Ty = I.getAllocatedType();
557 Type *NewTy = TypeMap->remapType(SrcTy: Ty);
558 if (Ty == NewTy)
559 return false;
560 I.setAllocatedType(NewTy);
561 return true;
562}
563
564bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitGetElementPtrInst(
565 GetElementPtrInst &I) {
566 Type *Ty = I.getSourceElementType();
567 Type *NewTy = TypeMap->remapType(SrcTy: Ty);
568 if (Ty == NewTy)
569 return false;
570 // We'll be rewriting the type `ptr addrspace(7)` out of existence soon, so
571 // make sure GEPs don't have different semantics with the new type.
572 I.setSourceElementType(NewTy);
573 I.setResultElementType(TypeMap->remapType(SrcTy: I.getResultElementType()));
574 return true;
575}
576
// Rewrite a load of a type containing fat pointers into a load of the
// i160-based equivalent type, followed by `inttoptr` casts that rebuild the
// original fat-pointer-typed value for existing users.
bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitLoadInst(LoadInst &LI) {
  Type *Ty = LI.getType();
  Type *IntTy = TypeMap->remapType(SrcTy: Ty);
  if (Ty == IntTy)
    return false;

  IRB.SetInsertPoint(&LI);
  // Clone the original load so all its other properties carry over, then
  // just change the result type to the remapped integer type.
  auto *NLI = cast<LoadInst>(Val: LI.clone());
  NLI->mutateType(Ty: IntTy);
  NLI = IRB.Insert(I: NLI);
  NLI->takeName(V: &LI);

  // Reconstruct the fat-pointer value and splice it in for the old load.
  Value *CastBack = intsToFatPtrs(V: NLI, From: IntTy, To: Ty, Name: NLI->getName());
  LI.replaceAllUsesWith(V: CastBack);
  LI.eraseFromParent();
  return true;
}
594
// Rewrite a store of a type containing fat pointers so that the integer
// equivalent is what actually reaches memory.
bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitStoreInst(StoreInst &SI) {
  Value *V = SI.getValueOperand();
  Type *Ty = V->getType();
  Type *IntTy = TypeMap->remapType(SrcTy: Ty);
  if (Ty == IntTy)
    return false;

  IRB.SetInsertPoint(&SI);
  Value *IntV = fatPtrsToInts(V, From: Ty, To: IntTy, Name: V->getName());
  // Keep assignment-tracking debug records pointing at the value that is
  // actually stored.
  for (auto *Dbg : at::getDVRAssignmentMarkers(Inst: &SI))
    Dbg->setRawLocation(ValueAsMetadata::get(V: IntV));

  // Replace the stored operand in place; the store itself is unchanged.
  SI.setOperand(i_nocapture: 0, Val_nocapture: IntV);
  return true;
}
610
611bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemCpyInst(
612 MemCpyInst &MCI) {
613 // TODO: Allow memcpy.p7.p3 as a synonym for the direct-to-LDS copy, which'll
614 // need loop expansion here.
615 if (MCI.getSourceAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER &&
616 MCI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
617 return false;
618 llvm::expandMemCpyAsLoop(MemCpy: &MCI,
619 TTI: TM->getTargetTransformInfo(F: *MCI.getFunction()));
620 MCI.eraseFromParent();
621 return true;
622}
623
624bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemMoveInst(
625 MemMoveInst &MMI) {
626 if (MMI.getSourceAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER &&
627 MMI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
628 return false;
629 reportFatalUsageError(
630 reason: "memmove() on buffer descriptors is not implemented because pointer "
631 "comparison on buffer descriptors isn't implemented\n");
632}
633
634bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemSetInst(
635 MemSetInst &MSI) {
636 if (MSI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
637 return false;
638 llvm::expandMemSetAsLoop(MemSet: &MSI,
639 TTI: TM->getTargetTransformInfo(F: *MSI.getFunction()));
640 MSI.eraseFromParent();
641 return true;
642}
643
644bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemSetPatternInst(
645 MemSetPatternInst &MSPI) {
646 if (MSPI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
647 return false;
648 llvm::expandMemSetPatternAsLoop(
649 MemSet: &MSPI, TTI: TM->getTargetTransformInfo(F: *MSPI.getFunction()));
650 MSPI.eraseFromParent();
651 return true;
652}
653
654namespace {
/// Convert loads/stores of types that the buffer intrinsics can't handle into
/// one or more such loads/stores that consist of legal types.
///
/// Do this by
/// 1. Recursing into structs (and arrays that don't share a memory layout with
/// vectors) since the intrinsics can't handle complex types.
/// 2. Converting arrays of non-aggregate, byte-sized types into their
/// corresponding vectors
/// 3. Bitcasting unsupported types, namely overly-long scalars and byte
/// vectors, into vectors of supported types.
/// 4. Splitting up excessively long reads/writes into multiple operations.
///
/// Note that this doesn't handle complex data structures, but, in the future,
/// the aggregate load splitter from SROA could be refactored to allow for that
/// case.
class LegalizeBufferContentTypesVisitor
    : public InstVisitor<LegalizeBufferContentTypesVisitor, bool> {
  friend class InstVisitor<LegalizeBufferContentTypesVisitor, bool>;

  IRBuilder<InstSimplifyFolder> IRB;

  const DataLayout &DL;

  /// If T is [N x U], where U is a scalar type, return the vector type
  /// <N x U>, otherwise, return T.
  Type *scalarArrayTypeAsVector(Type *MaybeArrayType);
  // Move values between the array and vector forms of the same elements.
  Value *arrayToVector(Value *V, Type *TargetType, const Twine &Name);
  Value *vectorToArray(Value *V, Type *OrigType, const Twine &Name);

  /// Break up the loads of a struct into the loads of its components

  /// Convert a vector or scalar type that can't be operated on by buffer
  /// intrinsics to one that would be legal through bitcasts and/or truncation.
  /// Uses the wider of i32, i16, or i8 where possible.
  Type *legalNonAggregateFor(Type *T);
  Value *makeLegalNonAggregate(Value *V, Type *TargetType, const Twine &Name);
  Value *makeIllegalNonAggregate(Value *V, Type *OrigType, const Twine &Name);

  /// An [index, length] pair naming a contiguous run of vector elements.
  struct VecSlice {
    uint64_t Index = 0;
    uint64_t Length = 0;
    VecSlice() = delete;
    // Needed for some Clangs
    VecSlice(uint64_t Index, uint64_t Length) : Index(Index), Length(Length) {}
  };
  /// Return the [index, length] pairs into which `T` needs to be cut to form
  /// legal buffer load or store operations. Clears `Slices`. Creates an empty
  /// `Slices` for non-vector inputs and creates one slice if no slicing will be
  /// needed.
  void getVecSlices(Type *T, SmallVectorImpl<VecSlice> &Slices);

  // Extract/insert one VecSlice's worth of elements from/into a vector value.
  Value *extractSlice(Value *Vec, VecSlice S, const Twine &Name);
  Value *insertSlice(Value *Whole, Value *Part, VecSlice S, const Twine &Name);

  /// In most cases, return `LegalType`. However, when given an input that would
  /// normally be a legal type for the buffer intrinsics to return but that
  /// isn't hooked up through SelectionDAG, return a type of the same width that
  /// can be used with the relevant intrinsics. Specifically, handle the cases:
  /// - <1 x T> => T for all T
  /// - <N x i8> <=> i16, i32, 2xi32, 4xi32 (as needed)
  /// - <N x T> where T is under 32 bits and the total size is 96 bits <=> <3 x
  /// i32>
  Type *intrinsicTypeFor(Type *LegalType);

  // Recursive workers behind visitLoadInst()/visitStoreInst(); `AggIdxs` and
  // `AggByteOffset` track the position within the aggregate being recursed
  // into.
  bool visitLoadImpl(LoadInst &OrigLI, Type *PartType,
                     SmallVectorImpl<uint32_t> &AggIdxs, uint64_t AggByteOffset,
                     Value *&Result, const Twine &Name);
  /// Return value is (Changed, ModifiedInPlace)
  std::pair<bool, bool> visitStoreImpl(StoreInst &OrigSI, Type *PartType,
                                       SmallVectorImpl<uint32_t> &AggIdxs,
                                       uint64_t AggByteOffset,
                                       const Twine &Name);

  bool visitInstruction(Instruction &I) { return false; }
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);

public:
  LegalizeBufferContentTypesVisitor(const DataLayout &DL, LLVMContext &Ctx)
      : IRB(Ctx, InstSimplifyFolder(DL)), DL(DL) {}
  bool processFunction(Function &F);
};
737} // namespace
738
739Type *LegalizeBufferContentTypesVisitor::scalarArrayTypeAsVector(Type *T) {
740 ArrayType *AT = dyn_cast<ArrayType>(Val: T);
741 if (!AT)
742 return T;
743 Type *ET = AT->getElementType();
744 if (!ET->isSingleValueType() || isa<VectorType>(Val: ET))
745 reportFatalUsageError(reason: "loading non-scalar arrays from buffer fat pointers "
746 "should have recursed");
747 if (!DL.typeSizeEqualsStoreSize(Ty: AT))
748 reportFatalUsageError(
749 reason: "loading padded arrays from buffer fat pinters should have recursed");
750 return FixedVectorType::get(ElementType: ET, NumElts: AT->getNumElements());
751}
752
753Value *LegalizeBufferContentTypesVisitor::arrayToVector(Value *V,
754 Type *TargetType,
755 const Twine &Name) {
756 Value *VectorRes = PoisonValue::get(T: TargetType);
757 auto *VT = cast<FixedVectorType>(Val: TargetType);
758 unsigned EC = VT->getNumElements();
759 for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
760 Value *Elem = IRB.CreateExtractValue(Agg: V, Idxs: I, Name: Name + ".elem." + Twine(I));
761 VectorRes = IRB.CreateInsertElement(Vec: VectorRes, NewElt: Elem, Idx: I,
762 Name: Name + ".as.vec." + Twine(I));
763 }
764 return VectorRes;
765}
766
767Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V,
768 Type *OrigType,
769 const Twine &Name) {
770 Value *ArrayRes = PoisonValue::get(T: OrigType);
771 ArrayType *AT = cast<ArrayType>(Val: OrigType);
772 unsigned EC = AT->getNumElements();
773 for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
774 Value *Elem = IRB.CreateExtractElement(Vec: V, Idx: I, Name: Name + ".elem." + Twine(I));
775 ArrayRes = IRB.CreateInsertValue(Agg: ArrayRes, Val: Elem, Idxs: I,
776 Name: Name + ".as.array." + Twine(I));
777 }
778 return ArrayRes;
779}
780
Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) {
  TypeSize Size = DL.getTypeStoreSizeInBits(Ty: T);
  // Implicitly zero-extend to the next byte if needed: non-byte-sized values
  // are re-expressed as an integer spanning their whole store size.
  if (!DL.typeSizeEqualsStoreSize(Ty: T))
    T = IRB.getIntNTy(N: Size.getFixedValue());
  Type *ElemTy = T->getScalarType();
  if (isa<PointerType, ScalableVectorType>(Val: ElemTy)) {
    // Pointers are always big enough, and we'll let scalable vectors through to
    // fail in codegen.
    return T;
  }
  unsigned ElemSize = DL.getTypeSizeInBits(Ty: ElemTy).getFixedValue();
  if (isPowerOf2_32(Value: ElemSize) && ElemSize >= 16 && ElemSize <= 128) {
    // [vectors of] anything that's 16/32/64/128 bits can be cast and split into
    // legal buffer operations.
    return T;
  }
  // Otherwise re-express the value as [a vector of] the widest integer type
  // (i32 > i16 > i8) that evenly divides the total bit width.
  Type *BestVectorElemType = nullptr;
  if (Size.isKnownMultipleOf(RHS: 32))
    BestVectorElemType = IRB.getInt32Ty();
  else if (Size.isKnownMultipleOf(RHS: 16))
    BestVectorElemType = IRB.getInt16Ty();
  else
    BestVectorElemType = IRB.getInt8Ty();
  unsigned NumCastElems =
      Size.getFixedValue() / BestVectorElemType->getIntegerBitWidth();
  // A one-element result degenerates to the plain scalar.
  if (NumCastElems == 1)
    return BestVectorElemType;
  return FixedVectorType::get(ElementType: BestVectorElemType, NumElts: NumCastElems);
}
811
812Value *LegalizeBufferContentTypesVisitor::makeLegalNonAggregate(
813 Value *V, Type *TargetType, const Twine &Name) {
814 Type *SourceType = V->getType();
815 TypeSize SourceSize = DL.getTypeSizeInBits(Ty: SourceType);
816 TypeSize TargetSize = DL.getTypeSizeInBits(Ty: TargetType);
817 if (SourceSize != TargetSize) {
818 Type *ShortScalarTy = IRB.getIntNTy(N: SourceSize.getFixedValue());
819 Type *ByteScalarTy = IRB.getIntNTy(N: TargetSize.getFixedValue());
820 Value *AsScalar = IRB.CreateBitCast(V, DestTy: ShortScalarTy, Name: Name + ".as.scalar");
821 Value *Zext = IRB.CreateZExt(V: AsScalar, DestTy: ByteScalarTy, Name: Name + ".zext");
822 V = Zext;
823 SourceType = ByteScalarTy;
824 }
825 return IRB.CreateBitCast(V, DestTy: TargetType, Name: Name + ".legal");
826}
827
828Value *LegalizeBufferContentTypesVisitor::makeIllegalNonAggregate(
829 Value *V, Type *OrigType, const Twine &Name) {
830 Type *LegalType = V->getType();
831 TypeSize LegalSize = DL.getTypeSizeInBits(Ty: LegalType);
832 TypeSize OrigSize = DL.getTypeSizeInBits(Ty: OrigType);
833 if (LegalSize != OrigSize) {
834 Type *ShortScalarTy = IRB.getIntNTy(N: OrigSize.getFixedValue());
835 Type *ByteScalarTy = IRB.getIntNTy(N: LegalSize.getFixedValue());
836 Value *AsScalar = IRB.CreateBitCast(V, DestTy: ByteScalarTy, Name: Name + ".bytes.cast");
837 Value *Trunc = IRB.CreateTrunc(V: AsScalar, DestTy: ShortScalarTy, Name: Name + ".trunc");
838 return IRB.CreateBitCast(V: Trunc, DestTy: OrigType, Name: Name + ".orig");
839 }
840 return IRB.CreateBitCast(V, DestTy: OrigType, Name: Name + ".real.ty");
841}
842
Type *LegalizeBufferContentTypesVisitor::intrinsicTypeFor(Type *LegalType) {
  auto *VT = dyn_cast<FixedVectorType>(Val: LegalType);
  if (!VT)
    return LegalType;
  Type *ET = VT->getElementType();
  // Explicitly return the element type of 1-element vectors because the
  // underlying intrinsics don't like <1 x T> even though it's a synonym for T.
  if (VT->getNumElements() == 1)
    return ET;
  // 96-bit values with sub-32-bit elements are accessed as <3 x i32>.
  if (DL.getTypeSizeInBits(Ty: LegalType) == 96 && DL.getTypeSizeInBits(Ty: ET) < 32)
    return FixedVectorType::get(ElementType: IRB.getInt32Ty(), NumElts: 3);
  // Byte vectors become the integer (or <N x i32>) of the same total width.
  if (ET->isIntegerTy(Bitwidth: 8)) {
    switch (VT->getNumElements()) {
    default:
      return LegalType; // Let it crash later
    case 1:
      // Unreachable in practice: 1-element vectors were handled above.
      return IRB.getInt8Ty();
    case 2:
      return IRB.getInt16Ty();
    case 4:
      return IRB.getInt32Ty();
    case 8:
      return FixedVectorType::get(ElementType: IRB.getInt32Ty(), NumElts: 2);
    case 16:
      return FixedVectorType::get(ElementType: IRB.getInt32Ty(), NumElts: 4);
    }
  }
  return LegalType;
}
872
void LegalizeBufferContentTypesVisitor::getVecSlices(
    Type *T, SmallVectorImpl<VecSlice> &Slices) {
  Slices.clear();
  // Non-vector types get no slices; callers treat that as "access the whole
  // value at once".
  auto *VT = dyn_cast<FixedVectorType>(Val: T);
  if (!VT)
    return;

  uint64_t ElemBitWidth =
      DL.getTypeSizeInBits(Ty: VT->getElementType()).getFixedValue();

  // How many elements fit in each available buffer access width; each of these
  // is 0 when an element is wider than the access in question.
  uint64_t ElemsPer4Words = 128 / ElemBitWidth;
  uint64_t ElemsPer2Words = ElemsPer4Words / 2;
  uint64_t ElemsPerWord = ElemsPer2Words / 2;
  uint64_t ElemsPerShort = ElemsPerWord / 2;
  uint64_t ElemsPerByte = ElemsPerShort / 2;
  // If the elements evenly pack into 32-bit words, we can use 3-word stores,
  // such as for <6 x bfloat> or <3 x i32>, but we can't do this for, for
  // example, <3 x i64>, since that's not slicing.
  uint64_t ElemsPer3Words = ElemsPerWord * 3;

  uint64_t TotalElems = VT->getNumElements();
  uint64_t Index = 0;
  // Greedily take the widest slice that still fits the remaining elements.
  auto TrySlice = [&](unsigned MaybeLen) {
    if (MaybeLen > 0 && Index + MaybeLen <= TotalElems) {
      VecSlice Slice{/*Index=*/Index, /*Length=*/MaybeLen};
      Slices.push_back(Elt: Slice);
      Index += MaybeLen;
      return true;
    }
    return false;
  };
  while (Index < TotalElems) {
    TrySlice(ElemsPer4Words) || TrySlice(ElemsPer3Words) ||
        TrySlice(ElemsPer2Words) || TrySlice(ElemsPerWord) ||
        TrySlice(ElemsPerShort) || TrySlice(ElemsPerByte);
  }
}
910
911Value *LegalizeBufferContentTypesVisitor::extractSlice(Value *Vec, VecSlice S,
912 const Twine &Name) {
913 auto *VecVT = dyn_cast<FixedVectorType>(Val: Vec->getType());
914 if (!VecVT)
915 return Vec;
916 if (S.Length == VecVT->getNumElements() && S.Index == 0)
917 return Vec;
918 if (S.Length == 1)
919 return IRB.CreateExtractElement(Vec, Idx: S.Index,
920 Name: Name + ".slice." + Twine(S.Index));
921 SmallVector<int> Mask = llvm::to_vector(
922 Range: llvm::iota_range<int>(S.Index, S.Index + S.Length, /*Inclusive=*/false));
923 return IRB.CreateShuffleVector(V: Vec, Mask, Name: Name + ".slice." + Twine(S.Index));
924}
925
Value *LegalizeBufferContentTypesVisitor::insertSlice(Value *Whole, Value *Part,
                                                      VecSlice S,
                                                      const Twine &Name) {
  // Scalars have exactly one "slice": the part simply replaces the whole.
  auto *WholeVT = dyn_cast<FixedVectorType>(Val: Whole->getType());
  if (!WholeVT)
    return Part;
  // A whole-vector slice also replaces the value outright.
  if (S.Length == WholeVT->getNumElements() && S.Index == 0)
    return Part;
  // Single-element slices are a plain insertelement.
  if (S.Length == 1) {
    return IRB.CreateInsertElement(Vec: Whole, NewElt: Part, Idx: S.Index,
                                   Name: Name + ".slice." + Twine(S.Index));
  }
  int NumElems = cast<FixedVectorType>(Val: Whole->getType())->getNumElements();

  // Extend the slice with poisons to make the main shufflevector happy.
  SmallVector<int> ExtPartMask(NumElems, -1);
  for (auto [I, E] : llvm::enumerate(
           First: MutableArrayRef<int>(ExtPartMask).take_front(N: S.Length))) {
    E = I;
  }
  Value *ExtPart = IRB.CreateShuffleVector(V: Part, Mask: ExtPartMask,
                                           Name: Name + ".ext." + Twine(S.Index));

  // Start from an identity mask over `Whole`, then redirect the lanes in
  // [Index, Index+Length) to come from the (extended) part, whose lanes sit at
  // indices NumElems..NumElems+Length-1 in the two-operand shuffle.
  SmallVector<int> Mask =
      llvm::to_vector(Range: llvm::iota_range<int>(0, NumElems, /*Inclusive=*/false));
  for (auto [I, E] :
       llvm::enumerate(First: MutableArrayRef<int>(Mask).slice(N: S.Index, M: S.Length)))
    E = I + NumElems;
  return IRB.CreateShuffleVector(V1: Whole, V2: ExtPart, Mask,
                                 Name: Name + ".parts." + Twine(S.Index));
}
957
bool LegalizeBufferContentTypesVisitor::visitLoadImpl(
    LoadInst &OrigLI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
    uint64_t AggByteOff, Value *&Result, const Twine &Name) {
  // Structs: recurse into each member at its layout offset, folding the
  // member loads into `Result` via the aggregate-index stack `AggIdxs`.
  if (auto *ST = dyn_cast<StructType>(Val: PartType)) {
    const StructLayout *Layout = DL.getStructLayout(Ty: ST);
    bool Changed = false;
    for (auto [I, ElemTy, Offset] :
         llvm::enumerate(First: ST->elements(), Rest: Layout->getMemberOffsets())) {
      AggIdxs.push_back(Elt: I);
      Changed |= visitLoadImpl(OrigLI, PartType: ElemTy, AggIdxs,
                               AggByteOff: AggByteOff + Offset.getFixedValue(), Result,
                               Name: Name + "." + Twine(I));
      AggIdxs.pop_back();
    }
    return Changed;
  }
  // Arrays whose elements can't be re-expressed as a flat vector (aggregate,
  // vector, or padded elements): recurse element by element. Scalar,
  // tightly-packed arrays fall through to the vector handling below.
  if (auto *AT = dyn_cast<ArrayType>(Val: PartType)) {
    Type *ElemTy = AT->getElementType();
    if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(Ty: ElemTy) ||
        ElemTy->isVectorTy()) {
      TypeSize ElemStoreSize = DL.getTypeStoreSize(Ty: ElemTy);
      bool Changed = false;
      for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
                                               /*Inclusive=*/false)) {
        AggIdxs.push_back(Elt: I);
        Changed |= visitLoadImpl(OrigLI, PartType: ElemTy, AggIdxs,
                                 AggByteOff: AggByteOff + I * ElemStoreSize.getFixedValue(),
                                 Result, Name: Name + Twine(I));
        AggIdxs.pop_back();
      }
      return Changed;
    }
  }

  // Typical case: a scalar or vector (possibly a scalar array re-expressed as
  // a vector) that may need legalization and/or slicing.

  Type *ArrayAsVecType = scalarArrayTypeAsVector(T: PartType);
  Type *LegalType = legalNonAggregateFor(T: ArrayAsVecType);

  SmallVector<VecSlice> Slices;
  getVecSlices(T: LegalType, Slices);
  bool HasSlices = Slices.size() > 1;
  bool IsAggPart = !AggIdxs.empty();
  Value *LoadsRes;
  if (!HasSlices && !IsAggPart) {
    // Whole-value load: clone the original load with the loadable type and
    // cast the result back to the legal form.
    Type *LoadableType = intrinsicTypeFor(LegalType);
    if (LoadableType == PartType)
      return false;

    IRB.SetInsertPoint(&OrigLI);
    auto *NLI = cast<LoadInst>(Val: OrigLI.clone());
    NLI->mutateType(Ty: LoadableType);
    NLI = IRB.Insert(I: NLI);
    NLI->setName(Name + ".loadable");

    LoadsRes = IRB.CreateBitCast(V: NLI, DestTy: LegalType, Name: Name + ".from.loadable");
  } else {
    // Split load: emit one load per slice at its byte offset and reassemble
    // the value with insertSlice().
    IRB.SetInsertPoint(&OrigLI);
    LoadsRes = PoisonValue::get(T: LegalType);
    Value *OrigPtr = OrigLI.getPointerOperand();
    // If we're needing to spill something into more than one load, its legal
    // type will be a vector (ex. an i256 load will have LegalType = <8 x i32>).
    // But if we're already a scalar (which can happen if we're splitting up a
    // struct), the element type will be the legal type itself.
    Type *ElemType = LegalType->getScalarType();
    unsigned ElemBytes = DL.getTypeStoreSize(Ty: ElemType);
    AAMDNodes AANodes = OrigLI.getAAMetadata();
    // Aggregate members that need no slicing still need one whole-value slice.
    if (IsAggPart && Slices.empty())
      Slices.push_back(Elt: VecSlice{/*Index=*/0, /*Length=*/1});
    for (VecSlice S : Slices) {
      Type *SliceType =
          S.Length != 1 ? FixedVectorType::get(ElementType: ElemType, NumElts: S.Length) : ElemType;
      int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
      // You can't reasonably expect loads to wrap around the edge of memory.
      Value *NewPtr = IRB.CreateGEP(
          Ty: IRB.getInt8Ty(), Ptr: OrigLI.getPointerOperand(), IdxList: IRB.getInt32(C: ByteOffset),
          Name: OrigPtr->getName() + ".off.ptr." + Twine(ByteOffset),
          NW: GEPNoWrapFlags::noUnsignedWrap());
      Type *LoadableType = intrinsicTypeFor(LegalType: SliceType);
      LoadInst *NewLI = IRB.CreateAlignedLoad(
          Ty: LoadableType, Ptr: NewPtr, Align: commonAlignment(A: OrigLI.getAlign(), Offset: ByteOffset),
          Name: Name + ".off." + Twine(ByteOffset));
      // Carry over metadata, AA info, atomicity, and volatility from the
      // original load to each slice load.
      copyMetadataForLoad(Dest&: *NewLI, Source: OrigLI);
      NewLI->setAAMetadata(
          AANodes.adjustForAccess(Offset: ByteOffset, AccessTy: LoadableType, DL));
      NewLI->setAtomic(Ordering: OrigLI.getOrdering(), SSID: OrigLI.getSyncScopeID());
      NewLI->setVolatile(OrigLI.isVolatile());
      Value *Loaded = IRB.CreateBitCast(V: NewLI, DestTy: SliceType,
                                        Name: NewLI->getName() + ".from.loadable");
      LoadsRes = insertSlice(Whole: LoadsRes, Part: Loaded, S, Name);
    }
  }
  // Undo the legalization so the result has the type the IR expects.
  if (LegalType != ArrayAsVecType)
    LoadsRes = makeIllegalNonAggregate(V: LoadsRes, OrigType: ArrayAsVecType, Name);
  if (ArrayAsVecType != PartType)
    LoadsRes = vectorToArray(V: LoadsRes, OrigType: PartType, Name);

  if (IsAggPart)
    Result = IRB.CreateInsertValue(Agg: Result, Val: LoadsRes, Idxs: AggIdxs, Name);
  else
    Result = LoadsRes;
  return true;
}
1061
1062bool LegalizeBufferContentTypesVisitor::visitLoadInst(LoadInst &LI) {
1063 if (LI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
1064 return false;
1065
1066 SmallVector<uint32_t> AggIdxs;
1067 Type *OrigType = LI.getType();
1068 Value *Result = PoisonValue::get(T: OrigType);
1069 bool Changed = visitLoadImpl(OrigLI&: LI, PartType: OrigType, AggIdxs, AggByteOff: 0, Result, Name: LI.getName());
1070 if (!Changed)
1071 return false;
1072 Result->takeName(V: &LI);
1073 LI.replaceAllUsesWith(V: Result);
1074 LI.eraseFromParent();
1075 return Changed;
1076}
1077
std::pair<bool, bool> LegalizeBufferContentTypesVisitor::visitStoreImpl(
    StoreInst &OrigSI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
    uint64_t AggByteOff, const Twine &Name) {
  // Structs: recurse into each member at its layout offset. This path never
  // modifies the original store in place.
  if (auto *ST = dyn_cast<StructType>(Val: PartType)) {
    const StructLayout *Layout = DL.getStructLayout(Ty: ST);
    bool Changed = false;
    for (auto [I, ElemTy, Offset] :
         llvm::enumerate(First: ST->elements(), Rest: Layout->getMemberOffsets())) {
      AggIdxs.push_back(Elt: I);
      Changed |= std::get<0>(in: visitStoreImpl(OrigSI, PartType: ElemTy, AggIdxs,
                                            AggByteOff: AggByteOff + Offset.getFixedValue(),
                                            Name: Name + "." + Twine(I)));
      AggIdxs.pop_back();
    }
    return std::make_pair(x&: Changed, /*ModifiedInPlace=*/y: false);
  }
  // Arrays whose elements can't be re-expressed as a flat vector: recurse per
  // element. Scalar, tightly-packed arrays fall through to the vector path.
  if (auto *AT = dyn_cast<ArrayType>(Val: PartType)) {
    Type *ElemTy = AT->getElementType();
    if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(Ty: ElemTy) ||
        ElemTy->isVectorTy()) {
      TypeSize ElemStoreSize = DL.getTypeStoreSize(Ty: ElemTy);
      bool Changed = false;
      for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
                                               /*Inclusive=*/false)) {
        AggIdxs.push_back(Elt: I);
        Changed |= std::get<0>(in: visitStoreImpl(
            OrigSI, PartType: ElemTy, AggIdxs,
            AggByteOff: AggByteOff + I * ElemStoreSize.getFixedValue(), Name: Name + Twine(I)));
        AggIdxs.pop_back();
      }
      return std::make_pair(x&: Changed, /*ModifiedInPlace=*/y: false);
    }
  }

  // Typical case: extract the (possibly aggregate-member) value, legalize it,
  // and store it whole or in slices.
  Value *OrigData = OrigSI.getValueOperand();
  Value *NewData = OrigData;

  bool IsAggPart = !AggIdxs.empty();
  if (IsAggPart)
    NewData = IRB.CreateExtractValue(Agg: NewData, Idxs: AggIdxs, Name);

  Type *ArrayAsVecType = scalarArrayTypeAsVector(T: PartType);
  if (ArrayAsVecType != PartType) {
    NewData = arrayToVector(V: NewData, TargetType: ArrayAsVecType, Name);
  }

  Type *LegalType = legalNonAggregateFor(T: ArrayAsVecType);
  if (LegalType != ArrayAsVecType) {
    NewData = makeLegalNonAggregate(V: NewData, TargetType: LegalType, Name);
  }

  SmallVector<VecSlice> Slices;
  getVecSlices(T: LegalType, Slices);
  bool NeedToSplit = Slices.size() > 1 || IsAggPart;
  if (!NeedToSplit) {
    // Whole-value store: just retype the stored operand in place.
    Type *StorableType = intrinsicTypeFor(LegalType);
    if (StorableType == PartType)
      return std::make_pair(/*Changed=*/x: false, /*ModifiedInPlace=*/y: false);
    NewData = IRB.CreateBitCast(V: NewData, DestTy: StorableType, Name: Name + ".storable");
    OrigSI.setOperand(i_nocapture: 0, Val_nocapture: NewData);
    return std::make_pair(/*Changed=*/x: true, /*ModifiedInPlace=*/y: true);
  }

  // Split store: emit one store per slice; the caller erases the original.
  Value *OrigPtr = OrigSI.getPointerOperand();
  Type *ElemType = LegalType->getScalarType();
  // Aggregate members that need no slicing still need one whole-value slice.
  if (IsAggPart && Slices.empty())
    Slices.push_back(Elt: VecSlice{/*Index=*/0, /*Length=*/1});
  unsigned ElemBytes = DL.getTypeStoreSize(Ty: ElemType);
  AAMDNodes AANodes = OrigSI.getAAMetadata();
  for (VecSlice S : Slices) {
    Type *SliceType =
        S.Length != 1 ? FixedVectorType::get(ElementType: ElemType, NumElts: S.Length) : ElemType;
    int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
    Value *NewPtr =
        IRB.CreateGEP(Ty: IRB.getInt8Ty(), Ptr: OrigPtr, IdxList: IRB.getInt32(C: ByteOffset),
                      Name: OrigPtr->getName() + ".part." + Twine(S.Index),
                      NW: GEPNoWrapFlags::noUnsignedWrap());
    Value *DataSlice = extractSlice(Vec: NewData, S, Name);
    Type *StorableType = intrinsicTypeFor(LegalType: SliceType);
    DataSlice = IRB.CreateBitCast(V: DataSlice, DestTy: StorableType,
                                  Name: DataSlice->getName() + ".storable");
    // Clone the original store to preserve its ordering/volatility flags, then
    // swap in the slice's data, pointer, alignment, and AA metadata.
    auto *NewSI = cast<StoreInst>(Val: OrigSI.clone());
    NewSI->setAlignment(commonAlignment(A: OrigSI.getAlign(), Offset: ByteOffset));
    IRB.Insert(I: NewSI);
    NewSI->setOperand(i_nocapture: 0, Val_nocapture: DataSlice);
    NewSI->setOperand(i_nocapture: 1, Val_nocapture: NewPtr);
    NewSI->setAAMetadata(AANodes.adjustForAccess(Offset: ByteOffset, AccessTy: StorableType, DL));
  }
  return std::make_pair(/*Changed=*/x: true, /*ModifiedInPlace=*/y: false);
}
1168
1169bool LegalizeBufferContentTypesVisitor::visitStoreInst(StoreInst &SI) {
1170 if (SI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
1171 return false;
1172 IRB.SetInsertPoint(&SI);
1173 SmallVector<uint32_t> AggIdxs;
1174 Value *OrigData = SI.getValueOperand();
1175 auto [Changed, ModifiedInPlace] =
1176 visitStoreImpl(OrigSI&: SI, PartType: OrigData->getType(), AggIdxs, AggByteOff: 0, Name: OrigData->getName());
1177 if (Changed && !ModifiedInPlace)
1178 SI.eraseFromParent();
1179 return Changed;
1180}
1181
/// Legalize the contents of every buffer-fat-pointer load and store in `F`.
/// Returns true if any instruction was rewritten.
bool LegalizeBufferContentTypesVisitor::processFunction(Function &F) {
  bool Changed = false;
  // Note, memory transfer intrinsics won't be rewritten here: visitInstruction
  // returns false for everything except loads and stores.
  // NOTE(review): the original comment was truncated mid-sentence; the
  // completion above reflects what the visitor visibly handles — confirm the
  // original intent.
  for (Instruction &I : make_early_inc_range(Range: instructions(F))) {
    Changed |= visit(I);
  }
  return Changed;
}
1190
1191/// Return the ptr addrspace(8) and i32 (resource and offset parts) in a lowered
1192/// buffer fat pointer constant.
1193static std::pair<Constant *, Constant *>
1194splitLoweredFatBufferConst(Constant *C) {
1195 assert(isSplitFatPtr(C->getType()) && "Not a split fat buffer pointer");
1196 return std::make_pair(x: C->getAggregateElement(Elt: 0u), y: C->getAggregateElement(Elt: 1u));
1197}
1198
namespace {
/// Handle the remapping of ptr addrspace(7) constants.
class FatPtrConstMaterializer final : public ValueMaterializer {
  // Maps the original fat-pointer-containing types to their split struct form.
  BufferFatPtrToStructTypeMap *TypeMap;
  // An internal mapper that is used to recurse into the arguments of constants.
  // While the documentation for `ValueMapper` specifies not to use it
  // recursively, examination of the logic in mapValue() shows that it can
  // safely be used recursively when handling constants, like it does in its own
  // logic.
  ValueMapper InternalMapper;

  /// Split the buffer fat pointer constant `C` into its {resource, offset}
  /// struct form; returns nullptr when `C` can't be handled here.
  Constant *materializeBufferFatPtrConst(Constant *C);

public:
  // UnderlyingMap is the value map this materializer will be filling.
  FatPtrConstMaterializer(BufferFatPtrToStructTypeMap *TypeMap,
                          ValueToValueMapTy &UnderlyingMap)
      : TypeMap(TypeMap),
        InternalMapper(UnderlyingMap, RF_None, TypeMap, this) {}
  ~FatPtrConstMaterializer() = default;

  Value *materialize(Value *V) override;
};
} // namespace
1223
Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) {
  Type *SrcTy = C->getType();
  auto *NewTy = dyn_cast<StructType>(Val: TypeMap->remapType(SrcTy));
  // Trivial constants map to the struct of the corresponding trivial parts.
  if (C->isNullValue())
    return ConstantAggregateZero::getNullValue(Ty: NewTy);
  if (isa<PoisonValue>(Val: C)) {
    return ConstantStruct::get(T: NewTy,
                               V: {PoisonValue::get(T: NewTy->getElementType(N: 0)),
                                  PoisonValue::get(T: NewTy->getElementType(N: 1))});
  }
  if (isa<UndefValue>(Val: C)) {
    return ConstantStruct::get(T: NewTy,
                               V: {UndefValue::get(T: NewTy->getElementType(N: 0)),
                                  UndefValue::get(T: NewTy->getElementType(N: 1))});
  }

  // Vectors of fat pointers become a struct of {vector of resources, vector
  // of offsets}, mapping each element through the internal mapper.
  if (auto *VC = dyn_cast<ConstantVector>(Val: C)) {
    // Splat vectors only need their single element mapped.
    if (Constant *S = VC->getSplatValue()) {
      Constant *NewS = InternalMapper.mapConstant(C: *S);
      if (!NewS)
        return nullptr;
      auto [Rsrc, Off] = splitLoweredFatBufferConst(C: NewS);
      auto EC = VC->getType()->getElementCount();
      return ConstantStruct::get(T: NewTy, V: {ConstantVector::getSplat(EC, Elt: Rsrc),
                                          ConstantVector::getSplat(EC, Elt: Off)});
    }
    SmallVector<Constant *> Rsrcs;
    SmallVector<Constant *> Offs;
    for (Value *Op : VC->operand_values()) {
      auto *NewOp = dyn_cast_or_null<Constant>(Val: InternalMapper.mapValue(V: *Op));
      if (!NewOp)
        return nullptr;
      auto [Rsrc, Off] = splitLoweredFatBufferConst(C: NewOp);
      Rsrcs.push_back(Elt: Rsrc);
      Offs.push_back(Elt: Off);
    }
    Constant *RsrcVec = ConstantVector::get(V: Rsrcs);
    Constant *OffVec = ConstantVector::get(V: Offs);
    return ConstantStruct::get(T: NewTy, V: {RsrcVec, OffVec});
  }

  // Remaining constant kinds are either unsupported (globals) or should have
  // been rewritten away before this pass runs (constant expressions).
  if (isa<GlobalValue>(Val: C))
    reportFatalUsageError(reason: "global values containing ptr addrspace(7) (buffer "
                          "fat pointer) values are not supported");

  if (isa<ConstantExpr>(Val: C))
    reportFatalUsageError(
        reason: "constant exprs containing ptr addrspace(7) (buffer "
        "fat pointer) values should have been expanded earlier");

  return nullptr;
}
1276
1277Value *FatPtrConstMaterializer::materialize(Value *V) {
1278 Constant *C = dyn_cast<Constant>(Val: V);
1279 if (!C)
1280 return nullptr;
1281 // Structs and other types that happen to contain fat pointers get remapped
1282 // by the mapValue() logic.
1283 if (!isBufferFatPtrConst(C))
1284 return nullptr;
1285 return materializeBufferFatPtrConst(C);
1286}
1287
1288using PtrParts = std::pair<Value *, Value *>;
namespace {
// The visitor returns the resource and offset parts for an instruction if they
// can be computed, or (nullptr, nullptr) for cases that don't have a meaningful
// value mapping.
class SplitPtrStructs : public InstVisitor<SplitPtrStructs, PtrParts> {
  // Caches of the resource / offset parts already computed for each value.
  ValueToValueMapTy RsrcParts;
  ValueToValueMapTy OffParts;

  // Track instructions that have been rewritten into a user of the component
  // parts of their ptr addrspace(7) input. Instructions that produced
  // ptr addrspace(7) parts should **not** be RAUW'd before being added to this
  // set, as that replacement will be handled in a post-visit step. However,
  // instructions that yield values that aren't fat pointers (ex. ptrtoint)
  // should RAUW themselves with new instructions that use the split parts
  // of their arguments during processing.
  DenseSet<Instruction *> SplitUsers;

  // Nodes that need a second look once we've computed the parts for all other
  // instructions to see if, for example, we really need to phi on the resource
  // part.
  SmallVector<Instruction *> Conditionals;
  // Temporary instructions produced while lowering conditionals that should be
  // killed.
  SmallVector<Instruction *> ConditionalTemps;

  // Subtarget info, needed for determining what cache control bits to set.
  const TargetMachine *TM;
  const GCNSubtarget *ST = nullptr;

  IRBuilder<InstSimplifyFolder> IRB;

  // Copy metadata between instructions if applicable.
  void copyMetadata(Value *Dest, Value *Src);

  // Get the resource and offset parts of the value V, inserting appropriate
  // extractvalue calls if needed.
  PtrParts getPtrParts(Value *V);

  // Given an instruction that could produce multiple resource parts (a PHI or
  // select), collect the set of possible instructions that could have provided
  // its resource parts (the `Roots`) and the set of
  // conditional instructions visited during the search (`Seen`). If, after
  // removing the root of the search from `Seen` and `Roots`, `Seen` is a subset
  // of `Roots` and `Roots - Seen` contains one element, the resource part of
  // that element can replace the resource part of all other elements in `Seen`.
  void getPossibleRsrcRoots(Instruction *I, SmallPtrSetImpl<Value *> &Roots,
                            SmallPtrSetImpl<Value *> &Seen);
  void processConditionals();

  // If an instruction has been split into resource and offset parts,
  // delete that instruction. If any of its uses have not themselves been split
  // into parts (for example, an insertvalue), construct the structure
  // that the type rewrites declared should be produced by the dying instruction
  // and use that.
  // Also, kill the temporary extractvalue operations produced by the two-stage
  // lowering of PHIs and conditionals.
  void killAndReplaceSplitInstructions(SmallVectorImpl<Instruction *> &Origs);

  // Helpers for annotating and lowering the rewritten memory operations.
  void setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx);
  void insertPreMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
  void insertPostMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
  Value *handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr, Type *Ty,
                          Align Alignment, AtomicOrdering Order,
                          bool IsVolatile, SyncScope::ID SSID);

public:
  SplitPtrStructs(const DataLayout &DL, LLVMContext &Ctx,
                  const TargetMachine *TM)
      : TM(TM), IRB(Ctx, InstSimplifyFolder(DL)) {}

  void processFunction(Function &F);

  PtrParts visitInstruction(Instruction &I);
  PtrParts visitLoadInst(LoadInst &LI);
  PtrParts visitStoreInst(StoreInst &SI);
  PtrParts visitAtomicRMWInst(AtomicRMWInst &AI);
  PtrParts visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI);
  PtrParts visitGetElementPtrInst(GetElementPtrInst &GEP);

  PtrParts visitPtrToAddrInst(PtrToAddrInst &PA);
  PtrParts visitPtrToIntInst(PtrToIntInst &PI);
  PtrParts visitIntToPtrInst(IntToPtrInst &IP);
  PtrParts visitAddrSpaceCastInst(AddrSpaceCastInst &I);
  PtrParts visitICmpInst(ICmpInst &Cmp);
  PtrParts visitFreezeInst(FreezeInst &I);

  PtrParts visitExtractElementInst(ExtractElementInst &I);
  PtrParts visitInsertElementInst(InsertElementInst &I);
  PtrParts visitShuffleVectorInst(ShuffleVectorInst &I);

  PtrParts visitPHINode(PHINode &PHI);
  PtrParts visitSelectInst(SelectInst &SI);

  PtrParts visitIntrinsicInst(IntrinsicInst &II);
};
} // namespace
1385
1386void SplitPtrStructs::copyMetadata(Value *Dest, Value *Src) {
1387 auto *DestI = dyn_cast<Instruction>(Val: Dest);
1388 auto *SrcI = dyn_cast<Instruction>(Val: Src);
1389
1390 if (!DestI || !SrcI)
1391 return;
1392
1393 DestI->copyMetadata(SrcInst: *SrcI);
1394}
1395
PtrParts SplitPtrStructs::getPtrParts(Value *V) {
  assert(isSplitFatPtr(V->getType()) && "it's not meaningful to get the parts "
                                        "of something that wasn't rewritten");
  // The cache entries are taken by pointer so the assignments below both
  // return the parts and memoize them for future calls.
  // NOTE(review): the recursive visit() below may insert into RsrcParts /
  // OffParts; this code assumes the entries for V remain valid across that —
  // confirm against ValueMap's pointer-stability guarantees.
  auto *RsrcEntry = &RsrcParts[V];
  auto *OffEntry = &OffParts[V];
  if (*RsrcEntry && *OffEntry)
    return {*RsrcEntry, *OffEntry};

  // Constants split without emitting any instructions.
  if (auto *C = dyn_cast<Constant>(Val: V)) {
    auto [Rsrc, Off] = splitLoweredFatBufferConst(C);
    return {*RsrcEntry = Rsrc, *OffEntry = Off};
  }

  IRBuilder<InstSimplifyFolder>::InsertPointGuard Guard(IRB);
  if (auto *I = dyn_cast<Instruction>(Val: V)) {
    LLVM_DEBUG(dbgs() << "Recursing to split parts of " << *I << "\n");
    auto [Rsrc, Off] = visit(I&: *I);
    if (Rsrc && Off)
      return {*RsrcEntry = Rsrc, *OffEntry = Off};
    // We'll be creating the new values after the relevant instruction.
    // This instruction generates a value and so isn't a terminator.
    IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
    IRB.SetCurrentDebugLocation(I->getDebugLoc());
  } else if (auto *A = dyn_cast<Argument>(Val: V)) {
    IRB.SetInsertPointPastAllocas(A->getParent());
    IRB.SetCurrentDebugLocation(DebugLoc());
  }
  // Fall back to extracting the parts from the rewritten struct value.
  Value *Rsrc = IRB.CreateExtractValue(Agg: V, Idxs: 0, Name: V->getName() + ".rsrc");
  Value *Off = IRB.CreateExtractValue(Agg: V, Idxs: 1, Name: V->getName() + ".off");
  return {*RsrcEntry = Rsrc, *OffEntry = Off};
}
1427
/// Returns the instruction that defines the resource part of the value V.
/// Note that this is not getUnderlyingObject(), since that looks through
/// operations like ptrmask which might modify the resource part.
///
/// We can limit ourselves to just looking through GEPs followed by looking
/// through addrspacecasts because only those two operations preserve the
/// resource part, and because operations on an `addrspace(8)` (which is the
/// legal input to this addrspacecast) would produce a different resource part.
static Value *rsrcPartRoot(Value *V) {
  // Strip GEPs first (they only adjust the offset part), then the
  // addrspacecasts feeding them; per the comment above, no further GEPs can
  // appear below the casts.
  while (auto *GEP = dyn_cast<GEPOperator>(Val: V))
    V = GEP->getPointerOperand();
  while (auto *ASC = dyn_cast<AddrSpaceCastOperator>(Val: V))
    V = ASC->getPointerOperand();
  return V;
}
1443
1444void SplitPtrStructs::getPossibleRsrcRoots(Instruction *I,
1445 SmallPtrSetImpl<Value *> &Roots,
1446 SmallPtrSetImpl<Value *> &Seen) {
1447 if (auto *PHI = dyn_cast<PHINode>(Val: I)) {
1448 if (!Seen.insert(Ptr: I).second)
1449 return;
1450 for (Value *In : PHI->incoming_values()) {
1451 In = rsrcPartRoot(V: In);
1452 Roots.insert(Ptr: In);
1453 if (isa<PHINode, SelectInst>(Val: In))
1454 getPossibleRsrcRoots(I: cast<Instruction>(Val: In), Roots, Seen);
1455 }
1456 } else if (auto *SI = dyn_cast<SelectInst>(Val: I)) {
1457 if (!Seen.insert(Ptr: SI).second)
1458 return;
1459 Value *TrueVal = rsrcPartRoot(V: SI->getTrueValue());
1460 Value *FalseVal = rsrcPartRoot(V: SI->getFalseValue());
1461 Roots.insert(Ptr: TrueVal);
1462 Roots.insert(Ptr: FalseVal);
1463 if (isa<PHINode, SelectInst>(Val: TrueVal))
1464 getPossibleRsrcRoots(I: cast<Instruction>(Val: TrueVal), Roots, Seen);
1465 if (isa<PHINode, SelectInst>(Val: FalseVal))
1466 getPossibleRsrcRoots(I: cast<Instruction>(Val: FalseVal), Roots, Seen);
1467 } else {
1468 llvm_unreachable("getPossibleRsrcParts() only works on phi and select");
1469 }
1470}
1471
/// Rewrite the phi and select nodes recorded in `Conditionals` now that every
/// instruction has been visited and has resource/offset parts available.
///
/// The offset part always becomes a fresh phi over the incoming offsets. For
/// the resource part, we first try to prove (by walking the graph of phis and
/// selects) that all inputs share a single resource-part root; if they do,
/// that root is used directly and no resource phi/select is needed at all.
void SplitPtrStructs::processConditionals() {
  // Cache of conditionals whose unique resource part was already determined
  // by an earlier iteration (they were in the same phi/select cycle).
  SmallDenseMap<Value *, Value *> FoundRsrcs;
  SmallPtrSet<Value *, 4> Roots;
  SmallPtrSet<Value *, 4> Seen;
  for (Instruction *I : Conditionals) {
    // These have to exist by now because we've visited these nodes.
    Value *Rsrc = RsrcParts[I];
    Value *Off = OffParts[I];
    assert(Rsrc && Off && "must have visited conditionals by now");

    // If set, the single value that serves as the resource part for this
    // conditional; if unset, a resource phi must be materialized.
    std::optional<Value *> MaybeRsrc;
    auto MaybeFoundRsrc = FoundRsrcs.find(Val: I);
    if (MaybeFoundRsrc != FoundRsrcs.end()) {
      MaybeRsrc = MaybeFoundRsrc->second;
    } else {
      // Preserve the builder's insert point across the traversal below.
      IRBuilder<InstSimplifyFolder>::InsertPointGuard Guard(IRB);
      Roots.clear();
      Seen.clear();
      getPossibleRsrcRoots(I, Roots, Seen);
      LLVM_DEBUG(dbgs() << "Processing conditional: " << *I << "\n");
#ifndef NDEBUG
      for (Value *V : Roots)
        LLVM_DEBUG(dbgs() << "Root: " << *V << "\n");
      for (Value *V : Seen)
        LLVM_DEBUG(dbgs() << "Seen: " << *V << "\n");
#endif
      // If we are our own possible root, then we shouldn't block our
      // replacement with a valid incoming value.
      Roots.erase(Ptr: I);
      // We don't want to block the optimization for conditionals that don't
      // refer to themselves but did see themselves during the traversal.
      Seen.erase(Ptr: I);

      // If every conditional we saw is itself a candidate root, the only
      // "real" root is whatever is left over - and if there is exactly one
      // such value, every path yields the same resource part.
      if (set_is_subset(S1: Seen, S2: Roots)) {
        auto Diff = set_difference(S1: Roots, S2: Seen);
        if (Diff.size() == 1) {
          Value *RootVal = *Diff.begin();
          // Handle the case where previous loops already looked through
          // an addrspacecast.
          if (isSplitFatPtr(Ty: RootVal->getType()))
            MaybeRsrc = std::get<0>(in: getPtrParts(V: RootVal));
          else
            MaybeRsrc = RootVal;
        }
      }
    }

    if (auto *PHI = dyn_cast<PHINode>(Val: I)) {
      Value *NewRsrc;
      StructType *PHITy = cast<StructType>(Val: PHI->getType());
      IRB.SetInsertPoint(*PHI->getInsertionPointAfterDef());
      IRB.SetCurrentDebugLocation(PHI->getDebugLoc());
      if (MaybeRsrc) {
        // All paths agree on one resource part; no resource phi needed.
        NewRsrc = *MaybeRsrc;
      } else {
        // Build a phi over the incoming resource parts.
        Type *RsrcTy = PHITy->getElementType(N: 0);
        auto *RsrcPHI = IRB.CreatePHI(Ty: RsrcTy, NumReservedValues: PHI->getNumIncomingValues());
        RsrcPHI->takeName(V: Rsrc);
        for (auto [V, BB] : llvm::zip(t: PHI->incoming_values(), u: PHI->blocks())) {
          Value *VRsrc = std::get<0>(in: getPtrParts(V));
          RsrcPHI->addIncoming(V: VRsrc, BB);
        }
        copyMetadata(Dest: RsrcPHI, Src: PHI);
        NewRsrc = RsrcPHI;
      }

      // The offset part always becomes a phi over the incoming offsets.
      Type *OffTy = PHITy->getElementType(N: 1);
      auto *NewOff = IRB.CreatePHI(Ty: OffTy, NumReservedValues: PHI->getNumIncomingValues());
      NewOff->takeName(V: Off);
      for (auto [V, BB] : llvm::zip(t: PHI->incoming_values(), u: PHI->blocks())) {
        assert(OffParts.count(V) && "An offset part had to be created by now");
        Value *VOff = std::get<1>(in: getPtrParts(V));
        NewOff->addIncoming(V: VOff, BB);
      }
      copyMetadata(Dest: NewOff, Src: PHI);

      // Note: We don't eraseFromParent() the temporaries because we don't want
      // to put the corrections maps in an inconsistent state. That'll be
      // handled during the rest of the killing. Also, `ValueToValueMapTy`
      // guarantees that references in that map will be updated as well.
      // Note that if the temporary instruction got `InstSimplify`'d away, it
      // might be something like a block argument.
      if (auto *RsrcInst = dyn_cast<Instruction>(Val: Rsrc)) {
        ConditionalTemps.push_back(Elt: RsrcInst);
        RsrcInst->replaceAllUsesWith(V: NewRsrc);
      }
      if (auto *OffInst = dyn_cast<Instruction>(Val: Off)) {
        ConditionalTemps.push_back(Elt: OffInst);
        OffInst->replaceAllUsesWith(V: NewOff);
      }

      // Save on recomputing the cycle traversals in known-root cases.
      if (MaybeRsrc)
        for (Value *V : Seen)
          FoundRsrcs[V] = NewRsrc;
    } else if (isa<SelectInst>(Val: I)) {
      if (MaybeRsrc) {
        if (auto *RsrcInst = dyn_cast<Instruction>(Val: Rsrc)) {
          // Guard against conditionals that were already folded away.
          if (RsrcInst != *MaybeRsrc) {
            ConditionalTemps.push_back(Elt: RsrcInst);
            RsrcInst->replaceAllUsesWith(V: *MaybeRsrc);
          }
        }
        // Save on recomputing the cycle traversals in known-root cases.
        for (Value *V : Seen)
          FoundRsrcs[V] = *MaybeRsrc;
      }
    } else {
      llvm_unreachable("Only PHIs and selects go in the conditionals list");
    }
  }
}
1584
/// Erase the original fat-pointer instructions in \p Origs that were split,
/// after (a) rewriting debug value records that referenced them into two
/// fragments (resource and offset), and (b) rebuilding the struct value from
/// its parts for any remaining users that were not themselves split.
void SplitPtrStructs::killAndReplaceSplitInstructions(
    SmallVectorImpl<Instruction *> &Origs) {
  // The stand-in temporaries created for phis/selects are dead now.
  for (Instruction *I : ConditionalTemps)
    I->eraseFromParent();

  for (Instruction *I : Origs) {
    // Only instructions that were actually split get killed here.
    if (!SplitUsers.contains(V: I))
      continue;

    // Split each debug record into a resource fragment (reusing the original
    // record) and an offset fragment (a clone).
    SmallVector<DbgVariableRecord *> Dbgs;
    findDbgValues(V: I, DbgVariableRecords&: Dbgs);
    for (DbgVariableRecord *Dbg : Dbgs) {
      auto &DL = I->getDataLayout();
      assert(isSplitFatPtr(I->getType()) &&
             "We should've RAUW'd away loads, stores, etc. at this point");
      DbgVariableRecord *OffDbg = Dbg->clone();
      auto [Rsrc, Off] = getPtrParts(V: I);

      // Fragment layout: resource bits first, offset bits after.
      int64_t RsrcSz = DL.getTypeSizeInBits(Ty: Rsrc->getType());
      int64_t OffSz = DL.getTypeSizeInBits(Ty: Off->getType());

      std::optional<DIExpression *> RsrcExpr =
          DIExpression::createFragmentExpression(Expr: Dbg->getExpression(), OffsetInBits: 0,
                                                 SizeInBits: RsrcSz);
      std::optional<DIExpression *> OffExpr =
          DIExpression::createFragmentExpression(Expr: Dbg->getExpression(), OffsetInBits: RsrcSz,
                                                 SizeInBits: OffSz);
      if (OffExpr) {
        OffDbg->setExpression(*OffExpr);
        OffDbg->replaceVariableLocationOp(OldValue: I, NewValue: Off);
        OffDbg->insertBefore(InsertBefore: Dbg);
      } else {
        // createFragmentExpression() can decline to produce an expression;
        // drop the cloned record in that case.
        OffDbg->eraseFromParent();
      }
      if (RsrcExpr) {
        Dbg->setExpression(*RsrcExpr);
        Dbg->replaceVariableLocationOp(OldValue: I, NewValue: Rsrc);
      } else {
        // No usable resource fragment: point the record at poison rather
        // than at the soon-to-be-erased instruction.
        Dbg->replaceVariableLocationOp(OldValue: I, NewValue: PoisonValue::get(T: I->getType()));
      }
    }

    // Users that were themselves split only see poison - they're also about
    // to be erased.
    Value *Poison = PoisonValue::get(T: I->getType());
    I->replaceUsesWithIf(New: Poison, ShouldReplace: [&](const Use &U) -> bool {
      if (const auto *UI = dyn_cast<Instruction>(Val: U.getUser()))
        return SplitUsers.contains(V: UI);
      return false;
    });

    if (I->use_empty()) {
      I->eraseFromParent();
      continue;
    }
    // A user that was not split still needs the original struct value;
    // rebuild it from the parts with insertvalues.
    IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
    IRB.SetCurrentDebugLocation(I->getDebugLoc());
    auto [Rsrc, Off] = getPtrParts(V: I);
    Value *Struct = PoisonValue::get(T: I->getType());
    Struct = IRB.CreateInsertValue(Agg: Struct, Val: Rsrc, Idxs: 0);
    Struct = IRB.CreateInsertValue(Agg: Struct, Val: Off, Idxs: 1);
    copyMetadata(Dest: Struct, Src: I);
    Struct->takeName(V: I);
    I->replaceAllUsesWith(V: Struct);
    I->eraseFromParent();
  }
}
1650
1651void SplitPtrStructs::setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx) {
1652 LLVMContext &Ctx = Intr->getContext();
1653 Intr->addParamAttr(ArgNo: RsrcArgIdx, Attr: Attribute::getWithAlignment(Context&: Ctx, Alignment: A));
1654}
1655
1656void SplitPtrStructs::insertPreMemOpFence(AtomicOrdering Order,
1657 SyncScope::ID SSID) {
1658 switch (Order) {
1659 case AtomicOrdering::Release:
1660 case AtomicOrdering::AcquireRelease:
1661 case AtomicOrdering::SequentiallyConsistent:
1662 IRB.CreateFence(Ordering: AtomicOrdering::Release, SSID);
1663 break;
1664 default:
1665 break;
1666 }
1667}
1668
1669void SplitPtrStructs::insertPostMemOpFence(AtomicOrdering Order,
1670 SyncScope::ID SSID) {
1671 switch (Order) {
1672 case AtomicOrdering::Acquire:
1673 case AtomicOrdering::AcquireRelease:
1674 case AtomicOrdering::SequentiallyConsistent:
1675 IRB.CreateFence(Ordering: AtomicOrdering::Acquire, SSID);
1676 break;
1677 default:
1678 break;
1679 }
1680}
1681
1682Value *SplitPtrStructs::handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr,
1683 Type *Ty, Align Alignment,
1684 AtomicOrdering Order, bool IsVolatile,
1685 SyncScope::ID SSID) {
1686 IRB.SetInsertPoint(I);
1687
1688 auto [Rsrc, Off] = getPtrParts(V: Ptr);
1689 SmallVector<Value *, 5> Args;
1690 if (Arg)
1691 Args.push_back(Elt: Arg);
1692 Args.push_back(Elt: Rsrc);
1693 Args.push_back(Elt: Off);
1694 insertPreMemOpFence(Order, SSID);
1695 // soffset is always 0 for these cases, where we always want any offset to be
1696 // part of bounds checking and we don't know which parts of the GEPs is
1697 // uniform.
1698 Args.push_back(Elt: IRB.getInt32(C: 0));
1699
1700 uint32_t Aux = 0;
1701 if (IsVolatile)
1702 Aux |= AMDGPU::CPol::VOLATILE;
1703 Args.push_back(Elt: IRB.getInt32(C: Aux));
1704
1705 Intrinsic::ID IID = Intrinsic::not_intrinsic;
1706 if (isa<LoadInst>(Val: I))
1707 IID = Order == AtomicOrdering::NotAtomic
1708 ? Intrinsic::amdgcn_raw_ptr_buffer_load
1709 : Intrinsic::amdgcn_raw_ptr_atomic_buffer_load;
1710 else if (isa<StoreInst>(Val: I))
1711 IID = Intrinsic::amdgcn_raw_ptr_buffer_store;
1712 else if (auto *RMW = dyn_cast<AtomicRMWInst>(Val: I)) {
1713 switch (RMW->getOperation()) {
1714 case AtomicRMWInst::Xchg:
1715 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_swap;
1716 break;
1717 case AtomicRMWInst::Add:
1718 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_add;
1719 break;
1720 case AtomicRMWInst::Sub:
1721 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_sub;
1722 break;
1723 case AtomicRMWInst::And:
1724 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_and;
1725 break;
1726 case AtomicRMWInst::Or:
1727 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_or;
1728 break;
1729 case AtomicRMWInst::Xor:
1730 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_xor;
1731 break;
1732 case AtomicRMWInst::Max:
1733 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smax;
1734 break;
1735 case AtomicRMWInst::Min:
1736 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smin;
1737 break;
1738 case AtomicRMWInst::UMax:
1739 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umax;
1740 break;
1741 case AtomicRMWInst::UMin:
1742 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umin;
1743 break;
1744 case AtomicRMWInst::FAdd:
1745 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fadd;
1746 break;
1747 case AtomicRMWInst::FMax:
1748 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmax;
1749 break;
1750 case AtomicRMWInst::FMin:
1751 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmin;
1752 break;
1753 case AtomicRMWInst::USubCond:
1754 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_cond_sub_u32;
1755 break;
1756 case AtomicRMWInst::USubSat:
1757 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_sub_clamp_u32;
1758 break;
1759 case AtomicRMWInst::FSub: {
1760 reportFatalUsageError(
1761 reason: "atomic floating point subtraction not supported for "
1762 "buffer resources and should've been expanded away");
1763 break;
1764 }
1765 case AtomicRMWInst::FMaximum: {
1766 reportFatalUsageError(
1767 reason: "atomic floating point fmaximum not supported for "
1768 "buffer resources and should've been expanded away");
1769 break;
1770 }
1771 case AtomicRMWInst::FMinimum: {
1772 reportFatalUsageError(
1773 reason: "atomic floating point fminimum not supported for "
1774 "buffer resources and should've been expanded away");
1775 break;
1776 }
1777 case AtomicRMWInst::FMaximumNum: {
1778 reportFatalUsageError(
1779 reason: "atomic floating point fmaximumnum not supported for "
1780 "buffer resources and should've been expanded away");
1781 break;
1782 }
1783 case AtomicRMWInst::FMinimumNum: {
1784 reportFatalUsageError(
1785 reason: "atomic floating point fminimumnum not supported for "
1786 "buffer resources and should've been expanded away");
1787 break;
1788 }
1789 case AtomicRMWInst::Nand:
1790 reportFatalUsageError(
1791 reason: "atomic nand not supported for buffer resources and "
1792 "should've been expanded away");
1793 break;
1794 case AtomicRMWInst::UIncWrap:
1795 case AtomicRMWInst::UDecWrap:
1796 reportFatalUsageError(
1797 reason: "wrapping increment/decrement not supported for "
1798 "buffer resources and should've been expanded away");
1799 break;
1800 case AtomicRMWInst::BAD_BINOP:
1801 llvm_unreachable("Not sure how we got a bad binop");
1802 }
1803 }
1804
1805 auto *Call = IRB.CreateIntrinsic(ID: IID, Types: Ty, Args);
1806 copyMetadata(Dest: Call, Src: I);
1807 setAlign(Intr: Call, A: Alignment, RsrcArgIdx: Arg ? 1 : 0);
1808 Call->takeName(V: I);
1809
1810 insertPostMemOpFence(Order, SSID);
1811 // The "no moving p7 directly" rewrites ensure that this load or store won't
1812 // itself need to be split into parts.
1813 SplitUsers.insert(V: I);
1814 I->replaceAllUsesWith(V: Call);
1815 return Call;
1816}
1817
// Fallback for instruction kinds with no dedicated handler: they produce no
// split parts.
PtrParts SplitPtrStructs::visitInstruction(Instruction &I) {
  return {nullptr, nullptr};
}
1821
1822PtrParts SplitPtrStructs::visitLoadInst(LoadInst &LI) {
1823 if (!isSplitFatPtr(Ty: LI.getPointerOperandType()))
1824 return {nullptr, nullptr};
1825 handleMemoryInst(I: &LI, Arg: nullptr, Ptr: LI.getPointerOperand(), Ty: LI.getType(),
1826 Alignment: LI.getAlign(), Order: LI.getOrdering(), IsVolatile: LI.isVolatile(),
1827 SSID: LI.getSyncScopeID());
1828 return {nullptr, nullptr};
1829}
1830
1831PtrParts SplitPtrStructs::visitStoreInst(StoreInst &SI) {
1832 if (!isSplitFatPtr(Ty: SI.getPointerOperandType()))
1833 return {nullptr, nullptr};
1834 Value *Arg = SI.getValueOperand();
1835 handleMemoryInst(I: &SI, Arg, Ptr: SI.getPointerOperand(), Ty: Arg->getType(),
1836 Alignment: SI.getAlign(), Order: SI.getOrdering(), IsVolatile: SI.isVolatile(),
1837 SSID: SI.getSyncScopeID());
1838 return {nullptr, nullptr};
1839}
1840
1841PtrParts SplitPtrStructs::visitAtomicRMWInst(AtomicRMWInst &AI) {
1842 if (!isSplitFatPtr(Ty: AI.getPointerOperand()->getType()))
1843 return {nullptr, nullptr};
1844 Value *Arg = AI.getValOperand();
1845 handleMemoryInst(I: &AI, Arg, Ptr: AI.getPointerOperand(), Ty: Arg->getType(),
1846 Alignment: AI.getAlign(), Order: AI.getOrdering(), IsVolatile: AI.isVolatile(),
1847 SSID: AI.getSyncScopeID());
1848 return {nullptr, nullptr};
1849}
1850
1851// Unlike load, store, and RMW, cmpxchg needs special handling to account
1852// for the boolean argument.
1853PtrParts SplitPtrStructs::visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI) {
1854 Value *Ptr = AI.getPointerOperand();
1855 if (!isSplitFatPtr(Ty: Ptr->getType()))
1856 return {nullptr, nullptr};
1857 IRB.SetInsertPoint(&AI);
1858
1859 Type *Ty = AI.getNewValOperand()->getType();
1860 AtomicOrdering Order = AI.getMergedOrdering();
1861 SyncScope::ID SSID = AI.getSyncScopeID();
1862 bool IsNonTemporal = AI.getMetadata(KindID: LLVMContext::MD_nontemporal);
1863
1864 auto [Rsrc, Off] = getPtrParts(V: Ptr);
1865 insertPreMemOpFence(Order, SSID);
1866
1867 uint32_t Aux = 0;
1868 if (IsNonTemporal)
1869 Aux |= AMDGPU::CPol::SLC;
1870 if (AI.isVolatile())
1871 Aux |= AMDGPU::CPol::VOLATILE;
1872 auto *Call =
1873 IRB.CreateIntrinsic(ID: Intrinsic::amdgcn_raw_ptr_buffer_atomic_cmpswap, Types: Ty,
1874 Args: {AI.getNewValOperand(), AI.getCompareOperand(), Rsrc,
1875 Off, IRB.getInt32(C: 0), IRB.getInt32(C: Aux)});
1876 copyMetadata(Dest: Call, Src: &AI);
1877 setAlign(Intr: Call, A: AI.getAlign(), RsrcArgIdx: 2);
1878 Call->takeName(V: &AI);
1879 insertPostMemOpFence(Order, SSID);
1880
1881 Value *Res = PoisonValue::get(T: AI.getType());
1882 Res = IRB.CreateInsertValue(Agg: Res, Val: Call, Idxs: 0);
1883 if (!AI.isWeak()) {
1884 Value *Succeeded = IRB.CreateICmpEQ(LHS: Call, RHS: AI.getCompareOperand());
1885 Res = IRB.CreateInsertValue(Agg: Res, Val: Succeeded, Idxs: 1);
1886 }
1887 SplitUsers.insert(V: &AI);
1888 AI.replaceAllUsesWith(V: Res);
1889 return {nullptr, nullptr};
1890}
1891
// Lower a GEP on a split fat pointer to arithmetic on the offset part; the
// resource part is carried through unchanged (possibly splatted to a vector).
PtrParts SplitPtrStructs::visitGetElementPtrInst(GetElementPtrInst &GEP) {
  using namespace llvm::PatternMatch;
  Value *Ptr = GEP.getPointerOperand();
  if (!isSplitFatPtr(Ty: Ptr->getType()))
    return {nullptr, nullptr};
  IRB.SetInsertPoint(&GEP);

  auto [Rsrc, Off] = getPtrParts(V: Ptr);
  const DataLayout &DL = GEP.getDataLayout();
  bool IsNUW = GEP.hasNoUnsignedWrap();
  bool IsNUSW = GEP.hasNoUnsignedSignedWrap();

  StructType *ResTy = cast<StructType>(Val: GEP.getType());
  Type *ResRsrcTy = ResTy->getElementType(N: 0);
  VectorType *ResRsrcVecTy = dyn_cast<VectorType>(Val: ResRsrcTy);
  // A vector-of-pointers result from a scalar pointer operand means the
  // scalar parts must be broadcast before use.
  bool BroadcastsPtr = ResRsrcVecTy && !isa<VectorType>(Val: Off->getType());

  // In order to call emitGEPOffset() and thus not have to reimplement it,
  // we need the GEP result to have ptr addrspace(7) type.
  Type *FatPtrTy =
      ResRsrcTy->getWithNewType(EltTy: IRB.getPtrTy(AddrSpace: AMDGPUAS::BUFFER_FAT_POINTER));
  GEP.mutateType(Ty: FatPtrTy);
  Value *OffAccum = emitGEPOffset(Builder: &IRB, DL, GEP: &GEP);
  // Undo the temporary type mutation now that the offset is computed.
  GEP.mutateType(Ty: ResTy);

  if (BroadcastsPtr) {
    Rsrc = IRB.CreateVectorSplat(EC: ResRsrcVecTy->getElementCount(), V: Rsrc,
                                 Name: Rsrc->getName());
    Off = IRB.CreateVectorSplat(EC: ResRsrcVecTy->getElementCount(), V: Off,
                                Name: Off->getName());
  }
  if (match(V: OffAccum, P: m_Zero())) { // Constant-zero offset
    SplitUsers.insert(V: &GEP);
    return {Rsrc, Off};
  }

  bool HasNonNegativeOff = false;
  if (auto *CI = dyn_cast<ConstantInt>(Val: OffAccum)) {
    HasNonNegativeOff = !CI->isNegative();
  }
  Value *NewOff;
  if (match(V: Off, P: m_Zero())) {
    // Adding to a known-zero base offset is just the accumulated GEP offset.
    NewOff = OffAccum;
  } else {
    // The add is nuw if the GEP itself is nuw, or if it is nusw and the
    // constant offset is provably non-negative (so the unsigned add cannot
    // wrap either).
    NewOff = IRB.CreateAdd(LHS: Off, RHS: OffAccum, Name: "",
                           /*hasNUW=*/HasNUW: IsNUW || (IsNUSW && HasNonNegativeOff),
                           /*hasNSW=*/HasNSW: false);
  }
  copyMetadata(Dest: NewOff, Src: &GEP);
  NewOff->takeName(V: &GEP);
  SplitUsers.insert(V: &GEP);
  return {Rsrc, NewOff};
}
1945
// Lower ptrtoint on a split fat pointer. The integer form of a fat pointer
// is (rsrc-as-int << BufferOffsetWidth) | off; results no wider than the
// offset only see the (cast) offset part.
PtrParts SplitPtrStructs::visitPtrToIntInst(PtrToIntInst &PI) {
  Value *Ptr = PI.getPointerOperand();
  if (!isSplitFatPtr(Ty: Ptr->getType()))
    return {nullptr, nullptr};
  IRB.SetInsertPoint(&PI);

  Type *ResTy = PI.getType();
  unsigned Width = ResTy->getScalarSizeInBits();

  auto [Rsrc, Off] = getPtrParts(V: Ptr);
  const DataLayout &DL = PI.getDataLayout();
  unsigned FatPtrWidth = DL.getPointerSizeInBits(AS: AMDGPUAS::BUFFER_FAT_POINTER);

  Value *Res;
  if (Width <= BufferOffsetWidth) {
    // The result is too narrow to hold any resource bits.
    Res = IRB.CreateIntCast(V: Off, DestTy: ResTy, /*isSigned=*/false,
                            Name: PI.getName() + ".off");
  } else {
    Value *RsrcInt = IRB.CreatePtrToInt(V: Rsrc, DestTy: ResTy, Name: PI.getName() + ".rsrc");
    // The shift gets nuw (and nsw) when the result type is wide enough that
    // no resource bits can be shifted out.
    Value *Shl = IRB.CreateShl(
        LHS: RsrcInt,
        RHS: ConstantExpr::getIntegerValue(Ty: ResTy, V: APInt(Width, BufferOffsetWidth)),
        Name: "", HasNUW: Width >= FatPtrWidth, HasNSW: Width > FatPtrWidth);
    Value *OffCast = IRB.CreateIntCast(V: Off, DestTy: ResTy, /*isSigned=*/false,
                                       Name: PI.getName() + ".off");
    Res = IRB.CreateOr(LHS: Shl, RHS: OffCast);
  }

  copyMetadata(Dest: Res, Src: &PI);
  Res->takeName(V: &PI);
  SplitUsers.insert(V: &PI);
  PI.replaceAllUsesWith(V: Res);
  return {nullptr, nullptr};
}
1980
1981PtrParts SplitPtrStructs::visitPtrToAddrInst(PtrToAddrInst &PA) {
1982 Value *Ptr = PA.getPointerOperand();
1983 if (!isSplitFatPtr(Ty: Ptr->getType()))
1984 return {nullptr, nullptr};
1985 IRB.SetInsertPoint(&PA);
1986
1987 auto [Rsrc, Off] = getPtrParts(V: Ptr);
1988 Value *Res = IRB.CreateIntCast(V: Off, DestTy: PA.getType(), /*isSigned=*/false);
1989 copyMetadata(Dest: Res, Src: &PA);
1990 Res->takeName(V: &PA);
1991 SplitUsers.insert(V: &PA);
1992 PA.replaceAllUsesWith(V: Res);
1993 return {nullptr, nullptr};
1994}
1995
// Lower inttoptr to a split fat pointer: the bits above BufferOffsetWidth
// become the resource pointer and the low bits become the offset.
PtrParts SplitPtrStructs::visitIntToPtrInst(IntToPtrInst &IP) {
  if (!isSplitFatPtr(Ty: IP.getType()))
    return {nullptr, nullptr};
  IRB.SetInsertPoint(&IP);
  const DataLayout &DL = IP.getDataLayout();
  unsigned RsrcPtrWidth = DL.getPointerSizeInBits(AS: AMDGPUAS::BUFFER_RESOURCE);
  Value *Int = IP.getOperand(i_nocapture: 0);
  Type *IntTy = Int->getType();
  Type *RsrcIntTy = IntTy->getWithNewBitWidth(NewBitWidth: RsrcPtrWidth);
  unsigned Width = IntTy->getScalarSizeInBits();

  auto *RetTy = cast<StructType>(Val: IP.getType());
  Type *RsrcTy = RetTy->getElementType(N: 0);
  Type *OffTy = RetTy->getElementType(N: 1);
  // Shift the resource bits down, resize to the resource pointer's integer
  // width, and convert back to a pointer.
  Value *RsrcPart = IRB.CreateLShr(
      LHS: Int,
      RHS: ConstantExpr::getIntegerValue(Ty: IntTy, V: APInt(Width, BufferOffsetWidth)));
  Value *RsrcInt = IRB.CreateIntCast(V: RsrcPart, DestTy: RsrcIntTy, /*isSigned=*/false);
  Value *Rsrc = IRB.CreateIntToPtr(V: RsrcInt, DestTy: RsrcTy, Name: IP.getName() + ".rsrc");
  // The offset is simply the low bits of the integer.
  Value *Off =
      IRB.CreateIntCast(V: Int, DestTy: OffTy, /*IsSigned=*/isSigned: false, Name: IP.getName() + ".off");

  copyMetadata(Dest: Rsrc, Src: &IP);
  SplitUsers.insert(V: &IP);
  return {Rsrc, Off};
}
2022
// Lower an addrspacecast producing a split fat pointer. Supported sources
// are: the same split type (no-op), null/undef/poison constants, and a bare
// buffer resource (addrspace 8), which becomes the resource part with a zero
// offset. Anything else is a usage error.
PtrParts SplitPtrStructs::visitAddrSpaceCastInst(AddrSpaceCastInst &I) {
  // TODO(krzysz00): handle casts from ptr addrspace(7) to global pointers
  // by computing the effective address.
  if (!isSplitFatPtr(Ty: I.getType()))
    return {nullptr, nullptr};
  IRB.SetInsertPoint(&I);
  Value *In = I.getPointerOperand();
  // No-op casts preserve parts
  if (In->getType() == I.getType()) {
    auto [Rsrc, Off] = getPtrParts(V: In);
    SplitUsers.insert(V: &I);
    return {Rsrc, Off};
  }

  auto *ResTy = cast<StructType>(Val: I.getType());
  Type *RsrcTy = ResTy->getElementType(N: 0);
  Type *OffTy = ResTy->getElementType(N: 1);
  Value *ZeroOff = Constant::getNullValue(Ty: OffTy);

  // Special case for null pointers, undef, and poison, which can be created by
  // address space propagation.
  auto *InConst = dyn_cast<Constant>(Val: In);
  if (InConst && InConst->isNullValue()) {
    Value *NullRsrc = Constant::getNullValue(Ty: RsrcTy);
    SplitUsers.insert(V: &I);
    return {NullRsrc, ZeroOff};
  }
  if (isa<PoisonValue>(Val: In)) {
    Value *PoisonRsrc = PoisonValue::get(T: RsrcTy);
    Value *PoisonOff = PoisonValue::get(T: OffTy);
    SplitUsers.insert(V: &I);
    return {PoisonRsrc, PoisonOff};
  }
  if (isa<UndefValue>(Val: In)) {
    Value *UndefRsrc = UndefValue::get(T: RsrcTy);
    Value *UndefOff = UndefValue::get(T: OffTy);
    SplitUsers.insert(V: &I);
    return {UndefRsrc, UndefOff};
  }

  // The remaining legal source is a bare resource pointer: it becomes the
  // resource part with a zero offset.
  if (I.getSrcAddressSpace() != AMDGPUAS::BUFFER_RESOURCE)
    reportFatalUsageError(
        reason: "only buffer resources (addrspace 8) and null/poison pointers can be "
        "cast to buffer fat pointers (addrspace 7)");
  SplitUsers.insert(V: &I);
  return {In, ZeroOff};
}
2070
2071PtrParts SplitPtrStructs::visitICmpInst(ICmpInst &Cmp) {
2072 Value *Lhs = Cmp.getOperand(i_nocapture: 0);
2073 if (!isSplitFatPtr(Ty: Lhs->getType()))
2074 return {nullptr, nullptr};
2075 Value *Rhs = Cmp.getOperand(i_nocapture: 1);
2076 IRB.SetInsertPoint(&Cmp);
2077 ICmpInst::Predicate Pred = Cmp.getPredicate();
2078
2079 assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
2080 "Pointer comparison is only equal or unequal");
2081 auto [LhsRsrc, LhsOff] = getPtrParts(V: Lhs);
2082 auto [RhsRsrc, RhsOff] = getPtrParts(V: Rhs);
2083 Value *Res = IRB.CreateICmp(P: Pred, LHS: LhsOff, RHS: RhsOff);
2084 copyMetadata(Dest: Res, Src: &Cmp);
2085 Res->takeName(V: &Cmp);
2086 SplitUsers.insert(V: &Cmp);
2087 Cmp.replaceAllUsesWith(V: Res);
2088 return {nullptr, nullptr};
2089}
2090
2091PtrParts SplitPtrStructs::visitFreezeInst(FreezeInst &I) {
2092 if (!isSplitFatPtr(Ty: I.getType()))
2093 return {nullptr, nullptr};
2094 IRB.SetInsertPoint(&I);
2095 auto [Rsrc, Off] = getPtrParts(V: I.getOperand(i_nocapture: 0));
2096
2097 Value *RsrcRes = IRB.CreateFreeze(V: Rsrc, Name: I.getName() + ".rsrc");
2098 copyMetadata(Dest: RsrcRes, Src: &I);
2099 Value *OffRes = IRB.CreateFreeze(V: Off, Name: I.getName() + ".off");
2100 copyMetadata(Dest: OffRes, Src: &I);
2101 SplitUsers.insert(V: &I);
2102 return {RsrcRes, OffRes};
2103}
2104
2105PtrParts SplitPtrStructs::visitExtractElementInst(ExtractElementInst &I) {
2106 if (!isSplitFatPtr(Ty: I.getType()))
2107 return {nullptr, nullptr};
2108 IRB.SetInsertPoint(&I);
2109 Value *Vec = I.getVectorOperand();
2110 Value *Idx = I.getIndexOperand();
2111 auto [Rsrc, Off] = getPtrParts(V: Vec);
2112
2113 Value *RsrcRes = IRB.CreateExtractElement(Vec: Rsrc, Idx, Name: I.getName() + ".rsrc");
2114 copyMetadata(Dest: RsrcRes, Src: &I);
2115 Value *OffRes = IRB.CreateExtractElement(Vec: Off, Idx, Name: I.getName() + ".off");
2116 copyMetadata(Dest: OffRes, Src: &I);
2117 SplitUsers.insert(V: &I);
2118 return {RsrcRes, OffRes};
2119}
2120
2121PtrParts SplitPtrStructs::visitInsertElementInst(InsertElementInst &I) {
2122 // The mutated instructions temporarily don't return vectors, and so
2123 // we need the generic getType() here to avoid crashes.
2124 if (!isSplitFatPtr(Ty: cast<Instruction>(Val&: I).getType()))
2125 return {nullptr, nullptr};
2126 IRB.SetInsertPoint(&I);
2127 Value *Vec = I.getOperand(i_nocapture: 0);
2128 Value *Elem = I.getOperand(i_nocapture: 1);
2129 Value *Idx = I.getOperand(i_nocapture: 2);
2130 auto [VecRsrc, VecOff] = getPtrParts(V: Vec);
2131 auto [ElemRsrc, ElemOff] = getPtrParts(V: Elem);
2132
2133 Value *RsrcRes =
2134 IRB.CreateInsertElement(Vec: VecRsrc, NewElt: ElemRsrc, Idx, Name: I.getName() + ".rsrc");
2135 copyMetadata(Dest: RsrcRes, Src: &I);
2136 Value *OffRes =
2137 IRB.CreateInsertElement(Vec: VecOff, NewElt: ElemOff, Idx, Name: I.getName() + ".off");
2138 copyMetadata(Dest: OffRes, Src: &I);
2139 SplitUsers.insert(V: &I);
2140 return {RsrcRes, OffRes};
2141}
2142
2143PtrParts SplitPtrStructs::visitShuffleVectorInst(ShuffleVectorInst &I) {
2144 // Cast is needed for the same reason as insertelement's.
2145 if (!isSplitFatPtr(Ty: cast<Instruction>(Val&: I).getType()))
2146 return {nullptr, nullptr};
2147 IRB.SetInsertPoint(&I);
2148
2149 Value *V1 = I.getOperand(i_nocapture: 0);
2150 Value *V2 = I.getOperand(i_nocapture: 1);
2151 ArrayRef<int> Mask = I.getShuffleMask();
2152 auto [V1Rsrc, V1Off] = getPtrParts(V: V1);
2153 auto [V2Rsrc, V2Off] = getPtrParts(V: V2);
2154
2155 Value *RsrcRes =
2156 IRB.CreateShuffleVector(V1: V1Rsrc, V2: V2Rsrc, Mask, Name: I.getName() + ".rsrc");
2157 copyMetadata(Dest: RsrcRes, Src: &I);
2158 Value *OffRes =
2159 IRB.CreateShuffleVector(V1: V1Off, V2: V2Off, Mask, Name: I.getName() + ".off");
2160 copyMetadata(Dest: OffRes, Src: &I);
2161 SplitUsers.insert(V: &I);
2162 return {RsrcRes, OffRes};
2163}
2164
2165PtrParts SplitPtrStructs::visitPHINode(PHINode &PHI) {
2166 if (!isSplitFatPtr(Ty: PHI.getType()))
2167 return {nullptr, nullptr};
2168 IRB.SetInsertPoint(*PHI.getInsertionPointAfterDef());
2169 // Phi nodes will be handled in post-processing after we've visited every
2170 // instruction. However, instead of just returning {nullptr, nullptr},
2171 // we explicitly create the temporary extractvalue operations that are our
2172 // temporary results so that they end up at the beginning of the block with
2173 // the PHIs.
2174 Value *TmpRsrc = IRB.CreateExtractValue(Agg: &PHI, Idxs: 0, Name: PHI.getName() + ".rsrc");
2175 Value *TmpOff = IRB.CreateExtractValue(Agg: &PHI, Idxs: 1, Name: PHI.getName() + ".off");
2176 Conditionals.push_back(Elt: &PHI);
2177 SplitUsers.insert(V: &PHI);
2178 return {TmpRsrc, TmpOff};
2179}
2180
2181PtrParts SplitPtrStructs::visitSelectInst(SelectInst &SI) {
2182 if (!isSplitFatPtr(Ty: SI.getType()))
2183 return {nullptr, nullptr};
2184 IRB.SetInsertPoint(&SI);
2185
2186 Value *Cond = SI.getCondition();
2187 Value *True = SI.getTrueValue();
2188 Value *False = SI.getFalseValue();
2189 auto [TrueRsrc, TrueOff] = getPtrParts(V: True);
2190 auto [FalseRsrc, FalseOff] = getPtrParts(V: False);
2191
2192 Value *RsrcRes =
2193 IRB.CreateSelect(C: Cond, True: TrueRsrc, False: FalseRsrc, Name: SI.getName() + ".rsrc", MDFrom: &SI);
2194 copyMetadata(Dest: RsrcRes, Src: &SI);
2195 Conditionals.push_back(Elt: &SI);
2196 Value *OffRes =
2197 IRB.CreateSelect(C: Cond, True: TrueOff, False: FalseOff, Name: SI.getName() + ".off", MDFrom: &SI);
2198 copyMetadata(Dest: OffRes, Src: &SI);
2199 SplitUsers.insert(V: &SI);
2200 return {RsrcRes, OffRes};
2201}
2202
2203/// Returns true if this intrinsic needs to be removed when it is
2204/// applied to `ptr addrspace(7)` values. Calls to these intrinsics are
2205/// rewritten into calls to versions of that intrinsic on the resource
2206/// descriptor.
2207static bool isRemovablePointerIntrinsic(Intrinsic::ID IID) {
2208 switch (IID) {
2209 default:
2210 return false;
2211 case Intrinsic::amdgcn_make_buffer_rsrc:
2212 case Intrinsic::ptrmask:
2213 case Intrinsic::invariant_start:
2214 case Intrinsic::invariant_end:
2215 case Intrinsic::launder_invariant_group:
2216 case Intrinsic::strip_invariant_group:
2217 case Intrinsic::memcpy:
2218 case Intrinsic::memcpy_inline:
2219 case Intrinsic::memmove:
2220 case Intrinsic::memset:
2221 case Intrinsic::memset_inline:
2222 case Intrinsic::experimental_memset_pattern:
2223 case Intrinsic::amdgcn_load_to_lds:
2224 case Intrinsic::amdgcn_load_async_to_lds:
2225 return true;
2226 }
2227}
2228
// Rewrite the intrinsics that operate directly on fat pointers. Intrinsics
// not listed here fall through to the generic visitor.
PtrParts SplitPtrStructs::visitIntrinsicInst(IntrinsicInst &I) {
  Intrinsic::ID IID = I.getIntrinsicID();
  switch (IID) {
  default:
    break;
  case Intrinsic::amdgcn_make_buffer_rsrc: {
    // A make.buffer.rsrc that returns a fat pointer becomes one that returns
    // the resource (addrspace 8) pointer, paired with a zero offset.
    if (!isSplitFatPtr(Ty: I.getType()))
      return {nullptr, nullptr};
    Value *Base = I.getArgOperand(i: 0);
    Value *Stride = I.getArgOperand(i: 1);
    Value *NumRecords = I.getArgOperand(i: 2);
    Value *Flags = I.getArgOperand(i: 3);
    auto *SplitType = cast<StructType>(Val: I.getType());
    Type *RsrcType = SplitType->getElementType(N: 0);
    Type *OffType = SplitType->getElementType(N: 1);
    IRB.SetInsertPoint(&I);
    Value *Rsrc = IRB.CreateIntrinsic(ID: IID, Types: {RsrcType, Base->getType()},
                                      Args: {Base, Stride, NumRecords, Flags});
    copyMetadata(Dest: Rsrc, Src: &I);
    Rsrc->takeName(V: &I);
    Value *Zero = Constant::getNullValue(Ty: OffType);
    SplitUsers.insert(V: &I);
    return {Rsrc, Zero};
  }
  case Intrinsic::ptrmask: {
    // Masking a fat pointer only affects the offset part; the resource part
    // passes through untouched.
    Value *Ptr = I.getArgOperand(i: 0);
    if (!isSplitFatPtr(Ty: Ptr->getType()))
      return {nullptr, nullptr};
    Value *Mask = I.getArgOperand(i: 1);
    IRB.SetInsertPoint(&I);
    auto [Rsrc, Off] = getPtrParts(V: Ptr);
    if (Mask->getType() != Off->getType())
      reportFatalUsageError(reason: "offset width is not equal to index width of fat "
                            "pointer (data layout not set up correctly?)");
    Value *OffRes = IRB.CreateAnd(LHS: Off, RHS: Mask, Name: I.getName() + ".off");
    copyMetadata(Dest: OffRes, Src: &I);
    SplitUsers.insert(V: &I);
    return {Rsrc, OffRes};
  }
  // Pointer annotation intrinsics that, given their object-wide nature
  // operate on the resource part.
  case Intrinsic::invariant_start: {
    Value *Ptr = I.getArgOperand(i: 1);
    if (!isSplitFatPtr(Ty: Ptr->getType()))
      return {nullptr, nullptr};
    IRB.SetInsertPoint(&I);
    auto [Rsrc, Off] = getPtrParts(V: Ptr);
    Type *NewTy = PointerType::get(C&: I.getContext(), AddressSpace: AMDGPUAS::BUFFER_RESOURCE);
    auto *NewRsrc = IRB.CreateIntrinsic(ID: IID, Types: {NewTy}, Args: {I.getOperand(i_nocapture: 0), Rsrc});
    copyMetadata(Dest: NewRsrc, Src: &I);
    NewRsrc->takeName(V: &I);
    SplitUsers.insert(V: &I);
    I.replaceAllUsesWith(V: NewRsrc);
    return {nullptr, nullptr};
  }
  case Intrinsic::invariant_end: {
    Value *RealPtr = I.getArgOperand(i: 2);
    if (!isSplitFatPtr(Ty: RealPtr->getType()))
      return {nullptr, nullptr};
    IRB.SetInsertPoint(&I);
    Value *RealRsrc = getPtrParts(V: RealPtr).first;
    Value *InvPtr = I.getArgOperand(i: 0);
    Value *Size = I.getArgOperand(i: 1);
    Value *NewRsrc = IRB.CreateIntrinsic(ID: IID, Types: {RealRsrc->getType()},
                                         Args: {InvPtr, Size, RealRsrc});
    copyMetadata(Dest: NewRsrc, Src: &I);
    NewRsrc->takeName(V: &I);
    SplitUsers.insert(V: &I);
    I.replaceAllUsesWith(V: NewRsrc);
    return {nullptr, nullptr};
  }
  case Intrinsic::launder_invariant_group:
  case Intrinsic::strip_invariant_group: {
    // These return a pointer; the rewritten call produces the new resource
    // part while the offset is carried through unchanged.
    Value *Ptr = I.getArgOperand(i: 0);
    if (!isSplitFatPtr(Ty: Ptr->getType()))
      return {nullptr, nullptr};
    IRB.SetInsertPoint(&I);
    auto [Rsrc, Off] = getPtrParts(V: Ptr);
    Value *NewRsrc = IRB.CreateIntrinsic(ID: IID, Types: {Rsrc->getType()}, Args: {Rsrc});
    copyMetadata(Dest: NewRsrc, Src: &I);
    NewRsrc->takeName(V: &I);
    SplitUsers.insert(V: &I);
    return {NewRsrc, Off};
  }
  case Intrinsic::amdgcn_load_to_lds:
  case Intrinsic::amdgcn_load_async_to_lds: {
    // Loads to LDS through a fat pointer become the raw buffer variants of
    // the intrinsic, with the offset part as the voffset operand and a zero
    // soffset.
    Value *Ptr = I.getArgOperand(i: 0);
    if (!isSplitFatPtr(Ty: Ptr->getType()))
      return {nullptr, nullptr};
    IRB.SetInsertPoint(&I);
    auto [Rsrc, Off] = getPtrParts(V: Ptr);
    Value *LDSPtr = I.getArgOperand(i: 1);
    Value *LoadSize = I.getArgOperand(i: 2);
    Value *ImmOff = I.getArgOperand(i: 3);
    Value *Aux = I.getArgOperand(i: 4);
    Value *SOffset = IRB.getInt32(C: 0);
    Intrinsic::ID NewIntr =
        IID == Intrinsic::amdgcn_load_to_lds
            ? Intrinsic::amdgcn_raw_ptr_buffer_load_lds
            : Intrinsic::amdgcn_raw_ptr_buffer_load_async_lds;
    Instruction *NewLoad = IRB.CreateIntrinsic(
        ID: NewIntr, Types: {}, Args: {Rsrc, LDSPtr, LoadSize, Off, SOffset, ImmOff, Aux});
    copyMetadata(Dest: NewLoad, Src: &I);
    SplitUsers.insert(V: &I);
    I.replaceAllUsesWith(V: NewLoad);
    return {nullptr, nullptr};
  }
  }
  return {nullptr, nullptr};
}
2339
/// Split every buffer-fat-pointer value in \p F into its {resource, offset}
/// parts by visiting each instruction, then resolve conditionals (phis,
/// selects) and erase the replaced instructions. Per-function state is
/// cleared afterwards.
void SplitPtrStructs::processFunction(Function &F) {
  ST = &TM->getSubtarget<GCNSubtarget>(F);
  // Snapshot the instruction list up front: visit() may insert new
  // instructions, so we must not iterate the live list while rewriting.
  SmallVector<Instruction *, 0> Originals(
      llvm::make_pointer_range(Range: instructions(F)));
  LLVM_DEBUG(dbgs() << "Splitting pointer structs in function: " << F.getName()
                    << "\n");
  for (Instruction *I : Originals) {
    // In some cases, instruction order doesn't reflect program order,
    // so the visit() call will have already visited certain instructions
    // by the time this loop gets to them. Avoid re-visiting these so as to,
    // for example, avoid processing the same conditional twice.
    if (SplitUsers.contains(V: I))
      continue;
    auto [Rsrc, Off] = visit(I);
    // A split result is all-or-nothing: either both parts exist or neither.
    assert(((Rsrc && Off) || (!Rsrc && !Off)) &&
           "Can't have a resource but no offset");
    if (Rsrc)
      RsrcParts[I] = Rsrc;
    if (Off)
      OffParts[I] = Off;
  }
  processConditionals();
  killAndReplaceSplitInstructions(Origs&: Originals);

  // Clean up after ourselves to save on memory.
  RsrcParts.clear();
  OffParts.clear();
  SplitUsers.clear();
  Conditionals.clear();
  ConditionalTemps.clear();
}
2371
namespace {
/// Legacy pass-manager wrapper for the buffer fat pointer lowering. The
/// actual transformation lives in run(); runOnModule() only obtains the
/// TargetMachine via TargetPassConfig and forwards to it.
class AMDGPULowerBufferFatPointers : public ModulePass {
public:
  static char ID;

  AMDGPULowerBufferFatPointers() : ModulePass(ID) {}

  /// Perform the lowering on \p M for target \p TM; returns true if the
  /// module was changed.
  bool run(Module &M, const TargetMachine &TM);
  bool runOnModule(Module &M) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override;
};
} // namespace
2385
2386/// Returns true if there are values that have a buffer fat pointer in them,
2387/// which means we'll need to perform rewrites on this function. As a side
2388/// effect, this will populate the type remapping cache.
2389static bool containsBufferFatPointers(const Function &F,
2390 BufferFatPtrToStructTypeMap *TypeMap) {
2391 bool HasFatPointers = false;
2392 for (const BasicBlock &BB : F)
2393 for (const Instruction &I : BB) {
2394 HasFatPointers |= (I.getType() != TypeMap->remapType(SrcTy: I.getType()));
2395 // Catch null pointer constants in loads, stores, etc.
2396 for (const Value *V : I.operand_values())
2397 HasFatPointers |= (V->getType() != TypeMap->remapType(SrcTy: V->getType()));
2398 }
2399 return HasFatPointers;
2400}
2401
2402static bool hasFatPointerInterface(const Function &F,
2403 BufferFatPtrToStructTypeMap *TypeMap) {
2404 Type *Ty = F.getFunctionType();
2405 return Ty != TypeMap->remapType(SrcTy: Ty);
2406}
2407
/// Move the body of `OldF` into a new function with type \p NewTy, returning
/// it. The new function takes over `OldF`'s name, attributes, metadata, and
/// position in the module; `OldF` is left empty (callers erase it). Identity
/// entries for the moved blocks/instructions are recorded in \p CloneMap so a
/// later ValueMapper pass can rewrite them in place.
static Function *moveFunctionAdaptingType(Function *OldF, FunctionType *NewTy,
                                          ValueToValueMapTy &CloneMap) {
  bool IsIntrinsic = OldF->isIntrinsic();
  Function *NewF =
      Function::Create(Ty: NewTy, Linkage: OldF->getLinkage(), AddrSpace: OldF->getAddressSpace());
  NewF->copyAttributesFrom(Src: OldF);
  NewF->copyMetadata(Src: OldF, Offset: 0);
  NewF->takeName(V: OldF);
  NewF->updateAfterNameChange();
  NewF->setDLLStorageClass(OldF->getDLLStorageClass());
  OldF->getParent()->getFunctionList().insertAfter(where: OldF->getIterator(), New: NewF);

  // Splice the basic blocks over one at a time, recording identity mappings
  // for the blocks and their instructions.
  while (!OldF->empty()) {
    BasicBlock *BB = &OldF->front();
    BB->removeFromParent();
    BB->insertInto(Parent: NewF);
    CloneMap[BB] = BB;
    for (Instruction &I : *BB) {
      CloneMap[&I] = &I;
    }
  }

  SmallVector<AttributeSet> ArgAttrs;
  AttributeList OldAttrs = OldF->getAttributes();

  for (auto [I, OldArg, NewArg] : enumerate(First: OldF->args(), Rest: NewF->args())) {
    CloneMap[&NewArg] = &OldArg;
    NewArg.takeName(V: &OldArg);
    Type *OldArgTy = OldArg.getType(), *NewArgTy = NewArg.getType();
    // Temporarily mutate type of `NewArg` to allow RAUW to work.
    NewArg.mutateType(Ty: OldArgTy);
    OldArg.replaceAllUsesWith(V: &NewArg);
    NewArg.mutateType(Ty: NewArgTy);

    AttributeSet ArgAttr = OldAttrs.getParamAttrs(ArgNo: I);
    // Intrinsics get their attributes fixed later.
    if (OldArgTy != NewArgTy && !IsIntrinsic)
      ArgAttr = ArgAttr.removeAttributes(
          C&: NewF->getContext(),
          AttrsToRemove: AttributeFuncs::typeIncompatible(Ty: NewArgTy, AS: ArgAttr));
    ArgAttrs.push_back(Elt: ArgAttr);
  }
  // Likewise drop return attributes that became invalid for the new return
  // type (intrinsics are again fixed up later, by remangling).
  AttributeSet RetAttrs = OldAttrs.getRetAttrs();
  if (OldF->getReturnType() != NewF->getReturnType() && !IsIntrinsic)
    RetAttrs = RetAttrs.removeAttributes(
        C&: NewF->getContext(),
        AttrsToRemove: AttributeFuncs::typeIncompatible(Ty: NewF->getReturnType(), AS: RetAttrs));
  NewF->setAttributes(AttributeList::get(
      C&: NewF->getContext(), FnAttrs: OldAttrs.getFnAttrs(), RetAttrs, ArgAttrs));
  return NewF;
}
2460
2461static void makeCloneInPraceMap(Function *F, ValueToValueMapTy &CloneMap) {
2462 for (Argument &A : F->args())
2463 CloneMap[&A] = &A;
2464 for (BasicBlock &BB : *F) {
2465 CloneMap[&BB] = &BB;
2466 for (Instruction &I : BB)
2467 CloneMap[&I] = &I;
2468 }
2469}
2470
/// Top-level driver for the lowering. Phases, in order:
///   1. Diagnose and drop globals that live in or contain fat pointers.
///   2. Expand constant expressions/aggregates involving fat pointers into
///      instructions.
///   3. Rewrite memory operations and buffer-content types per function.
///   4. Remap fat pointer types to {resource, offset} structs via
///      ValueMapper, recreating functions whose signatures change.
///   5. Split the struct values into separate parts and fix up intrinsics.
/// Returns true if \p M was modified.
bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
  bool Changed = false;
  const DataLayout &DL = M.getDataLayout();
  // Record the functions which need to be remapped.
  // The second element of the pair indicates whether the function has to have
  // its arguments or return types adjusted.
  SmallVector<std::pair<Function *, bool>> NeedsRemap;

  LLVMContext &Ctx = M.getContext();

  BufferFatPtrToStructTypeMap StructTM(DL);
  BufferFatPtrToIntTypeMap IntTM(DL);
  // Phase 1: globals cannot be in, or contain, buffer fat pointers; emit an
  // error and replace such globals with poison rather than crashing later.
  for (GlobalVariable &GV : make_early_inc_range(Range: M.globals())) {
    if (GV.getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
      // FIXME: Use DiagnosticInfo unsupported but it requires a Function
      Ctx.emitError(ErrorStr: "global variables with a buffer fat pointer address "
                    "space (7) are not supported");
      GV.replaceAllUsesWith(V: PoisonValue::get(T: GV.getType()));
      GV.eraseFromParent();
      Changed = true;
      continue;
    }

    Type *VT = GV.getValueType();
    if (VT != StructTM.remapType(SrcTy: VT)) {
      // FIXME: Use DiagnosticInfo unsupported but it requires a Function
      Ctx.emitError(ErrorStr: "global variables that contain buffer fat pointers "
                    "(address space 7 pointers) are unsupported. Use "
                    "buffer resource pointers (address space 8) instead");
      GV.replaceAllUsesWith(V: PoisonValue::get(T: GV.getType()));
      GV.eraseFromParent();
      Changed = true;
      continue;
    }
  }

  // Phase 2: find and expand fat-pointer constants.
  {
    // Collect all constant exprs and aggregates referenced by any function.
    SmallVector<Constant *, 8> Worklist;
    for (Function &F : M.functions())
      for (Instruction &I : instructions(F))
        for (Value *Op : I.operands())
          if (isa<ConstantExpr, ConstantAggregate>(Val: Op))
            Worklist.push_back(Elt: cast<Constant>(Val: Op));

    // Recursively look for any referenced buffer pointer constants.
    SmallPtrSet<Constant *, 8> Visited;
    SetVector<Constant *> BufferFatPtrConsts;
    while (!Worklist.empty()) {
      Constant *C = Worklist.pop_back_val();
      if (!Visited.insert(Ptr: C).second)
        continue;
      if (isBufferFatPtrOrVector(Ty: C->getType()))
        BufferFatPtrConsts.insert(X: C);
      for (Value *Op : C->operands())
        if (isa<ConstantExpr, ConstantAggregate>(Val: Op))
          Worklist.push_back(Elt: cast<Constant>(Val: Op));
    }

    // Expand all constant expressions using fat buffer pointers to
    // instructions.
    Changed |= convertUsersOfConstantsToInstructions(
        Consts: BufferFatPtrConsts.getArrayRef(), /*RestrictToFunc=*/nullptr,
        /*RemoveDeadConstants=*/false, /*IncludeSelf=*/true);
  }

  // Phase 3: per function, note whether the body and/or signature mention
  // fat pointers, rewriting memory ops and buffer content types on the way.
  StoreFatPtrsAsIntsAndExpandMemcpyVisitor MemOpsRewrite(&IntTM, DL,
                                                         M.getContext(), &TM);
  LegalizeBufferContentTypesVisitor BufferContentsTypeRewrite(DL,
                                                              M.getContext());
  for (Function &F : M.functions()) {
    bool InterfaceChange = hasFatPointerInterface(F, TypeMap: &StructTM);
    bool BodyChanges = containsBufferFatPointers(F, TypeMap: &StructTM);
    Changed |= MemOpsRewrite.processFunction(F);
    if (InterfaceChange || BodyChanges) {
      NeedsRemap.push_back(Elt: std::make_pair(x: &F, y&: InterfaceChange));
      Changed |= BufferContentsTypeRewrite.processFunction(F);
    }
  }
  if (NeedsRemap.empty())
    return Changed;

  SmallVector<Function *> NeedsPostProcess;
  SmallVector<Function *> Intrinsics;
  // Keep one big map so as to memoize constants across functions.
  ValueToValueMapTy CloneMap;
  FatPtrConstMaterializer Materializer(&StructTM, CloneMap);

  // Phase 4: remap fat pointer types, moving each function's body into a
  // freshly created function when its signature has to change.
  ValueMapper LowerInFuncs(CloneMap, RF_None, &StructTM, &Materializer);
  for (auto [F, InterfaceChange] : NeedsRemap) {
    Function *NewF = F;
    if (InterfaceChange)
      NewF = moveFunctionAdaptingType(
          OldF: F, NewTy: cast<FunctionType>(Val: StructTM.remapType(SrcTy: F->getFunctionType())),
          CloneMap);
    else
      makeCloneInPraceMap(F, CloneMap);
    LowerInFuncs.remapFunction(F&: *NewF);
    if (NewF->isIntrinsic())
      Intrinsics.push_back(Elt: NewF);
    else
      NeedsPostProcess.push_back(Elt: NewF);
    if (InterfaceChange) {
      F->replaceAllUsesWith(V: NewF);
      F->eraseFromParent();
    }
    Changed = true;
  }
  StructTM.clear();
  IntTM.clear();
  CloneMap.clear();

  // Phase 5: split struct values into {resource, offset} parts, then erase
  // or remangle intrinsic declarations whose signatures were rewritten.
  SplitPtrStructs Splitter(DL, M.getContext(), &TM);
  for (Function *F : NeedsPostProcess)
    Splitter.processFunction(F&: *F);
  for (Function *F : Intrinsics) {
    // use_empty() can also occur with cases like masked load, which will
    // have been rewritten out of the module by now but not erased.
    if (F->use_empty() || isRemovablePointerIntrinsic(IID: F->getIntrinsicID())) {
      F->eraseFromParent();
    } else {
      std::optional<Function *> NewF = Intrinsic::remangleIntrinsicFunction(F);
      if (NewF)
        F->replaceAllUsesWith(V: *NewF);
    }
  }
  return Changed;
}
2599
2600bool AMDGPULowerBufferFatPointers::runOnModule(Module &M) {
2601 TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
2602 const TargetMachine &TM = TPC.getTM<TargetMachine>();
2603 return run(M, TM);
2604}
2605
// Unique address used by the legacy pass manager to identify this pass.
char AMDGPULowerBufferFatPointers::ID = 0;

// Externally visible handle so other code can reference the pass by ID.
char &llvm::AMDGPULowerBufferFatPointersID = AMDGPULowerBufferFatPointers::ID;
2609
void AMDGPULowerBufferFatPointers::getAnalysisUsage(AnalysisUsage &AU) const {
  // TargetPassConfig is needed to fetch the TargetMachine in runOnModule().
  AU.addRequired<TargetPassConfig>();
}
2613
// Register the legacy pass and its TargetPassConfig dependency.
#define PASS_DESC "Lower buffer fat pointer operations to buffer resources"
INITIALIZE_PASS_BEGIN(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC, false,
                    false)
#undef PASS_DESC
2621
/// Factory for the legacy-pass-manager version of this pass.
ModulePass *llvm::createAMDGPULowerBufferFatPointersPass() {
  return new AMDGPULowerBufferFatPointers();
}
2625
2626PreservedAnalyses
2627AMDGPULowerBufferFatPointersPass::run(Module &M, ModuleAnalysisManager &MA) {
2628 return AMDGPULowerBufferFatPointers().run(M, TM) ? PreservedAnalyses::none()
2629 : PreservedAnalyses::all();
2630}
2631