1//===-- AMDGPULowerBufferFatPointers.cpp ---------------------------=//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass lowers operations on buffer fat pointers (addrspace 7) to
10// operations on buffer resources (addrspace 8) and is needed for correct
11// codegen.
12//
13// # Background
14//
15// Address space 7 (the buffer fat pointer) is a 160-bit pointer that consists
16// of a 128-bit buffer descriptor and a 32-bit offset into that descriptor.
// The buffer resource part needs to be a "raw" buffer resource
18// (it must have a stride of 0 and bounds checks must be in raw buffer mode
19// or disabled).
20//
21// When these requirements are met, a buffer resource can be treated as a
22// typical (though quite wide) pointer that follows typical LLVM pointer
23// semantics. This allows the frontend to reason about such buffers (which are
24// often encountered in the context of SPIR-V kernels).
25//
26// However, because of their non-power-of-2 size, these fat pointers cannot be
27// present during translation to MIR (though this restriction may be lifted
28// during the transition to GlobalISel). Therefore, this pass is needed in order
29// to correctly implement these fat pointers.
30//
31// The resource intrinsics take the resource part (the address space 8 pointer)
32// and the offset part (the 32-bit integer) as separate arguments. In addition,
33// many users of these buffers manipulate the offset while leaving the resource
34// part alone. For these reasons, we want to typically separate the resource
35// and offset parts into separate variables, but combine them together when
36// encountering cases where this is required, such as by inserting these values
// into aggregates or moving them to memory.
38//
39// Therefore, at a high level, `ptr addrspace(7) %x` becomes `ptr addrspace(8)
40// %x.rsrc` and `i32 %x.off`, which will be combined into `{ptr addrspace(8),
41// i32} %x = {%x.rsrc, %x.off}` if needed. Similarly, `vector<Nxp7>` becomes
// `{vector<Nxp8>, vector<Nxi32>}` and its component parts.
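//
// As an illustrative sketch (not the verbatim output of this pass) of why the
// split form is convenient, an offset-only manipulation such as
// ```
// %y = getelementptr i8, ptr addrspace(7) %x, i32 4
// ```
// conceptually becomes
// ```
// %y.rsrc = %x.rsrc ; the resource part is unchanged
// %y.off = add i32 %x.off, 4
// ```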
43//
44// # Implementation
45//
46// This pass proceeds in three main phases:
47//
48// ## Rewriting loads and stores of p7 and memcpy()-like handling
49//
// The first phase is to rewrite away all loads and stores of `ptr addrspace(7)`,
// including aggregates containing such pointers, to ones that use `i160`. This
// is handled by `StoreFatPtrsAsIntsAndExpandMemcpyVisitor`, which visits
53// loads, stores, and allocas and, if the loaded or stored type contains `ptr
54// addrspace(7)`, rewrites that type to one where the p7s are replaced by i160s,
55// copying other parts of aggregates as needed. In the case of a store, each
// pointer is `ptrtoint`d to i160 before storing, and loaded integers are
57// `inttoptr`d back. This same transformation is applied to vectors of pointers.
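//
// For example (a sketch, not the exact IR this pass emits), a store such as
// ```
// store ptr addrspace(7) %p, ptr %slot
// ```
// becomes
// ```
// %p.int = ptrtoint ptr addrspace(7) %p to i160
// store i160 %p.int, ptr %slot
// ```
// with a matching `load i160`/`inttoptr` pair when reading the value back.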
58//
59// Such a transformation allows the later phases of the pass to not need
// to handle buffer fat pointers moving to and from memory, where we would
// have to handle the incompatibility between a `{Nxp8, Nxi32}` representation
// and `Nxi160` directly. Instead, that transposing action (where the vectors
// of resources and vectors of offsets are concatenated before being stored to
// memory) is handled through implementing `inttoptr` and `ptrtoint` only.
65//
// Atomic operations on `ptr addrspace(7)` values are not supported, as the
67// hardware does not include a 160-bit atomic.
68//
// In order to save on O(N) work and to ensure that the contents type
// legalizer correctly splits up wide loads, we also unconditionally lower
// memcpy-like intrinsics into loops here.
72//
73// ## Buffer contents type legalization
74//
75// The underlying buffer intrinsics only support types up to 128 bits long,
76// and don't support complex types. If buffer operations were
77// standard pointer operations that could be represented as MIR-level loads,
78// this would be handled by the various legalization schemes in instruction
79// selection. However, because we have to do the conversion from `load` and
80// `store` to intrinsics at LLVM IR level, we must perform that legalization
81// ourselves.
82//
83// This involves a combination of
84// - Converting arrays to vectors where possible
85// - Otherwise, splitting loads and stores of aggregates into loads/stores of
86// each component.
87// - Zero-extending things to fill a whole number of bytes
// - Casting values of types that don't neatly correspond to supported machine
// value types (for example, an i96 or i256) into ones that would work
// (such as <3 x i32> and <8 x i32>, respectively; see the sketch after this
// list)
92// - Splitting values that are too long (such as aforementioned <8 x i32>) into
93// multiple operations.
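//
// As a sketch of the combined effect (the exact IR differs), a
// `load i96, ptr addrspace(7) %p` is legalized to something like
// ```
// %wide = load <3 x i32>, ptr addrspace(7) %p
// %res = bitcast <3 x i32> %wide to i96
// ```
// while a too-long value such as <8 x i32> is split into one <4 x i32> access
// at byte offset 0 and another at byte offset 16.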
94//
95// ## Type remapping
96//
97// We use a `ValueMapper` to mangle uses of [vectors of] buffer fat pointers
98// to the corresponding struct type, which has a resource part and an offset
99// part.
100//
// This uses a `BufferFatPtrToStructTypeMap` and a `FatPtrConstMaterializer`
// to perform the remapping, usually by way of `setType`ing values. Constants
// are handled here because there isn't a good way to fix them up later.
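//
// For example, the constant `ptr addrspace(7) null` is materialized as
// ```
// {ptr addrspace(8), i32} zeroinitializer
// ```
// and a `poison` fat pointer becomes a struct with `poison` in both fields.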
104//
105// This has the downside of leaving the IR in an invalid state (for example,
106// the instruction `getelementptr {ptr addrspace(8), i32} %p, ...` will exist),
107// but all such invalid states will be resolved by the third phase.
108//
// Functions that don't take buffer fat pointers are modified in place. Those
// that do take such pointers have their basic blocks moved to a new function
// whose arguments and return values use {ptr addrspace(8), i32} in place of
// the fat pointers.
112// This phase also records intrinsics so that they can be remangled or deleted
113// later.
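//
// For example (a sketch), a function declared as
// ```
// define float @f(ptr addrspace(7) %p)
// ```
// has its body moved into a new function along the lines of
// ```
// define float @f({ptr addrspace(8), i32} %p)
// ```
// and any intrinsic declarations it uses are recorded for later remangling or
// deletion.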
114//
115// ## Splitting pointer structs
116//
117// The meat of this pass consists of defining semantics for operations that
118// produce or consume [vectors of] buffer fat pointers in terms of their
// resource and offset parts. This is accomplished through the `SplitPtrStructs`
120// visitor.
121//
122// In the first pass through each function that is being lowered, the splitter
123// inserts new instructions to implement the split-structures behavior, which is
124// needed for correctness and performance. It records a list of "split users",
125// instructions that are being replaced by operations on the resource and offset
126// parts.
127//
// Split users do not necessarily need to produce parts themselves (a
// `load float, ptr addrspace(7)` does not, for example), but, if they do not
130// generate fat buffer pointers, they must RAUW in their replacement
131// instructions during the initial visit.
132//
133// When these new instructions are created, they use the split parts recorded
134// for their initial arguments in order to generate their replacements, creating
135// a parallel set of instructions that does not refer to the original fat
136// pointer values but instead to their resource and offset components.
137//
138// Instructions, such as `extractvalue`, that produce buffer fat pointers from
139// sources that do not have split parts, have such parts generated using
140// `extractvalue`. This is also the initial handling of PHI nodes, which
141// are then cleaned up.
142//
143// ### Conditionals
144//
145// PHI nodes are initially given resource parts via `extractvalue`. However,
146// this is not an efficient rewrite of such nodes, as, in most cases, the
147// resource part in a conditional or loop remains constant throughout the loop
148// and only the offset varies. Failing to optimize away these constant resources
149// would cause additional registers to be sent around loops and might lead to
150// waterfall loops being generated for buffer operations due to the
151// "non-uniform" resource argument.
152//
153// Therefore, after all instructions have been visited, the pointer splitter
154// post-processes all encountered conditionals. Given a PHI node or select,
155// getPossibleRsrcRoots() collects all values that the resource parts of that
156// conditional's input could come from as well as collecting all conditional
157// instructions encountered during the search. If, after filtering out the
158// initial node itself, the set of encountered conditionals is a subset of the
159// potential roots and there is a single potential resource that isn't in the
160// conditional set, that value is the only possible value the resource argument
161// could have throughout the control flow.
162//
163// If that condition is met, then a PHI node can have its resource part changed
164// to the singleton value and then be replaced by a PHI on the offsets.
165// Otherwise, each PHI node is split into two, one for the resource part and one
166// for the offset part, which replace the temporary `extractvalue` instructions
167// that were added during the first pass.
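//
// For example (illustrative), a loop that steps a fat pointer through memory,
// ```
// %p = phi ptr addrspace(7) [ %init, %entry ], [ %p.next, %loop ]
// ```
// will usually be rewritten as a PHI on the offset only,
// ```
// %p.rsrc = %init.rsrc ; the single possible resource root
// %p.off = phi i32 [ %init.off, %entry ], [ %p.next.off, %loop ]
// ```
// falling back to a second PHI on the resource part only when the resource
// genuinely varies.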
168//
169// Similar logic applies to `select`, where
// `%z = select i1 %cond, ptr addrspace(7) %x, ptr addrspace(7) %y`
// can be split into `%z.rsrc = %x.rsrc` and
// `%z.off = select i1 %cond, i32 %x.off, i32 %y.off`
173// if both `%x` and `%y` have the same resource part, but two `select`
174// operations will be needed if they do not.
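//
// That is, the general case is (sketch):
// ```
// %z.rsrc = select i1 %cond, ptr addrspace(8) %x.rsrc, ptr addrspace(8) %y.rsrc
// %z.off = select i1 %cond, i32 %x.off, i32 %y.off
// ```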
175//
176// ### Final processing
177//
178// After conditionals have been cleaned up, the IR for each function is
179// rewritten to remove all the old instructions that have been split up.
180//
181// Any instruction that used to produce a buffer fat pointer (and therefore now
182// produces a resource-and-offset struct after type remapping) is
183// replaced as follows:
184// 1. All debug value annotations are cloned to reflect that the resource part
185// and offset parts are computed separately and constitute different
186// fragments of the underlying source language variable.
187// 2. All uses that were themselves split are replaced by a `poison` of the
188// struct type, as they will themselves be erased soon. This rule, combined
189// with debug handling, should leave the use lists of split instructions
190// empty in almost all cases.
191// 3. If a user of the original struct-valued result remains, the structure
192// needed for the new types to work is constructed out of the newly-defined
193// parts, and the original instruction is replaced by this structure
194// before being erased. Instructions requiring this construction include
195// `ret` and `insertvalue`.
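//
// For instance (a sketch), a `ret ptr addrspace(7) %x` that survives to this
// point is rewritten along the lines of
// ```
// %x.0 = insertvalue {ptr addrspace(8), i32} poison, ptr addrspace(8) %x.rsrc, 0
// %x.1 = insertvalue {ptr addrspace(8), i32} %x.0, i32 %x.off, 1
// ret {ptr addrspace(8), i32} %x.1
// ```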
196//
197// # Consequences
198//
199// This pass does not alter the CFG.
200//
201// Alias analysis information will become coarser, as the LLVM alias analyzer
202// cannot handle the buffer intrinsics. Specifically, while we can determine
203// that the following two loads do not alias:
204// ```
205// %y = getelementptr i32, ptr addrspace(7) %x, i32 1
206// %a = load i32, ptr addrspace(7) %x
207// %b = load i32, ptr addrspace(7) %y
208// ```
209// we cannot (except through some code that runs during scheduling) determine
210// that the rewritten loads below do not alias.
211// ```
212// %y.off = add i32 %x.off, 1
213// %a = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8) %x.rsrc, i32
214// %x.off, ...)
215// %b = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8)
216// %x.rsrc, i32 %y.off, ...)
217// ```
218// However, existing alias information is preserved.
219//===----------------------------------------------------------------------===//
220
221#include "AMDGPU.h"
222#include "AMDGPUTargetMachine.h"
223#include "GCNSubtarget.h"
224#include "SIDefines.h"
225#include "llvm/ADT/SetOperations.h"
226#include "llvm/ADT/SmallVector.h"
227#include "llvm/Analysis/InstSimplifyFolder.h"
228#include "llvm/Analysis/Utils/Local.h"
229#include "llvm/CodeGen/TargetPassConfig.h"
230#include "llvm/IR/AttributeMask.h"
231#include "llvm/IR/Constants.h"
232#include "llvm/IR/DebugInfo.h"
233#include "llvm/IR/DerivedTypes.h"
234#include "llvm/IR/IRBuilder.h"
235#include "llvm/IR/InstIterator.h"
236#include "llvm/IR/InstVisitor.h"
237#include "llvm/IR/Instructions.h"
238#include "llvm/IR/IntrinsicInst.h"
239#include "llvm/IR/Intrinsics.h"
240#include "llvm/IR/IntrinsicsAMDGPU.h"
241#include "llvm/IR/Metadata.h"
242#include "llvm/IR/Operator.h"
243#include "llvm/IR/PatternMatch.h"
244#include "llvm/IR/ReplaceConstant.h"
245#include "llvm/IR/ValueHandle.h"
246#include "llvm/Pass.h"
247#include "llvm/Support/AMDGPUAddrSpace.h"
248#include "llvm/Support/Alignment.h"
249#include "llvm/Support/AtomicOrdering.h"
250#include "llvm/Support/Debug.h"
251#include "llvm/Support/ErrorHandling.h"
252#include "llvm/Transforms/Utils/Cloning.h"
253#include "llvm/Transforms/Utils/Local.h"
254#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
255#include "llvm/Transforms/Utils/ValueMapper.h"
256
257#define DEBUG_TYPE "amdgpu-lower-buffer-fat-pointers"
258
259using namespace llvm;
260
261static constexpr unsigned BufferOffsetWidth = 32;
262
263namespace {
264/// Recursively replace instances of ptr addrspace(7) and vector<Nxptr
265/// addrspace(7)> with some other type as defined by the relevant subclass.
266class BufferFatPtrTypeLoweringBase : public ValueMapTypeRemapper {
267 DenseMap<Type *, Type *> Map;
268
269 Type *remapTypeImpl(Type *Ty);
270
271protected:
272 virtual Type *remapScalar(PointerType *PT) = 0;
273 virtual Type *remapVector(VectorType *VT) = 0;
274
275 const DataLayout &DL;
276
277public:
278 BufferFatPtrTypeLoweringBase(const DataLayout &DL) : DL(DL) {}
279 Type *remapType(Type *SrcTy) override;
280 void clear() { Map.clear(); }
281};
282
283/// Remap ptr addrspace(7) to i160 and vector<Nxptr addrspace(7)> to
/// vector<Nxi160> in order to correctly handle loading/storing these values
285/// from memory.
286class BufferFatPtrToIntTypeMap : public BufferFatPtrTypeLoweringBase {
287 using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;
288
289protected:
290 Type *remapScalar(PointerType *PT) override { return DL.getIntPtrType(PT); }
291 Type *remapVector(VectorType *VT) override { return DL.getIntPtrType(VT); }
292};
293
294/// Remap ptr addrspace(7) to {ptr addrspace(8), i32} (the resource and offset
295/// parts of the pointer) so that we can easily rewrite operations on these
296/// values that aren't loading them from or storing them to memory.
297class BufferFatPtrToStructTypeMap : public BufferFatPtrTypeLoweringBase {
298 using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;
299
300protected:
301 Type *remapScalar(PointerType *PT) override;
302 Type *remapVector(VectorType *VT) override;
303};
304} // namespace
305
306// This code is adapted from the type remapper in lib/Linker/IRMover.cpp
307Type *BufferFatPtrTypeLoweringBase::remapTypeImpl(Type *Ty) {
308 Type **Entry = &Map[Ty];
309 if (*Entry)
310 return *Entry;
311 if (auto *PT = dyn_cast<PointerType>(Val: Ty)) {
312 if (PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
313 return *Entry = remapScalar(PT);
314 }
315 }
316 if (auto *VT = dyn_cast<VectorType>(Val: Ty)) {
317 auto *PT = dyn_cast<PointerType>(Val: VT->getElementType());
318 if (PT && PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
319 return *Entry = remapVector(VT);
320 }
321 return *Entry = Ty;
322 }
  // Whether the type is one that is structurally uniqued - that is, whether
  // it is not a named struct (the only kind of type for which multiple
  // structurally identical types can have distinct `Type*`s).
326 StructType *TyAsStruct = dyn_cast<StructType>(Val: Ty);
327 bool IsUniqued = !TyAsStruct || TyAsStruct->isLiteral();
328 // Base case for ints, floats, opaque pointers, and so on, which don't
329 // require recursion.
330 if (Ty->getNumContainedTypes() == 0 && IsUniqued)
331 return *Entry = Ty;
332 bool Changed = false;
333 SmallVector<Type *> ElementTypes(Ty->getNumContainedTypes(), nullptr);
334 for (unsigned int I = 0, E = Ty->getNumContainedTypes(); I < E; ++I) {
335 Type *OldElem = Ty->getContainedType(i: I);
336 Type *NewElem = remapTypeImpl(Ty: OldElem);
337 ElementTypes[I] = NewElem;
338 Changed |= (OldElem != NewElem);
339 }
  // Recursive calls to remapTypeImpl() may have invalidated the pointer.
341 Entry = &Map[Ty];
342 if (!Changed) {
343 return *Entry = Ty;
344 }
345 if (auto *ArrTy = dyn_cast<ArrayType>(Val: Ty))
346 return *Entry = ArrayType::get(ElementType: ElementTypes[0], NumElements: ArrTy->getNumElements());
347 if (auto *FnTy = dyn_cast<FunctionType>(Val: Ty))
348 return *Entry = FunctionType::get(Result: ElementTypes[0],
349 Params: ArrayRef(ElementTypes).slice(N: 1),
350 isVarArg: FnTy->isVarArg());
351 if (auto *STy = dyn_cast<StructType>(Val: Ty)) {
352 // Genuine opaque types don't have a remapping.
353 if (STy->isOpaque())
354 return *Entry = Ty;
355 bool IsPacked = STy->isPacked();
356 if (IsUniqued)
357 return *Entry = StructType::get(Context&: Ty->getContext(), Elements: ElementTypes, isPacked: IsPacked);
358 SmallString<16> Name(STy->getName());
359 STy->setName("");
360 return *Entry = StructType::create(Context&: Ty->getContext(), Elements: ElementTypes, Name,
361 isPacked: IsPacked);
362 }
363 llvm_unreachable("Unknown type of type that contains elements");
364}
365
366Type *BufferFatPtrTypeLoweringBase::remapType(Type *SrcTy) {
367 return remapTypeImpl(Ty: SrcTy);
368}
369
370Type *BufferFatPtrToStructTypeMap::remapScalar(PointerType *PT) {
371 LLVMContext &Ctx = PT->getContext();
372 return StructType::get(elt1: PointerType::get(C&: Ctx, AddressSpace: AMDGPUAS::BUFFER_RESOURCE),
373 elts: IntegerType::get(C&: Ctx, NumBits: BufferOffsetWidth));
374}
375
376Type *BufferFatPtrToStructTypeMap::remapVector(VectorType *VT) {
377 ElementCount EC = VT->getElementCount();
378 LLVMContext &Ctx = VT->getContext();
379 Type *RsrcVec =
380 VectorType::get(ElementType: PointerType::get(C&: Ctx, AddressSpace: AMDGPUAS::BUFFER_RESOURCE), EC);
381 Type *OffVec = VectorType::get(ElementType: IntegerType::get(C&: Ctx, NumBits: BufferOffsetWidth), EC);
382 return StructType::get(elt1: RsrcVec, elts: OffVec);
383}
384
385static bool isBufferFatPtrOrVector(Type *Ty) {
386 if (auto *PT = dyn_cast<PointerType>(Val: Ty->getScalarType()))
387 return PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER;
388 return false;
389}
390
391// True if the type is {ptr addrspace(8), i32} or a struct containing vectors of
392// those types. Used to quickly skip instructions we don't need to process.
393static bool isSplitFatPtr(Type *Ty) {
394 auto *ST = dyn_cast<StructType>(Val: Ty);
395 if (!ST)
396 return false;
397 if (!ST->isLiteral() || ST->getNumElements() != 2)
398 return false;
399 auto *MaybeRsrc =
400 dyn_cast<PointerType>(Val: ST->getElementType(N: 0)->getScalarType());
401 auto *MaybeOff =
402 dyn_cast<IntegerType>(Val: ST->getElementType(N: 1)->getScalarType());
403 return MaybeRsrc && MaybeOff &&
404 MaybeRsrc->getAddressSpace() == AMDGPUAS::BUFFER_RESOURCE &&
405 MaybeOff->getBitWidth() == BufferOffsetWidth;
406}
407
408// True if the result type or any argument types are buffer fat pointers.
409static bool isBufferFatPtrConst(Constant *C) {
410 Type *T = C->getType();
411 return isBufferFatPtrOrVector(Ty: T) || any_of(Range: C->operands(), P: [](const Use &U) {
412 return isBufferFatPtrOrVector(Ty: U.get()->getType());
413 });
414}
415
416namespace {
417/// Convert [vectors of] buffer fat pointers to integers when they are read from
418/// or stored to memory. This ensures that these pointers will have the same
419/// memory layout as before they are lowered, even though they will no longer
420/// have their previous layout in registers/in the program (they'll be broken
421/// down into resource and offset parts). This has the downside of imposing
422/// marshalling costs when reading or storing these values, but since placing
423/// such pointers into memory is an uncommon operation at best, we feel that
424/// this cost is acceptable for better performance in the common case.
425class StoreFatPtrsAsIntsAndExpandMemcpyVisitor
426 : public InstVisitor<StoreFatPtrsAsIntsAndExpandMemcpyVisitor, bool> {
427 BufferFatPtrToIntTypeMap *TypeMap;
428
429 ValueToValueMapTy ConvertedForStore;
430
431 IRBuilder<InstSimplifyFolder> IRB;
432
433 const TargetMachine *TM;
434
  // Convert all the buffer fat pointers within the input value to integers
436 // so that it can be stored in memory.
437 Value *fatPtrsToInts(Value *V, Type *From, Type *To, const Twine &Name);
  // Convert all the i160s that need to be buffer fat pointers (as specified
  // by the To type) into those pointers to preserve the semantics of the rest
440 // of the program.
441 Value *intsToFatPtrs(Value *V, Type *From, Type *To, const Twine &Name);
442
443public:
444 StoreFatPtrsAsIntsAndExpandMemcpyVisitor(BufferFatPtrToIntTypeMap *TypeMap,
445 const DataLayout &DL,
446 LLVMContext &Ctx,
447 const TargetMachine *TM)
448 : TypeMap(TypeMap), IRB(Ctx, InstSimplifyFolder(DL)), TM(TM) {}
449 bool processFunction(Function &F);
450
451 bool visitInstruction(Instruction &I) { return false; }
452 bool visitAllocaInst(AllocaInst &I);
453 bool visitLoadInst(LoadInst &LI);
454 bool visitStoreInst(StoreInst &SI);
455 bool visitGetElementPtrInst(GetElementPtrInst &I);
456
457 bool visitMemCpyInst(MemCpyInst &MCI);
458 bool visitMemMoveInst(MemMoveInst &MMI);
459 bool visitMemSetInst(MemSetInst &MSI);
460 bool visitMemSetPatternInst(MemSetPatternInst &MSPI);
461};
462} // namespace
463
464Value *StoreFatPtrsAsIntsAndExpandMemcpyVisitor::fatPtrsToInts(
465 Value *V, Type *From, Type *To, const Twine &Name) {
466 if (From == To)
467 return V;
468 ValueToValueMapTy::iterator Find = ConvertedForStore.find(Val: V);
469 if (Find != ConvertedForStore.end())
470 return Find->second;
471 if (isBufferFatPtrOrVector(Ty: From)) {
472 Value *Cast = IRB.CreatePtrToInt(V, DestTy: To, Name: Name + ".int");
473 ConvertedForStore[V] = Cast;
474 return Cast;
475 }
476 if (From->getNumContainedTypes() == 0)
477 return V;
478 // Structs, arrays, and other compound types.
479 Value *Ret = PoisonValue::get(T: To);
480 if (auto *AT = dyn_cast<ArrayType>(Val: From)) {
481 Type *FromPart = AT->getArrayElementType();
482 Type *ToPart = cast<ArrayType>(Val: To)->getElementType();
483 for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
484 Value *Field = IRB.CreateExtractValue(Agg: V, Idxs: I);
485 Value *NewField =
486 fatPtrsToInts(V: Field, From: FromPart, To: ToPart, Name: Name + "." + Twine(I));
487 Ret = IRB.CreateInsertValue(Agg: Ret, Val: NewField, Idxs: I);
488 }
489 } else {
490 for (auto [Idx, FromPart, ToPart] :
491 enumerate(First: From->subtypes(), Rest: To->subtypes())) {
492 Value *Field = IRB.CreateExtractValue(Agg: V, Idxs: Idx);
493 Value *NewField =
494 fatPtrsToInts(V: Field, From: FromPart, To: ToPart, Name: Name + "." + Twine(Idx));
495 Ret = IRB.CreateInsertValue(Agg: Ret, Val: NewField, Idxs: Idx);
496 }
497 }
498 ConvertedForStore[V] = Ret;
499 return Ret;
500}
501
502Value *StoreFatPtrsAsIntsAndExpandMemcpyVisitor::intsToFatPtrs(
503 Value *V, Type *From, Type *To, const Twine &Name) {
504 if (From == To)
505 return V;
506 if (isBufferFatPtrOrVector(Ty: To)) {
507 Value *Cast = IRB.CreateIntToPtr(V, DestTy: To, Name: Name + ".ptr");
508 return Cast;
509 }
510 if (From->getNumContainedTypes() == 0)
511 return V;
512 // Structs, arrays, and other compound types.
513 Value *Ret = PoisonValue::get(T: To);
514 if (auto *AT = dyn_cast<ArrayType>(Val: From)) {
515 Type *FromPart = AT->getArrayElementType();
516 Type *ToPart = cast<ArrayType>(Val: To)->getElementType();
517 for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
518 Value *Field = IRB.CreateExtractValue(Agg: V, Idxs: I);
519 Value *NewField =
520 intsToFatPtrs(V: Field, From: FromPart, To: ToPart, Name: Name + "." + Twine(I));
521 Ret = IRB.CreateInsertValue(Agg: Ret, Val: NewField, Idxs: I);
522 }
523 } else {
524 for (auto [Idx, FromPart, ToPart] :
525 enumerate(First: From->subtypes(), Rest: To->subtypes())) {
526 Value *Field = IRB.CreateExtractValue(Agg: V, Idxs: Idx);
527 Value *NewField =
528 intsToFatPtrs(V: Field, From: FromPart, To: ToPart, Name: Name + "." + Twine(Idx));
529 Ret = IRB.CreateInsertValue(Agg: Ret, Val: NewField, Idxs: Idx);
530 }
531 }
532 return Ret;
533}
534
535bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::processFunction(Function &F) {
536 bool Changed = false;
537 // Process memcpy-like instructions after the main iteration because they can
538 // invalidate iterators.
539 SmallVector<WeakTrackingVH> CanBecomeLoops;
540 for (Instruction &I : make_early_inc_range(Range: instructions(F))) {
541 if (isa<MemTransferInst, MemSetInst, MemSetPatternInst>(Val: I))
542 CanBecomeLoops.push_back(Elt: &I);
543 else
544 Changed |= visit(I);
545 }
546 for (WeakTrackingVH VH : make_early_inc_range(Range&: CanBecomeLoops)) {
547 Changed |= visit(I: cast<Instruction>(Val&: VH));
548 }
549 ConvertedForStore.clear();
550 return Changed;
551}
552
553bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitAllocaInst(AllocaInst &I) {
554 Type *Ty = I.getAllocatedType();
555 Type *NewTy = TypeMap->remapType(SrcTy: Ty);
556 if (Ty == NewTy)
557 return false;
558 I.setAllocatedType(NewTy);
559 return true;
560}
561
562bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitGetElementPtrInst(
563 GetElementPtrInst &I) {
564 Type *Ty = I.getSourceElementType();
565 Type *NewTy = TypeMap->remapType(SrcTy: Ty);
566 if (Ty == NewTy)
567 return false;
568 // We'll be rewriting the type `ptr addrspace(7)` out of existence soon, so
569 // make sure GEPs don't have different semantics with the new type.
570 I.setSourceElementType(NewTy);
571 I.setResultElementType(TypeMap->remapType(SrcTy: I.getResultElementType()));
572 return true;
573}
574
575bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitLoadInst(LoadInst &LI) {
576 Type *Ty = LI.getType();
577 Type *IntTy = TypeMap->remapType(SrcTy: Ty);
578 if (Ty == IntTy)
579 return false;
580
581 IRB.SetInsertPoint(&LI);
582 auto *NLI = cast<LoadInst>(Val: LI.clone());
583 NLI->mutateType(Ty: IntTy);
584 NLI = IRB.Insert(I: NLI);
585 NLI->takeName(V: &LI);
586
587 Value *CastBack = intsToFatPtrs(V: NLI, From: IntTy, To: Ty, Name: NLI->getName());
588 LI.replaceAllUsesWith(V: CastBack);
589 LI.eraseFromParent();
590 return true;
591}
592
593bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitStoreInst(StoreInst &SI) {
594 Value *V = SI.getValueOperand();
595 Type *Ty = V->getType();
596 Type *IntTy = TypeMap->remapType(SrcTy: Ty);
597 if (Ty == IntTy)
598 return false;
599
600 IRB.SetInsertPoint(&SI);
601 Value *IntV = fatPtrsToInts(V, From: Ty, To: IntTy, Name: V->getName());
602 for (auto *Dbg : at::getAssignmentMarkers(Inst: &SI))
603 Dbg->setValue(IntV);
604
605 SI.setOperand(i_nocapture: 0, Val_nocapture: IntV);
606 return true;
607}
608
609bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemCpyInst(
610 MemCpyInst &MCI) {
611 // TODO: Allow memcpy.p7.p3 as a synonym for the direct-to-LDS copy, which'll
612 // need loop expansion here.
613 if (MCI.getSourceAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER &&
614 MCI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
615 return false;
616 llvm::expandMemCpyAsLoop(MemCpy: &MCI,
617 TTI: TM->getTargetTransformInfo(F: *MCI.getFunction()));
618 MCI.eraseFromParent();
619 return true;
620}
621
622bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemMoveInst(
623 MemMoveInst &MMI) {
624 if (MMI.getSourceAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER &&
625 MMI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
626 return false;
627 reportFatalUsageError(
628 reason: "memmove() on buffer descriptors is not implemented because pointer "
629 "comparison on buffer descriptors isn't implemented\n");
630}
631
632bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemSetInst(
633 MemSetInst &MSI) {
634 if (MSI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
635 return false;
636 llvm::expandMemSetAsLoop(MemSet: &MSI);
637 MSI.eraseFromParent();
638 return true;
639}
640
641bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemSetPatternInst(
642 MemSetPatternInst &MSPI) {
643 if (MSPI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
644 return false;
645 llvm::expandMemSetPatternAsLoop(MemSet: &MSPI);
646 MSPI.eraseFromParent();
647 return true;
648}
649
650namespace {
651/// Convert loads/stores of types that the buffer intrinsics can't handle into
/// one or more such loads/stores that consist of legal types.
653///
654/// Do this by
655/// 1. Recursing into structs (and arrays that don't share a memory layout with
656/// vectors) since the intrinsics can't handle complex types.
657/// 2. Converting arrays of non-aggregate, byte-sized types into their
658/// corresponding vectors
659/// 3. Bitcasting unsupported types, namely overly-long scalars and byte
660/// vectors, into vectors of supported types.
661/// 4. Splitting up excessively long reads/writes into multiple operations.
662///
/// Note that this doesn't handle complex data structures, but, in the future,
664/// the aggregate load splitter from SROA could be refactored to allow for that
665/// case.
666class LegalizeBufferContentTypesVisitor
667 : public InstVisitor<LegalizeBufferContentTypesVisitor, bool> {
668 friend class InstVisitor<LegalizeBufferContentTypesVisitor, bool>;
669
670 IRBuilder<InstSimplifyFolder> IRB;
671
672 const DataLayout &DL;
673
674 /// If T is [N x U], where U is a scalar type, return the vector type
675 /// <N x U>, otherwise, return T.
676 Type *scalarArrayTypeAsVector(Type *MaybeArrayType);
677 Value *arrayToVector(Value *V, Type *TargetType, const Twine &Name);
678 Value *vectorToArray(Value *V, Type *OrigType, const Twine &Name);
679
680 /// Break up the loads of a struct into the loads of its components
681
682 /// Convert a vector or scalar type that can't be operated on by buffer
683 /// intrinsics to one that would be legal through bitcasts and/or truncation.
684 /// Uses the wider of i32, i16, or i8 where possible.
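  /// For example (assuming the usual data layout), an i96 becomes <3 x i32>
  /// and a <6 x i8> becomes <3 x i16>.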
685 Type *legalNonAggregateFor(Type *T);
686 Value *makeLegalNonAggregate(Value *V, Type *TargetType, const Twine &Name);
687 Value *makeIllegalNonAggregate(Value *V, Type *OrigType, const Twine &Name);
688
689 struct VecSlice {
690 uint64_t Index = 0;
691 uint64_t Length = 0;
692 VecSlice() = delete;
693 // Needed for some Clangs
694 VecSlice(uint64_t Index, uint64_t Length) : Index(Index), Length(Length) {}
695 };
696 /// Return the [index, length] pairs into which `T` needs to be cut to form
697 /// legal buffer load or store operations. Clears `Slices`. Creates an empty
698 /// `Slices` for non-vector inputs and creates one slice if no slicing will be
699 /// needed.
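  /// For example, a <9 x i32> input is sliced as (0, 4), (4, 4), (8, 1),
  /// i.e. two four-word accesses followed by a single-word one.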
700 void getVecSlices(Type *T, SmallVectorImpl<VecSlice> &Slices);
701
702 Value *extractSlice(Value *Vec, VecSlice S, const Twine &Name);
703 Value *insertSlice(Value *Whole, Value *Part, VecSlice S, const Twine &Name);
704
705 /// In most cases, return `LegalType`. However, when given an input that would
706 /// normally be a legal type for the buffer intrinsics to return but that
707 /// isn't hooked up through SelectionDAG, return a type of the same width that
708 /// can be used with the relevant intrinsics. Specifically, handle the cases:
709 /// - <1 x T> => T for all T
710 /// - <N x i8> <=> i16, i32, 2xi32, 4xi32 (as needed)
711 /// - <N x T> where T is under 32 bits and the total size is 96 bits <=> <3 x
712 /// i32>
713 Type *intrinsicTypeFor(Type *LegalType);
714
715 bool visitLoadImpl(LoadInst &OrigLI, Type *PartType,
716 SmallVectorImpl<uint32_t> &AggIdxs, uint64_t AggByteOffset,
717 Value *&Result, const Twine &Name);
718 /// Return value is (Changed, ModifiedInPlace)
719 std::pair<bool, bool> visitStoreImpl(StoreInst &OrigSI, Type *PartType,
720 SmallVectorImpl<uint32_t> &AggIdxs,
721 uint64_t AggByteOffset,
722 const Twine &Name);
723
724 bool visitInstruction(Instruction &I) { return false; }
725 bool visitLoadInst(LoadInst &LI);
726 bool visitStoreInst(StoreInst &SI);
727
728public:
729 LegalizeBufferContentTypesVisitor(const DataLayout &DL, LLVMContext &Ctx)
730 : IRB(Ctx, InstSimplifyFolder(DL)), DL(DL) {}
731 bool processFunction(Function &F);
732};
733} // namespace
734
Type *LegalizeBufferContentTypesVisitor::scalarArrayTypeAsVector(Type *T) {
  ArrayType *AT = dyn_cast<ArrayType>(T);
  if (!AT)
    return T;
  Type *ET = AT->getElementType();
  if (!ET->isSingleValueType() || isa<VectorType>(ET))
    reportFatalUsageError("loading non-scalar arrays from buffer fat pointers "
                          "should have recursed");
  if (!DL.typeSizeEqualsStoreSize(AT))
    reportFatalUsageError(
        "loading padded arrays from buffer fat pointers should have recursed");
  return FixedVectorType::get(ET, AT->getNumElements());
}
748
749Value *LegalizeBufferContentTypesVisitor::arrayToVector(Value *V,
750 Type *TargetType,
751 const Twine &Name) {
752 Value *VectorRes = PoisonValue::get(T: TargetType);
753 auto *VT = cast<FixedVectorType>(Val: TargetType);
754 unsigned EC = VT->getNumElements();
755 for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
756 Value *Elem = IRB.CreateExtractValue(Agg: V, Idxs: I, Name: Name + ".elem." + Twine(I));
757 VectorRes = IRB.CreateInsertElement(Vec: VectorRes, NewElt: Elem, Idx: I,
758 Name: Name + ".as.vec." + Twine(I));
759 }
760 return VectorRes;
761}
762
763Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V,
764 Type *OrigType,
765 const Twine &Name) {
766 Value *ArrayRes = PoisonValue::get(T: OrigType);
767 ArrayType *AT = cast<ArrayType>(Val: OrigType);
768 unsigned EC = AT->getNumElements();
769 for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
770 Value *Elem = IRB.CreateExtractElement(Vec: V, Idx: I, Name: Name + ".elem." + Twine(I));
771 ArrayRes = IRB.CreateInsertValue(Agg: ArrayRes, Val: Elem, Idxs: I,
772 Name: Name + ".as.array." + Twine(I));
773 }
774 return ArrayRes;
775}
776
777Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) {
778 TypeSize Size = DL.getTypeStoreSizeInBits(Ty: T);
779 // Implicitly zero-extend to the next byte if needed
780 if (!DL.typeSizeEqualsStoreSize(Ty: T))
781 T = IRB.getIntNTy(N: Size.getFixedValue());
782 Type *ElemTy = T->getScalarType();
783 if (isa<PointerType, ScalableVectorType>(Val: ElemTy)) {
784 // Pointers are always big enough, and we'll let scalable vectors through to
785 // fail in codegen.
786 return T;
787 }
788 unsigned ElemSize = DL.getTypeSizeInBits(Ty: ElemTy).getFixedValue();
789 if (isPowerOf2_32(Value: ElemSize) && ElemSize >= 16 && ElemSize <= 128) {
790 // [vectors of] anything that's 16/32/64/128 bits can be cast and split into
791 // legal buffer operations.
792 return T;
793 }
794 Type *BestVectorElemType = nullptr;
795 if (Size.isKnownMultipleOf(RHS: 32))
796 BestVectorElemType = IRB.getInt32Ty();
797 else if (Size.isKnownMultipleOf(RHS: 16))
798 BestVectorElemType = IRB.getInt16Ty();
799 else
800 BestVectorElemType = IRB.getInt8Ty();
801 unsigned NumCastElems =
802 Size.getFixedValue() / BestVectorElemType->getIntegerBitWidth();
803 if (NumCastElems == 1)
804 return BestVectorElemType;
805 return FixedVectorType::get(ElementType: BestVectorElemType, NumElts: NumCastElems);
806}
807
808Value *LegalizeBufferContentTypesVisitor::makeLegalNonAggregate(
809 Value *V, Type *TargetType, const Twine &Name) {
810 Type *SourceType = V->getType();
811 TypeSize SourceSize = DL.getTypeSizeInBits(Ty: SourceType);
812 TypeSize TargetSize = DL.getTypeSizeInBits(Ty: TargetType);
813 if (SourceSize != TargetSize) {
814 Type *ShortScalarTy = IRB.getIntNTy(N: SourceSize.getFixedValue());
815 Type *ByteScalarTy = IRB.getIntNTy(N: TargetSize.getFixedValue());
816 Value *AsScalar = IRB.CreateBitCast(V, DestTy: ShortScalarTy, Name: Name + ".as.scalar");
817 Value *Zext = IRB.CreateZExt(V: AsScalar, DestTy: ByteScalarTy, Name: Name + ".zext");
818 V = Zext;
819 SourceType = ByteScalarTy;
820 }
821 return IRB.CreateBitCast(V, DestTy: TargetType, Name: Name + ".legal");
822}
823
824Value *LegalizeBufferContentTypesVisitor::makeIllegalNonAggregate(
825 Value *V, Type *OrigType, const Twine &Name) {
826 Type *LegalType = V->getType();
827 TypeSize LegalSize = DL.getTypeSizeInBits(Ty: LegalType);
828 TypeSize OrigSize = DL.getTypeSizeInBits(Ty: OrigType);
829 if (LegalSize != OrigSize) {
830 Type *ShortScalarTy = IRB.getIntNTy(N: OrigSize.getFixedValue());
831 Type *ByteScalarTy = IRB.getIntNTy(N: LegalSize.getFixedValue());
832 Value *AsScalar = IRB.CreateBitCast(V, DestTy: ByteScalarTy, Name: Name + ".bytes.cast");
833 Value *Trunc = IRB.CreateTrunc(V: AsScalar, DestTy: ShortScalarTy, Name: Name + ".trunc");
834 return IRB.CreateBitCast(V: Trunc, DestTy: OrigType, Name: Name + ".orig");
835 }
836 return IRB.CreateBitCast(V, DestTy: OrigType, Name: Name + ".real.ty");
837}
838
839Type *LegalizeBufferContentTypesVisitor::intrinsicTypeFor(Type *LegalType) {
840 auto *VT = dyn_cast<FixedVectorType>(Val: LegalType);
841 if (!VT)
842 return LegalType;
843 Type *ET = VT->getElementType();
844 // Explicitly return the element type of 1-element vectors because the
845 // underlying intrinsics don't like <1 x T> even though it's a synonym for T.
846 if (VT->getNumElements() == 1)
847 return ET;
848 if (DL.getTypeSizeInBits(Ty: LegalType) == 96 && DL.getTypeSizeInBits(Ty: ET) < 32)
849 return FixedVectorType::get(ElementType: IRB.getInt32Ty(), NumElts: 3);
850 if (ET->isIntegerTy(Bitwidth: 8)) {
851 switch (VT->getNumElements()) {
852 default:
853 return LegalType; // Let it crash later
854 case 1:
855 return IRB.getInt8Ty();
856 case 2:
857 return IRB.getInt16Ty();
858 case 4:
859 return IRB.getInt32Ty();
860 case 8:
861 return FixedVectorType::get(ElementType: IRB.getInt32Ty(), NumElts: 2);
862 case 16:
863 return FixedVectorType::get(ElementType: IRB.getInt32Ty(), NumElts: 4);
864 }
865 }
866 return LegalType;
867}
868
869void LegalizeBufferContentTypesVisitor::getVecSlices(
870 Type *T, SmallVectorImpl<VecSlice> &Slices) {
871 Slices.clear();
872 auto *VT = dyn_cast<FixedVectorType>(Val: T);
873 if (!VT)
874 return;
875
876 uint64_t ElemBitWidth =
877 DL.getTypeSizeInBits(Ty: VT->getElementType()).getFixedValue();
878
879 uint64_t ElemsPer4Words = 128 / ElemBitWidth;
880 uint64_t ElemsPer2Words = ElemsPer4Words / 2;
881 uint64_t ElemsPerWord = ElemsPer2Words / 2;
882 uint64_t ElemsPerShort = ElemsPerWord / 2;
883 uint64_t ElemsPerByte = ElemsPerShort / 2;
884 // If the elements evenly pack into 32-bit words, we can use 3-word stores,
  // such as for <6 x bfloat> or <3 x i32>, but we can't do this for, for
886 // example, <3 x i64>, since that's not slicing.
887 uint64_t ElemsPer3Words = ElemsPerWord * 3;
888
889 uint64_t TotalElems = VT->getNumElements();
890 uint64_t Index = 0;
891 auto TrySlice = [&](unsigned MaybeLen) {
892 if (MaybeLen > 0 && Index + MaybeLen <= TotalElems) {
893 VecSlice Slice{/*Index=*/Index, /*Length=*/MaybeLen};
894 Slices.push_back(Elt: Slice);
895 Index += MaybeLen;
896 return true;
897 }
898 return false;
899 };
900 while (Index < TotalElems) {
901 TrySlice(ElemsPer4Words) || TrySlice(ElemsPer3Words) ||
902 TrySlice(ElemsPer2Words) || TrySlice(ElemsPerWord) ||
903 TrySlice(ElemsPerShort) || TrySlice(ElemsPerByte);
904 }
905}
906
907Value *LegalizeBufferContentTypesVisitor::extractSlice(Value *Vec, VecSlice S,
908 const Twine &Name) {
909 auto *VecVT = dyn_cast<FixedVectorType>(Val: Vec->getType());
910 if (!VecVT)
911 return Vec;
912 if (S.Length == VecVT->getNumElements() && S.Index == 0)
913 return Vec;
914 if (S.Length == 1)
915 return IRB.CreateExtractElement(Vec, Idx: S.Index,
916 Name: Name + ".slice." + Twine(S.Index));
917 SmallVector<int> Mask = llvm::to_vector(
918 Range: llvm::iota_range<int>(S.Index, S.Index + S.Length, /*Inclusive=*/false));
919 return IRB.CreateShuffleVector(V: Vec, Mask, Name: Name + ".slice." + Twine(S.Index));
920}
921
922Value *LegalizeBufferContentTypesVisitor::insertSlice(Value *Whole, Value *Part,
923 VecSlice S,
924 const Twine &Name) {
925 auto *WholeVT = dyn_cast<FixedVectorType>(Val: Whole->getType());
926 if (!WholeVT)
927 return Part;
928 if (S.Length == WholeVT->getNumElements() && S.Index == 0)
929 return Part;
930 if (S.Length == 1) {
931 return IRB.CreateInsertElement(Vec: Whole, NewElt: Part, Idx: S.Index,
932 Name: Name + ".slice." + Twine(S.Index));
933 }
934 int NumElems = cast<FixedVectorType>(Val: Whole->getType())->getNumElements();
935
936 // Extend the slice with poisons to make the main shufflevector happy.
937 SmallVector<int> ExtPartMask(NumElems, -1);
938 for (auto [I, E] : llvm::enumerate(
939 First: MutableArrayRef<int>(ExtPartMask).take_front(N: S.Length))) {
940 E = I;
941 }
942 Value *ExtPart = IRB.CreateShuffleVector(V: Part, Mask: ExtPartMask,
943 Name: Name + ".ext." + Twine(S.Index));
944
945 SmallVector<int> Mask =
946 llvm::to_vector(Range: llvm::iota_range<int>(0, NumElems, /*Inclusive=*/false));
947 for (auto [I, E] :
948 llvm::enumerate(First: MutableArrayRef<int>(Mask).slice(N: S.Index, M: S.Length)))
949 E = I + NumElems;
950 return IRB.CreateShuffleVector(V1: Whole, V2: ExtPart, Mask,
951 Name: Name + ".parts." + Twine(S.Index));
952}
953
954bool LegalizeBufferContentTypesVisitor::visitLoadImpl(
955 LoadInst &OrigLI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
956 uint64_t AggByteOff, Value *&Result, const Twine &Name) {
957 if (auto *ST = dyn_cast<StructType>(Val: PartType)) {
958 const StructLayout *Layout = DL.getStructLayout(Ty: ST);
959 bool Changed = false;
960 for (auto [I, ElemTy, Offset] :
961 llvm::enumerate(First: ST->elements(), Rest: Layout->getMemberOffsets())) {
962 AggIdxs.push_back(Elt: I);
963 Changed |= visitLoadImpl(OrigLI, PartType: ElemTy, AggIdxs,
964 AggByteOff: AggByteOff + Offset.getFixedValue(), Result,
965 Name: Name + "." + Twine(I));
966 AggIdxs.pop_back();
967 }
968 return Changed;
969 }
970 if (auto *AT = dyn_cast<ArrayType>(Val: PartType)) {
971 Type *ElemTy = AT->getElementType();
972 if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(Ty: ElemTy) ||
973 ElemTy->isVectorTy()) {
974 TypeSize ElemStoreSize = DL.getTypeStoreSize(Ty: ElemTy);
975 bool Changed = false;
976 for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
977 /*Inclusive=*/false)) {
978 AggIdxs.push_back(Elt: I);
979 Changed |= visitLoadImpl(OrigLI, PartType: ElemTy, AggIdxs,
980 AggByteOff: AggByteOff + I * ElemStoreSize.getFixedValue(),
981 Result, Name: Name + Twine(I));
982 AggIdxs.pop_back();
983 }
984 return Changed;
985 }
986 }
987
988 // Typical case
989
990 Type *ArrayAsVecType = scalarArrayTypeAsVector(T: PartType);
991 Type *LegalType = legalNonAggregateFor(T: ArrayAsVecType);
992
993 SmallVector<VecSlice> Slices;
994 getVecSlices(T: LegalType, Slices);
995 bool HasSlices = Slices.size() > 1;
996 bool IsAggPart = !AggIdxs.empty();
997 Value *LoadsRes;
998 if (!HasSlices && !IsAggPart) {
999 Type *LoadableType = intrinsicTypeFor(LegalType);
1000 if (LoadableType == PartType)
1001 return false;
1002
1003 IRB.SetInsertPoint(&OrigLI);
1004 auto *NLI = cast<LoadInst>(Val: OrigLI.clone());
1005 NLI->mutateType(Ty: LoadableType);
1006 NLI = IRB.Insert(I: NLI);
1007 NLI->setName(Name + ".loadable");
1008
1009 LoadsRes = IRB.CreateBitCast(V: NLI, DestTy: LegalType, Name: Name + ".from.loadable");
1010 } else {
1011 IRB.SetInsertPoint(&OrigLI);
1012 LoadsRes = PoisonValue::get(T: LegalType);
1013 Value *OrigPtr = OrigLI.getPointerOperand();
    // If we need to split something into more than one load, its legal
1015 // type will be a vector (ex. an i256 load will have LegalType = <8 x i32>).
1016 // But if we're already a scalar (which can happen if we're splitting up a
1017 // struct), the element type will be the legal type itself.
1018 Type *ElemType = LegalType->getScalarType();
1019 unsigned ElemBytes = DL.getTypeStoreSize(Ty: ElemType);
1020 AAMDNodes AANodes = OrigLI.getAAMetadata();
1021 if (IsAggPart && Slices.empty())
1022 Slices.push_back(Elt: VecSlice{/*Index=*/0, /*Length=*/1});
1023 for (VecSlice S : Slices) {
1024 Type *SliceType =
1025 S.Length != 1 ? FixedVectorType::get(ElementType: ElemType, NumElts: S.Length) : ElemType;
1026 int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
1027 // You can't reasonably expect loads to wrap around the edge of memory.
1028 Value *NewPtr = IRB.CreateGEP(
1029 Ty: IRB.getInt8Ty(), Ptr: OrigLI.getPointerOperand(), IdxList: IRB.getInt32(C: ByteOffset),
1030 Name: OrigPtr->getName() + ".off.ptr." + Twine(ByteOffset),
1031 NW: GEPNoWrapFlags::noUnsignedWrap());
1032 Type *LoadableType = intrinsicTypeFor(LegalType: SliceType);
1033 LoadInst *NewLI = IRB.CreateAlignedLoad(
1034 Ty: LoadableType, Ptr: NewPtr, Align: commonAlignment(A: OrigLI.getAlign(), Offset: ByteOffset),
1035 Name: Name + ".off." + Twine(ByteOffset));
1036 copyMetadataForLoad(Dest&: *NewLI, Source: OrigLI);
1037 NewLI->setAAMetadata(
1038 AANodes.adjustForAccess(Offset: ByteOffset, AccessTy: LoadableType, DL));
1039 NewLI->setAtomic(Ordering: OrigLI.getOrdering(), SSID: OrigLI.getSyncScopeID());
1040 NewLI->setVolatile(OrigLI.isVolatile());
1041 Value *Loaded = IRB.CreateBitCast(V: NewLI, DestTy: SliceType,
1042 Name: NewLI->getName() + ".from.loadable");
1043 LoadsRes = insertSlice(Whole: LoadsRes, Part: Loaded, S, Name);
1044 }
1045 }
1046 if (LegalType != ArrayAsVecType)
1047 LoadsRes = makeIllegalNonAggregate(V: LoadsRes, OrigType: ArrayAsVecType, Name);
1048 if (ArrayAsVecType != PartType)
1049 LoadsRes = vectorToArray(V: LoadsRes, OrigType: PartType, Name);
1050
1051 if (IsAggPart)
1052 Result = IRB.CreateInsertValue(Agg: Result, Val: LoadsRes, Idxs: AggIdxs, Name);
1053 else
1054 Result = LoadsRes;
1055 return true;
1056}
1057
1058bool LegalizeBufferContentTypesVisitor::visitLoadInst(LoadInst &LI) {
1059 if (LI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
1060 return false;
1061
1062 SmallVector<uint32_t> AggIdxs;
1063 Type *OrigType = LI.getType();
1064 Value *Result = PoisonValue::get(T: OrigType);
1065 bool Changed = visitLoadImpl(OrigLI&: LI, PartType: OrigType, AggIdxs, AggByteOff: 0, Result, Name: LI.getName());
1066 if (!Changed)
1067 return false;
1068 Result->takeName(V: &LI);
1069 LI.replaceAllUsesWith(V: Result);
1070 LI.eraseFromParent();
1071 return Changed;
1072}
1073
1074std::pair<bool, bool> LegalizeBufferContentTypesVisitor::visitStoreImpl(
1075 StoreInst &OrigSI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
1076 uint64_t AggByteOff, const Twine &Name) {
1077 if (auto *ST = dyn_cast<StructType>(Val: PartType)) {
1078 const StructLayout *Layout = DL.getStructLayout(Ty: ST);
1079 bool Changed = false;
1080 for (auto [I, ElemTy, Offset] :
1081 llvm::enumerate(First: ST->elements(), Rest: Layout->getMemberOffsets())) {
1082 AggIdxs.push_back(Elt: I);
1083 Changed |= std::get<0>(in: visitStoreImpl(OrigSI, PartType: ElemTy, AggIdxs,
1084 AggByteOff: AggByteOff + Offset.getFixedValue(),
1085 Name: Name + "." + Twine(I)));
1086 AggIdxs.pop_back();
1087 }
1088 return std::make_pair(x&: Changed, /*ModifiedInPlace=*/y: false);
1089 }
1090 if (auto *AT = dyn_cast<ArrayType>(Val: PartType)) {
1091 Type *ElemTy = AT->getElementType();
1092 if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(Ty: ElemTy) ||
1093 ElemTy->isVectorTy()) {
1094 TypeSize ElemStoreSize = DL.getTypeStoreSize(Ty: ElemTy);
1095 bool Changed = false;
1096 for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
1097 /*Inclusive=*/false)) {
1098 AggIdxs.push_back(Elt: I);
1099 Changed |= std::get<0>(in: visitStoreImpl(
1100 OrigSI, PartType: ElemTy, AggIdxs,
1101 AggByteOff: AggByteOff + I * ElemStoreSize.getFixedValue(), Name: Name + Twine(I)));
1102 AggIdxs.pop_back();
1103 }
1104 return std::make_pair(x&: Changed, /*ModifiedInPlace=*/y: false);
1105 }
1106 }
1107
1108 Value *OrigData = OrigSI.getValueOperand();
1109 Value *NewData = OrigData;
1110
1111 bool IsAggPart = !AggIdxs.empty();
1112 if (IsAggPart)
1113 NewData = IRB.CreateExtractValue(Agg: NewData, Idxs: AggIdxs, Name);
1114
1115 Type *ArrayAsVecType = scalarArrayTypeAsVector(T: PartType);
1116 if (ArrayAsVecType != PartType) {
1117 NewData = arrayToVector(V: NewData, TargetType: ArrayAsVecType, Name);
1118 }
1119
1120 Type *LegalType = legalNonAggregateFor(T: ArrayAsVecType);
1121 if (LegalType != ArrayAsVecType) {
1122 NewData = makeLegalNonAggregate(V: NewData, TargetType: LegalType, Name);
1123 }
1124
1125 SmallVector<VecSlice> Slices;
1126 getVecSlices(T: LegalType, Slices);
1127 bool NeedToSplit = Slices.size() > 1 || IsAggPart;
1128 if (!NeedToSplit) {
1129 Type *StorableType = intrinsicTypeFor(LegalType);
1130 if (StorableType == PartType)
1131 return std::make_pair(/*Changed=*/x: false, /*ModifiedInPlace=*/y: false);
1132 NewData = IRB.CreateBitCast(V: NewData, DestTy: StorableType, Name: Name + ".storable");
1133 OrigSI.setOperand(i_nocapture: 0, Val_nocapture: NewData);
1134 return std::make_pair(/*Changed=*/x: true, /*ModifiedInPlace=*/y: true);
1135 }
1136
1137 Value *OrigPtr = OrigSI.getPointerOperand();
1138 Type *ElemType = LegalType->getScalarType();
1139 if (IsAggPart && Slices.empty())
1140 Slices.push_back(Elt: VecSlice{/*Index=*/0, /*Length=*/1});
1141 unsigned ElemBytes = DL.getTypeStoreSize(Ty: ElemType);
1142 AAMDNodes AANodes = OrigSI.getAAMetadata();
1143 for (VecSlice S : Slices) {
1144 Type *SliceType =
1145 S.Length != 1 ? FixedVectorType::get(ElementType: ElemType, NumElts: S.Length) : ElemType;
1146 int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
1147 Value *NewPtr =
1148 IRB.CreateGEP(Ty: IRB.getInt8Ty(), Ptr: OrigPtr, IdxList: IRB.getInt32(C: ByteOffset),
1149 Name: OrigPtr->getName() + ".part." + Twine(S.Index),
1150 NW: GEPNoWrapFlags::noUnsignedWrap());
1151 Value *DataSlice = extractSlice(Vec: NewData, S, Name);
1152 Type *StorableType = intrinsicTypeFor(LegalType: SliceType);
1153 DataSlice = IRB.CreateBitCast(V: DataSlice, DestTy: StorableType,
1154 Name: DataSlice->getName() + ".storable");
1155 auto *NewSI = cast<StoreInst>(Val: OrigSI.clone());
1156 NewSI->setAlignment(commonAlignment(A: OrigSI.getAlign(), Offset: ByteOffset));
1157 IRB.Insert(I: NewSI);
1158 NewSI->setOperand(i_nocapture: 0, Val_nocapture: DataSlice);
1159 NewSI->setOperand(i_nocapture: 1, Val_nocapture: NewPtr);
1160 NewSI->setAAMetadata(AANodes.adjustForAccess(Offset: ByteOffset, AccessTy: StorableType, DL));
1161 }
1162 return std::make_pair(/*Changed=*/x: true, /*ModifiedInPlace=*/y: false);
1163}
1164
1165bool LegalizeBufferContentTypesVisitor::visitStoreInst(StoreInst &SI) {
1166 if (SI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
1167 return false;
1168 IRB.SetInsertPoint(&SI);
1169 SmallVector<uint32_t> AggIdxs;
1170 Value *OrigData = SI.getValueOperand();
1171 auto [Changed, ModifiedInPlace] =
1172 visitStoreImpl(OrigSI&: SI, PartType: OrigData->getType(), AggIdxs, AggByteOff: 0, Name: OrigData->getName());
1173 if (Changed && !ModifiedInPlace)
1174 SI.eraseFromParent();
1175 return Changed;
1176}
1177
1178bool LegalizeBufferContentTypesVisitor::processFunction(Function &F) {
1179 bool Changed = false;
  // Note: memory transfer intrinsics won't be seen here; they were already
  // expanded into loops by StoreFatPtrsAsIntsAndExpandMemcpyVisitor.
1181 for (Instruction &I : make_early_inc_range(Range: instructions(F))) {
1182 Changed |= visit(I);
1183 }
1184 return Changed;
1185}
1186
1187/// Return the ptr addrspace(8) and i32 (resource and offset parts) in a lowered
1188/// buffer fat pointer constant.
1189static std::pair<Constant *, Constant *>
1190splitLoweredFatBufferConst(Constant *C) {
1191 assert(isSplitFatPtr(C->getType()) && "Not a split fat buffer pointer");
1192 return std::make_pair(x: C->getAggregateElement(Elt: 0u), y: C->getAggregateElement(Elt: 1u));
1193}
1194
1195namespace {
1196/// Handle the remapping of ptr addrspace(7) constants.
1197class FatPtrConstMaterializer final : public ValueMaterializer {
1198 BufferFatPtrToStructTypeMap *TypeMap;
1199 // An internal mapper that is used to recurse into the arguments of constants.
1200 // While the documentation for `ValueMapper` specifies not to use it
1201 // recursively, examination of the logic in mapValue() shows that it can
1202 // safely be used recursively when handling constants, like it does in its own
1203 // logic.
1204 ValueMapper InternalMapper;
1205
1206 Constant *materializeBufferFatPtrConst(Constant *C);
1207
1208public:
1209 // UnderlyingMap is the value map this materializer will be filling.
1210 FatPtrConstMaterializer(BufferFatPtrToStructTypeMap *TypeMap,
1211 ValueToValueMapTy &UnderlyingMap)
1212 : TypeMap(TypeMap),
1213 InternalMapper(UnderlyingMap, RF_None, TypeMap, this) {}
1214 ~FatPtrConstMaterializer() = default;
1215
1216 Value *materialize(Value *V) override;
1217};
1218} // namespace
1219
1220Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) {
1221 Type *SrcTy = C->getType();
1222 auto *NewTy = dyn_cast<StructType>(Val: TypeMap->remapType(SrcTy));
1223 if (C->isNullValue())
1224 return ConstantAggregateZero::getNullValue(Ty: NewTy);
1225 if (isa<PoisonValue>(Val: C)) {
1226 return ConstantStruct::get(T: NewTy,
1227 V: {PoisonValue::get(T: NewTy->getElementType(N: 0)),
1228 PoisonValue::get(T: NewTy->getElementType(N: 1))});
1229 }
1230 if (isa<UndefValue>(Val: C)) {
1231 return ConstantStruct::get(T: NewTy,
1232 V: {UndefValue::get(T: NewTy->getElementType(N: 0)),
1233 UndefValue::get(T: NewTy->getElementType(N: 1))});
1234 }
1235
1236 if (auto *VC = dyn_cast<ConstantVector>(Val: C)) {
1237 if (Constant *S = VC->getSplatValue()) {
1238 Constant *NewS = InternalMapper.mapConstant(C: *S);
1239 if (!NewS)
1240 return nullptr;
1241 auto [Rsrc, Off] = splitLoweredFatBufferConst(C: NewS);
1242 auto EC = VC->getType()->getElementCount();
1243 return ConstantStruct::get(T: NewTy, V: {ConstantVector::getSplat(EC, Elt: Rsrc),
1244 ConstantVector::getSplat(EC, Elt: Off)});
1245 }
1246 SmallVector<Constant *> Rsrcs;
1247 SmallVector<Constant *> Offs;
1248 for (Value *Op : VC->operand_values()) {
1249 auto *NewOp = dyn_cast_or_null<Constant>(Val: InternalMapper.mapValue(V: *Op));
1250 if (!NewOp)
1251 return nullptr;
1252 auto [Rsrc, Off] = splitLoweredFatBufferConst(C: NewOp);
1253 Rsrcs.push_back(Elt: Rsrc);
1254 Offs.push_back(Elt: Off);
1255 }
1256 Constant *RsrcVec = ConstantVector::get(V: Rsrcs);
1257 Constant *OffVec = ConstantVector::get(V: Offs);
1258 return ConstantStruct::get(T: NewTy, V: {RsrcVec, OffVec});
1259 }
1260
1261 if (isa<GlobalValue>(Val: C))
1262 reportFatalUsageError(reason: "global values containing ptr addrspace(7) (buffer "
1263 "fat pointer) values are not supported");
1264
1265 if (isa<ConstantExpr>(Val: C))
1266 reportFatalUsageError(
1267 reason: "constant exprs containing ptr addrspace(7) (buffer "
1268 "fat pointer) values should have been expanded earlier");
1269
1270 return nullptr;
1271}
1272
1273Value *FatPtrConstMaterializer::materialize(Value *V) {
1274 Constant *C = dyn_cast<Constant>(Val: V);
1275 if (!C)
1276 return nullptr;
1277 // Structs and other types that happen to contain fat pointers get remapped
1278 // by the mapValue() logic.
1279 if (!isBufferFatPtrConst(C))
1280 return nullptr;
1281 return materializeBufferFatPtrConst(C);
1282}
1283
1284using PtrParts = std::pair<Value *, Value *>;
1285namespace {
1286// The visitor returns the resource and offset parts for an instruction if they
1287// can be computed, or (nullptr, nullptr) for cases that don't have a meaningful
1288// value mapping.
1289class SplitPtrStructs : public InstVisitor<SplitPtrStructs, PtrParts> {
1290 ValueToValueMapTy RsrcParts;
1291 ValueToValueMapTy OffParts;
1292
1293 // Track instructions that have been rewritten into a user of the component
1294 // parts of their ptr addrspace(7) input. Instructions that produced
1295 // ptr addrspace(7) parts should **not** be RAUW'd before being added to this
1296 // set, as that replacement will be handled in a post-visit step. However,
1297 // instructions that yield values that aren't fat pointers (ex. ptrtoint)
1298 // should RAUW themselves with new instructions that use the split parts
1299 // of their arguments during processing.
1300 DenseSet<Instruction *> SplitUsers;
1301
1302 // Nodes that need a second look once we've computed the parts for all other
1303 // instructions to see if, for example, we really need to phi on the resource
1304 // part.
1305 SmallVector<Instruction *> Conditionals;
1306 // Temporary instructions produced while lowering conditionals that should be
1307 // killed.
1308 SmallVector<Instruction *> ConditionalTemps;
1309
1310 // Subtarget info, needed for determining what cache control bits to set.
1311 const TargetMachine *TM;
1312 const GCNSubtarget *ST = nullptr;
1313
1314 IRBuilder<InstSimplifyFolder> IRB;
1315
1316 // Copy metadata between instructions if applicable.
1317 void copyMetadata(Value *Dest, Value *Src);
1318
1319 // Get the resource and offset parts of the value V, inserting the
1320 // appropriate extractvalue instructions if needed.
1321 PtrParts getPtrParts(Value *V);
1322
1323 // Given an instruction that could produce multiple resource parts (a PHI or
1324 // select), collect the set of values that could have provided its resource
1325 // part (the `Roots`) and the set of conditional instructions visited during
1326 // the search (`Seen`). If, after removing the root of the search from `Seen`
1327 // and `Roots`, `Seen` is a subset of `Roots` and `Roots - Seen` contains one
1328 // element, the resource part of that element can replace the resource part of
1329 // all other elements in `Seen`.
1330 void getPossibleRsrcRoots(Instruction *I, SmallPtrSetImpl<Value *> &Roots,
1331 SmallPtrSetImpl<Value *> &Seen);
1332 void processConditionals();
1333
1334 // If an instruction has been split into resource and offset parts,
1335 // delete that instruction. If any of its uses have not themselves been split
1336 // into parts (for example, an insertvalue), construct the struct value
1337 // that the type rewrites declare the dying instruction should produce
1338 // and use that instead.
1339 // Also, kill the temporary extractvalue operations produced by the two-stage
1340 // lowering of PHIs and conditionals.
1341 void killAndReplaceSplitInstructions(SmallVectorImpl<Instruction *> &Origs);
1342
1343 void setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx);
1344 void insertPreMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
1345 void insertPostMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
1346 Value *handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr, Type *Ty,
1347 Align Alignment, AtomicOrdering Order,
1348 bool IsVolatile, SyncScope::ID SSID);
1349
1350public:
1351 SplitPtrStructs(const DataLayout &DL, LLVMContext &Ctx,
1352 const TargetMachine *TM)
1353 : TM(TM), IRB(Ctx, InstSimplifyFolder(DL)) {}
1354
1355 void processFunction(Function &F);
1356
1357 PtrParts visitInstruction(Instruction &I);
1358 PtrParts visitLoadInst(LoadInst &LI);
1359 PtrParts visitStoreInst(StoreInst &SI);
1360 PtrParts visitAtomicRMWInst(AtomicRMWInst &AI);
1361 PtrParts visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI);
1362 PtrParts visitGetElementPtrInst(GetElementPtrInst &GEP);
1363
1364 PtrParts visitPtrToIntInst(PtrToIntInst &PI);
1365 PtrParts visitIntToPtrInst(IntToPtrInst &IP);
1366 PtrParts visitAddrSpaceCastInst(AddrSpaceCastInst &I);
1367 PtrParts visitICmpInst(ICmpInst &Cmp);
1368 PtrParts visitFreezeInst(FreezeInst &I);
1369
1370 PtrParts visitExtractElementInst(ExtractElementInst &I);
1371 PtrParts visitInsertElementInst(InsertElementInst &I);
1372 PtrParts visitShuffleVectorInst(ShuffleVectorInst &I);
1373
1374 PtrParts visitPHINode(PHINode &PHI);
1375 PtrParts visitSelectInst(SelectInst &SI);
1376
1377 PtrParts visitIntrinsicInst(IntrinsicInst &II);
1378};
1379} // namespace
1380
1381void SplitPtrStructs::copyMetadata(Value *Dest, Value *Src) {
1382 auto *DestI = dyn_cast<Instruction>(Val: Dest);
1383 auto *SrcI = dyn_cast<Instruction>(Val: Src);
1384
1385 if (!DestI || !SrcI)
1386 return;
1387
1388 DestI->copyMetadata(SrcInst: *SrcI);
1389}
1390
1391PtrParts SplitPtrStructs::getPtrParts(Value *V) {
1392 assert(isSplitFatPtr(V->getType()) && "it's not meaningful to get the parts "
1393 "of something that wasn't rewritten");
1394 auto *RsrcEntry = &RsrcParts[V];
1395 auto *OffEntry = &OffParts[V];
1396 if (*RsrcEntry && *OffEntry)
1397 return {*RsrcEntry, *OffEntry};
1398
1399 if (auto *C = dyn_cast<Constant>(Val: V)) {
1400 auto [Rsrc, Off] = splitLoweredFatBufferConst(C);
1401 return {*RsrcEntry = Rsrc, *OffEntry = Off};
1402 }
1403
1404 IRBuilder<InstSimplifyFolder>::InsertPointGuard Guard(IRB);
1405 if (auto *I = dyn_cast<Instruction>(Val: V)) {
1406 LLVM_DEBUG(dbgs() << "Recursing to split parts of " << *I << "\n");
1407 auto [Rsrc, Off] = visit(I&: *I);
1408 if (Rsrc && Off)
1409 return {*RsrcEntry = Rsrc, *OffEntry = Off};
1410 // We'll be creating the new values after the relevant instruction.
1411 // This instruction generates a value and so isn't a terminator.
1412 IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
1413 IRB.SetCurrentDebugLocation(I->getDebugLoc());
1414 } else if (auto *A = dyn_cast<Argument>(Val: V)) {
1415 IRB.SetInsertPointPastAllocas(A->getParent());
1416 IRB.SetCurrentDebugLocation(DebugLoc());
1417 }
1418 Value *Rsrc = IRB.CreateExtractValue(Agg: V, Idxs: 0, Name: V->getName() + ".rsrc");
1419 Value *Off = IRB.CreateExtractValue(Agg: V, Idxs: 1, Name: V->getName() + ".off");
1420 return {*RsrcEntry = Rsrc, *OffEntry = Off};
1421}
1422
1423/// Returns the value that the resource part of V is ultimately derived from.
1424/// Note that this is not getUnderlyingObject(), since that looks through
1425/// operations like ptrmask which might modify the resource part.
1426///
1427/// We can limit ourselves to looking through GEPs followed by looking through
1428/// addrspacecasts because only those two operations preserve the resource
1429/// part, and because any operation on an `addrspace(8)` pointer (the legal
1430/// input to such an addrspacecast) would produce a different resource part.
1431static Value *rsrcPartRoot(Value *V) {
1432 while (auto *GEP = dyn_cast<GEPOperator>(Val: V))
1433 V = GEP->getPointerOperand();
1434 while (auto *ASC = dyn_cast<AddrSpaceCastOperator>(Val: V))
1435 V = ASC->getPointerOperand();
1436 return V;
1437}
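
// For example (illustrative names), given
//   %fat = addrspacecast ptr addrspace(8) %buf to ptr addrspace(7)
//   %q = getelementptr i8, ptr addrspace(7) %fat, i32 16
// rsrcPartRoot(%q) walks back through the GEP and then the addrspacecast and
// returns %buf.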
1438
1439void SplitPtrStructs::getPossibleRsrcRoots(Instruction *I,
1440 SmallPtrSetImpl<Value *> &Roots,
1441 SmallPtrSetImpl<Value *> &Seen) {
1442 if (auto *PHI = dyn_cast<PHINode>(Val: I)) {
1443 if (!Seen.insert(Ptr: I).second)
1444 return;
1445 for (Value *In : PHI->incoming_values()) {
1446 In = rsrcPartRoot(V: In);
1447 Roots.insert(Ptr: In);
1448 if (isa<PHINode, SelectInst>(Val: In))
1449 getPossibleRsrcRoots(I: cast<Instruction>(Val: In), Roots, Seen);
1450 }
1451 } else if (auto *SI = dyn_cast<SelectInst>(Val: I)) {
1452 if (!Seen.insert(Ptr: SI).second)
1453 return;
1454 Value *TrueVal = rsrcPartRoot(V: SI->getTrueValue());
1455 Value *FalseVal = rsrcPartRoot(V: SI->getFalseValue());
1456 Roots.insert(Ptr: TrueVal);
1457 Roots.insert(Ptr: FalseVal);
1458 if (isa<PHINode, SelectInst>(Val: TrueVal))
1459 getPossibleRsrcRoots(I: cast<Instruction>(Val: TrueVal), Roots, Seen);
1460 if (isa<PHINode, SelectInst>(Val: FalseVal))
1461 getPossibleRsrcRoots(I: cast<Instruction>(Val: FalseVal), Roots, Seen);
1462 } else {
1463 llvm_unreachable("getPossibleRsrcRoots() only works on phi and select");
1464 }
1465}
1466
1467void SplitPtrStructs::processConditionals() {
1468 SmallDenseMap<Value *, Value *> FoundRsrcs;
1469 SmallPtrSet<Value *, 4> Roots;
1470 SmallPtrSet<Value *, 4> Seen;
1471 for (Instruction *I : Conditionals) {
1472 // These have to exist by now because we've visited these nodes.
1473 Value *Rsrc = RsrcParts[I];
1474 Value *Off = OffParts[I];
1475 assert(Rsrc && Off && "must have visited conditionals by now");
1476
1477 std::optional<Value *> MaybeRsrc;
1478 auto MaybeFoundRsrc = FoundRsrcs.find(Val: I);
1479 if (MaybeFoundRsrc != FoundRsrcs.end()) {
1480 MaybeRsrc = MaybeFoundRsrc->second;
1481 } else {
1482 IRBuilder<InstSimplifyFolder>::InsertPointGuard Guard(IRB);
1483 Roots.clear();
1484 Seen.clear();
1485 getPossibleRsrcRoots(I, Roots, Seen);
1486 LLVM_DEBUG(dbgs() << "Processing conditional: " << *I << "\n");
1487#ifndef NDEBUG
1488 for (Value *V : Roots)
1489 LLVM_DEBUG(dbgs() << "Root: " << *V << "\n");
1490 for (Value *V : Seen)
1491 LLVM_DEBUG(dbgs() << "Seen: " << *V << "\n");
1492#endif
1493 // If we are our own possible root, then we shouldn't block our
1494 // replacement with a valid incoming value.
1495 Roots.erase(Ptr: I);
1496 // We don't want to block the optimization for conditionals that don't
1497 // refer to themselves but did see themselves during the traversal.
1498 Seen.erase(Ptr: I);
1499
1500 if (set_is_subset(S1: Seen, S2: Roots)) {
1501 auto Diff = set_difference(S1: Roots, S2: Seen);
1502 if (Diff.size() == 1) {
1503 Value *RootVal = *Diff.begin();
1504 // Handle the case where previous loops already looked through
1505 // an addrspacecast.
1506 if (isSplitFatPtr(Ty: RootVal->getType()))
1507 MaybeRsrc = std::get<0>(in: getPtrParts(V: RootVal));
1508 else
1509 MaybeRsrc = RootVal;
1510 }
1511 }
1512 }
1513
1514 if (auto *PHI = dyn_cast<PHINode>(Val: I)) {
1515 Value *NewRsrc;
1516 StructType *PHITy = cast<StructType>(Val: PHI->getType());
1517 IRB.SetInsertPoint(*PHI->getInsertionPointAfterDef());
1518 IRB.SetCurrentDebugLocation(PHI->getDebugLoc());
1519 if (MaybeRsrc) {
1520 NewRsrc = *MaybeRsrc;
1521 } else {
1522 Type *RsrcTy = PHITy->getElementType(N: 0);
1523 auto *RsrcPHI = IRB.CreatePHI(Ty: RsrcTy, NumReservedValues: PHI->getNumIncomingValues());
1524 RsrcPHI->takeName(V: Rsrc);
1525 for (auto [V, BB] : llvm::zip(t: PHI->incoming_values(), u: PHI->blocks())) {
1526 Value *VRsrc = std::get<0>(in: getPtrParts(V));
1527 RsrcPHI->addIncoming(V: VRsrc, BB);
1528 }
1529 copyMetadata(Dest: RsrcPHI, Src: PHI);
1530 NewRsrc = RsrcPHI;
1531 }
1532
1533 Type *OffTy = PHITy->getElementType(N: 1);
1534 auto *NewOff = IRB.CreatePHI(Ty: OffTy, NumReservedValues: PHI->getNumIncomingValues());
1535 NewOff->takeName(V: Off);
1536 for (auto [V, BB] : llvm::zip(t: PHI->incoming_values(), u: PHI->blocks())) {
1537 assert(OffParts.count(V) && "An offset part had to be created by now");
1538 Value *VOff = std::get<1>(in: getPtrParts(V));
1539 NewOff->addIncoming(V: VOff, BB);
1540 }
1541 copyMetadata(Dest: NewOff, Src: PHI);
1542
1543 // Note: We don't eraseFromParent() the temporaries here because we don't
1544 // want to put the correction maps in an inconsistent state; that'll be
1545 // handled during the rest of the killing. Also, `ValueToValueMapTy`
1546 // guarantees that references in that map will be updated as well.
1547 // Note that if the temporary instruction got `InstSimplify`'d away, it
1548 // might not be an instruction at all (for example, a function argument).
1549 if (auto *RsrcInst = dyn_cast<Instruction>(Val: Rsrc)) {
1550 ConditionalTemps.push_back(Elt: RsrcInst);
1551 RsrcInst->replaceAllUsesWith(V: NewRsrc);
1552 }
1553 if (auto *OffInst = dyn_cast<Instruction>(Val: Off)) {
1554 ConditionalTemps.push_back(Elt: OffInst);
1555 OffInst->replaceAllUsesWith(V: NewOff);
1556 }
1557
1558 // Save on recomputing the cycle traversals in known-root cases.
1559 if (MaybeRsrc)
1560 for (Value *V : Seen)
1561 FoundRsrcs[V] = NewRsrc;
1562 } else if (isa<SelectInst>(Val: I)) {
1563 if (MaybeRsrc) {
1564 if (auto *RsrcInst = dyn_cast<Instruction>(Val: Rsrc)) {
1565 ConditionalTemps.push_back(Elt: RsrcInst);
1566 RsrcInst->replaceAllUsesWith(V: *MaybeRsrc);
1567 }
1568 for (Value *V : Seen)
1569 FoundRsrcs[V] = *MaybeRsrc;
1570 }
1571 } else {
1572 llvm_unreachable("Only PHIs and selects go in the conditionals list");
1573 }
1574 }
1575}
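
// As a sketch of the common case (illustrative names, not verbatim output): a
// loop-carried fat pointer
//   %p = phi ptr addrspace(7) [ %base, %entry ], [ %p.next, %loop ]
// where %p.next is a GEP of %p has %base as the only remaining resource root,
// so only an offset PHI is emitted,
//   %p.off = phi i32 [ %base.off, %entry ], [ %p.next.off, %loop ]
// and all uses of %p's resource part are rewritten to use %base.rsrc directly,
// since the resource part cannot change around the loop.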
1576
1577void SplitPtrStructs::killAndReplaceSplitInstructions(
1578 SmallVectorImpl<Instruction *> &Origs) {
1579 for (Instruction *I : ConditionalTemps)
1580 I->eraseFromParent();
1581
1582 for (Instruction *I : Origs) {
1583 if (!SplitUsers.contains(V: I))
1584 continue;
1585
1586 SmallVector<DbgValueInst *> Dbgs;
1587 findDbgValues(DbgValues&: Dbgs, V: I);
1588 for (auto *Dbg : Dbgs) {
1589 IRB.SetInsertPoint(Dbg);
1590 auto &DL = I->getDataLayout();
1591 assert(isSplitFatPtr(I->getType()) &&
1592 "We should've RAUW'd away loads, stores, etc. at this point");
1593 auto *OffDbg = cast<DbgValueInst>(Val: Dbg->clone());
1594 copyMetadata(Dest: OffDbg, Src: Dbg);
1595 auto [Rsrc, Off] = getPtrParts(V: I);
1596
1597 int64_t RsrcSz = DL.getTypeSizeInBits(Ty: Rsrc->getType());
1598 int64_t OffSz = DL.getTypeSizeInBits(Ty: Off->getType());
1599
1600 std::optional<DIExpression *> RsrcExpr =
1601 DIExpression::createFragmentExpression(Expr: Dbg->getExpression(), OffsetInBits: 0,
1602 SizeInBits: RsrcSz);
1603 std::optional<DIExpression *> OffExpr =
1604 DIExpression::createFragmentExpression(Expr: Dbg->getExpression(), OffsetInBits: RsrcSz,
1605 SizeInBits: OffSz);
1606 if (OffExpr) {
1607 OffDbg->setExpression(*OffExpr);
1608 OffDbg->replaceVariableLocationOp(OldValue: I, NewValue: Off);
1609 IRB.Insert(I: OffDbg);
1610 } else {
1611 OffDbg->deleteValue();
1612 }
1613 if (RsrcExpr) {
1614 Dbg->setExpression(*RsrcExpr);
1615 Dbg->replaceVariableLocationOp(OldValue: I, NewValue: Rsrc);
1616 } else {
1617 Dbg->replaceVariableLocationOp(OldValue: I, NewValue: PoisonValue::get(T: I->getType()));
1618 }
1619 }
1620
1621 Value *Poison = PoisonValue::get(T: I->getType());
1622 I->replaceUsesWithIf(New: Poison, ShouldReplace: [&](const Use &U) -> bool {
1623 if (const auto *UI = dyn_cast<Instruction>(Val: U.getUser()))
1624 return SplitUsers.contains(V: UI);
1625 return false;
1626 });
1627
1628 if (I->use_empty()) {
1629 I->eraseFromParent();
1630 continue;
1631 }
1632 IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
1633 IRB.SetCurrentDebugLocation(I->getDebugLoc());
1634 auto [Rsrc, Off] = getPtrParts(V: I);
1635 Value *Struct = PoisonValue::get(T: I->getType());
1636 Struct = IRB.CreateInsertValue(Agg: Struct, Val: Rsrc, Idxs: 0);
1637 Struct = IRB.CreateInsertValue(Agg: Struct, Val: Off, Idxs: 1);
1638 copyMetadata(Dest: Struct, Src: I);
1639 Struct->takeName(V: I);
1640 I->replaceAllUsesWith(V: Struct);
1641 I->eraseFromParent();
1642 }
1643}
1644
1645void SplitPtrStructs::setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx) {
1646 LLVMContext &Ctx = Intr->getContext();
1647 Intr->addParamAttr(ArgNo: RsrcArgIdx, Attr: Attribute::getWithAlignment(Context&: Ctx, Alignment: A));
1648}
1649
1650void SplitPtrStructs::insertPreMemOpFence(AtomicOrdering Order,
1651 SyncScope::ID SSID) {
1652 switch (Order) {
1653 case AtomicOrdering::Release:
1654 case AtomicOrdering::AcquireRelease:
1655 case AtomicOrdering::SequentiallyConsistent:
1656 IRB.CreateFence(Ordering: AtomicOrdering::Release, SSID);
1657 break;
1658 default:
1659 break;
1660 }
1661}
1662
1663void SplitPtrStructs::insertPostMemOpFence(AtomicOrdering Order,
1664 SyncScope::ID SSID) {
1665 switch (Order) {
1666 case AtomicOrdering::Acquire:
1667 case AtomicOrdering::AcquireRelease:
1668 case AtomicOrdering::SequentiallyConsistent:
1669 IRB.CreateFence(Ordering: AtomicOrdering::Acquire, SSID);
1670 break;
1671 default:
1672 break;
1673 }
1674}
1675
1676Value *SplitPtrStructs::handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr,
1677 Type *Ty, Align Alignment,
1678 AtomicOrdering Order, bool IsVolatile,
1679 SyncScope::ID SSID) {
1680 IRB.SetInsertPoint(I);
1681
1682 auto [Rsrc, Off] = getPtrParts(V: Ptr);
1683 SmallVector<Value *, 5> Args;
1684 if (Arg)
1685 Args.push_back(Elt: Arg);
1686 Args.push_back(Elt: Rsrc);
1687 Args.push_back(Elt: Off);
1688 insertPreMemOpFence(Order, SSID);
1689 // soffset is always 0 for these cases, since we always want any offset to be
1690 // part of bounds checking and we don't know which parts of the GEPs are
1691 // uniform.
1692 Args.push_back(Elt: IRB.getInt32(C: 0));
1693
1694 uint32_t Aux = 0;
1695 if (IsVolatile)
1696 Aux |= AMDGPU::CPol::VOLATILE;
1697 Args.push_back(Elt: IRB.getInt32(C: Aux));
1698
1699 Intrinsic::ID IID = Intrinsic::not_intrinsic;
1700 if (isa<LoadInst>(Val: I))
1701 IID = Order == AtomicOrdering::NotAtomic
1702 ? Intrinsic::amdgcn_raw_ptr_buffer_load
1703 : Intrinsic::amdgcn_raw_ptr_atomic_buffer_load;
1704 else if (isa<StoreInst>(Val: I))
1705 IID = Intrinsic::amdgcn_raw_ptr_buffer_store;
1706 else if (auto *RMW = dyn_cast<AtomicRMWInst>(Val: I)) {
1707 switch (RMW->getOperation()) {
1708 case AtomicRMWInst::Xchg:
1709 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_swap;
1710 break;
1711 case AtomicRMWInst::Add:
1712 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_add;
1713 break;
1714 case AtomicRMWInst::Sub:
1715 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_sub;
1716 break;
1717 case AtomicRMWInst::And:
1718 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_and;
1719 break;
1720 case AtomicRMWInst::Or:
1721 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_or;
1722 break;
1723 case AtomicRMWInst::Xor:
1724 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_xor;
1725 break;
1726 case AtomicRMWInst::Max:
1727 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smax;
1728 break;
1729 case AtomicRMWInst::Min:
1730 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smin;
1731 break;
1732 case AtomicRMWInst::UMax:
1733 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umax;
1734 break;
1735 case AtomicRMWInst::UMin:
1736 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umin;
1737 break;
1738 case AtomicRMWInst::FAdd:
1739 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fadd;
1740 break;
1741 case AtomicRMWInst::FMax:
1742 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmax;
1743 break;
1744 case AtomicRMWInst::FMin:
1745 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmin;
1746 break;
1747 case AtomicRMWInst::FSub: {
1748 reportFatalUsageError(
1749 reason: "atomic floating point subtraction not supported for "
1750 "buffer resources and should've been expanded away");
1751 break;
1752 }
1753 case AtomicRMWInst::FMaximum: {
1754 reportFatalUsageError(
1755 reason: "atomic floating point fmaximum not supported for "
1756 "buffer resources and should've been expanded away");
1757 break;
1758 }
1759 case AtomicRMWInst::FMinimum: {
1760 reportFatalUsageError(
1761 reason: "atomic floating point fminimum not supported for "
1762 "buffer resources and should've been expanded away");
1763 break;
1764 }
1765 case AtomicRMWInst::Nand:
1766 reportFatalUsageError(
1767 reason: "atomic nand not supported for buffer resources and "
1768 "should've been expanded away");
1769 break;
1770 case AtomicRMWInst::UIncWrap:
1771 case AtomicRMWInst::UDecWrap:
1772 reportFatalUsageError(reason: "wrapping increment/decrement not supported for "
1773 "buffer resources and should've been expanded away");
1774 break;
1775 case AtomicRMWInst::BAD_BINOP:
1776 llvm_unreachable("Not sure how we got a bad binop");
1777 case AtomicRMWInst::USubCond:
1778 case AtomicRMWInst::USubSat:
1779 break;
1780 }
1781 }
1782
1783 auto *Call = IRB.CreateIntrinsic(ID: IID, Types: Ty, Args);
1784 copyMetadata(Dest: Call, Src: I);
1785 setAlign(Intr: Call, A: Alignment, RsrcArgIdx: Arg ? 1 : 0);
1786 Call->takeName(V: I);
1787
1788 insertPostMemOpFence(Order, SSID);
1789 // The "no moving p7 directly" rewrites ensure that this load or store won't
1790 // itself need to be split into parts.
1791 SplitUsers.insert(V: I);
1792 I->replaceAllUsesWith(V: Call);
1793 return Call;
1794}
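
// Schematically (illustrative names; the real output also carries the
// alignment as a parameter attribute on the resource argument), a simple load
//   %v = load i32, ptr addrspace(7) %p, align 4
// becomes
//   %v = call i32 @llvm.amdgcn.raw.ptr.buffer.load.i32(
//            ptr addrspace(8) %p.rsrc, i32 %p.off, i32 0, i32 0)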
1795
1796PtrParts SplitPtrStructs::visitInstruction(Instruction &I) {
1797 return {nullptr, nullptr};
1798}
1799
1800PtrParts SplitPtrStructs::visitLoadInst(LoadInst &LI) {
1801 if (!isSplitFatPtr(Ty: LI.getPointerOperandType()))
1802 return {nullptr, nullptr};
1803 handleMemoryInst(I: &LI, Arg: nullptr, Ptr: LI.getPointerOperand(), Ty: LI.getType(),
1804 Alignment: LI.getAlign(), Order: LI.getOrdering(), IsVolatile: LI.isVolatile(),
1805 SSID: LI.getSyncScopeID());
1806 return {nullptr, nullptr};
1807}
1808
1809PtrParts SplitPtrStructs::visitStoreInst(StoreInst &SI) {
1810 if (!isSplitFatPtr(Ty: SI.getPointerOperandType()))
1811 return {nullptr, nullptr};
1812 Value *Arg = SI.getValueOperand();
1813 handleMemoryInst(I: &SI, Arg, Ptr: SI.getPointerOperand(), Ty: Arg->getType(),
1814 Alignment: SI.getAlign(), Order: SI.getOrdering(), IsVolatile: SI.isVolatile(),
1815 SSID: SI.getSyncScopeID());
1816 return {nullptr, nullptr};
1817}
1818
1819PtrParts SplitPtrStructs::visitAtomicRMWInst(AtomicRMWInst &AI) {
1820 if (!isSplitFatPtr(Ty: AI.getPointerOperand()->getType()))
1821 return {nullptr, nullptr};
1822 Value *Arg = AI.getValOperand();
1823 handleMemoryInst(I: &AI, Arg, Ptr: AI.getPointerOperand(), Ty: Arg->getType(),
1824 Alignment: AI.getAlign(), Order: AI.getOrdering(), IsVolatile: AI.isVolatile(),
1825 SSID: AI.getSyncScopeID());
1826 return {nullptr, nullptr};
1827}
1828
1829// Unlike load, store, and RMW, cmpxchg needs special handling to account
1830// for the boolean success value in its result.
1831PtrParts SplitPtrStructs::visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI) {
1832 Value *Ptr = AI.getPointerOperand();
1833 if (!isSplitFatPtr(Ty: Ptr->getType()))
1834 return {nullptr, nullptr};
1835 IRB.SetInsertPoint(&AI);
1836
1837 Type *Ty = AI.getNewValOperand()->getType();
1838 AtomicOrdering Order = AI.getMergedOrdering();
1839 SyncScope::ID SSID = AI.getSyncScopeID();
1840 bool IsNonTemporal = AI.getMetadata(KindID: LLVMContext::MD_nontemporal);
1841
1842 auto [Rsrc, Off] = getPtrParts(V: Ptr);
1843 insertPreMemOpFence(Order, SSID);
1844
1845 uint32_t Aux = 0;
1846 if (IsNonTemporal)
1847 Aux |= AMDGPU::CPol::SLC;
1848 if (AI.isVolatile())
1849 Aux |= AMDGPU::CPol::VOLATILE;
1850 auto *Call =
1851 IRB.CreateIntrinsic(ID: Intrinsic::amdgcn_raw_ptr_buffer_atomic_cmpswap, Types: Ty,
1852 Args: {AI.getNewValOperand(), AI.getCompareOperand(), Rsrc,
1853 Off, IRB.getInt32(C: 0), IRB.getInt32(C: Aux)});
1854 copyMetadata(Dest: Call, Src: &AI);
1855 setAlign(Intr: Call, A: AI.getAlign(), RsrcArgIdx: 2);
1856 Call->takeName(V: &AI);
1857 insertPostMemOpFence(Order, SSID);
1858
1859 Value *Res = PoisonValue::get(T: AI.getType());
1860 Res = IRB.CreateInsertValue(Agg: Res, Val: Call, Idxs: 0);
1861 if (!AI.isWeak()) {
1862 Value *Succeeded = IRB.CreateICmpEQ(LHS: Call, RHS: AI.getCompareOperand());
1863 Res = IRB.CreateInsertValue(Agg: Res, Val: Succeeded, Idxs: 1);
1864 }
1865 SplitUsers.insert(V: &AI);
1866 AI.replaceAllUsesWith(V: Res);
1867 return {nullptr, nullptr};
1868}
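
// For example (illustrative names, seq_cst shown),
//   %r = cmpxchg ptr addrspace(7) %p, i32 %old, i32 %new seq_cst seq_cst
// becomes roughly
//   fence release
//   %v = call i32 @llvm.amdgcn.raw.ptr.buffer.atomic.cmpswap.i32(
//            i32 %new, i32 %old, ptr addrspace(8) %p.rsrc, i32 %p.off,
//            i32 0, i32 0)
//   fence acquire
// with the {i32, i1} result rebuilt from %v and, for non-weak cmpxchg, an
// `icmp eq` against %old for the success bit.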
1869
1870PtrParts SplitPtrStructs::visitGetElementPtrInst(GetElementPtrInst &GEP) {
1871 using namespace llvm::PatternMatch;
1872 Value *Ptr = GEP.getPointerOperand();
1873 if (!isSplitFatPtr(Ty: Ptr->getType()))
1874 return {nullptr, nullptr};
1875 IRB.SetInsertPoint(&GEP);
1876
1877 auto [Rsrc, Off] = getPtrParts(V: Ptr);
1878 const DataLayout &DL = GEP.getDataLayout();
1879 bool IsNUW = GEP.hasNoUnsignedWrap();
1880 bool IsNUSW = GEP.hasNoUnsignedSignedWrap();
1881
1882 StructType *ResTy = cast<StructType>(Val: GEP.getType());
1883 Type *ResRsrcTy = ResTy->getElementType(N: 0);
1884 VectorType *ResRsrcVecTy = dyn_cast<VectorType>(Val: ResRsrcTy);
1885 bool BroadcastsPtr = ResRsrcVecTy && !isa<VectorType>(Val: Off->getType());
1886
1887 // In order to call emitGEPOffset() and thus not have to reimplement it,
1888 // we need the GEP result to have ptr addrspace(7) type.
1889 Type *FatPtrTy =
1890 ResRsrcTy->getWithNewType(EltTy: IRB.getPtrTy(AddrSpace: AMDGPUAS::BUFFER_FAT_POINTER));
1891 GEP.mutateType(Ty: FatPtrTy);
1892 Value *OffAccum = emitGEPOffset(Builder: &IRB, DL, GEP: &GEP);
1893 GEP.mutateType(Ty: ResTy);
1894
1895 if (BroadcastsPtr) {
1896 Rsrc = IRB.CreateVectorSplat(EC: ResRsrcVecTy->getElementCount(), V: Rsrc,
1897 Name: Rsrc->getName());
1898 Off = IRB.CreateVectorSplat(EC: ResRsrcVecTy->getElementCount(), V: Off,
1899 Name: Off->getName());
1900 }
1901 if (match(V: OffAccum, P: m_Zero())) { // Constant-zero offset
1902 SplitUsers.insert(V: &GEP);
1903 return {Rsrc, Off};
1904 }
1905
1906 bool HasNonNegativeOff = false;
1907 if (auto *CI = dyn_cast<ConstantInt>(Val: OffAccum)) {
1908 HasNonNegativeOff = !CI->isNegative();
1909 }
1910 Value *NewOff;
1911 if (match(V: Off, P: m_Zero())) {
1912 NewOff = OffAccum;
1913 } else {
1914 NewOff = IRB.CreateAdd(LHS: Off, RHS: OffAccum, Name: "",
1915 /*hasNUW=*/HasNUW: IsNUW || (IsNUSW && HasNonNegativeOff),
1916 /*hasNSW=*/HasNSW: false);
1917 }
1918 copyMetadata(Dest: NewOff, Src: &GEP);
1919 NewOff->takeName(V: &GEP);
1920 SplitUsers.insert(V: &GEP);
1921 return {Rsrc, NewOff};
1922}
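
// Schematically (illustrative names), a GEP whose accumulated offset isn't
// known to be zero, such as
//   %q = getelementptr i32, ptr addrspace(7) %p, i32 %i
// keeps %p.rsrc as its resource part and computes its offset part as roughly
//   %i.bytes = mul i32 %i, 4          ; via emitGEPOffset()
//   %q.off = add i32 %p.off, %i.bytes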
1923
1924PtrParts SplitPtrStructs::visitPtrToIntInst(PtrToIntInst &PI) {
1925 Value *Ptr = PI.getPointerOperand();
1926 if (!isSplitFatPtr(Ty: Ptr->getType()))
1927 return {nullptr, nullptr};
1928 IRB.SetInsertPoint(&PI);
1929
1930 Type *ResTy = PI.getType();
1931 unsigned Width = ResTy->getScalarSizeInBits();
1932
1933 auto [Rsrc, Off] = getPtrParts(V: Ptr);
1934 const DataLayout &DL = PI.getDataLayout();
1935 unsigned FatPtrWidth = DL.getPointerSizeInBits(AS: AMDGPUAS::BUFFER_FAT_POINTER);
1936
1937 Value *Res;
1938 if (Width <= BufferOffsetWidth) {
1939 Res = IRB.CreateIntCast(V: Off, DestTy: ResTy, /*isSigned=*/false,
1940 Name: PI.getName() + ".off");
1941 } else {
1942 Value *RsrcInt = IRB.CreatePtrToInt(V: Rsrc, DestTy: ResTy, Name: PI.getName() + ".rsrc");
1943 Value *Shl = IRB.CreateShl(
1944 LHS: RsrcInt,
1945 RHS: ConstantExpr::getIntegerValue(Ty: ResTy, V: APInt(Width, BufferOffsetWidth)),
1946 Name: "", HasNUW: Width >= FatPtrWidth, HasNSW: Width > FatPtrWidth);
1947 Value *OffCast = IRB.CreateIntCast(V: Off, DestTy: ResTy, /*isSigned=*/false,
1948 Name: PI.getName() + ".off");
1949 Res = IRB.CreateOr(LHS: Shl, RHS: OffCast);
1950 }
1951
1952 copyMetadata(Dest: Res, Src: &PI);
1953 Res->takeName(V: &PI);
1954 SplitUsers.insert(V: &PI);
1955 PI.replaceAllUsesWith(V: Res);
1956 return {nullptr, nullptr};
1957}
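
// For example (illustrative names), a full-width cast
//   %i = ptrtoint ptr addrspace(7) %p to i160
// produces roughly
//   %rsrc.int = ptrtoint ptr addrspace(8) %p.rsrc to i160
//   %hi = shl nuw i160 %rsrc.int, 32
//   %lo = zext i32 %p.off to i160
//   %i = or i160 %hi, %lo
// while casts to i32 or narrower reduce to (a truncation of) the offset part.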
1958
1959PtrParts SplitPtrStructs::visitIntToPtrInst(IntToPtrInst &IP) {
1960 if (!isSplitFatPtr(Ty: IP.getType()))
1961 return {nullptr, nullptr};
1962 IRB.SetInsertPoint(&IP);
1963 const DataLayout &DL = IP.getDataLayout();
1964 unsigned RsrcPtrWidth = DL.getPointerSizeInBits(AS: AMDGPUAS::BUFFER_RESOURCE);
1965 Value *Int = IP.getOperand(i_nocapture: 0);
1966 Type *IntTy = Int->getType();
1967 Type *RsrcIntTy = IntTy->getWithNewBitWidth(NewBitWidth: RsrcPtrWidth);
1968 unsigned Width = IntTy->getScalarSizeInBits();
1969
1970 auto *RetTy = cast<StructType>(Val: IP.getType());
1971 Type *RsrcTy = RetTy->getElementType(N: 0);
1972 Type *OffTy = RetTy->getElementType(N: 1);
1973 Value *RsrcPart = IRB.CreateLShr(
1974 LHS: Int,
1975 RHS: ConstantExpr::getIntegerValue(Ty: IntTy, V: APInt(Width, BufferOffsetWidth)));
1976 Value *RsrcInt = IRB.CreateIntCast(V: RsrcPart, DestTy: RsrcIntTy, /*isSigned=*/false);
1977 Value *Rsrc = IRB.CreateIntToPtr(V: RsrcInt, DestTy: RsrcTy, Name: IP.getName() + ".rsrc");
1978 Value *Off =
1979 IRB.CreateIntCast(V: Int, DestTy: OffTy, /*IsSigned=*/isSigned: false, Name: IP.getName() + ".off");
1980
1981 copyMetadata(Dest: Rsrc, Src: &IP);
1982 SplitUsers.insert(V: &IP);
1983 return {Rsrc, Off};
1984}
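
// The inverse direction (illustrative names):
//   %p = inttoptr i160 %i to ptr addrspace(7)
// becomes roughly
//   %hi = lshr i160 %i, 32
//   %rsrc.int = trunc i160 %hi to i128
//   %p.rsrc = inttoptr i128 %rsrc.int to ptr addrspace(8)
//   %p.off = trunc i160 %i to i32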
1985
1986PtrParts SplitPtrStructs::visitAddrSpaceCastInst(AddrSpaceCastInst &I) {
1987 // TODO(krzysz00): handle casts from ptr addrspace(7) to global pointers
1988 // by computing the effective address.
1989 if (!isSplitFatPtr(Ty: I.getType()))
1990 return {nullptr, nullptr};
1991 IRB.SetInsertPoint(&I);
1992 Value *In = I.getPointerOperand();
1993 // No-op casts preserve parts
1994 if (In->getType() == I.getType()) {
1995 auto [Rsrc, Off] = getPtrParts(V: In);
1996 SplitUsers.insert(V: &I);
1997 return {Rsrc, Off};
1998 }
1999
2000 auto *ResTy = cast<StructType>(Val: I.getType());
2001 Type *RsrcTy = ResTy->getElementType(N: 0);
2002 Type *OffTy = ResTy->getElementType(N: 1);
2003 Value *ZeroOff = Constant::getNullValue(Ty: OffTy);
2004
2005 // Special case for null pointers, undef, and poison, which can be created by
2006 // address space propagation.
2007 auto *InConst = dyn_cast<Constant>(Val: In);
2008 if (InConst && InConst->isNullValue()) {
2009 Value *NullRsrc = Constant::getNullValue(Ty: RsrcTy);
2010 SplitUsers.insert(V: &I);
2011 return {NullRsrc, ZeroOff};
2012 }
2013 if (isa<PoisonValue>(Val: In)) {
2014 Value *PoisonRsrc = PoisonValue::get(T: RsrcTy);
2015 Value *PoisonOff = PoisonValue::get(T: OffTy);
2016 SplitUsers.insert(V: &I);
2017 return {PoisonRsrc, PoisonOff};
2018 }
2019 if (isa<UndefValue>(Val: In)) {
2020 Value *UndefRsrc = UndefValue::get(T: RsrcTy);
2021 Value *UndefOff = UndefValue::get(T: OffTy);
2022 SplitUsers.insert(V: &I);
2023 return {UndefRsrc, UndefOff};
2024 }
2025
2026 if (I.getSrcAddressSpace() != AMDGPUAS::BUFFER_RESOURCE)
2027 reportFatalUsageError(
2028 reason: "only buffer resources (addrspace 8) and null/poison pointers can be "
2029 "cast to buffer fat pointers (addrspace 7)");
2030 SplitUsers.insert(V: &I);
2031 return {In, ZeroOff};
2032}
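
// For example (illustrative names),
//   %p = addrspacecast ptr addrspace(8) %rsrc to ptr addrspace(7)
// is split into the pair (%rsrc, i32 0), while casting a null, undef, or
// poison input yields the corresponding constant for both parts.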
2033
2034PtrParts SplitPtrStructs::visitICmpInst(ICmpInst &Cmp) {
2035 Value *Lhs = Cmp.getOperand(i_nocapture: 0);
2036 if (!isSplitFatPtr(Ty: Lhs->getType()))
2037 return {nullptr, nullptr};
2038 Value *Rhs = Cmp.getOperand(i_nocapture: 1);
2039 IRB.SetInsertPoint(&Cmp);
2040 ICmpInst::Predicate Pred = Cmp.getPredicate();
2041
2042 assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
2043 "Pointer comparison is only equal or unequal");
2044 auto [LhsRsrc, LhsOff] = getPtrParts(V: Lhs);
2045 auto [RhsRsrc, RhsOff] = getPtrParts(V: Rhs);
2046 Value *RsrcCmp =
2047 IRB.CreateICmp(P: Pred, LHS: LhsRsrc, RHS: RhsRsrc, Name: Cmp.getName() + ".rsrc");
2048 copyMetadata(Dest: RsrcCmp, Src: &Cmp);
2049 Value *OffCmp = IRB.CreateICmp(P: Pred, LHS: LhsOff, RHS: RhsOff, Name: Cmp.getName() + ".off");
2050 copyMetadata(Dest: OffCmp, Src: &Cmp);
2051
2052 Value *Res = nullptr;
2053 if (Pred == ICmpInst::ICMP_EQ)
2054 Res = IRB.CreateAnd(LHS: RsrcCmp, RHS: OffCmp);
2055 else if (Pred == ICmpInst::ICMP_NE)
2056 Res = IRB.CreateOr(LHS: RsrcCmp, RHS: OffCmp);
2057 copyMetadata(Dest: Res, Src: &Cmp);
2058 Res->takeName(V: &Cmp);
2059 SplitUsers.insert(V: &Cmp);
2060 Cmp.replaceAllUsesWith(V: Res);
2061 return {nullptr, nullptr};
2062}
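
// Schematically (illustrative names),
//   %eq = icmp eq ptr addrspace(7) %a, %b
// becomes
//   %eq.rsrc = icmp eq ptr addrspace(8) %a.rsrc, %b.rsrc
//   %eq.off = icmp eq i32 %a.off, %b.off
//   %eq = and i1 %eq.rsrc, %eq.off
// with `or` in place of `and` for the ne predicate.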
2063
2064PtrParts SplitPtrStructs::visitFreezeInst(FreezeInst &I) {
2065 if (!isSplitFatPtr(Ty: I.getType()))
2066 return {nullptr, nullptr};
2067 IRB.SetInsertPoint(&I);
2068 auto [Rsrc, Off] = getPtrParts(V: I.getOperand(i_nocapture: 0));
2069
2070 Value *RsrcRes = IRB.CreateFreeze(V: Rsrc, Name: I.getName() + ".rsrc");
2071 copyMetadata(Dest: RsrcRes, Src: &I);
2072 Value *OffRes = IRB.CreateFreeze(V: Off, Name: I.getName() + ".off");
2073 copyMetadata(Dest: OffRes, Src: &I);
2074 SplitUsers.insert(V: &I);
2075 return {RsrcRes, OffRes};
2076}
2077
2078PtrParts SplitPtrStructs::visitExtractElementInst(ExtractElementInst &I) {
2079 if (!isSplitFatPtr(Ty: I.getType()))
2080 return {nullptr, nullptr};
2081 IRB.SetInsertPoint(&I);
2082 Value *Vec = I.getVectorOperand();
2083 Value *Idx = I.getIndexOperand();
2084 auto [Rsrc, Off] = getPtrParts(V: Vec);
2085
2086 Value *RsrcRes = IRB.CreateExtractElement(Vec: Rsrc, Idx, Name: I.getName() + ".rsrc");
2087 copyMetadata(Dest: RsrcRes, Src: &I);
2088 Value *OffRes = IRB.CreateExtractElement(Vec: Off, Idx, Name: I.getName() + ".off");
2089 copyMetadata(Dest: OffRes, Src: &I);
2090 SplitUsers.insert(V: &I);
2091 return {RsrcRes, OffRes};
2092}
2093
2094PtrParts SplitPtrStructs::visitInsertElementInst(InsertElementInst &I) {
2095 // The mutated instructions temporarily don't return vectors, and so
2096 // we need the generic getType() here to avoid crashes.
2097 if (!isSplitFatPtr(Ty: cast<Instruction>(Val&: I).getType()))
2098 return {nullptr, nullptr};
2099 IRB.SetInsertPoint(&I);
2100 Value *Vec = I.getOperand(i_nocapture: 0);
2101 Value *Elem = I.getOperand(i_nocapture: 1);
2102 Value *Idx = I.getOperand(i_nocapture: 2);
2103 auto [VecRsrc, VecOff] = getPtrParts(V: Vec);
2104 auto [ElemRsrc, ElemOff] = getPtrParts(V: Elem);
2105
2106 Value *RsrcRes =
2107 IRB.CreateInsertElement(Vec: VecRsrc, NewElt: ElemRsrc, Idx, Name: I.getName() + ".rsrc");
2108 copyMetadata(Dest: RsrcRes, Src: &I);
2109 Value *OffRes =
2110 IRB.CreateInsertElement(Vec: VecOff, NewElt: ElemOff, Idx, Name: I.getName() + ".off");
2111 copyMetadata(Dest: OffRes, Src: &I);
2112 SplitUsers.insert(V: &I);
2113 return {RsrcRes, OffRes};
2114}
2115
2116PtrParts SplitPtrStructs::visitShuffleVectorInst(ShuffleVectorInst &I) {
2117 // Cast is needed for the same reason as insertelement's.
2118 if (!isSplitFatPtr(Ty: cast<Instruction>(Val&: I).getType()))
2119 return {nullptr, nullptr};
2120 IRB.SetInsertPoint(&I);
2121
2122 Value *V1 = I.getOperand(i_nocapture: 0);
2123 Value *V2 = I.getOperand(i_nocapture: 1);
2124 ArrayRef<int> Mask = I.getShuffleMask();
2125 auto [V1Rsrc, V1Off] = getPtrParts(V: V1);
2126 auto [V2Rsrc, V2Off] = getPtrParts(V: V2);
2127
2128 Value *RsrcRes =
2129 IRB.CreateShuffleVector(V1: V1Rsrc, V2: V2Rsrc, Mask, Name: I.getName() + ".rsrc");
2130 copyMetadata(Dest: RsrcRes, Src: &I);
2131 Value *OffRes =
2132 IRB.CreateShuffleVector(V1: V1Off, V2: V2Off, Mask, Name: I.getName() + ".off");
2133 copyMetadata(Dest: OffRes, Src: &I);
2134 SplitUsers.insert(V: &I);
2135 return {RsrcRes, OffRes};
2136}
2137
2138PtrParts SplitPtrStructs::visitPHINode(PHINode &PHI) {
2139 if (!isSplitFatPtr(Ty: PHI.getType()))
2140 return {nullptr, nullptr};
2141 IRB.SetInsertPoint(*PHI.getInsertionPointAfterDef());
2142 // Phi nodes will be handled in post-processing after we've visited every
2143 // instruction. However, instead of just returning {nullptr, nullptr},
2144 // we explicitly create the temporary extractvalue operations that are our
2145 // temporary results so that they end up at the beginning of the block with
2146 // the PHIs.
2147 Value *TmpRsrc = IRB.CreateExtractValue(Agg: &PHI, Idxs: 0, Name: PHI.getName() + ".rsrc");
2148 Value *TmpOff = IRB.CreateExtractValue(Agg: &PHI, Idxs: 1, Name: PHI.getName() + ".off");
2149 Conditionals.push_back(Elt: &PHI);
2150 SplitUsers.insert(V: &PHI);
2151 return {TmpRsrc, TmpOff};
2152}
2153
2154PtrParts SplitPtrStructs::visitSelectInst(SelectInst &SI) {
2155 if (!isSplitFatPtr(Ty: SI.getType()))
2156 return {nullptr, nullptr};
2157 IRB.SetInsertPoint(&SI);
2158
2159 Value *Cond = SI.getCondition();
2160 Value *True = SI.getTrueValue();
2161 Value *False = SI.getFalseValue();
2162 auto [TrueRsrc, TrueOff] = getPtrParts(V: True);
2163 auto [FalseRsrc, FalseOff] = getPtrParts(V: False);
2164
2165 Value *RsrcRes =
2166 IRB.CreateSelect(C: Cond, True: TrueRsrc, False: FalseRsrc, Name: SI.getName() + ".rsrc", MDFrom: &SI);
2167 copyMetadata(Dest: RsrcRes, Src: &SI);
2168 Conditionals.push_back(Elt: &SI);
2169 Value *OffRes =
2170 IRB.CreateSelect(C: Cond, True: TrueOff, False: FalseOff, Name: SI.getName() + ".off", MDFrom: &SI);
2171 copyMetadata(Dest: OffRes, Src: &SI);
2172 SplitUsers.insert(V: &SI);
2173 return {RsrcRes, OffRes};
2174}
2175
2176/// Returns true if this intrinsic needs to be removed when it is
2177/// applied to `ptr addrspace(7)` values. Calls to these intrinsics are
2178/// rewritten into calls to versions of that intrinsic on the resource
2179/// descriptor.
2180static bool isRemovablePointerIntrinsic(Intrinsic::ID IID) {
2181 switch (IID) {
2182 default:
2183 return false;
2184 case Intrinsic::amdgcn_make_buffer_rsrc:
2185 case Intrinsic::ptrmask:
2186 case Intrinsic::invariant_start:
2187 case Intrinsic::invariant_end:
2188 case Intrinsic::launder_invariant_group:
2189 case Intrinsic::strip_invariant_group:
2190 case Intrinsic::memcpy:
2191 case Intrinsic::memcpy_inline:
2192 case Intrinsic::memmove:
2193 case Intrinsic::memset:
2194 case Intrinsic::memset_inline:
2195 case Intrinsic::experimental_memset_pattern:
2196 case Intrinsic::amdgcn_load_to_lds:
2197 return true;
2198 }
2199}
2200
2201PtrParts SplitPtrStructs::visitIntrinsicInst(IntrinsicInst &I) {
2202 Intrinsic::ID IID = I.getIntrinsicID();
2203 switch (IID) {
2204 default:
2205 break;
2206 case Intrinsic::amdgcn_make_buffer_rsrc: {
2207 if (!isSplitFatPtr(Ty: I.getType()))
2208 return {nullptr, nullptr};
2209 Value *Base = I.getArgOperand(i: 0);
2210 Value *Stride = I.getArgOperand(i: 1);
2211 Value *NumRecords = I.getArgOperand(i: 2);
2212 Value *Flags = I.getArgOperand(i: 3);
2213 auto *SplitType = cast<StructType>(Val: I.getType());
2214 Type *RsrcType = SplitType->getElementType(N: 0);
2215 Type *OffType = SplitType->getElementType(N: 1);
2216 IRB.SetInsertPoint(&I);
2217 Value *Rsrc = IRB.CreateIntrinsic(ID: IID, Types: {RsrcType, Base->getType()},
2218 Args: {Base, Stride, NumRecords, Flags});
2219 copyMetadata(Dest: Rsrc, Src: &I);
2220 Rsrc->takeName(V: &I);
2221 Value *Zero = Constant::getNullValue(Ty: OffType);
2222 SplitUsers.insert(V: &I);
2223 return {Rsrc, Zero};
2224 }
2225 case Intrinsic::ptrmask: {
2226 Value *Ptr = I.getArgOperand(i: 0);
2227 if (!isSplitFatPtr(Ty: Ptr->getType()))
2228 return {nullptr, nullptr};
2229 Value *Mask = I.getArgOperand(i: 1);
2230 IRB.SetInsertPoint(&I);
2231 auto [Rsrc, Off] = getPtrParts(V: Ptr);
2232 if (Mask->getType() != Off->getType())
2233 reportFatalUsageError(reason: "offset width is not equal to index width of fat "
2234 "pointer (data layout not set up correctly?)");
2235 Value *OffRes = IRB.CreateAnd(LHS: Off, RHS: Mask, Name: I.getName() + ".off");
2236 copyMetadata(Dest: OffRes, Src: &I);
2237 SplitUsers.insert(V: &I);
2238 return {Rsrc, OffRes};
2239 }
2240 // Pointer annotation intrinsics that, given their object-wide nature,
2241 // operate on the resource part.
2242 case Intrinsic::invariant_start: {
2243 Value *Ptr = I.getArgOperand(i: 1);
2244 if (!isSplitFatPtr(Ty: Ptr->getType()))
2245 return {nullptr, nullptr};
2246 IRB.SetInsertPoint(&I);
2247 auto [Rsrc, Off] = getPtrParts(V: Ptr);
2248 Type *NewTy = PointerType::get(C&: I.getContext(), AddressSpace: AMDGPUAS::BUFFER_RESOURCE);
2249 auto *NewRsrc = IRB.CreateIntrinsic(ID: IID, Types: {NewTy}, Args: {I.getOperand(i_nocapture: 0), Rsrc});
2250 copyMetadata(Dest: NewRsrc, Src: &I);
2251 NewRsrc->takeName(V: &I);
2252 SplitUsers.insert(V: &I);
2253 I.replaceAllUsesWith(V: NewRsrc);
2254 return {nullptr, nullptr};
2255 }
2256 case Intrinsic::invariant_end: {
2257 Value *RealPtr = I.getArgOperand(i: 2);
2258 if (!isSplitFatPtr(Ty: RealPtr->getType()))
2259 return {nullptr, nullptr};
2260 IRB.SetInsertPoint(&I);
2261 Value *RealRsrc = getPtrParts(V: RealPtr).first;
2262 Value *InvPtr = I.getArgOperand(i: 0);
2263 Value *Size = I.getArgOperand(i: 1);
2264 Value *NewRsrc = IRB.CreateIntrinsic(ID: IID, Types: {RealRsrc->getType()},
2265 Args: {InvPtr, Size, RealRsrc});
2266 copyMetadata(Dest: NewRsrc, Src: &I);
2267 NewRsrc->takeName(V: &I);
2268 SplitUsers.insert(V: &I);
2269 I.replaceAllUsesWith(V: NewRsrc);
2270 return {nullptr, nullptr};
2271 }
2272 case Intrinsic::launder_invariant_group:
2273 case Intrinsic::strip_invariant_group: {
2274 Value *Ptr = I.getArgOperand(i: 0);
2275 if (!isSplitFatPtr(Ty: Ptr->getType()))
2276 return {nullptr, nullptr};
2277 IRB.SetInsertPoint(&I);
2278 auto [Rsrc, Off] = getPtrParts(V: Ptr);
2279 Value *NewRsrc = IRB.CreateIntrinsic(ID: IID, Types: {Rsrc->getType()}, Args: {Rsrc});
2280 copyMetadata(Dest: NewRsrc, Src: &I);
2281 NewRsrc->takeName(V: &I);
2282 SplitUsers.insert(V: &I);
2283 return {NewRsrc, Off};
2284 }
2285 case Intrinsic::amdgcn_load_to_lds: {
2286 Value *Ptr = I.getArgOperand(i: 0);
2287 if (!isSplitFatPtr(Ty: Ptr->getType()))
2288 return {nullptr, nullptr};
2289 IRB.SetInsertPoint(&I);
2290 auto [Rsrc, Off] = getPtrParts(V: Ptr);
2291 Value *LDSPtr = I.getArgOperand(i: 1);
2292 Value *LoadSize = I.getArgOperand(i: 2);
2293 Value *ImmOff = I.getArgOperand(i: 3);
2294 Value *Aux = I.getArgOperand(i: 4);
2295 Value *SOffset = IRB.getInt32(C: 0);
2296 Instruction *NewLoad = IRB.CreateIntrinsic(
2297 ID: Intrinsic::amdgcn_raw_ptr_buffer_load_lds, Types: {},
2298 Args: {Rsrc, LDSPtr, LoadSize, Off, SOffset, ImmOff, Aux});
2299 copyMetadata(Dest: NewLoad, Src: &I);
2300 SplitUsers.insert(V: &I);
2301 I.replaceAllUsesWith(V: NewLoad);
2302 return {nullptr, nullptr};
2303 }
2304 }
2305 return {nullptr, nullptr};
2306}
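
// As one example of the rewrites above (illustrative names),
//   %q = call ptr addrspace(7) @llvm.ptrmask.p7.i32(ptr addrspace(7) %p,
//                                                   i32 -16)
// keeps %p.rsrc as its resource part and computes its offset part as
//   %q.off = and i32 %p.off, -16
// since the mask has the width of the offset, only the offset can be affected.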
2307
2308void SplitPtrStructs::processFunction(Function &F) {
2309 ST = &TM->getSubtarget<GCNSubtarget>(F);
2310 SmallVector<Instruction *, 0> Originals(
2311 llvm::make_pointer_range(Range: instructions(F)));
2312 LLVM_DEBUG(dbgs() << "Splitting pointer structs in function: " << F.getName()
2313 << "\n");
2314 for (Instruction *I : Originals) {
2315 auto [Rsrc, Off] = visit(I);
2316 assert(((Rsrc && Off) || (!Rsrc && !Off)) &&
2317 "Can't have a resource but no offset");
2318 if (Rsrc)
2319 RsrcParts[I] = Rsrc;
2320 if (Off)
2321 OffParts[I] = Off;
2322 }
2323 processConditionals();
2324 killAndReplaceSplitInstructions(Origs&: Originals);
2325
2326 // Clean up after ourselves to save on memory.
2327 RsrcParts.clear();
2328 OffParts.clear();
2329 SplitUsers.clear();
2330 Conditionals.clear();
2331 ConditionalTemps.clear();
2332}
2333
2334namespace {
2335class AMDGPULowerBufferFatPointers : public ModulePass {
2336public:
2337 static char ID;
2338
2339 AMDGPULowerBufferFatPointers() : ModulePass(ID) {}
2340
2341 bool run(Module &M, const TargetMachine &TM);
2342 bool runOnModule(Module &M) override;
2343
2344 void getAnalysisUsage(AnalysisUsage &AU) const override;
2345};
2346} // namespace
2347
2348/// Returns true if there are values that have a buffer fat pointer in them,
2349/// which means we'll need to perform rewrites on this function. As a side
2350/// effect, this will populate the type remapping cache.
2351static bool containsBufferFatPointers(const Function &F,
2352 BufferFatPtrToStructTypeMap *TypeMap) {
2353 bool HasFatPointers = false;
2354 for (const BasicBlock &BB : F)
2355 for (const Instruction &I : BB)
2356 HasFatPointers |= (I.getType() != TypeMap->remapType(SrcTy: I.getType()));
2357 return HasFatPointers;
2358}
2359
2360static bool hasFatPointerInterface(const Function &F,
2361 BufferFatPtrToStructTypeMap *TypeMap) {
2362 Type *Ty = F.getFunctionType();
2363 return Ty != TypeMap->remapType(SrcTy: Ty);
2364}
2365
2366/// Move the body of `OldF` into a new function, returning it.
2367static Function *moveFunctionAdaptingType(Function *OldF, FunctionType *NewTy,
2368 ValueToValueMapTy &CloneMap) {
2369 bool IsIntrinsic = OldF->isIntrinsic();
2370 Function *NewF =
2371 Function::Create(Ty: NewTy, Linkage: OldF->getLinkage(), AddrSpace: OldF->getAddressSpace());
2372 NewF->copyAttributesFrom(Src: OldF);
2373 NewF->copyMetadata(Src: OldF, Offset: 0);
2374 NewF->takeName(V: OldF);
2375 NewF->updateAfterNameChange();
2376 NewF->setDLLStorageClass(OldF->getDLLStorageClass());
2377 OldF->getParent()->getFunctionList().insertAfter(where: OldF->getIterator(), New: NewF);
2378
2379 while (!OldF->empty()) {
2380 BasicBlock *BB = &OldF->front();
2381 BB->removeFromParent();
2382 BB->insertInto(Parent: NewF);
2383 CloneMap[BB] = BB;
2384 for (Instruction &I : *BB) {
2385 CloneMap[&I] = &I;
2386 }
2387 }
2388
2389 SmallVector<AttributeSet> ArgAttrs;
2390 AttributeList OldAttrs = OldF->getAttributes();
2391
2392 for (auto [I, OldArg, NewArg] : enumerate(First: OldF->args(), Rest: NewF->args())) {
2393 CloneMap[&NewArg] = &OldArg;
2394 NewArg.takeName(V: &OldArg);
2395 Type *OldArgTy = OldArg.getType(), *NewArgTy = NewArg.getType();
2396 // Temporarily mutate type of `NewArg` to allow RAUW to work.
2397 NewArg.mutateType(Ty: OldArgTy);
2398 OldArg.replaceAllUsesWith(V: &NewArg);
2399 NewArg.mutateType(Ty: NewArgTy);
2400
2401 AttributeSet ArgAttr = OldAttrs.getParamAttrs(ArgNo: I);
2402 // Intrinsics get their attributes fixed later.
2403 if (OldArgTy != NewArgTy && !IsIntrinsic)
2404 ArgAttr = ArgAttr.removeAttributes(
2405 C&: NewF->getContext(),
2406 AttrsToRemove: AttributeFuncs::typeIncompatible(Ty: NewArgTy, AS: ArgAttr));
2407 ArgAttrs.push_back(Elt: ArgAttr);
2408 }
2409 AttributeSet RetAttrs = OldAttrs.getRetAttrs();
2410 if (OldF->getReturnType() != NewF->getReturnType() && !IsIntrinsic)
2411 RetAttrs = RetAttrs.removeAttributes(
2412 C&: NewF->getContext(),
2413 AttrsToRemove: AttributeFuncs::typeIncompatible(Ty: NewF->getReturnType(), AS: RetAttrs));
2414 NewF->setAttributes(AttributeList::get(
2415 C&: NewF->getContext(), FnAttrs: OldAttrs.getFnAttrs(), RetAttrs, ArgAttrs));
2416 return NewF;
2417}
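
// For illustration (assumed signature), a function such as
//   define ptr addrspace(7) @f(ptr addrspace(7) %p)
// is moved into a newly created function of type
//   define { ptr addrspace(8), i32 } @f({ ptr addrspace(8), i32 } %p)
// with the body, attributes, and metadata carried over and the old arguments
// RAUW'd to the new ones via the temporary type mutation above.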
2418
2419static void makeCloneInPlaceMap(Function *F, ValueToValueMapTy &CloneMap) {
2420 for (Argument &A : F->args())
2421 CloneMap[&A] = &A;
2422 for (BasicBlock &BB : *F) {
2423 CloneMap[&BB] = &BB;
2424 for (Instruction &I : BB)
2425 CloneMap[&I] = &I;
2426 }
2427}
2428
2429bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
2430 bool Changed = false;
2431 const DataLayout &DL = M.getDataLayout();
2432 // Record the functions which need to be remapped.
2433 // The second element of the pair indicates whether the function has to have
2434 // its arguments or return types adjusted.
2435 SmallVector<std::pair<Function *, bool>> NeedsRemap;
2436
2437 LLVMContext &Ctx = M.getContext();
2438
2439 BufferFatPtrToStructTypeMap StructTM(DL);
2440 BufferFatPtrToIntTypeMap IntTM(DL);
2441 for (const GlobalVariable &GV : M.globals()) {
2442 if (GV.getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
2443 // FIXME: Use DiagnosticInfo unsupported but it requires a Function
2444 Ctx.emitError(ErrorStr: "global variables with a buffer fat pointer address "
2445 "space (7) are not supported");
2446 continue;
2447 }
2448
2449 Type *VT = GV.getValueType();
2450 if (VT != StructTM.remapType(SrcTy: VT)) {
2451 // FIXME: Use DiagnosticInfo unsupported but it requires a Function
2452 Ctx.emitError(ErrorStr: "global variables that contain buffer fat pointers "
2453 "(address space 7 pointers) are unsupported. Use "
2454 "buffer resource pointers (address space 8) instead");
2455 continue;
2456 }
2457 }
2458
2459 {
2460 // Collect all constant exprs and aggregates referenced by any function.
2461 SmallVector<Constant *, 8> Worklist;
2462 for (Function &F : M.functions())
2463 for (Instruction &I : instructions(F))
2464 for (Value *Op : I.operands())
2465 if (isa<ConstantExpr, ConstantAggregate>(Val: Op))
2466 Worklist.push_back(Elt: cast<Constant>(Val: Op));
2467
2468 // Recursively look for any referenced buffer pointer constants.
2469 SmallPtrSet<Constant *, 8> Visited;
2470 SetVector<Constant *> BufferFatPtrConsts;
2471 while (!Worklist.empty()) {
2472 Constant *C = Worklist.pop_back_val();
2473 if (!Visited.insert(Ptr: C).second)
2474 continue;
2475 if (isBufferFatPtrOrVector(Ty: C->getType()))
2476 BufferFatPtrConsts.insert(X: C);
2477 for (Value *Op : C->operands())
2478 if (isa<ConstantExpr, ConstantAggregate>(Val: Op))
2479 Worklist.push_back(Elt: cast<Constant>(Val: Op));
2480 }
2481
2482 // Expand all constant expressions using fat buffer pointers to
2483 // instructions.
2484 Changed |= convertUsersOfConstantsToInstructions(
2485 Consts: BufferFatPtrConsts.getArrayRef(), /*RestrictToFunc=*/nullptr,
2486 /*RemoveDeadConstants=*/false, /*IncludeSelf=*/true);
2487 }
2488
2489 StoreFatPtrsAsIntsAndExpandMemcpyVisitor MemOpsRewrite(&IntTM, DL,
2490 M.getContext(), &TM);
2491 LegalizeBufferContentTypesVisitor BufferContentsTypeRewrite(DL,
2492 M.getContext());
2493 for (Function &F : M.functions()) {
2494 bool InterfaceChange = hasFatPointerInterface(F, TypeMap: &StructTM);
2495 bool BodyChanges = containsBufferFatPointers(F, TypeMap: &StructTM);
2496 Changed |= MemOpsRewrite.processFunction(F);
2497 if (InterfaceChange || BodyChanges) {
2498 NeedsRemap.push_back(Elt: std::make_pair(x: &F, y&: InterfaceChange));
2499 Changed |= BufferContentsTypeRewrite.processFunction(F);
2500 }
2501 }
2502 if (NeedsRemap.empty())
2503 return Changed;
2504
2505 SmallVector<Function *> NeedsPostProcess;
2506 SmallVector<Function *> Intrinsics;
2507 // Keep one big map so as to memoize constants across functions.
2508 ValueToValueMapTy CloneMap;
2509 FatPtrConstMaterializer Materializer(&StructTM, CloneMap);
2510
2511 ValueMapper LowerInFuncs(CloneMap, RF_None, &StructTM, &Materializer);
2512 for (auto [F, InterfaceChange] : NeedsRemap) {
2513 Function *NewF = F;
2514 if (InterfaceChange)
2515 NewF = moveFunctionAdaptingType(
2516 OldF: F, NewTy: cast<FunctionType>(Val: StructTM.remapType(SrcTy: F->getFunctionType())),
2517 CloneMap);
2518 else
2519 makeCloneInPlaceMap(F, CloneMap);
2520 LowerInFuncs.remapFunction(F&: *NewF);
2521 if (NewF->isIntrinsic())
2522 Intrinsics.push_back(Elt: NewF);
2523 else
2524 NeedsPostProcess.push_back(Elt: NewF);
2525 if (InterfaceChange) {
2526 F->replaceAllUsesWith(V: NewF);
2527 F->eraseFromParent();
2528 }
2529 Changed = true;
2530 }
2531 StructTM.clear();
2532 IntTM.clear();
2533 CloneMap.clear();
2534
2535 SplitPtrStructs Splitter(DL, M.getContext(), &TM);
2536 for (Function *F : NeedsPostProcess)
2537 Splitter.processFunction(F&: *F);
2538 for (Function *F : Intrinsics) {
2539 if (isRemovablePointerIntrinsic(IID: F->getIntrinsicID())) {
2540 F->eraseFromParent();
2541 } else {
2542 std::optional<Function *> NewF = Intrinsic::remangleIntrinsicFunction(F);
2543 if (NewF)
2544 F->replaceAllUsesWith(V: *NewF);
2545 }
2546 }
2547 return Changed;
2548}
2549
2550bool AMDGPULowerBufferFatPointers::runOnModule(Module &M) {
2551 TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
2552 const TargetMachine &TM = TPC.getTM<TargetMachine>();
2553 return run(M, TM);
2554}
2555
2556char AMDGPULowerBufferFatPointers::ID = 0;
2557
2558char &llvm::AMDGPULowerBufferFatPointersID = AMDGPULowerBufferFatPointers::ID;
2559
2560void AMDGPULowerBufferFatPointers::getAnalysisUsage(AnalysisUsage &AU) const {
2561 AU.addRequired<TargetPassConfig>();
2562}
2563
2564#define PASS_DESC "Lower buffer fat pointer operations to buffer resources"
2565INITIALIZE_PASS_BEGIN(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC,
2566 false, false)
2567INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
2568INITIALIZE_PASS_END(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC, false,
2569 false)
2570#undef PASS_DESC
2571
2572ModulePass *llvm::createAMDGPULowerBufferFatPointersPass() {
2573 return new AMDGPULowerBufferFatPointers();
2574}
2575
2576PreservedAnalyses
2577AMDGPULowerBufferFatPointersPass::run(Module &M, ModuleAnalysisManager &MA) {
2578 return AMDGPULowerBufferFatPointers().run(M, TM) ? PreservedAnalyses::none()
2579 : PreservedAnalyses::all();
2580}
2581