1//===-- AMDGPULowerBufferFatPointers.cpp ---------------------------=//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass lowers operations on buffer fat pointers (addrspace 7) to
10// operations on buffer resources (addrspace 8) and is needed for correct
11// codegen.
12//
13// # Background
14//
15// Address space 7 (the buffer fat pointer) is a 160-bit pointer that consists
16// of a 128-bit buffer descriptor and a 32-bit offset into that descriptor.
17// The buffer resource part needs to be it needs to be a "raw" buffer resource
18// (it must have a stride of 0 and bounds checks must be in raw buffer mode
19// or disabled).
20//
21// When these requirements are met, a buffer resource can be treated as a
22// typical (though quite wide) pointer that follows typical LLVM pointer
23// semantics. This allows the frontend to reason about such buffers (which are
24// often encountered in the context of SPIR-V kernels).
25//
26// However, because of their non-power-of-2 size, these fat pointers cannot be
27// present during translation to MIR (though this restriction may be lifted
28// during the transition to GlobalISel). Therefore, this pass is needed in order
29// to correctly implement these fat pointers.
30//
31// The resource intrinsics take the resource part (the address space 8 pointer)
32// and the offset part (the 32-bit integer) as separate arguments. In addition,
33// many users of these buffers manipulate the offset while leaving the resource
34// part alone. For these reasons, we want to typically separate the resource
35// and offset parts into separate variables, but combine them together when
36// encountering cases where this is required, such as by inserting these values
37// into aggretates or moving them to memory.
38//
39// Therefore, at a high level, `ptr addrspace(7) %x` becomes `ptr addrspace(8)
40// %x.rsrc` and `i32 %x.off`, which will be combined into `{ptr addrspace(8),
41// i32} %x = {%x.rsrc, %x.off}` if needed. Similarly, `vector<Nxp7>` becomes
42// `{vector<Nxp8>, vector<Nxi32 >}` and its component parts.
43//
44// # Implementation
45//
46// This pass proceeds in three main phases:
47//
48// ## Rewriting loads and stores of p7 and memcpy()-like handling
49//
50// The first phase is to rewrite away all loads and stors of `ptr addrspace(7)`,
51// including aggregates containing such pointers, to ones that use `i160`. This
52// is handled by `StoreFatPtrsAsIntsAndExpandMemcpyVisitor` , which visits
53// loads, stores, and allocas and, if the loaded or stored type contains `ptr
54// addrspace(7)`, rewrites that type to one where the p7s are replaced by i160s,
55// copying other parts of aggregates as needed. In the case of a store, each
56// pointer is `ptrtoint`d to i160 before storing, and load integers are
57// `inttoptr`d back. This same transformation is applied to vectors of pointers.
58//
59// Such a transformation allows the later phases of the pass to not need
60// to handle buffer fat pointers moving to and from memory, where we load
61// have to handle the incompatibility between a `{Nxp8, Nxi32}` representation
62// and `Nxi60` directly. Instead, that transposing action (where the vectors
63// of resources and vectors of offsets are concatentated before being stored to
64// memory) are handled through implementing `inttoptr` and `ptrtoint` only.
65//
66// Atomics operations on `ptr addrspace(7)` values are not suppported, as the
67// hardware does not include a 160-bit atomic.
68//
69// In order to save on O(N) work and to ensure that the contents type
70// legalizer correctly splits up wide loads, also unconditionally lower
71// memcpy-like intrinsics into loops here.
72//
73// ## Buffer contents type legalization
74//
75// The underlying buffer intrinsics only support types up to 128 bits long,
76// and don't support complex types. If buffer operations were
77// standard pointer operations that could be represented as MIR-level loads,
78// this would be handled by the various legalization schemes in instruction
79// selection. However, because we have to do the conversion from `load` and
80// `store` to intrinsics at LLVM IR level, we must perform that legalization
81// ourselves.
82//
83// This involves a combination of
84// - Converting arrays to vectors where possible
85// - Otherwise, splitting loads and stores of aggregates into loads/stores of
86// each component.
87// - Zero-extending things to fill a whole number of bytes
88// - Casting values of types that don't neatly correspond to supported machine
89// value
90// (for example, an i96 or i256) into ones that would work (
91// like <3 x i32> and <8 x i32>, respectively)
92// - Splitting values that are too long (such as aforementioned <8 x i32>) into
93// multiple operations.
94//
95// ## Type remapping
96//
97// We use a `ValueMapper` to mangle uses of [vectors of] buffer fat pointers
98// to the corresponding struct type, which has a resource part and an offset
99// part.
100//
101// This uses a `BufferFatPtrToStructTypeMap` and a `FatPtrConstMaterializer`
102// to, usually by way of `setType`ing values. Constants are handled here
103// because there isn't a good way to fix them up later.
104//
105// This has the downside of leaving the IR in an invalid state (for example,
106// the instruction `getelementptr {ptr addrspace(8), i32} %p, ...` will exist),
107// but all such invalid states will be resolved by the third phase.
108//
109// Functions that don't take buffer fat pointers are modified in place. Those
110// that do take such pointers have their basic blocks moved to a new function
111// with arguments that are {ptr addrspace(8), i32} arguments and return values.
112// This phase also records intrinsics so that they can be remangled or deleted
113// later.
114//
115// ## Splitting pointer structs
116//
117// The meat of this pass consists of defining semantics for operations that
118// produce or consume [vectors of] buffer fat pointers in terms of their
119// resource and offset parts. This is accomplished throgh the `SplitPtrStructs`
120// visitor.
121//
122// In the first pass through each function that is being lowered, the splitter
123// inserts new instructions to implement the split-structures behavior, which is
124// needed for correctness and performance. It records a list of "split users",
125// instructions that are being replaced by operations on the resource and offset
126// parts.
127//
128// Split users do not necessarily need to produce parts themselves (
129// a `load float, ptr addrspace(7)` does not, for example), but, if they do not
130// generate fat buffer pointers, they must RAUW in their replacement
131// instructions during the initial visit.
132//
133// When these new instructions are created, they use the split parts recorded
134// for their initial arguments in order to generate their replacements, creating
135// a parallel set of instructions that does not refer to the original fat
136// pointer values but instead to their resource and offset components.
137//
138// Instructions, such as `extractvalue`, that produce buffer fat pointers from
139// sources that do not have split parts, have such parts generated using
140// `extractvalue`. This is also the initial handling of PHI nodes, which
141// are then cleaned up.
142//
143// ### Conditionals
144//
145// PHI nodes are initially given resource parts via `extractvalue`. However,
146// this is not an efficient rewrite of such nodes, as, in most cases, the
147// resource part in a conditional or loop remains constant throughout the loop
148// and only the offset varies. Failing to optimize away these constant resources
149// would cause additional registers to be sent around loops and might lead to
150// waterfall loops being generated for buffer operations due to the
151// "non-uniform" resource argument.
152//
153// Therefore, after all instructions have been visited, the pointer splitter
154// post-processes all encountered conditionals. Given a PHI node or select,
155// getPossibleRsrcRoots() collects all values that the resource parts of that
156// conditional's input could come from as well as collecting all conditional
157// instructions encountered during the search. If, after filtering out the
158// initial node itself, the set of encountered conditionals is a subset of the
159// potential roots and there is a single potential resource that isn't in the
160// conditional set, that value is the only possible value the resource argument
161// could have throughout the control flow.
162//
163// If that condition is met, then a PHI node can have its resource part changed
164// to the singleton value and then be replaced by a PHI on the offsets.
165// Otherwise, each PHI node is split into two, one for the resource part and one
166// for the offset part, which replace the temporary `extractvalue` instructions
167// that were added during the first pass.
168//
169// Similar logic applies to `select`, where
170// `%z = select i1 %cond, %cond, ptr addrspace(7) %x, ptr addrspace(7) %y`
171// can be split into `%z.rsrc = %x.rsrc` and
172// `%z.off = select i1 %cond, ptr i32 %x.off, i32 %y.off`
173// if both `%x` and `%y` have the same resource part, but two `select`
174// operations will be needed if they do not.
175//
176// ### Final processing
177//
178// After conditionals have been cleaned up, the IR for each function is
179// rewritten to remove all the old instructions that have been split up.
180//
181// Any instruction that used to produce a buffer fat pointer (and therefore now
182// produces a resource-and-offset struct after type remapping) is
183// replaced as follows:
184// 1. All debug value annotations are cloned to reflect that the resource part
185// and offset parts are computed separately and constitute different
186// fragments of the underlying source language variable.
187// 2. All uses that were themselves split are replaced by a `poison` of the
188// struct type, as they will themselves be erased soon. This rule, combined
189// with debug handling, should leave the use lists of split instructions
190// empty in almost all cases.
191// 3. If a user of the original struct-valued result remains, the structure
192// needed for the new types to work is constructed out of the newly-defined
193// parts, and the original instruction is replaced by this structure
194// before being erased. Instructions requiring this construction include
195// `ret` and `insertvalue`.
196//
197// # Consequences
198//
199// This pass does not alter the CFG.
200//
201// Alias analysis information will become coarser, as the LLVM alias analyzer
202// cannot handle the buffer intrinsics. Specifically, while we can determine
203// that the following two loads do not alias:
204// ```
205// %y = getelementptr i32, ptr addrspace(7) %x, i32 1
206// %a = load i32, ptr addrspace(7) %x
207// %b = load i32, ptr addrspace(7) %y
208// ```
209// we cannot (except through some code that runs during scheduling) determine
210// that the rewritten loads below do not alias.
211// ```
212// %y.off = add i32 %x.off, 1
213// %a = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8) %x.rsrc, i32
214// %x.off, ...)
215// %b = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8)
216// %x.rsrc, i32 %y.off, ...)
217// ```
218// However, existing alias information is preserved.
219//===----------------------------------------------------------------------===//
220
221#include "AMDGPU.h"
222#include "AMDGPUTargetMachine.h"
223#include "GCNSubtarget.h"
224#include "SIDefines.h"
225#include "llvm/ADT/SetOperations.h"
226#include "llvm/ADT/SmallVector.h"
227#include "llvm/Analysis/InstSimplifyFolder.h"
228#include "llvm/Analysis/TargetTransformInfo.h"
229#include "llvm/Analysis/Utils/Local.h"
230#include "llvm/CodeGen/TargetPassConfig.h"
231#include "llvm/IR/AttributeMask.h"
232#include "llvm/IR/Constants.h"
233#include "llvm/IR/DebugInfo.h"
234#include "llvm/IR/DerivedTypes.h"
235#include "llvm/IR/IRBuilder.h"
236#include "llvm/IR/InstIterator.h"
237#include "llvm/IR/InstVisitor.h"
238#include "llvm/IR/Instructions.h"
239#include "llvm/IR/IntrinsicInst.h"
240#include "llvm/IR/Intrinsics.h"
241#include "llvm/IR/IntrinsicsAMDGPU.h"
242#include "llvm/IR/Metadata.h"
243#include "llvm/IR/Operator.h"
244#include "llvm/IR/PatternMatch.h"
245#include "llvm/IR/ReplaceConstant.h"
246#include "llvm/IR/ValueHandle.h"
247#include "llvm/InitializePasses.h"
248#include "llvm/Pass.h"
249#include "llvm/Support/AMDGPUAddrSpace.h"
250#include "llvm/Support/Alignment.h"
251#include "llvm/Support/AtomicOrdering.h"
252#include "llvm/Support/Debug.h"
253#include "llvm/Support/ErrorHandling.h"
254#include "llvm/Transforms/Utils/Cloning.h"
255#include "llvm/Transforms/Utils/Local.h"
256#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
257#include "llvm/Transforms/Utils/ValueMapper.h"
258
259#define DEBUG_TYPE "amdgpu-lower-buffer-fat-pointers"
260
261using namespace llvm;
262
263static constexpr unsigned BufferOffsetWidth = 32;
264
265namespace {
266/// Recursively replace instances of ptr addrspace(7) and vector<Nxptr
267/// addrspace(7)> with some other type as defined by the relevant subclass.
268class BufferFatPtrTypeLoweringBase : public ValueMapTypeRemapper {
269 DenseMap<Type *, Type *> Map;
270
271 Type *remapTypeImpl(Type *Ty);
272
273protected:
274 virtual Type *remapScalar(PointerType *PT) = 0;
275 virtual Type *remapVector(VectorType *VT) = 0;
276
277 const DataLayout &DL;
278
279public:
280 BufferFatPtrTypeLoweringBase(const DataLayout &DL) : DL(DL) {}
281 Type *remapType(Type *SrcTy) override;
282 void clear() { Map.clear(); }
283};
284
285/// Remap ptr addrspace(7) to i160 and vector<Nxptr addrspace(7)> to
286/// vector<Nxi60> in order to correctly handling loading/storing these values
287/// from memory.
288class BufferFatPtrToIntTypeMap : public BufferFatPtrTypeLoweringBase {
289 using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;
290
291protected:
292 Type *remapScalar(PointerType *PT) override { return DL.getIntPtrType(PT); }
293 Type *remapVector(VectorType *VT) override { return DL.getIntPtrType(VT); }
294};
295
296/// Remap ptr addrspace(7) to {ptr addrspace(8), i32} (the resource and offset
297/// parts of the pointer) so that we can easily rewrite operations on these
298/// values that aren't loading them from or storing them to memory.
299class BufferFatPtrToStructTypeMap : public BufferFatPtrTypeLoweringBase {
300 using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;
301
302protected:
303 Type *remapScalar(PointerType *PT) override;
304 Type *remapVector(VectorType *VT) override;
305};
306} // namespace
307
308// This code is adapted from the type remapper in lib/Linker/IRMover.cpp
309Type *BufferFatPtrTypeLoweringBase::remapTypeImpl(Type *Ty) {
310 Type **Entry = &Map[Ty];
311 if (*Entry)
312 return *Entry;
313 if (auto *PT = dyn_cast<PointerType>(Val: Ty)) {
314 if (PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
315 return *Entry = remapScalar(PT);
316 }
317 }
318 if (auto *VT = dyn_cast<VectorType>(Val: Ty)) {
319 auto *PT = dyn_cast<PointerType>(Val: VT->getElementType());
320 if (PT && PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
321 return *Entry = remapVector(VT);
322 }
323 return *Entry = Ty;
324 }
325 // Whether the type is one that is structurally uniqued - that is, if it is
326 // not a named struct (the only kind of type where multiple structurally
327 // identical types that have a distinct `Type*`)
328 StructType *TyAsStruct = dyn_cast<StructType>(Val: Ty);
329 bool IsUniqued = !TyAsStruct || TyAsStruct->isLiteral();
330 // Base case for ints, floats, opaque pointers, and so on, which don't
331 // require recursion.
332 if (Ty->getNumContainedTypes() == 0 && IsUniqued)
333 return *Entry = Ty;
334 bool Changed = false;
335 SmallVector<Type *> ElementTypes(Ty->getNumContainedTypes(), nullptr);
336 for (unsigned int I = 0, E = Ty->getNumContainedTypes(); I < E; ++I) {
337 Type *OldElem = Ty->getContainedType(i: I);
338 Type *NewElem = remapTypeImpl(Ty: OldElem);
339 ElementTypes[I] = NewElem;
340 Changed |= (OldElem != NewElem);
341 }
342 // Recursive calls to remapTypeImpl() may have invalidated pointer.
343 Entry = &Map[Ty];
344 if (!Changed) {
345 return *Entry = Ty;
346 }
347 if (auto *ArrTy = dyn_cast<ArrayType>(Val: Ty))
348 return *Entry = ArrayType::get(ElementType: ElementTypes[0], NumElements: ArrTy->getNumElements());
349 if (auto *FnTy = dyn_cast<FunctionType>(Val: Ty))
350 return *Entry = FunctionType::get(Result: ElementTypes[0],
351 Params: ArrayRef(ElementTypes).slice(N: 1),
352 isVarArg: FnTy->isVarArg());
353 if (auto *STy = dyn_cast<StructType>(Val: Ty)) {
354 // Genuine opaque types don't have a remapping.
355 if (STy->isOpaque())
356 return *Entry = Ty;
357 bool IsPacked = STy->isPacked();
358 if (IsUniqued)
359 return *Entry = StructType::get(Context&: Ty->getContext(), Elements: ElementTypes, isPacked: IsPacked);
360 SmallString<16> Name(STy->getName());
361 STy->setName("");
362 return *Entry = StructType::create(Context&: Ty->getContext(), Elements: ElementTypes, Name,
363 isPacked: IsPacked);
364 }
365 llvm_unreachable("Unknown type of type that contains elements");
366}
367
368Type *BufferFatPtrTypeLoweringBase::remapType(Type *SrcTy) {
369 return remapTypeImpl(Ty: SrcTy);
370}
371
372Type *BufferFatPtrToStructTypeMap::remapScalar(PointerType *PT) {
373 LLVMContext &Ctx = PT->getContext();
374 return StructType::get(elt1: PointerType::get(C&: Ctx, AddressSpace: AMDGPUAS::BUFFER_RESOURCE),
375 elts: IntegerType::get(C&: Ctx, NumBits: BufferOffsetWidth));
376}
377
378Type *BufferFatPtrToStructTypeMap::remapVector(VectorType *VT) {
379 ElementCount EC = VT->getElementCount();
380 LLVMContext &Ctx = VT->getContext();
381 Type *RsrcVec =
382 VectorType::get(ElementType: PointerType::get(C&: Ctx, AddressSpace: AMDGPUAS::BUFFER_RESOURCE), EC);
383 Type *OffVec = VectorType::get(ElementType: IntegerType::get(C&: Ctx, NumBits: BufferOffsetWidth), EC);
384 return StructType::get(elt1: RsrcVec, elts: OffVec);
385}
386
387static bool isBufferFatPtrOrVector(Type *Ty) {
388 if (auto *PT = dyn_cast<PointerType>(Val: Ty->getScalarType()))
389 return PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER;
390 return false;
391}
392
393// True if the type is {ptr addrspace(8), i32} or a struct containing vectors of
394// those types. Used to quickly skip instructions we don't need to process.
395static bool isSplitFatPtr(Type *Ty) {
396 auto *ST = dyn_cast<StructType>(Val: Ty);
397 if (!ST)
398 return false;
399 if (!ST->isLiteral() || ST->getNumElements() != 2)
400 return false;
401 auto *MaybeRsrc =
402 dyn_cast<PointerType>(Val: ST->getElementType(N: 0)->getScalarType());
403 auto *MaybeOff =
404 dyn_cast<IntegerType>(Val: ST->getElementType(N: 1)->getScalarType());
405 return MaybeRsrc && MaybeOff &&
406 MaybeRsrc->getAddressSpace() == AMDGPUAS::BUFFER_RESOURCE &&
407 MaybeOff->getBitWidth() == BufferOffsetWidth;
408}
409
410// True if the result type or any argument types are buffer fat pointers.
411static bool isBufferFatPtrConst(Constant *C) {
412 Type *T = C->getType();
413 return isBufferFatPtrOrVector(Ty: T) || any_of(Range: C->operands(), P: [](const Use &U) {
414 return isBufferFatPtrOrVector(Ty: U.get()->getType());
415 });
416}
417
418namespace {
419/// Convert [vectors of] buffer fat pointers to integers when they are read from
420/// or stored to memory. This ensures that these pointers will have the same
421/// memory layout as before they are lowered, even though they will no longer
422/// have their previous layout in registers/in the program (they'll be broken
423/// down into resource and offset parts). This has the downside of imposing
424/// marshalling costs when reading or storing these values, but since placing
425/// such pointers into memory is an uncommon operation at best, we feel that
426/// this cost is acceptable for better performance in the common case.
427class StoreFatPtrsAsIntsAndExpandMemcpyVisitor
428 : public InstVisitor<StoreFatPtrsAsIntsAndExpandMemcpyVisitor, bool> {
429 BufferFatPtrToIntTypeMap *TypeMap;
430
431 ValueToValueMapTy ConvertedForStore;
432
433 IRBuilder<InstSimplifyFolder> IRB;
434
435 const TargetMachine *TM;
436
437 // Convert all the buffer fat pointers within the input value to inttegers
438 // so that it can be stored in memory.
439 Value *fatPtrsToInts(Value *V, Type *From, Type *To, const Twine &Name);
440 // Convert all the i160s that need to be buffer fat pointers (as specified)
441 // by the To type) into those pointers to preserve the semantics of the rest
442 // of the program.
443 Value *intsToFatPtrs(Value *V, Type *From, Type *To, const Twine &Name);
444
445public:
446 StoreFatPtrsAsIntsAndExpandMemcpyVisitor(BufferFatPtrToIntTypeMap *TypeMap,
447 const DataLayout &DL,
448 LLVMContext &Ctx,
449 const TargetMachine *TM)
450 : TypeMap(TypeMap), IRB(Ctx, InstSimplifyFolder(DL)), TM(TM) {}
451 bool processFunction(Function &F);
452
453 bool visitInstruction(Instruction &I) { return false; }
454 bool visitAllocaInst(AllocaInst &I);
455 bool visitLoadInst(LoadInst &LI);
456 bool visitStoreInst(StoreInst &SI);
457 bool visitGetElementPtrInst(GetElementPtrInst &I);
458
459 bool visitMemCpyInst(MemCpyInst &MCI);
460 bool visitMemMoveInst(MemMoveInst &MMI);
461 bool visitMemSetInst(MemSetInst &MSI);
462 bool visitMemSetPatternInst(MemSetPatternInst &MSPI);
463};
464} // namespace
465
466Value *StoreFatPtrsAsIntsAndExpandMemcpyVisitor::fatPtrsToInts(
467 Value *V, Type *From, Type *To, const Twine &Name) {
468 if (From == To)
469 return V;
470 ValueToValueMapTy::iterator Find = ConvertedForStore.find(Val: V);
471 if (Find != ConvertedForStore.end())
472 return Find->second;
473 if (isBufferFatPtrOrVector(Ty: From)) {
474 Value *Cast = IRB.CreatePtrToInt(V, DestTy: To, Name: Name + ".int");
475 ConvertedForStore[V] = Cast;
476 return Cast;
477 }
478 if (From->getNumContainedTypes() == 0)
479 return V;
480 // Structs, arrays, and other compound types.
481 Value *Ret = PoisonValue::get(T: To);
482 if (auto *AT = dyn_cast<ArrayType>(Val: From)) {
483 Type *FromPart = AT->getArrayElementType();
484 Type *ToPart = cast<ArrayType>(Val: To)->getElementType();
485 for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
486 Value *Field = IRB.CreateExtractValue(Agg: V, Idxs: I);
487 Value *NewField =
488 fatPtrsToInts(V: Field, From: FromPart, To: ToPart, Name: Name + "." + Twine(I));
489 Ret = IRB.CreateInsertValue(Agg: Ret, Val: NewField, Idxs: I);
490 }
491 } else {
492 for (auto [Idx, FromPart, ToPart] :
493 enumerate(First: From->subtypes(), Rest: To->subtypes())) {
494 Value *Field = IRB.CreateExtractValue(Agg: V, Idxs: Idx);
495 Value *NewField =
496 fatPtrsToInts(V: Field, From: FromPart, To: ToPart, Name: Name + "." + Twine(Idx));
497 Ret = IRB.CreateInsertValue(Agg: Ret, Val: NewField, Idxs: Idx);
498 }
499 }
500 ConvertedForStore[V] = Ret;
501 return Ret;
502}
503
504Value *StoreFatPtrsAsIntsAndExpandMemcpyVisitor::intsToFatPtrs(
505 Value *V, Type *From, Type *To, const Twine &Name) {
506 if (From == To)
507 return V;
508 if (isBufferFatPtrOrVector(Ty: To)) {
509 Value *Cast = IRB.CreateIntToPtr(V, DestTy: To, Name: Name + ".ptr");
510 return Cast;
511 }
512 if (From->getNumContainedTypes() == 0)
513 return V;
514 // Structs, arrays, and other compound types.
515 Value *Ret = PoisonValue::get(T: To);
516 if (auto *AT = dyn_cast<ArrayType>(Val: From)) {
517 Type *FromPart = AT->getArrayElementType();
518 Type *ToPart = cast<ArrayType>(Val: To)->getElementType();
519 for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
520 Value *Field = IRB.CreateExtractValue(Agg: V, Idxs: I);
521 Value *NewField =
522 intsToFatPtrs(V: Field, From: FromPart, To: ToPart, Name: Name + "." + Twine(I));
523 Ret = IRB.CreateInsertValue(Agg: Ret, Val: NewField, Idxs: I);
524 }
525 } else {
526 for (auto [Idx, FromPart, ToPart] :
527 enumerate(First: From->subtypes(), Rest: To->subtypes())) {
528 Value *Field = IRB.CreateExtractValue(Agg: V, Idxs: Idx);
529 Value *NewField =
530 intsToFatPtrs(V: Field, From: FromPart, To: ToPart, Name: Name + "." + Twine(Idx));
531 Ret = IRB.CreateInsertValue(Agg: Ret, Val: NewField, Idxs: Idx);
532 }
533 }
534 return Ret;
535}
536
537bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::processFunction(Function &F) {
538 bool Changed = false;
539 // Process memcpy-like instructions after the main iteration because they can
540 // invalidate iterators.
541 SmallVector<WeakTrackingVH> CanBecomeLoops;
542 for (Instruction &I : make_early_inc_range(Range: instructions(F))) {
543 if (isa<MemTransferInst, MemSetInst, MemSetPatternInst>(Val: I))
544 CanBecomeLoops.push_back(Elt: &I);
545 else
546 Changed |= visit(I);
547 }
548 for (WeakTrackingVH VH : make_early_inc_range(Range&: CanBecomeLoops)) {
549 Changed |= visit(I: cast<Instruction>(Val&: VH));
550 }
551 ConvertedForStore.clear();
552 return Changed;
553}
554
555bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitAllocaInst(AllocaInst &I) {
556 Type *Ty = I.getAllocatedType();
557 Type *NewTy = TypeMap->remapType(SrcTy: Ty);
558 if (Ty == NewTy)
559 return false;
560 I.setAllocatedType(NewTy);
561 return true;
562}
563
564bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitGetElementPtrInst(
565 GetElementPtrInst &I) {
566 Type *Ty = I.getSourceElementType();
567 Type *NewTy = TypeMap->remapType(SrcTy: Ty);
568 if (Ty == NewTy)
569 return false;
570 // We'll be rewriting the type `ptr addrspace(7)` out of existence soon, so
571 // make sure GEPs don't have different semantics with the new type.
572 I.setSourceElementType(NewTy);
573 I.setResultElementType(TypeMap->remapType(SrcTy: I.getResultElementType()));
574 return true;
575}
576
577bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitLoadInst(LoadInst &LI) {
578 Type *Ty = LI.getType();
579 Type *IntTy = TypeMap->remapType(SrcTy: Ty);
580 if (Ty == IntTy)
581 return false;
582
583 IRB.SetInsertPoint(&LI);
584 auto *NLI = cast<LoadInst>(Val: LI.clone());
585 NLI->mutateType(Ty: IntTy);
586 NLI = IRB.Insert(I: NLI);
587 NLI->takeName(V: &LI);
588
589 Value *CastBack = intsToFatPtrs(V: NLI, From: IntTy, To: Ty, Name: NLI->getName());
590 LI.replaceAllUsesWith(V: CastBack);
591 LI.eraseFromParent();
592 return true;
593}
594
595bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitStoreInst(StoreInst &SI) {
596 Value *V = SI.getValueOperand();
597 Type *Ty = V->getType();
598 Type *IntTy = TypeMap->remapType(SrcTy: Ty);
599 if (Ty == IntTy)
600 return false;
601
602 IRB.SetInsertPoint(&SI);
603 Value *IntV = fatPtrsToInts(V, From: Ty, To: IntTy, Name: V->getName());
604 for (auto *Dbg : at::getDVRAssignmentMarkers(Inst: &SI))
605 Dbg->setRawLocation(ValueAsMetadata::get(V: IntV));
606
607 SI.setOperand(i_nocapture: 0, Val_nocapture: IntV);
608 return true;
609}
610
611bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemCpyInst(
612 MemCpyInst &MCI) {
613 // TODO: Allow memcpy.p7.p3 as a synonym for the direct-to-LDS copy, which'll
614 // need loop expansion here.
615 if (MCI.getSourceAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER &&
616 MCI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
617 return false;
618 llvm::expandMemCpyAsLoop(MemCpy: &MCI,
619 TTI: TM->getTargetTransformInfo(F: *MCI.getFunction()));
620 MCI.eraseFromParent();
621 return true;
622}
623
624bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemMoveInst(
625 MemMoveInst &MMI) {
626 if (MMI.getSourceAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER &&
627 MMI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
628 return false;
629 reportFatalUsageError(
630 reason: "memmove() on buffer descriptors is not implemented because pointer "
631 "comparison on buffer descriptors isn't implemented\n");
632}
633
634bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemSetInst(
635 MemSetInst &MSI) {
636 if (MSI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
637 return false;
638 llvm::expandMemSetAsLoop(MemSet: &MSI);
639 MSI.eraseFromParent();
640 return true;
641}
642
643bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemSetPatternInst(
644 MemSetPatternInst &MSPI) {
645 if (MSPI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
646 return false;
647 llvm::expandMemSetPatternAsLoop(MemSet: &MSPI);
648 MSPI.eraseFromParent();
649 return true;
650}
651
652namespace {
653/// Convert loads/stores of types that the buffer intrinsics can't handle into
654/// one ore more such loads/stores that consist of legal types.
655///
656/// Do this by
657/// 1. Recursing into structs (and arrays that don't share a memory layout with
658/// vectors) since the intrinsics can't handle complex types.
659/// 2. Converting arrays of non-aggregate, byte-sized types into their
660/// corresponding vectors
661/// 3. Bitcasting unsupported types, namely overly-long scalars and byte
662/// vectors, into vectors of supported types.
663/// 4. Splitting up excessively long reads/writes into multiple operations.
664///
665/// Note that this doesn't handle complex data strucures, but, in the future,
666/// the aggregate load splitter from SROA could be refactored to allow for that
667/// case.
668class LegalizeBufferContentTypesVisitor
669 : public InstVisitor<LegalizeBufferContentTypesVisitor, bool> {
670 friend class InstVisitor<LegalizeBufferContentTypesVisitor, bool>;
671
672 IRBuilder<InstSimplifyFolder> IRB;
673
674 const DataLayout &DL;
675
676 /// If T is [N x U], where U is a scalar type, return the vector type
677 /// <N x U>, otherwise, return T.
678 Type *scalarArrayTypeAsVector(Type *MaybeArrayType);
679 Value *arrayToVector(Value *V, Type *TargetType, const Twine &Name);
680 Value *vectorToArray(Value *V, Type *OrigType, const Twine &Name);
681
682 /// Break up the loads of a struct into the loads of its components
683
684 /// Convert a vector or scalar type that can't be operated on by buffer
685 /// intrinsics to one that would be legal through bitcasts and/or truncation.
686 /// Uses the wider of i32, i16, or i8 where possible.
687 Type *legalNonAggregateFor(Type *T);
688 Value *makeLegalNonAggregate(Value *V, Type *TargetType, const Twine &Name);
689 Value *makeIllegalNonAggregate(Value *V, Type *OrigType, const Twine &Name);
690
691 struct VecSlice {
692 uint64_t Index = 0;
693 uint64_t Length = 0;
694 VecSlice() = delete;
695 // Needed for some Clangs
696 VecSlice(uint64_t Index, uint64_t Length) : Index(Index), Length(Length) {}
697 };
698 /// Return the [index, length] pairs into which `T` needs to be cut to form
699 /// legal buffer load or store operations. Clears `Slices`. Creates an empty
700 /// `Slices` for non-vector inputs and creates one slice if no slicing will be
701 /// needed.
702 void getVecSlices(Type *T, SmallVectorImpl<VecSlice> &Slices);
703
704 Value *extractSlice(Value *Vec, VecSlice S, const Twine &Name);
705 Value *insertSlice(Value *Whole, Value *Part, VecSlice S, const Twine &Name);
706
707 /// In most cases, return `LegalType`. However, when given an input that would
708 /// normally be a legal type for the buffer intrinsics to return but that
709 /// isn't hooked up through SelectionDAG, return a type of the same width that
710 /// can be used with the relevant intrinsics. Specifically, handle the cases:
711 /// - <1 x T> => T for all T
712 /// - <N x i8> <=> i16, i32, 2xi32, 4xi32 (as needed)
713 /// - <N x T> where T is under 32 bits and the total size is 96 bits <=> <3 x
714 /// i32>
715 Type *intrinsicTypeFor(Type *LegalType);
716
717 bool visitLoadImpl(LoadInst &OrigLI, Type *PartType,
718 SmallVectorImpl<uint32_t> &AggIdxs, uint64_t AggByteOffset,
719 Value *&Result, const Twine &Name);
720 /// Return value is (Changed, ModifiedInPlace)
721 std::pair<bool, bool> visitStoreImpl(StoreInst &OrigSI, Type *PartType,
722 SmallVectorImpl<uint32_t> &AggIdxs,
723 uint64_t AggByteOffset,
724 const Twine &Name);
725
726 bool visitInstruction(Instruction &I) { return false; }
727 bool visitLoadInst(LoadInst &LI);
728 bool visitStoreInst(StoreInst &SI);
729
730public:
731 LegalizeBufferContentTypesVisitor(const DataLayout &DL, LLVMContext &Ctx)
732 : IRB(Ctx, InstSimplifyFolder(DL)), DL(DL) {}
733 bool processFunction(Function &F);
734};
735} // namespace
736
737Type *LegalizeBufferContentTypesVisitor::scalarArrayTypeAsVector(Type *T) {
738 ArrayType *AT = dyn_cast<ArrayType>(Val: T);
739 if (!AT)
740 return T;
741 Type *ET = AT->getElementType();
742 if (!ET->isSingleValueType() || isa<VectorType>(Val: ET))
743 reportFatalUsageError(reason: "loading non-scalar arrays from buffer fat pointers "
744 "should have recursed");
745 if (!DL.typeSizeEqualsStoreSize(Ty: AT))
746 reportFatalUsageError(
747 reason: "loading padded arrays from buffer fat pinters should have recursed");
748 return FixedVectorType::get(ElementType: ET, NumElts: AT->getNumElements());
749}
750
751Value *LegalizeBufferContentTypesVisitor::arrayToVector(Value *V,
752 Type *TargetType,
753 const Twine &Name) {
754 Value *VectorRes = PoisonValue::get(T: TargetType);
755 auto *VT = cast<FixedVectorType>(Val: TargetType);
756 unsigned EC = VT->getNumElements();
757 for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
758 Value *Elem = IRB.CreateExtractValue(Agg: V, Idxs: I, Name: Name + ".elem." + Twine(I));
759 VectorRes = IRB.CreateInsertElement(Vec: VectorRes, NewElt: Elem, Idx: I,
760 Name: Name + ".as.vec." + Twine(I));
761 }
762 return VectorRes;
763}
764
765Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V,
766 Type *OrigType,
767 const Twine &Name) {
768 Value *ArrayRes = PoisonValue::get(T: OrigType);
769 ArrayType *AT = cast<ArrayType>(Val: OrigType);
770 unsigned EC = AT->getNumElements();
771 for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
772 Value *Elem = IRB.CreateExtractElement(Vec: V, Idx: I, Name: Name + ".elem." + Twine(I));
773 ArrayRes = IRB.CreateInsertValue(Agg: ArrayRes, Val: Elem, Idxs: I,
774 Name: Name + ".as.array." + Twine(I));
775 }
776 return ArrayRes;
777}
778
779Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) {
780 TypeSize Size = DL.getTypeStoreSizeInBits(Ty: T);
781 // Implicitly zero-extend to the next byte if needed
782 if (!DL.typeSizeEqualsStoreSize(Ty: T))
783 T = IRB.getIntNTy(N: Size.getFixedValue());
784 Type *ElemTy = T->getScalarType();
785 if (isa<PointerType, ScalableVectorType>(Val: ElemTy)) {
786 // Pointers are always big enough, and we'll let scalable vectors through to
787 // fail in codegen.
788 return T;
789 }
790 unsigned ElemSize = DL.getTypeSizeInBits(Ty: ElemTy).getFixedValue();
791 if (isPowerOf2_32(Value: ElemSize) && ElemSize >= 16 && ElemSize <= 128) {
792 // [vectors of] anything that's 16/32/64/128 bits can be cast and split into
793 // legal buffer operations.
794 return T;
795 }
796 Type *BestVectorElemType = nullptr;
797 if (Size.isKnownMultipleOf(RHS: 32))
798 BestVectorElemType = IRB.getInt32Ty();
799 else if (Size.isKnownMultipleOf(RHS: 16))
800 BestVectorElemType = IRB.getInt16Ty();
801 else
802 BestVectorElemType = IRB.getInt8Ty();
803 unsigned NumCastElems =
804 Size.getFixedValue() / BestVectorElemType->getIntegerBitWidth();
805 if (NumCastElems == 1)
806 return BestVectorElemType;
807 return FixedVectorType::get(ElementType: BestVectorElemType, NumElts: NumCastElems);
808}
809
810Value *LegalizeBufferContentTypesVisitor::makeLegalNonAggregate(
811 Value *V, Type *TargetType, const Twine &Name) {
812 Type *SourceType = V->getType();
813 TypeSize SourceSize = DL.getTypeSizeInBits(Ty: SourceType);
814 TypeSize TargetSize = DL.getTypeSizeInBits(Ty: TargetType);
815 if (SourceSize != TargetSize) {
816 Type *ShortScalarTy = IRB.getIntNTy(N: SourceSize.getFixedValue());
817 Type *ByteScalarTy = IRB.getIntNTy(N: TargetSize.getFixedValue());
818 Value *AsScalar = IRB.CreateBitCast(V, DestTy: ShortScalarTy, Name: Name + ".as.scalar");
819 Value *Zext = IRB.CreateZExt(V: AsScalar, DestTy: ByteScalarTy, Name: Name + ".zext");
820 V = Zext;
821 SourceType = ByteScalarTy;
822 }
823 return IRB.CreateBitCast(V, DestTy: TargetType, Name: Name + ".legal");
824}
825
826Value *LegalizeBufferContentTypesVisitor::makeIllegalNonAggregate(
827 Value *V, Type *OrigType, const Twine &Name) {
828 Type *LegalType = V->getType();
829 TypeSize LegalSize = DL.getTypeSizeInBits(Ty: LegalType);
830 TypeSize OrigSize = DL.getTypeSizeInBits(Ty: OrigType);
831 if (LegalSize != OrigSize) {
832 Type *ShortScalarTy = IRB.getIntNTy(N: OrigSize.getFixedValue());
833 Type *ByteScalarTy = IRB.getIntNTy(N: LegalSize.getFixedValue());
834 Value *AsScalar = IRB.CreateBitCast(V, DestTy: ByteScalarTy, Name: Name + ".bytes.cast");
835 Value *Trunc = IRB.CreateTrunc(V: AsScalar, DestTy: ShortScalarTy, Name: Name + ".trunc");
836 return IRB.CreateBitCast(V: Trunc, DestTy: OrigType, Name: Name + ".orig");
837 }
838 return IRB.CreateBitCast(V, DestTy: OrigType, Name: Name + ".real.ty");
839}
840
841Type *LegalizeBufferContentTypesVisitor::intrinsicTypeFor(Type *LegalType) {
842 auto *VT = dyn_cast<FixedVectorType>(Val: LegalType);
843 if (!VT)
844 return LegalType;
845 Type *ET = VT->getElementType();
846 // Explicitly return the element type of 1-element vectors because the
847 // underlying intrinsics don't like <1 x T> even though it's a synonym for T.
848 if (VT->getNumElements() == 1)
849 return ET;
850 if (DL.getTypeSizeInBits(Ty: LegalType) == 96 && DL.getTypeSizeInBits(Ty: ET) < 32)
851 return FixedVectorType::get(ElementType: IRB.getInt32Ty(), NumElts: 3);
852 if (ET->isIntegerTy(Bitwidth: 8)) {
853 switch (VT->getNumElements()) {
854 default:
855 return LegalType; // Let it crash later
856 case 1:
857 return IRB.getInt8Ty();
858 case 2:
859 return IRB.getInt16Ty();
860 case 4:
861 return IRB.getInt32Ty();
862 case 8:
863 return FixedVectorType::get(ElementType: IRB.getInt32Ty(), NumElts: 2);
864 case 16:
865 return FixedVectorType::get(ElementType: IRB.getInt32Ty(), NumElts: 4);
866 }
867 }
868 return LegalType;
869}
870
871void LegalizeBufferContentTypesVisitor::getVecSlices(
872 Type *T, SmallVectorImpl<VecSlice> &Slices) {
873 Slices.clear();
874 auto *VT = dyn_cast<FixedVectorType>(Val: T);
875 if (!VT)
876 return;
877
878 uint64_t ElemBitWidth =
879 DL.getTypeSizeInBits(Ty: VT->getElementType()).getFixedValue();
880
881 uint64_t ElemsPer4Words = 128 / ElemBitWidth;
882 uint64_t ElemsPer2Words = ElemsPer4Words / 2;
883 uint64_t ElemsPerWord = ElemsPer2Words / 2;
884 uint64_t ElemsPerShort = ElemsPerWord / 2;
885 uint64_t ElemsPerByte = ElemsPerShort / 2;
886 // If the elements evenly pack into 32-bit words, we can use 3-word stores,
887 // such as for <6 x bfloat> or <3 x i32>, but we can't dot his for, for
888 // example, <3 x i64>, since that's not slicing.
889 uint64_t ElemsPer3Words = ElemsPerWord * 3;
890
891 uint64_t TotalElems = VT->getNumElements();
892 uint64_t Index = 0;
893 auto TrySlice = [&](unsigned MaybeLen) {
894 if (MaybeLen > 0 && Index + MaybeLen <= TotalElems) {
895 VecSlice Slice{/*Index=*/Index, /*Length=*/MaybeLen};
896 Slices.push_back(Elt: Slice);
897 Index += MaybeLen;
898 return true;
899 }
900 return false;
901 };
902 while (Index < TotalElems) {
903 TrySlice(ElemsPer4Words) || TrySlice(ElemsPer3Words) ||
904 TrySlice(ElemsPer2Words) || TrySlice(ElemsPerWord) ||
905 TrySlice(ElemsPerShort) || TrySlice(ElemsPerByte);
906 }
907}
908
909Value *LegalizeBufferContentTypesVisitor::extractSlice(Value *Vec, VecSlice S,
910 const Twine &Name) {
911 auto *VecVT = dyn_cast<FixedVectorType>(Val: Vec->getType());
912 if (!VecVT)
913 return Vec;
914 if (S.Length == VecVT->getNumElements() && S.Index == 0)
915 return Vec;
916 if (S.Length == 1)
917 return IRB.CreateExtractElement(Vec, Idx: S.Index,
918 Name: Name + ".slice." + Twine(S.Index));
919 SmallVector<int> Mask = llvm::to_vector(
920 Range: llvm::iota_range<int>(S.Index, S.Index + S.Length, /*Inclusive=*/false));
921 return IRB.CreateShuffleVector(V: Vec, Mask, Name: Name + ".slice." + Twine(S.Index));
922}
923
924Value *LegalizeBufferContentTypesVisitor::insertSlice(Value *Whole, Value *Part,
925 VecSlice S,
926 const Twine &Name) {
927 auto *WholeVT = dyn_cast<FixedVectorType>(Val: Whole->getType());
928 if (!WholeVT)
929 return Part;
930 if (S.Length == WholeVT->getNumElements() && S.Index == 0)
931 return Part;
932 if (S.Length == 1) {
933 return IRB.CreateInsertElement(Vec: Whole, NewElt: Part, Idx: S.Index,
934 Name: Name + ".slice." + Twine(S.Index));
935 }
936 int NumElems = cast<FixedVectorType>(Val: Whole->getType())->getNumElements();
937
938 // Extend the slice with poisons to make the main shufflevector happy.
939 SmallVector<int> ExtPartMask(NumElems, -1);
940 for (auto [I, E] : llvm::enumerate(
941 First: MutableArrayRef<int>(ExtPartMask).take_front(N: S.Length))) {
942 E = I;
943 }
944 Value *ExtPart = IRB.CreateShuffleVector(V: Part, Mask: ExtPartMask,
945 Name: Name + ".ext." + Twine(S.Index));
946
947 SmallVector<int> Mask =
948 llvm::to_vector(Range: llvm::iota_range<int>(0, NumElems, /*Inclusive=*/false));
949 for (auto [I, E] :
950 llvm::enumerate(First: MutableArrayRef<int>(Mask).slice(N: S.Index, M: S.Length)))
951 E = I + NumElems;
952 return IRB.CreateShuffleVector(V1: Whole, V2: ExtPart, Mask,
953 Name: Name + ".parts." + Twine(S.Index));
954}
955
956bool LegalizeBufferContentTypesVisitor::visitLoadImpl(
957 LoadInst &OrigLI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
958 uint64_t AggByteOff, Value *&Result, const Twine &Name) {
959 if (auto *ST = dyn_cast<StructType>(Val: PartType)) {
960 const StructLayout *Layout = DL.getStructLayout(Ty: ST);
961 bool Changed = false;
962 for (auto [I, ElemTy, Offset] :
963 llvm::enumerate(First: ST->elements(), Rest: Layout->getMemberOffsets())) {
964 AggIdxs.push_back(Elt: I);
965 Changed |= visitLoadImpl(OrigLI, PartType: ElemTy, AggIdxs,
966 AggByteOff: AggByteOff + Offset.getFixedValue(), Result,
967 Name: Name + "." + Twine(I));
968 AggIdxs.pop_back();
969 }
970 return Changed;
971 }
972 if (auto *AT = dyn_cast<ArrayType>(Val: PartType)) {
973 Type *ElemTy = AT->getElementType();
974 if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(Ty: ElemTy) ||
975 ElemTy->isVectorTy()) {
976 TypeSize ElemStoreSize = DL.getTypeStoreSize(Ty: ElemTy);
977 bool Changed = false;
978 for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
979 /*Inclusive=*/false)) {
980 AggIdxs.push_back(Elt: I);
981 Changed |= visitLoadImpl(OrigLI, PartType: ElemTy, AggIdxs,
982 AggByteOff: AggByteOff + I * ElemStoreSize.getFixedValue(),
983 Result, Name: Name + Twine(I));
984 AggIdxs.pop_back();
985 }
986 return Changed;
987 }
988 }
989
990 // Typical case
991
992 Type *ArrayAsVecType = scalarArrayTypeAsVector(T: PartType);
993 Type *LegalType = legalNonAggregateFor(T: ArrayAsVecType);
994
995 SmallVector<VecSlice> Slices;
996 getVecSlices(T: LegalType, Slices);
997 bool HasSlices = Slices.size() > 1;
998 bool IsAggPart = !AggIdxs.empty();
999 Value *LoadsRes;
1000 if (!HasSlices && !IsAggPart) {
1001 Type *LoadableType = intrinsicTypeFor(LegalType);
1002 if (LoadableType == PartType)
1003 return false;
1004
1005 IRB.SetInsertPoint(&OrigLI);
1006 auto *NLI = cast<LoadInst>(Val: OrigLI.clone());
1007 NLI->mutateType(Ty: LoadableType);
1008 NLI = IRB.Insert(I: NLI);
1009 NLI->setName(Name + ".loadable");
1010
1011 LoadsRes = IRB.CreateBitCast(V: NLI, DestTy: LegalType, Name: Name + ".from.loadable");
1012 } else {
1013 IRB.SetInsertPoint(&OrigLI);
1014 LoadsRes = PoisonValue::get(T: LegalType);
1015 Value *OrigPtr = OrigLI.getPointerOperand();
1016 // If we're needing to spill something into more than one load, its legal
1017 // type will be a vector (ex. an i256 load will have LegalType = <8 x i32>).
1018 // But if we're already a scalar (which can happen if we're splitting up a
1019 // struct), the element type will be the legal type itself.
1020 Type *ElemType = LegalType->getScalarType();
1021 unsigned ElemBytes = DL.getTypeStoreSize(Ty: ElemType);
1022 AAMDNodes AANodes = OrigLI.getAAMetadata();
1023 if (IsAggPart && Slices.empty())
1024 Slices.push_back(Elt: VecSlice{/*Index=*/0, /*Length=*/1});
1025 for (VecSlice S : Slices) {
1026 Type *SliceType =
1027 S.Length != 1 ? FixedVectorType::get(ElementType: ElemType, NumElts: S.Length) : ElemType;
1028 int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
1029 // You can't reasonably expect loads to wrap around the edge of memory.
1030 Value *NewPtr = IRB.CreateGEP(
1031 Ty: IRB.getInt8Ty(), Ptr: OrigLI.getPointerOperand(), IdxList: IRB.getInt32(C: ByteOffset),
1032 Name: OrigPtr->getName() + ".off.ptr." + Twine(ByteOffset),
1033 NW: GEPNoWrapFlags::noUnsignedWrap());
1034 Type *LoadableType = intrinsicTypeFor(LegalType: SliceType);
1035 LoadInst *NewLI = IRB.CreateAlignedLoad(
1036 Ty: LoadableType, Ptr: NewPtr, Align: commonAlignment(A: OrigLI.getAlign(), Offset: ByteOffset),
1037 Name: Name + ".off." + Twine(ByteOffset));
1038 copyMetadataForLoad(Dest&: *NewLI, Source: OrigLI);
1039 NewLI->setAAMetadata(
1040 AANodes.adjustForAccess(Offset: ByteOffset, AccessTy: LoadableType, DL));
1041 NewLI->setAtomic(Ordering: OrigLI.getOrdering(), SSID: OrigLI.getSyncScopeID());
1042 NewLI->setVolatile(OrigLI.isVolatile());
1043 Value *Loaded = IRB.CreateBitCast(V: NewLI, DestTy: SliceType,
1044 Name: NewLI->getName() + ".from.loadable");
1045 LoadsRes = insertSlice(Whole: LoadsRes, Part: Loaded, S, Name);
1046 }
1047 }
1048 if (LegalType != ArrayAsVecType)
1049 LoadsRes = makeIllegalNonAggregate(V: LoadsRes, OrigType: ArrayAsVecType, Name);
1050 if (ArrayAsVecType != PartType)
1051 LoadsRes = vectorToArray(V: LoadsRes, OrigType: PartType, Name);
1052
1053 if (IsAggPart)
1054 Result = IRB.CreateInsertValue(Agg: Result, Val: LoadsRes, Idxs: AggIdxs, Name);
1055 else
1056 Result = LoadsRes;
1057 return true;
1058}
1059
1060bool LegalizeBufferContentTypesVisitor::visitLoadInst(LoadInst &LI) {
1061 if (LI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
1062 return false;
1063
1064 SmallVector<uint32_t> AggIdxs;
1065 Type *OrigType = LI.getType();
1066 Value *Result = PoisonValue::get(T: OrigType);
1067 bool Changed = visitLoadImpl(OrigLI&: LI, PartType: OrigType, AggIdxs, AggByteOff: 0, Result, Name: LI.getName());
1068 if (!Changed)
1069 return false;
1070 Result->takeName(V: &LI);
1071 LI.replaceAllUsesWith(V: Result);
1072 LI.eraseFromParent();
1073 return Changed;
1074}
1075
1076std::pair<bool, bool> LegalizeBufferContentTypesVisitor::visitStoreImpl(
1077 StoreInst &OrigSI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
1078 uint64_t AggByteOff, const Twine &Name) {
1079 if (auto *ST = dyn_cast<StructType>(Val: PartType)) {
1080 const StructLayout *Layout = DL.getStructLayout(Ty: ST);
1081 bool Changed = false;
1082 for (auto [I, ElemTy, Offset] :
1083 llvm::enumerate(First: ST->elements(), Rest: Layout->getMemberOffsets())) {
1084 AggIdxs.push_back(Elt: I);
1085 Changed |= std::get<0>(in: visitStoreImpl(OrigSI, PartType: ElemTy, AggIdxs,
1086 AggByteOff: AggByteOff + Offset.getFixedValue(),
1087 Name: Name + "." + Twine(I)));
1088 AggIdxs.pop_back();
1089 }
1090 return std::make_pair(x&: Changed, /*ModifiedInPlace=*/y: false);
1091 }
1092 if (auto *AT = dyn_cast<ArrayType>(Val: PartType)) {
1093 Type *ElemTy = AT->getElementType();
1094 if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(Ty: ElemTy) ||
1095 ElemTy->isVectorTy()) {
1096 TypeSize ElemStoreSize = DL.getTypeStoreSize(Ty: ElemTy);
1097 bool Changed = false;
1098 for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
1099 /*Inclusive=*/false)) {
1100 AggIdxs.push_back(Elt: I);
1101 Changed |= std::get<0>(in: visitStoreImpl(
1102 OrigSI, PartType: ElemTy, AggIdxs,
1103 AggByteOff: AggByteOff + I * ElemStoreSize.getFixedValue(), Name: Name + Twine(I)));
1104 AggIdxs.pop_back();
1105 }
1106 return std::make_pair(x&: Changed, /*ModifiedInPlace=*/y: false);
1107 }
1108 }
1109
1110 Value *OrigData = OrigSI.getValueOperand();
1111 Value *NewData = OrigData;
1112
1113 bool IsAggPart = !AggIdxs.empty();
1114 if (IsAggPart)
1115 NewData = IRB.CreateExtractValue(Agg: NewData, Idxs: AggIdxs, Name);
1116
1117 Type *ArrayAsVecType = scalarArrayTypeAsVector(T: PartType);
1118 if (ArrayAsVecType != PartType) {
1119 NewData = arrayToVector(V: NewData, TargetType: ArrayAsVecType, Name);
1120 }
1121
1122 Type *LegalType = legalNonAggregateFor(T: ArrayAsVecType);
1123 if (LegalType != ArrayAsVecType) {
1124 NewData = makeLegalNonAggregate(V: NewData, TargetType: LegalType, Name);
1125 }
1126
1127 SmallVector<VecSlice> Slices;
1128 getVecSlices(T: LegalType, Slices);
1129 bool NeedToSplit = Slices.size() > 1 || IsAggPart;
1130 if (!NeedToSplit) {
1131 Type *StorableType = intrinsicTypeFor(LegalType);
1132 if (StorableType == PartType)
1133 return std::make_pair(/*Changed=*/x: false, /*ModifiedInPlace=*/y: false);
1134 NewData = IRB.CreateBitCast(V: NewData, DestTy: StorableType, Name: Name + ".storable");
1135 OrigSI.setOperand(i_nocapture: 0, Val_nocapture: NewData);
1136 return std::make_pair(/*Changed=*/x: true, /*ModifiedInPlace=*/y: true);
1137 }
1138
1139 Value *OrigPtr = OrigSI.getPointerOperand();
1140 Type *ElemType = LegalType->getScalarType();
1141 if (IsAggPart && Slices.empty())
1142 Slices.push_back(Elt: VecSlice{/*Index=*/0, /*Length=*/1});
1143 unsigned ElemBytes = DL.getTypeStoreSize(Ty: ElemType);
1144 AAMDNodes AANodes = OrigSI.getAAMetadata();
1145 for (VecSlice S : Slices) {
1146 Type *SliceType =
1147 S.Length != 1 ? FixedVectorType::get(ElementType: ElemType, NumElts: S.Length) : ElemType;
1148 int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
1149 Value *NewPtr =
1150 IRB.CreateGEP(Ty: IRB.getInt8Ty(), Ptr: OrigPtr, IdxList: IRB.getInt32(C: ByteOffset),
1151 Name: OrigPtr->getName() + ".part." + Twine(S.Index),
1152 NW: GEPNoWrapFlags::noUnsignedWrap());
1153 Value *DataSlice = extractSlice(Vec: NewData, S, Name);
1154 Type *StorableType = intrinsicTypeFor(LegalType: SliceType);
1155 DataSlice = IRB.CreateBitCast(V: DataSlice, DestTy: StorableType,
1156 Name: DataSlice->getName() + ".storable");
1157 auto *NewSI = cast<StoreInst>(Val: OrigSI.clone());
1158 NewSI->setAlignment(commonAlignment(A: OrigSI.getAlign(), Offset: ByteOffset));
1159 IRB.Insert(I: NewSI);
1160 NewSI->setOperand(i_nocapture: 0, Val_nocapture: DataSlice);
1161 NewSI->setOperand(i_nocapture: 1, Val_nocapture: NewPtr);
1162 NewSI->setAAMetadata(AANodes.adjustForAccess(Offset: ByteOffset, AccessTy: StorableType, DL));
1163 }
1164 return std::make_pair(/*Changed=*/x: true, /*ModifiedInPlace=*/y: false);
1165}
1166
1167bool LegalizeBufferContentTypesVisitor::visitStoreInst(StoreInst &SI) {
1168 if (SI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
1169 return false;
1170 IRB.SetInsertPoint(&SI);
1171 SmallVector<uint32_t> AggIdxs;
1172 Value *OrigData = SI.getValueOperand();
1173 auto [Changed, ModifiedInPlace] =
1174 visitStoreImpl(OrigSI&: SI, PartType: OrigData->getType(), AggIdxs, AggByteOff: 0, Name: OrigData->getName());
1175 if (Changed && !ModifiedInPlace)
1176 SI.eraseFromParent();
1177 return Changed;
1178}
1179
1180bool LegalizeBufferContentTypesVisitor::processFunction(Function &F) {
1181 bool Changed = false;
1182 // Note, memory transfer intrinsics won't
1183 for (Instruction &I : make_early_inc_range(Range: instructions(F))) {
1184 Changed |= visit(I);
1185 }
1186 return Changed;
1187}
1188
1189/// Return the ptr addrspace(8) and i32 (resource and offset parts) in a lowered
1190/// buffer fat pointer constant.
1191static std::pair<Constant *, Constant *>
1192splitLoweredFatBufferConst(Constant *C) {
1193 assert(isSplitFatPtr(C->getType()) && "Not a split fat buffer pointer");
1194 return std::make_pair(x: C->getAggregateElement(Elt: 0u), y: C->getAggregateElement(Elt: 1u));
1195}
1196
1197namespace {
1198/// Handle the remapping of ptr addrspace(7) constants.
1199class FatPtrConstMaterializer final : public ValueMaterializer {
1200 BufferFatPtrToStructTypeMap *TypeMap;
1201 // An internal mapper that is used to recurse into the arguments of constants.
1202 // While the documentation for `ValueMapper` specifies not to use it
1203 // recursively, examination of the logic in mapValue() shows that it can
1204 // safely be used recursively when handling constants, like it does in its own
1205 // logic.
1206 ValueMapper InternalMapper;
1207
1208 Constant *materializeBufferFatPtrConst(Constant *C);
1209
1210public:
1211 // UnderlyingMap is the value map this materializer will be filling.
1212 FatPtrConstMaterializer(BufferFatPtrToStructTypeMap *TypeMap,
1213 ValueToValueMapTy &UnderlyingMap)
1214 : TypeMap(TypeMap),
1215 InternalMapper(UnderlyingMap, RF_None, TypeMap, this) {}
1216 ~FatPtrConstMaterializer() = default;
1217
1218 Value *materialize(Value *V) override;
1219};
1220} // namespace
1221
1222Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) {
1223 Type *SrcTy = C->getType();
1224 auto *NewTy = dyn_cast<StructType>(Val: TypeMap->remapType(SrcTy));
1225 if (C->isNullValue())
1226 return ConstantAggregateZero::getNullValue(Ty: NewTy);
1227 if (isa<PoisonValue>(Val: C)) {
1228 return ConstantStruct::get(T: NewTy,
1229 V: {PoisonValue::get(T: NewTy->getElementType(N: 0)),
1230 PoisonValue::get(T: NewTy->getElementType(N: 1))});
1231 }
1232 if (isa<UndefValue>(Val: C)) {
1233 return ConstantStruct::get(T: NewTy,
1234 V: {UndefValue::get(T: NewTy->getElementType(N: 0)),
1235 UndefValue::get(T: NewTy->getElementType(N: 1))});
1236 }
1237
1238 if (auto *VC = dyn_cast<ConstantVector>(Val: C)) {
1239 if (Constant *S = VC->getSplatValue()) {
1240 Constant *NewS = InternalMapper.mapConstant(C: *S);
1241 if (!NewS)
1242 return nullptr;
1243 auto [Rsrc, Off] = splitLoweredFatBufferConst(C: NewS);
1244 auto EC = VC->getType()->getElementCount();
1245 return ConstantStruct::get(T: NewTy, V: {ConstantVector::getSplat(EC, Elt: Rsrc),
1246 ConstantVector::getSplat(EC, Elt: Off)});
1247 }
1248 SmallVector<Constant *> Rsrcs;
1249 SmallVector<Constant *> Offs;
1250 for (Value *Op : VC->operand_values()) {
1251 auto *NewOp = dyn_cast_or_null<Constant>(Val: InternalMapper.mapValue(V: *Op));
1252 if (!NewOp)
1253 return nullptr;
1254 auto [Rsrc, Off] = splitLoweredFatBufferConst(C: NewOp);
1255 Rsrcs.push_back(Elt: Rsrc);
1256 Offs.push_back(Elt: Off);
1257 }
1258 Constant *RsrcVec = ConstantVector::get(V: Rsrcs);
1259 Constant *OffVec = ConstantVector::get(V: Offs);
1260 return ConstantStruct::get(T: NewTy, V: {RsrcVec, OffVec});
1261 }
1262
1263 if (isa<GlobalValue>(Val: C))
1264 reportFatalUsageError(reason: "global values containing ptr addrspace(7) (buffer "
1265 "fat pointer) values are not supported");
1266
1267 if (isa<ConstantExpr>(Val: C))
1268 reportFatalUsageError(
1269 reason: "constant exprs containing ptr addrspace(7) (buffer "
1270 "fat pointer) values should have been expanded earlier");
1271
1272 return nullptr;
1273}
1274
1275Value *FatPtrConstMaterializer::materialize(Value *V) {
1276 Constant *C = dyn_cast<Constant>(Val: V);
1277 if (!C)
1278 return nullptr;
1279 // Structs and other types that happen to contain fat pointers get remapped
1280 // by the mapValue() logic.
1281 if (!isBufferFatPtrConst(C))
1282 return nullptr;
1283 return materializeBufferFatPtrConst(C);
1284}
1285
1286using PtrParts = std::pair<Value *, Value *>;
1287namespace {
1288// The visitor returns the resource and offset parts for an instruction if they
1289// can be computed, or (nullptr, nullptr) for cases that don't have a meaningful
1290// value mapping.
1291class SplitPtrStructs : public InstVisitor<SplitPtrStructs, PtrParts> {
1292 ValueToValueMapTy RsrcParts;
1293 ValueToValueMapTy OffParts;
1294
1295 // Track instructions that have been rewritten into a user of the component
1296 // parts of their ptr addrspace(7) input. Instructions that produced
1297 // ptr addrspace(7) parts should **not** be RAUW'd before being added to this
1298 // set, as that replacement will be handled in a post-visit step. However,
1299 // instructions that yield values that aren't fat pointers (ex. ptrtoint)
1300 // should RAUW themselves with new instructions that use the split parts
1301 // of their arguments during processing.
1302 DenseSet<Instruction *> SplitUsers;
1303
1304 // Nodes that need a second look once we've computed the parts for all other
1305 // instructions to see if, for example, we really need to phi on the resource
1306 // part.
1307 SmallVector<Instruction *> Conditionals;
1308 // Temporary instructions produced while lowering conditionals that should be
1309 // killed.
1310 SmallVector<Instruction *> ConditionalTemps;
1311
1312 // Subtarget info, needed for determining what cache control bits to set.
1313 const TargetMachine *TM;
1314 const GCNSubtarget *ST = nullptr;
1315
1316 IRBuilder<InstSimplifyFolder> IRB;
1317
1318 // Copy metadata between instructions if applicable.
1319 void copyMetadata(Value *Dest, Value *Src);
1320
1321 // Get the resource and offset parts of the value V, inserting appropriate
1322 // extractvalue calls if needed.
1323 PtrParts getPtrParts(Value *V);
1324
1325 // Given an instruction that could produce multiple resource parts (a PHI or
1326 // select), collect the set of possible instructions that could have provided
  // its resource part (the `Roots`) and the set of
1328 // conditional instructions visited during the search (`Seen`). If, after
1329 // removing the root of the search from `Seen` and `Roots`, `Seen` is a subset
1330 // of `Roots` and `Roots - Seen` contains one element, the resource part of
1331 // that element can replace the resource part of all other elements in `Seen`.
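  //
  // As an illustrative sketch (value names invented for exposition), given
  //   %p = phi ptr addrspace(7) [ %base, %entry ], [ %q, %loop ]
  //   %q = getelementptr i8, ptr addrspace(7) %p, i32 4
  // the search from %p yields Roots = {%base, %p} (the GEP is looked through)
  // and Seen = {%p}. After erasing %p from both sets, Seen is empty and
  // Roots = {%base} has a single element, so %base's resource part can stand
  // in for %p's and no PHI on the resource part is needed.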
1332 void getPossibleRsrcRoots(Instruction *I, SmallPtrSetImpl<Value *> &Roots,
1333 SmallPtrSetImpl<Value *> &Seen);
1334 void processConditionals();
1335
  // If an instruction has been split into resource and offset parts,
  // delete that instruction. If any of its uses have not themselves been split
  // into parts (for example, an insertvalue), construct the struct that the
  // type rewrites declare the dying instruction should produce and use that
  // value instead.
1341 // Also, kill the temporary extractvalue operations produced by the two-stage
1342 // lowering of PHIs and conditionals.
1343 void killAndReplaceSplitInstructions(SmallVectorImpl<Instruction *> &Origs);
1344
1345 void setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx);
1346 void insertPreMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
1347 void insertPostMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
1348 Value *handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr, Type *Ty,
1349 Align Alignment, AtomicOrdering Order,
1350 bool IsVolatile, SyncScope::ID SSID);
1351
1352public:
1353 SplitPtrStructs(const DataLayout &DL, LLVMContext &Ctx,
1354 const TargetMachine *TM)
1355 : TM(TM), IRB(Ctx, InstSimplifyFolder(DL)) {}
1356
1357 void processFunction(Function &F);
1358
1359 PtrParts visitInstruction(Instruction &I);
1360 PtrParts visitLoadInst(LoadInst &LI);
1361 PtrParts visitStoreInst(StoreInst &SI);
1362 PtrParts visitAtomicRMWInst(AtomicRMWInst &AI);
1363 PtrParts visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI);
1364 PtrParts visitGetElementPtrInst(GetElementPtrInst &GEP);
1365
1366 PtrParts visitPtrToAddrInst(PtrToAddrInst &PA);
1367 PtrParts visitPtrToIntInst(PtrToIntInst &PI);
1368 PtrParts visitIntToPtrInst(IntToPtrInst &IP);
1369 PtrParts visitAddrSpaceCastInst(AddrSpaceCastInst &I);
1370 PtrParts visitICmpInst(ICmpInst &Cmp);
1371 PtrParts visitFreezeInst(FreezeInst &I);
1372
1373 PtrParts visitExtractElementInst(ExtractElementInst &I);
1374 PtrParts visitInsertElementInst(InsertElementInst &I);
1375 PtrParts visitShuffleVectorInst(ShuffleVectorInst &I);
1376
1377 PtrParts visitPHINode(PHINode &PHI);
1378 PtrParts visitSelectInst(SelectInst &SI);
1379
1380 PtrParts visitIntrinsicInst(IntrinsicInst &II);
1381};
1382} // namespace
1383
1384void SplitPtrStructs::copyMetadata(Value *Dest, Value *Src) {
1385 auto *DestI = dyn_cast<Instruction>(Val: Dest);
1386 auto *SrcI = dyn_cast<Instruction>(Val: Src);
1387
1388 if (!DestI || !SrcI)
1389 return;
1390
1391 DestI->copyMetadata(SrcInst: *SrcI);
1392}
1393
1394PtrParts SplitPtrStructs::getPtrParts(Value *V) {
1395 assert(isSplitFatPtr(V->getType()) && "it's not meaningful to get the parts "
1396 "of something that wasn't rewritten");
1397 auto *RsrcEntry = &RsrcParts[V];
1398 auto *OffEntry = &OffParts[V];
1399 if (*RsrcEntry && *OffEntry)
1400 return {*RsrcEntry, *OffEntry};
1401
1402 if (auto *C = dyn_cast<Constant>(Val: V)) {
1403 auto [Rsrc, Off] = splitLoweredFatBufferConst(C);
1404 return {*RsrcEntry = Rsrc, *OffEntry = Off};
1405 }
1406
1407 IRBuilder<InstSimplifyFolder>::InsertPointGuard Guard(IRB);
1408 if (auto *I = dyn_cast<Instruction>(Val: V)) {
1409 LLVM_DEBUG(dbgs() << "Recursing to split parts of " << *I << "\n");
1410 auto [Rsrc, Off] = visit(I&: *I);
1411 if (Rsrc && Off)
1412 return {*RsrcEntry = Rsrc, *OffEntry = Off};
1413 // We'll be creating the new values after the relevant instruction.
1414 // This instruction generates a value and so isn't a terminator.
1415 IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
1416 IRB.SetCurrentDebugLocation(I->getDebugLoc());
1417 } else if (auto *A = dyn_cast<Argument>(Val: V)) {
1418 IRB.SetInsertPointPastAllocas(A->getParent());
1419 IRB.SetCurrentDebugLocation(DebugLoc());
1420 }
1421 Value *Rsrc = IRB.CreateExtractValue(Agg: V, Idxs: 0, Name: V->getName() + ".rsrc");
1422 Value *Off = IRB.CreateExtractValue(Agg: V, Idxs: 1, Name: V->getName() + ".off");
1423 return {*RsrcEntry = Rsrc, *OffEntry = Off};
1424}
1425
1426/// Returns the instruction that defines the resource part of the value V.
1427/// Note that this is not getUnderlyingObject(), since that looks through
1428/// operations like ptrmask which might modify the resource part.
1429///
1430/// We can limit ourselves to just looking through GEPs followed by looking
1431/// through addrspacecasts because only those two operations preserve the
1432/// resource part, and because operations on an `addrspace(8)` (which is the
1433/// legal input to this addrspacecast) would produce a different resource part.
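///
/// For example (a hypothetical chain, not taken from a test):
///   %a = addrspacecast ptr addrspace(8) %rsrc to ptr addrspace(7)
///   %b = getelementptr i8, ptr addrspace(7) %a, i32 16
/// rsrcPartRoot(%b) walks back through the GEP and then the addrspacecast and
/// returns %rsrc, since neither operation can change the resource part.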
1434static Value *rsrcPartRoot(Value *V) {
1435 while (auto *GEP = dyn_cast<GEPOperator>(Val: V))
1436 V = GEP->getPointerOperand();
1437 while (auto *ASC = dyn_cast<AddrSpaceCastOperator>(Val: V))
1438 V = ASC->getPointerOperand();
1439 return V;
1440}
1441
1442void SplitPtrStructs::getPossibleRsrcRoots(Instruction *I,
1443 SmallPtrSetImpl<Value *> &Roots,
1444 SmallPtrSetImpl<Value *> &Seen) {
1445 if (auto *PHI = dyn_cast<PHINode>(Val: I)) {
1446 if (!Seen.insert(Ptr: I).second)
1447 return;
1448 for (Value *In : PHI->incoming_values()) {
1449 In = rsrcPartRoot(V: In);
1450 Roots.insert(Ptr: In);
1451 if (isa<PHINode, SelectInst>(Val: In))
1452 getPossibleRsrcRoots(I: cast<Instruction>(Val: In), Roots, Seen);
1453 }
1454 } else if (auto *SI = dyn_cast<SelectInst>(Val: I)) {
1455 if (!Seen.insert(Ptr: SI).second)
1456 return;
1457 Value *TrueVal = rsrcPartRoot(V: SI->getTrueValue());
1458 Value *FalseVal = rsrcPartRoot(V: SI->getFalseValue());
1459 Roots.insert(Ptr: TrueVal);
1460 Roots.insert(Ptr: FalseVal);
1461 if (isa<PHINode, SelectInst>(Val: TrueVal))
1462 getPossibleRsrcRoots(I: cast<Instruction>(Val: TrueVal), Roots, Seen);
1463 if (isa<PHINode, SelectInst>(Val: FalseVal))
1464 getPossibleRsrcRoots(I: cast<Instruction>(Val: FalseVal), Roots, Seen);
1465 } else {
    llvm_unreachable("getPossibleRsrcRoots() only works on phi and select");
1467 }
1468}
1469
1470void SplitPtrStructs::processConditionals() {
1471 SmallDenseMap<Value *, Value *> FoundRsrcs;
1472 SmallPtrSet<Value *, 4> Roots;
1473 SmallPtrSet<Value *, 4> Seen;
1474 for (Instruction *I : Conditionals) {
1475 // These have to exist by now because we've visited these nodes.
1476 Value *Rsrc = RsrcParts[I];
1477 Value *Off = OffParts[I];
1478 assert(Rsrc && Off && "must have visited conditionals by now");
1479
1480 std::optional<Value *> MaybeRsrc;
1481 auto MaybeFoundRsrc = FoundRsrcs.find(Val: I);
1482 if (MaybeFoundRsrc != FoundRsrcs.end()) {
1483 MaybeRsrc = MaybeFoundRsrc->second;
1484 } else {
1485 IRBuilder<InstSimplifyFolder>::InsertPointGuard Guard(IRB);
1486 Roots.clear();
1487 Seen.clear();
1488 getPossibleRsrcRoots(I, Roots, Seen);
1489 LLVM_DEBUG(dbgs() << "Processing conditional: " << *I << "\n");
1490#ifndef NDEBUG
1491 for (Value *V : Roots)
1492 LLVM_DEBUG(dbgs() << "Root: " << *V << "\n");
1493 for (Value *V : Seen)
1494 LLVM_DEBUG(dbgs() << "Seen: " << *V << "\n");
1495#endif
1496 // If we are our own possible root, then we shouldn't block our
1497 // replacement with a valid incoming value.
1498 Roots.erase(Ptr: I);
1499 // We don't want to block the optimization for conditionals that don't
1500 // refer to themselves but did see themselves during the traversal.
1501 Seen.erase(Ptr: I);
1502
1503 if (set_is_subset(S1: Seen, S2: Roots)) {
1504 auto Diff = set_difference(S1: Roots, S2: Seen);
1505 if (Diff.size() == 1) {
1506 Value *RootVal = *Diff.begin();
1507 // Handle the case where previous loops already looked through
1508 // an addrspacecast.
1509 if (isSplitFatPtr(Ty: RootVal->getType()))
1510 MaybeRsrc = std::get<0>(in: getPtrParts(V: RootVal));
1511 else
1512 MaybeRsrc = RootVal;
1513 }
1514 }
1515 }
1516
1517 if (auto *PHI = dyn_cast<PHINode>(Val: I)) {
1518 Value *NewRsrc;
1519 StructType *PHITy = cast<StructType>(Val: PHI->getType());
1520 IRB.SetInsertPoint(*PHI->getInsertionPointAfterDef());
1521 IRB.SetCurrentDebugLocation(PHI->getDebugLoc());
1522 if (MaybeRsrc) {
1523 NewRsrc = *MaybeRsrc;
1524 } else {
1525 Type *RsrcTy = PHITy->getElementType(N: 0);
1526 auto *RsrcPHI = IRB.CreatePHI(Ty: RsrcTy, NumReservedValues: PHI->getNumIncomingValues());
1527 RsrcPHI->takeName(V: Rsrc);
1528 for (auto [V, BB] : llvm::zip(t: PHI->incoming_values(), u: PHI->blocks())) {
1529 Value *VRsrc = std::get<0>(in: getPtrParts(V));
1530 RsrcPHI->addIncoming(V: VRsrc, BB);
1531 }
1532 copyMetadata(Dest: RsrcPHI, Src: PHI);
1533 NewRsrc = RsrcPHI;
1534 }
1535
1536 Type *OffTy = PHITy->getElementType(N: 1);
1537 auto *NewOff = IRB.CreatePHI(Ty: OffTy, NumReservedValues: PHI->getNumIncomingValues());
1538 NewOff->takeName(V: Off);
1539 for (auto [V, BB] : llvm::zip(t: PHI->incoming_values(), u: PHI->blocks())) {
1540 assert(OffParts.count(V) && "An offset part had to be created by now");
1541 Value *VOff = std::get<1>(in: getPtrParts(V));
1542 NewOff->addIncoming(V: VOff, BB);
1543 }
1544 copyMetadata(Dest: NewOff, Src: PHI);
1545
1546 // Note: We don't eraseFromParent() the temporaries because we don't want
      // to put the correction maps in an inconsistent state. That'll be handled
1548 // during the rest of the killing. Also, `ValueToValueMapTy` guarantees
1549 // that references in that map will be updated as well.
1550 // Note that if the temporary instruction got `InstSimplify`'d away, it
1551 // might be something like a block argument.
1552 if (auto *RsrcInst = dyn_cast<Instruction>(Val: Rsrc)) {
1553 ConditionalTemps.push_back(Elt: RsrcInst);
1554 RsrcInst->replaceAllUsesWith(V: NewRsrc);
1555 }
1556 if (auto *OffInst = dyn_cast<Instruction>(Val: Off)) {
1557 ConditionalTemps.push_back(Elt: OffInst);
1558 OffInst->replaceAllUsesWith(V: NewOff);
1559 }
1560
1561 // Save on recomputing the cycle traversals in known-root cases.
1562 if (MaybeRsrc)
1563 for (Value *V : Seen)
1564 FoundRsrcs[V] = NewRsrc;
1565 } else if (isa<SelectInst>(Val: I)) {
1566 if (MaybeRsrc) {
1567 if (auto *RsrcInst = dyn_cast<Instruction>(Val: Rsrc)) {
1568 // Guard against conditionals that were already folded away.
1569 if (RsrcInst != *MaybeRsrc) {
1570 ConditionalTemps.push_back(Elt: RsrcInst);
1571 RsrcInst->replaceAllUsesWith(V: *MaybeRsrc);
1572 }
1573 }
1574 for (Value *V : Seen)
1575 FoundRsrcs[V] = *MaybeRsrc;
1576 }
1577 } else {
1578 llvm_unreachable("Only PHIs and selects go in the conditionals list");
1579 }
1580 }
1581}
1582
1583void SplitPtrStructs::killAndReplaceSplitInstructions(
1584 SmallVectorImpl<Instruction *> &Origs) {
1585 for (Instruction *I : ConditionalTemps)
1586 I->eraseFromParent();
1587
1588 for (Instruction *I : Origs) {
1589 if (!SplitUsers.contains(V: I))
1590 continue;
1591
1592 SmallVector<DbgVariableRecord *> Dbgs;
1593 findDbgValues(V: I, DbgVariableRecords&: Dbgs);
1594 for (DbgVariableRecord *Dbg : Dbgs) {
1595 auto &DL = I->getDataLayout();
1596 assert(isSplitFatPtr(I->getType()) &&
1597 "We should've RAUW'd away loads, stores, etc. at this point");
1598 DbgVariableRecord *OffDbg = Dbg->clone();
1599 auto [Rsrc, Off] = getPtrParts(V: I);
1600
1601 int64_t RsrcSz = DL.getTypeSizeInBits(Ty: Rsrc->getType());
1602 int64_t OffSz = DL.getTypeSizeInBits(Ty: Off->getType());
1603
1604 std::optional<DIExpression *> RsrcExpr =
1605 DIExpression::createFragmentExpression(Expr: Dbg->getExpression(), OffsetInBits: 0,
1606 SizeInBits: RsrcSz);
1607 std::optional<DIExpression *> OffExpr =
1608 DIExpression::createFragmentExpression(Expr: Dbg->getExpression(), OffsetInBits: RsrcSz,
1609 SizeInBits: OffSz);
1610 if (OffExpr) {
1611 OffDbg->setExpression(*OffExpr);
1612 OffDbg->replaceVariableLocationOp(OldValue: I, NewValue: Off);
1613 OffDbg->insertBefore(InsertBefore: Dbg);
1614 } else {
1615 OffDbg->eraseFromParent();
1616 }
1617 if (RsrcExpr) {
1618 Dbg->setExpression(*RsrcExpr);
1619 Dbg->replaceVariableLocationOp(OldValue: I, NewValue: Rsrc);
1620 } else {
1621 Dbg->replaceVariableLocationOp(OldValue: I, NewValue: PoisonValue::get(T: I->getType()));
1622 }
1623 }
1624
1625 Value *Poison = PoisonValue::get(T: I->getType());
1626 I->replaceUsesWithIf(New: Poison, ShouldReplace: [&](const Use &U) -> bool {
1627 if (const auto *UI = dyn_cast<Instruction>(Val: U.getUser()))
1628 return SplitUsers.contains(V: UI);
1629 return false;
1630 });
1631
1632 if (I->use_empty()) {
1633 I->eraseFromParent();
1634 continue;
1635 }
1636 IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
1637 IRB.SetCurrentDebugLocation(I->getDebugLoc());
1638 auto [Rsrc, Off] = getPtrParts(V: I);
1639 Value *Struct = PoisonValue::get(T: I->getType());
1640 Struct = IRB.CreateInsertValue(Agg: Struct, Val: Rsrc, Idxs: 0);
1641 Struct = IRB.CreateInsertValue(Agg: Struct, Val: Off, Idxs: 1);
1642 copyMetadata(Dest: Struct, Src: I);
1643 Struct->takeName(V: I);
1644 I->replaceAllUsesWith(V: Struct);
1645 I->eraseFromParent();
1646 }
1647}
1648
1649void SplitPtrStructs::setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx) {
1650 LLVMContext &Ctx = Intr->getContext();
1651 Intr->addParamAttr(ArgNo: RsrcArgIdx, Attr: Attribute::getWithAlignment(Context&: Ctx, Alignment: A));
1652}
1653
1654void SplitPtrStructs::insertPreMemOpFence(AtomicOrdering Order,
1655 SyncScope::ID SSID) {
1656 switch (Order) {
1657 case AtomicOrdering::Release:
1658 case AtomicOrdering::AcquireRelease:
1659 case AtomicOrdering::SequentiallyConsistent:
1660 IRB.CreateFence(Ordering: AtomicOrdering::Release, SSID);
1661 break;
1662 default:
1663 break;
1664 }
1665}
1666
1667void SplitPtrStructs::insertPostMemOpFence(AtomicOrdering Order,
1668 SyncScope::ID SSID) {
1669 switch (Order) {
1670 case AtomicOrdering::Acquire:
1671 case AtomicOrdering::AcquireRelease:
1672 case AtomicOrdering::SequentiallyConsistent:
1673 IRB.CreateFence(Ordering: AtomicOrdering::Acquire, SSID);
1674 break;
1675 default:
1676 break;
1677 }
1678}
1679
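// Lower a load, store, or atomicrmw on a fat pointer into the corresponding
// raw buffer intrinsic on the pointer's {resource, offset} parts. As a rough
// sketch (value names invented, cache-policy bits omitted), a load such as
//   %v = load i32, ptr addrspace(7) %p
// becomes approximately
//   %v = call i32 @llvm.amdgcn.raw.ptr.buffer.load.i32(
//            ptr addrspace(8) %p.rsrc, i32 %p.off, i32 0, i32 0)
// with fences inserted around it when the original operation's atomic
// ordering requires them.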
1680Value *SplitPtrStructs::handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr,
1681 Type *Ty, Align Alignment,
1682 AtomicOrdering Order, bool IsVolatile,
1683 SyncScope::ID SSID) {
1684 IRB.SetInsertPoint(I);
1685
1686 auto [Rsrc, Off] = getPtrParts(V: Ptr);
1687 SmallVector<Value *, 5> Args;
1688 if (Arg)
1689 Args.push_back(Elt: Arg);
1690 Args.push_back(Elt: Rsrc);
1691 Args.push_back(Elt: Off);
1692 insertPreMemOpFence(Order, SSID);
1693 // soffset is always 0 for these cases, where we always want any offset to be
  // part of bounds checking and we don't know which parts of the GEPs are
  // uniform.
1696 Args.push_back(Elt: IRB.getInt32(C: 0));
1697
1698 uint32_t Aux = 0;
1699 if (IsVolatile)
1700 Aux |= AMDGPU::CPol::VOLATILE;
1701 Args.push_back(Elt: IRB.getInt32(C: Aux));
1702
1703 Intrinsic::ID IID = Intrinsic::not_intrinsic;
1704 if (isa<LoadInst>(Val: I))
1705 IID = Order == AtomicOrdering::NotAtomic
1706 ? Intrinsic::amdgcn_raw_ptr_buffer_load
1707 : Intrinsic::amdgcn_raw_ptr_atomic_buffer_load;
1708 else if (isa<StoreInst>(Val: I))
1709 IID = Intrinsic::amdgcn_raw_ptr_buffer_store;
1710 else if (auto *RMW = dyn_cast<AtomicRMWInst>(Val: I)) {
1711 switch (RMW->getOperation()) {
1712 case AtomicRMWInst::Xchg:
1713 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_swap;
1714 break;
1715 case AtomicRMWInst::Add:
1716 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_add;
1717 break;
1718 case AtomicRMWInst::Sub:
1719 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_sub;
1720 break;
1721 case AtomicRMWInst::And:
1722 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_and;
1723 break;
1724 case AtomicRMWInst::Or:
1725 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_or;
1726 break;
1727 case AtomicRMWInst::Xor:
1728 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_xor;
1729 break;
1730 case AtomicRMWInst::Max:
1731 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smax;
1732 break;
1733 case AtomicRMWInst::Min:
1734 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smin;
1735 break;
1736 case AtomicRMWInst::UMax:
1737 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umax;
1738 break;
1739 case AtomicRMWInst::UMin:
1740 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umin;
1741 break;
1742 case AtomicRMWInst::FAdd:
1743 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fadd;
1744 break;
1745 case AtomicRMWInst::FMax:
1746 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmax;
1747 break;
1748 case AtomicRMWInst::FMin:
1749 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmin;
1750 break;
1751 case AtomicRMWInst::USubCond:
1752 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_cond_sub_u32;
1753 break;
1754 case AtomicRMWInst::USubSat:
1755 IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_sub_clamp_u32;
1756 break;
1757 case AtomicRMWInst::FSub: {
1758 reportFatalUsageError(
1759 reason: "atomic floating point subtraction not supported for "
1760 "buffer resources and should've been expanded away");
1761 break;
1762 }
1763 case AtomicRMWInst::FMaximum: {
1764 reportFatalUsageError(
1765 reason: "atomic floating point fmaximum not supported for "
1766 "buffer resources and should've been expanded away");
1767 break;
1768 }
1769 case AtomicRMWInst::FMinimum: {
1770 reportFatalUsageError(
1771 reason: "atomic floating point fminimum not supported for "
1772 "buffer resources and should've been expanded away");
1773 break;
1774 }
1775 case AtomicRMWInst::Nand:
1776 reportFatalUsageError(
1777 reason: "atomic nand not supported for buffer resources and "
1778 "should've been expanded away");
1779 break;
1780 case AtomicRMWInst::UIncWrap:
1781 case AtomicRMWInst::UDecWrap:
1782 reportFatalUsageError(
1783 reason: "wrapping increment/decrement not supported for "
1784 "buffer resources and should've been expanded away");
1785 break;
1786 case AtomicRMWInst::BAD_BINOP:
1787 llvm_unreachable("Not sure how we got a bad binop");
1788 }
1789 }
1790
1791 auto *Call = IRB.CreateIntrinsic(ID: IID, Types: Ty, Args);
1792 copyMetadata(Dest: Call, Src: I);
1793 setAlign(Intr: Call, A: Alignment, RsrcArgIdx: Arg ? 1 : 0);
1794 Call->takeName(V: I);
1795
1796 insertPostMemOpFence(Order, SSID);
  // The "no moving p7 directly" rewrites ensure that this memory operation
  // won't itself need to be split into parts.
1799 SplitUsers.insert(V: I);
1800 I->replaceAllUsesWith(V: Call);
1801 return Call;
1802}
1803
1804PtrParts SplitPtrStructs::visitInstruction(Instruction &I) {
1805 return {nullptr, nullptr};
1806}
1807
1808PtrParts SplitPtrStructs::visitLoadInst(LoadInst &LI) {
1809 if (!isSplitFatPtr(Ty: LI.getPointerOperandType()))
1810 return {nullptr, nullptr};
1811 handleMemoryInst(I: &LI, Arg: nullptr, Ptr: LI.getPointerOperand(), Ty: LI.getType(),
1812 Alignment: LI.getAlign(), Order: LI.getOrdering(), IsVolatile: LI.isVolatile(),
1813 SSID: LI.getSyncScopeID());
1814 return {nullptr, nullptr};
1815}
1816
1817PtrParts SplitPtrStructs::visitStoreInst(StoreInst &SI) {
1818 if (!isSplitFatPtr(Ty: SI.getPointerOperandType()))
1819 return {nullptr, nullptr};
1820 Value *Arg = SI.getValueOperand();
1821 handleMemoryInst(I: &SI, Arg, Ptr: SI.getPointerOperand(), Ty: Arg->getType(),
1822 Alignment: SI.getAlign(), Order: SI.getOrdering(), IsVolatile: SI.isVolatile(),
1823 SSID: SI.getSyncScopeID());
1824 return {nullptr, nullptr};
1825}
1826
1827PtrParts SplitPtrStructs::visitAtomicRMWInst(AtomicRMWInst &AI) {
1828 if (!isSplitFatPtr(Ty: AI.getPointerOperand()->getType()))
1829 return {nullptr, nullptr};
1830 Value *Arg = AI.getValOperand();
1831 handleMemoryInst(I: &AI, Arg, Ptr: AI.getPointerOperand(), Ty: Arg->getType(),
1832 Alignment: AI.getAlign(), Order: AI.getOrdering(), IsVolatile: AI.isVolatile(),
1833 SSID: AI.getSyncScopeID());
1834 return {nullptr, nullptr};
1835}
1836
// Unlike load, store, and RMW, cmpxchg needs special handling to account
// for the boolean success flag in its {value, i1} result, which has to be
// rebuilt from the intrinsic's scalar return value.
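// Schematically (names invented for exposition), a strong cmpxchg such as
//   %r = cmpxchg ptr addrspace(7) %p, i32 %old, i32 %new seq_cst seq_cst
// becomes roughly
//   %v = call i32 @llvm.amdgcn.raw.ptr.buffer.atomic.cmpswap.i32(
//            i32 %new, i32 %old, ptr addrspace(8) %p.rsrc, i32 %p.off,
//            i32 0, i32 %aux)
//   %ok = icmp eq i32 %v, %old
//   %r.0 = insertvalue { i32, i1 } poison, i32 %v, 0
//   %r = insertvalue { i32, i1 } %r.0, i1 %ok, 1
// bracketed by whatever fences the ordering requires.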
1839PtrParts SplitPtrStructs::visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI) {
1840 Value *Ptr = AI.getPointerOperand();
1841 if (!isSplitFatPtr(Ty: Ptr->getType()))
1842 return {nullptr, nullptr};
1843 IRB.SetInsertPoint(&AI);
1844
1845 Type *Ty = AI.getNewValOperand()->getType();
1846 AtomicOrdering Order = AI.getMergedOrdering();
1847 SyncScope::ID SSID = AI.getSyncScopeID();
1848 bool IsNonTemporal = AI.getMetadata(KindID: LLVMContext::MD_nontemporal);
1849
1850 auto [Rsrc, Off] = getPtrParts(V: Ptr);
1851 insertPreMemOpFence(Order, SSID);
1852
1853 uint32_t Aux = 0;
1854 if (IsNonTemporal)
1855 Aux |= AMDGPU::CPol::SLC;
1856 if (AI.isVolatile())
1857 Aux |= AMDGPU::CPol::VOLATILE;
1858 auto *Call =
1859 IRB.CreateIntrinsic(ID: Intrinsic::amdgcn_raw_ptr_buffer_atomic_cmpswap, Types: Ty,
1860 Args: {AI.getNewValOperand(), AI.getCompareOperand(), Rsrc,
1861 Off, IRB.getInt32(C: 0), IRB.getInt32(C: Aux)});
1862 copyMetadata(Dest: Call, Src: &AI);
1863 setAlign(Intr: Call, A: AI.getAlign(), RsrcArgIdx: 2);
1864 Call->takeName(V: &AI);
1865 insertPostMemOpFence(Order, SSID);
1866
1867 Value *Res = PoisonValue::get(T: AI.getType());
1868 Res = IRB.CreateInsertValue(Agg: Res, Val: Call, Idxs: 0);
1869 if (!AI.isWeak()) {
1870 Value *Succeeded = IRB.CreateICmpEQ(LHS: Call, RHS: AI.getCompareOperand());
1871 Res = IRB.CreateInsertValue(Agg: Res, Val: Succeeded, Idxs: 1);
1872 }
1873 SplitUsers.insert(V: &AI);
1874 AI.replaceAllUsesWith(V: Res);
1875 return {nullptr, nullptr};
1876}
1877
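// A GEP on a fat pointer only ever changes the offset part. As an
// illustrative sketch (names invented), a GEP like
//   %q = getelementptr i8, ptr addrspace(7) %p, i32 %i
// lowers to
//   %q.off = add i32 %p.off, %i
// while %q.rsrc is simply %p.rsrc; emitGEPOffset() is reused below to fold a
// possibly multi-index GEP into that single i32 addend.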
1878PtrParts SplitPtrStructs::visitGetElementPtrInst(GetElementPtrInst &GEP) {
1879 using namespace llvm::PatternMatch;
1880 Value *Ptr = GEP.getPointerOperand();
1881 if (!isSplitFatPtr(Ty: Ptr->getType()))
1882 return {nullptr, nullptr};
1883 IRB.SetInsertPoint(&GEP);
1884
1885 auto [Rsrc, Off] = getPtrParts(V: Ptr);
1886 const DataLayout &DL = GEP.getDataLayout();
1887 bool IsNUW = GEP.hasNoUnsignedWrap();
1888 bool IsNUSW = GEP.hasNoUnsignedSignedWrap();
1889
1890 StructType *ResTy = cast<StructType>(Val: GEP.getType());
1891 Type *ResRsrcTy = ResTy->getElementType(N: 0);
1892 VectorType *ResRsrcVecTy = dyn_cast<VectorType>(Val: ResRsrcTy);
1893 bool BroadcastsPtr = ResRsrcVecTy && !isa<VectorType>(Val: Off->getType());
1894
1895 // In order to call emitGEPOffset() and thus not have to reimplement it,
1896 // we need the GEP result to have ptr addrspace(7) type.
1897 Type *FatPtrTy =
1898 ResRsrcTy->getWithNewType(EltTy: IRB.getPtrTy(AddrSpace: AMDGPUAS::BUFFER_FAT_POINTER));
1899 GEP.mutateType(Ty: FatPtrTy);
1900 Value *OffAccum = emitGEPOffset(Builder: &IRB, DL, GEP: &GEP);
1901 GEP.mutateType(Ty: ResTy);
1902
1903 if (BroadcastsPtr) {
1904 Rsrc = IRB.CreateVectorSplat(EC: ResRsrcVecTy->getElementCount(), V: Rsrc,
1905 Name: Rsrc->getName());
1906 Off = IRB.CreateVectorSplat(EC: ResRsrcVecTy->getElementCount(), V: Off,
1907 Name: Off->getName());
1908 }
1909 if (match(V: OffAccum, P: m_Zero())) { // Constant-zero offset
1910 SplitUsers.insert(V: &GEP);
1911 return {Rsrc, Off};
1912 }
1913
1914 bool HasNonNegativeOff = false;
1915 if (auto *CI = dyn_cast<ConstantInt>(Val: OffAccum)) {
1916 HasNonNegativeOff = !CI->isNegative();
1917 }
1918 Value *NewOff;
1919 if (match(V: Off, P: m_Zero())) {
1920 NewOff = OffAccum;
1921 } else {
1922 NewOff = IRB.CreateAdd(LHS: Off, RHS: OffAccum, Name: "",
1923 /*hasNUW=*/HasNUW: IsNUW || (IsNUSW && HasNonNegativeOff),
1924 /*hasNSW=*/HasNSW: false);
1925 }
1926 copyMetadata(Dest: NewOff, Src: &GEP);
1927 NewOff->takeName(V: &GEP);
1928 SplitUsers.insert(V: &GEP);
1929 return {Rsrc, NewOff};
1930}
1931
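// ptrtoint on a fat pointer reconstitutes the full-width integer from the
// split parts. Schematically (names invented), for an i160 result:
//   %r.rsrc = ptrtoint ptr addrspace(8) %p.rsrc to i160
//   %r.shl  = shl i160 %r.rsrc, 32
//   %r.off  = zext i32 %p.off to i160
//   %r      = or i160 %r.shl, %r.off
// Results of 32 bits or fewer keep only the low-order bits, and so are taken
// from the offset part alone.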
1932PtrParts SplitPtrStructs::visitPtrToIntInst(PtrToIntInst &PI) {
1933 Value *Ptr = PI.getPointerOperand();
1934 if (!isSplitFatPtr(Ty: Ptr->getType()))
1935 return {nullptr, nullptr};
1936 IRB.SetInsertPoint(&PI);
1937
1938 Type *ResTy = PI.getType();
1939 unsigned Width = ResTy->getScalarSizeInBits();
1940
1941 auto [Rsrc, Off] = getPtrParts(V: Ptr);
1942 const DataLayout &DL = PI.getDataLayout();
1943 unsigned FatPtrWidth = DL.getPointerSizeInBits(AS: AMDGPUAS::BUFFER_FAT_POINTER);
1944
1945 Value *Res;
1946 if (Width <= BufferOffsetWidth) {
1947 Res = IRB.CreateIntCast(V: Off, DestTy: ResTy, /*isSigned=*/false,
1948 Name: PI.getName() + ".off");
1949 } else {
1950 Value *RsrcInt = IRB.CreatePtrToInt(V: Rsrc, DestTy: ResTy, Name: PI.getName() + ".rsrc");
1951 Value *Shl = IRB.CreateShl(
1952 LHS: RsrcInt,
1953 RHS: ConstantExpr::getIntegerValue(Ty: ResTy, V: APInt(Width, BufferOffsetWidth)),
1954 Name: "", HasNUW: Width >= FatPtrWidth, HasNSW: Width > FatPtrWidth);
1955 Value *OffCast = IRB.CreateIntCast(V: Off, DestTy: ResTy, /*isSigned=*/false,
1956 Name: PI.getName() + ".off");
1957 Res = IRB.CreateOr(LHS: Shl, RHS: OffCast);
1958 }
1959
1960 copyMetadata(Dest: Res, Src: &PI);
1961 Res->takeName(V: &PI);
1962 SplitUsers.insert(V: &PI);
1963 PI.replaceAllUsesWith(V: Res);
1964 return {nullptr, nullptr};
1965}
1966
1967PtrParts SplitPtrStructs::visitPtrToAddrInst(PtrToAddrInst &PA) {
1968 Value *Ptr = PA.getPointerOperand();
1969 if (!isSplitFatPtr(Ty: Ptr->getType()))
1970 return {nullptr, nullptr};
1971 IRB.SetInsertPoint(&PA);
1972
1973 auto [Rsrc, Off] = getPtrParts(V: Ptr);
1974 Value *Res = IRB.CreateIntCast(V: Off, DestTy: PA.getType(), /*isSigned=*/false);
1975 copyMetadata(Dest: Res, Src: &PA);
1976 Res->takeName(V: &PA);
1977 SplitUsers.insert(V: &PA);
1978 PA.replaceAllUsesWith(V: Res);
1979 return {nullptr, nullptr};
1980}
1981
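// inttoptr performs the inverse transposition: the bits above the offset
// become the resource and the low 32 bits become the offset. A rough sketch
// (names invented) for an i160 source:
//   %rsrc.shift = lshr i160 %int, 32
//   %rsrc.int   = trunc i160 %rsrc.shift to i128
//   %p.rsrc     = inttoptr i128 %rsrc.int to ptr addrspace(8)
//   %p.off      = trunc i160 %int to i32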
1982PtrParts SplitPtrStructs::visitIntToPtrInst(IntToPtrInst &IP) {
1983 if (!isSplitFatPtr(Ty: IP.getType()))
1984 return {nullptr, nullptr};
1985 IRB.SetInsertPoint(&IP);
1986 const DataLayout &DL = IP.getDataLayout();
1987 unsigned RsrcPtrWidth = DL.getPointerSizeInBits(AS: AMDGPUAS::BUFFER_RESOURCE);
1988 Value *Int = IP.getOperand(i_nocapture: 0);
1989 Type *IntTy = Int->getType();
1990 Type *RsrcIntTy = IntTy->getWithNewBitWidth(NewBitWidth: RsrcPtrWidth);
1991 unsigned Width = IntTy->getScalarSizeInBits();
1992
1993 auto *RetTy = cast<StructType>(Val: IP.getType());
1994 Type *RsrcTy = RetTy->getElementType(N: 0);
1995 Type *OffTy = RetTy->getElementType(N: 1);
1996 Value *RsrcPart = IRB.CreateLShr(
1997 LHS: Int,
1998 RHS: ConstantExpr::getIntegerValue(Ty: IntTy, V: APInt(Width, BufferOffsetWidth)));
1999 Value *RsrcInt = IRB.CreateIntCast(V: RsrcPart, DestTy: RsrcIntTy, /*isSigned=*/false);
2000 Value *Rsrc = IRB.CreateIntToPtr(V: RsrcInt, DestTy: RsrcTy, Name: IP.getName() + ".rsrc");
2001 Value *Off =
2002 IRB.CreateIntCast(V: Int, DestTy: OffTy, /*IsSigned=*/isSigned: false, Name: IP.getName() + ".off");
2003
2004 copyMetadata(Dest: Rsrc, Src: &IP);
2005 SplitUsers.insert(V: &IP);
2006 return {Rsrc, Off};
2007}
2008
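// Casting a buffer resource (addrspace 8) pointer to a fat pointer produces
// that resource with a zero offset. Schematically (names invented):
//   %p = addrspacecast ptr addrspace(8) %rsrc to ptr addrspace(7)
// splits into {%rsrc, i32 0}, while null, poison, and undef inputs map to
// the corresponding constants for both parts.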
2009PtrParts SplitPtrStructs::visitAddrSpaceCastInst(AddrSpaceCastInst &I) {
2010 // TODO(krzysz00): handle casts from ptr addrspace(7) to global pointers
2011 // by computing the effective address.
2012 if (!isSplitFatPtr(Ty: I.getType()))
2013 return {nullptr, nullptr};
2014 IRB.SetInsertPoint(&I);
2015 Value *In = I.getPointerOperand();
2016 // No-op casts preserve parts
2017 if (In->getType() == I.getType()) {
2018 auto [Rsrc, Off] = getPtrParts(V: In);
2019 SplitUsers.insert(V: &I);
2020 return {Rsrc, Off};
2021 }
2022
2023 auto *ResTy = cast<StructType>(Val: I.getType());
2024 Type *RsrcTy = ResTy->getElementType(N: 0);
2025 Type *OffTy = ResTy->getElementType(N: 1);
2026 Value *ZeroOff = Constant::getNullValue(Ty: OffTy);
2027
2028 // Special case for null pointers, undef, and poison, which can be created by
2029 // address space propagation.
2030 auto *InConst = dyn_cast<Constant>(Val: In);
2031 if (InConst && InConst->isNullValue()) {
2032 Value *NullRsrc = Constant::getNullValue(Ty: RsrcTy);
2033 SplitUsers.insert(V: &I);
2034 return {NullRsrc, ZeroOff};
2035 }
2036 if (isa<PoisonValue>(Val: In)) {
2037 Value *PoisonRsrc = PoisonValue::get(T: RsrcTy);
2038 Value *PoisonOff = PoisonValue::get(T: OffTy);
2039 SplitUsers.insert(V: &I);
2040 return {PoisonRsrc, PoisonOff};
2041 }
2042 if (isa<UndefValue>(Val: In)) {
2043 Value *UndefRsrc = UndefValue::get(T: RsrcTy);
2044 Value *UndefOff = UndefValue::get(T: OffTy);
2045 SplitUsers.insert(V: &I);
2046 return {UndefRsrc, UndefOff};
2047 }
2048
2049 if (I.getSrcAddressSpace() != AMDGPUAS::BUFFER_RESOURCE)
2050 reportFatalUsageError(
2051 reason: "only buffer resources (addrspace 8) and null/poison pointers can be "
2052 "cast to buffer fat pointers (addrspace 7)");
2053 SplitUsers.insert(V: &I);
2054 return {In, ZeroOff};
2055}
2056
2057PtrParts SplitPtrStructs::visitICmpInst(ICmpInst &Cmp) {
2058 Value *Lhs = Cmp.getOperand(i_nocapture: 0);
2059 if (!isSplitFatPtr(Ty: Lhs->getType()))
2060 return {nullptr, nullptr};
2061 Value *Rhs = Cmp.getOperand(i_nocapture: 1);
2062 IRB.SetInsertPoint(&Cmp);
2063 ICmpInst::Predicate Pred = Cmp.getPredicate();
2064
2065 assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
2066 "Pointer comparison is only equal or unequal");
2067 auto [LhsRsrc, LhsOff] = getPtrParts(V: Lhs);
2068 auto [RhsRsrc, RhsOff] = getPtrParts(V: Rhs);
  Value *RsrcCmp =
      IRB.CreateICmp(Pred, LhsRsrc, RhsRsrc, Cmp.getName() + ".rsrc");
  copyMetadata(RsrcCmp, &Cmp);
  Value *OffCmp = IRB.CreateICmp(Pred, LhsOff, RhsOff, Cmp.getName() + ".off");
  copyMetadata(OffCmp, &Cmp);
  // Fat pointers are equal only if both their resource and offset parts are
  // equal, and unequal if either part differs.
  Value *Res = Pred == ICmpInst::ICMP_EQ ? IRB.CreateAnd(RsrcCmp, OffCmp)
                                         : IRB.CreateOr(RsrcCmp, OffCmp);
  copyMetadata(Res, &Cmp);
2071 Res->takeName(V: &Cmp);
2072 SplitUsers.insert(V: &Cmp);
2073 Cmp.replaceAllUsesWith(V: Res);
2074 return {nullptr, nullptr};
2075}
2076
2077PtrParts SplitPtrStructs::visitFreezeInst(FreezeInst &I) {
2078 if (!isSplitFatPtr(Ty: I.getType()))
2079 return {nullptr, nullptr};
2080 IRB.SetInsertPoint(&I);
2081 auto [Rsrc, Off] = getPtrParts(V: I.getOperand(i_nocapture: 0));
2082
2083 Value *RsrcRes = IRB.CreateFreeze(V: Rsrc, Name: I.getName() + ".rsrc");
2084 copyMetadata(Dest: RsrcRes, Src: &I);
2085 Value *OffRes = IRB.CreateFreeze(V: Off, Name: I.getName() + ".off");
2086 copyMetadata(Dest: OffRes, Src: &I);
2087 SplitUsers.insert(V: &I);
2088 return {RsrcRes, OffRes};
2089}
2090
2091PtrParts SplitPtrStructs::visitExtractElementInst(ExtractElementInst &I) {
2092 if (!isSplitFatPtr(Ty: I.getType()))
2093 return {nullptr, nullptr};
2094 IRB.SetInsertPoint(&I);
2095 Value *Vec = I.getVectorOperand();
2096 Value *Idx = I.getIndexOperand();
2097 auto [Rsrc, Off] = getPtrParts(V: Vec);
2098
2099 Value *RsrcRes = IRB.CreateExtractElement(Vec: Rsrc, Idx, Name: I.getName() + ".rsrc");
2100 copyMetadata(Dest: RsrcRes, Src: &I);
2101 Value *OffRes = IRB.CreateExtractElement(Vec: Off, Idx, Name: I.getName() + ".off");
2102 copyMetadata(Dest: OffRes, Src: &I);
2103 SplitUsers.insert(V: &I);
2104 return {RsrcRes, OffRes};
2105}
2106
2107PtrParts SplitPtrStructs::visitInsertElementInst(InsertElementInst &I) {
2108 // The mutated instructions temporarily don't return vectors, and so
2109 // we need the generic getType() here to avoid crashes.
2110 if (!isSplitFatPtr(Ty: cast<Instruction>(Val&: I).getType()))
2111 return {nullptr, nullptr};
2112 IRB.SetInsertPoint(&I);
2113 Value *Vec = I.getOperand(i_nocapture: 0);
2114 Value *Elem = I.getOperand(i_nocapture: 1);
2115 Value *Idx = I.getOperand(i_nocapture: 2);
2116 auto [VecRsrc, VecOff] = getPtrParts(V: Vec);
2117 auto [ElemRsrc, ElemOff] = getPtrParts(V: Elem);
2118
2119 Value *RsrcRes =
2120 IRB.CreateInsertElement(Vec: VecRsrc, NewElt: ElemRsrc, Idx, Name: I.getName() + ".rsrc");
2121 copyMetadata(Dest: RsrcRes, Src: &I);
2122 Value *OffRes =
2123 IRB.CreateInsertElement(Vec: VecOff, NewElt: ElemOff, Idx, Name: I.getName() + ".off");
2124 copyMetadata(Dest: OffRes, Src: &I);
2125 SplitUsers.insert(V: &I);
2126 return {RsrcRes, OffRes};
2127}
2128
2129PtrParts SplitPtrStructs::visitShuffleVectorInst(ShuffleVectorInst &I) {
2130 // Cast is needed for the same reason as insertelement's.
2131 if (!isSplitFatPtr(Ty: cast<Instruction>(Val&: I).getType()))
2132 return {nullptr, nullptr};
2133 IRB.SetInsertPoint(&I);
2134
2135 Value *V1 = I.getOperand(i_nocapture: 0);
2136 Value *V2 = I.getOperand(i_nocapture: 1);
2137 ArrayRef<int> Mask = I.getShuffleMask();
2138 auto [V1Rsrc, V1Off] = getPtrParts(V: V1);
2139 auto [V2Rsrc, V2Off] = getPtrParts(V: V2);
2140
2141 Value *RsrcRes =
2142 IRB.CreateShuffleVector(V1: V1Rsrc, V2: V2Rsrc, Mask, Name: I.getName() + ".rsrc");
2143 copyMetadata(Dest: RsrcRes, Src: &I);
2144 Value *OffRes =
2145 IRB.CreateShuffleVector(V1: V1Off, V2: V2Off, Mask, Name: I.getName() + ".off");
2146 copyMetadata(Dest: OffRes, Src: &I);
2147 SplitUsers.insert(V: &I);
2148 return {RsrcRes, OffRes};
2149}
2150
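// PHIs are lowered in two stages. During the main visit, a fat-pointer PHI
// temporarily has its parts taken out with (names invented for exposition)
//   %p.rsrc = extractvalue { ptr addrspace(8), i32 } %p, 0
//   %p.off  = extractvalue { ptr addrspace(8), i32 } %p, 1
// and is queued in `Conditionals`; processConditionals() later replaces these
// temporaries either with PHIs on the parts or, for the resource, with a
// single known root if one exists, after which the temporaries are deleted.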
2151PtrParts SplitPtrStructs::visitPHINode(PHINode &PHI) {
2152 if (!isSplitFatPtr(Ty: PHI.getType()))
2153 return {nullptr, nullptr};
2154 IRB.SetInsertPoint(*PHI.getInsertionPointAfterDef());
2155 // Phi nodes will be handled in post-processing after we've visited every
2156 // instruction. However, instead of just returning {nullptr, nullptr},
2157 // we explicitly create the temporary extractvalue operations that are our
2158 // temporary results so that they end up at the beginning of the block with
2159 // the PHIs.
2160 Value *TmpRsrc = IRB.CreateExtractValue(Agg: &PHI, Idxs: 0, Name: PHI.getName() + ".rsrc");
2161 Value *TmpOff = IRB.CreateExtractValue(Agg: &PHI, Idxs: 1, Name: PHI.getName() + ".off");
2162 Conditionals.push_back(Elt: &PHI);
2163 SplitUsers.insert(V: &PHI);
2164 return {TmpRsrc, TmpOff};
2165}
2166
2167PtrParts SplitPtrStructs::visitSelectInst(SelectInst &SI) {
2168 if (!isSplitFatPtr(Ty: SI.getType()))
2169 return {nullptr, nullptr};
2170 IRB.SetInsertPoint(&SI);
2171
2172 Value *Cond = SI.getCondition();
2173 Value *True = SI.getTrueValue();
2174 Value *False = SI.getFalseValue();
2175 auto [TrueRsrc, TrueOff] = getPtrParts(V: True);
2176 auto [FalseRsrc, FalseOff] = getPtrParts(V: False);
2177
2178 Value *RsrcRes =
2179 IRB.CreateSelect(C: Cond, True: TrueRsrc, False: FalseRsrc, Name: SI.getName() + ".rsrc", MDFrom: &SI);
2180 copyMetadata(Dest: RsrcRes, Src: &SI);
2181 Conditionals.push_back(Elt: &SI);
2182 Value *OffRes =
2183 IRB.CreateSelect(C: Cond, True: TrueOff, False: FalseOff, Name: SI.getName() + ".off", MDFrom: &SI);
2184 copyMetadata(Dest: OffRes, Src: &SI);
2185 SplitUsers.insert(V: &SI);
2186 return {RsrcRes, OffRes};
2187}
2188
2189/// Returns true if this intrinsic needs to be removed when it is
2190/// applied to `ptr addrspace(7)` values. Calls to these intrinsics are
2191/// rewritten into calls to versions of that intrinsic on the resource
2192/// descriptor.
2193static bool isRemovablePointerIntrinsic(Intrinsic::ID IID) {
2194 switch (IID) {
2195 default:
2196 return false;
2197 case Intrinsic::amdgcn_make_buffer_rsrc:
2198 case Intrinsic::ptrmask:
2199 case Intrinsic::invariant_start:
2200 case Intrinsic::invariant_end:
2201 case Intrinsic::launder_invariant_group:
2202 case Intrinsic::strip_invariant_group:
2203 case Intrinsic::memcpy:
2204 case Intrinsic::memcpy_inline:
2205 case Intrinsic::memmove:
2206 case Intrinsic::memset:
2207 case Intrinsic::memset_inline:
2208 case Intrinsic::experimental_memset_pattern:
2209 case Intrinsic::amdgcn_load_to_lds:
2210 return true;
2211 }
2212}
2213
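// Intrinsic calls are rewritten case by case. For instance (sketched with
// invented names), a ptrmask call such as
//   %q = call ptr addrspace(7) @llvm.ptrmask.p7.i32(ptr addrspace(7) %p,
//                                                   i32 %mask)
// masks only the offset part:
//   %q.off = and i32 %p.off, %mask
// with the resource part passed through unchanged.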
2214PtrParts SplitPtrStructs::visitIntrinsicInst(IntrinsicInst &I) {
2215 Intrinsic::ID IID = I.getIntrinsicID();
2216 switch (IID) {
2217 default:
2218 break;
2219 case Intrinsic::amdgcn_make_buffer_rsrc: {
2220 if (!isSplitFatPtr(Ty: I.getType()))
2221 return {nullptr, nullptr};
2222 Value *Base = I.getArgOperand(i: 0);
2223 Value *Stride = I.getArgOperand(i: 1);
2224 Value *NumRecords = I.getArgOperand(i: 2);
2225 Value *Flags = I.getArgOperand(i: 3);
2226 auto *SplitType = cast<StructType>(Val: I.getType());
2227 Type *RsrcType = SplitType->getElementType(N: 0);
2228 Type *OffType = SplitType->getElementType(N: 1);
2229 IRB.SetInsertPoint(&I);
2230 Value *Rsrc = IRB.CreateIntrinsic(ID: IID, Types: {RsrcType, Base->getType()},
2231 Args: {Base, Stride, NumRecords, Flags});
2232 copyMetadata(Dest: Rsrc, Src: &I);
2233 Rsrc->takeName(V: &I);
2234 Value *Zero = Constant::getNullValue(Ty: OffType);
2235 SplitUsers.insert(V: &I);
2236 return {Rsrc, Zero};
2237 }
2238 case Intrinsic::ptrmask: {
2239 Value *Ptr = I.getArgOperand(i: 0);
2240 if (!isSplitFatPtr(Ty: Ptr->getType()))
2241 return {nullptr, nullptr};
2242 Value *Mask = I.getArgOperand(i: 1);
2243 IRB.SetInsertPoint(&I);
2244 auto [Rsrc, Off] = getPtrParts(V: Ptr);
2245 if (Mask->getType() != Off->getType())
2246 reportFatalUsageError(reason: "offset width is not equal to index width of fat "
2247 "pointer (data layout not set up correctly?)");
2248 Value *OffRes = IRB.CreateAnd(LHS: Off, RHS: Mask, Name: I.getName() + ".off");
2249 copyMetadata(Dest: OffRes, Src: &I);
2250 SplitUsers.insert(V: &I);
2251 return {Rsrc, OffRes};
2252 }
  // Pointer annotation intrinsics that, given their object-wide nature,
  // operate on the resource part.
2255 case Intrinsic::invariant_start: {
2256 Value *Ptr = I.getArgOperand(i: 1);
2257 if (!isSplitFatPtr(Ty: Ptr->getType()))
2258 return {nullptr, nullptr};
2259 IRB.SetInsertPoint(&I);
2260 auto [Rsrc, Off] = getPtrParts(V: Ptr);
2261 Type *NewTy = PointerType::get(C&: I.getContext(), AddressSpace: AMDGPUAS::BUFFER_RESOURCE);
2262 auto *NewRsrc = IRB.CreateIntrinsic(ID: IID, Types: {NewTy}, Args: {I.getOperand(i_nocapture: 0), Rsrc});
2263 copyMetadata(Dest: NewRsrc, Src: &I);
2264 NewRsrc->takeName(V: &I);
2265 SplitUsers.insert(V: &I);
2266 I.replaceAllUsesWith(V: NewRsrc);
2267 return {nullptr, nullptr};
2268 }
2269 case Intrinsic::invariant_end: {
2270 Value *RealPtr = I.getArgOperand(i: 2);
2271 if (!isSplitFatPtr(Ty: RealPtr->getType()))
2272 return {nullptr, nullptr};
2273 IRB.SetInsertPoint(&I);
2274 Value *RealRsrc = getPtrParts(V: RealPtr).first;
2275 Value *InvPtr = I.getArgOperand(i: 0);
2276 Value *Size = I.getArgOperand(i: 1);
2277 Value *NewRsrc = IRB.CreateIntrinsic(ID: IID, Types: {RealRsrc->getType()},
2278 Args: {InvPtr, Size, RealRsrc});
2279 copyMetadata(Dest: NewRsrc, Src: &I);
2280 NewRsrc->takeName(V: &I);
2281 SplitUsers.insert(V: &I);
2282 I.replaceAllUsesWith(V: NewRsrc);
2283 return {nullptr, nullptr};
2284 }
2285 case Intrinsic::launder_invariant_group:
2286 case Intrinsic::strip_invariant_group: {
2287 Value *Ptr = I.getArgOperand(i: 0);
2288 if (!isSplitFatPtr(Ty: Ptr->getType()))
2289 return {nullptr, nullptr};
2290 IRB.SetInsertPoint(&I);
2291 auto [Rsrc, Off] = getPtrParts(V: Ptr);
2292 Value *NewRsrc = IRB.CreateIntrinsic(ID: IID, Types: {Rsrc->getType()}, Args: {Rsrc});
2293 copyMetadata(Dest: NewRsrc, Src: &I);
2294 NewRsrc->takeName(V: &I);
2295 SplitUsers.insert(V: &I);
2296 return {NewRsrc, Off};
2297 }
2298 case Intrinsic::amdgcn_load_to_lds: {
2299 Value *Ptr = I.getArgOperand(i: 0);
2300 if (!isSplitFatPtr(Ty: Ptr->getType()))
2301 return {nullptr, nullptr};
2302 IRB.SetInsertPoint(&I);
2303 auto [Rsrc, Off] = getPtrParts(V: Ptr);
2304 Value *LDSPtr = I.getArgOperand(i: 1);
2305 Value *LoadSize = I.getArgOperand(i: 2);
2306 Value *ImmOff = I.getArgOperand(i: 3);
2307 Value *Aux = I.getArgOperand(i: 4);
2308 Value *SOffset = IRB.getInt32(C: 0);
2309 Instruction *NewLoad = IRB.CreateIntrinsic(
2310 ID: Intrinsic::amdgcn_raw_ptr_buffer_load_lds, Types: {},
2311 Args: {Rsrc, LDSPtr, LoadSize, Off, SOffset, ImmOff, Aux});
2312 copyMetadata(Dest: NewLoad, Src: &I);
2313 SplitUsers.insert(V: &I);
2314 I.replaceAllUsesWith(V: NewLoad);
2315 return {nullptr, nullptr};
2316 }
2317 }
2318 return {nullptr, nullptr};
2319}
2320
2321void SplitPtrStructs::processFunction(Function &F) {
2322 ST = &TM->getSubtarget<GCNSubtarget>(F);
2323 SmallVector<Instruction *, 0> Originals(
2324 llvm::make_pointer_range(Range: instructions(F)));
2325 LLVM_DEBUG(dbgs() << "Splitting pointer structs in function: " << F.getName()
2326 << "\n");
2327 for (Instruction *I : Originals) {
2328 // In some cases, instruction order doesn't reflect program order,
    // so the visit() call will have already visited certain instructions
2330 // by the time this loop gets to them. Avoid re-visiting these so as to,
2331 // for example, avoid processing the same conditional twice.
2332 if (SplitUsers.contains(V: I))
2333 continue;
2334 auto [Rsrc, Off] = visit(I);
2335 assert(((Rsrc && Off) || (!Rsrc && !Off)) &&
2336 "Can't have a resource but no offset");
2337 if (Rsrc)
2338 RsrcParts[I] = Rsrc;
2339 if (Off)
2340 OffParts[I] = Off;
2341 }
2342 processConditionals();
2343 killAndReplaceSplitInstructions(Origs&: Originals);
2344
2345 // Clean up after ourselves to save on memory.
2346 RsrcParts.clear();
2347 OffParts.clear();
2348 SplitUsers.clear();
2349 Conditionals.clear();
2350 ConditionalTemps.clear();
2351}
2352
2353namespace {
2354class AMDGPULowerBufferFatPointers : public ModulePass {
2355public:
2356 static char ID;
2357
2358 AMDGPULowerBufferFatPointers() : ModulePass(ID) {}
2359
2360 bool run(Module &M, const TargetMachine &TM);
2361 bool runOnModule(Module &M) override;
2362
2363 void getAnalysisUsage(AnalysisUsage &AU) const override;
2364};
2365} // namespace
2366
2367/// Returns true if there are values that have a buffer fat pointer in them,
2368/// which means we'll need to perform rewrites on this function. As a side
2369/// effect, this will populate the type remapping cache.
2370static bool containsBufferFatPointers(const Function &F,
2371 BufferFatPtrToStructTypeMap *TypeMap) {
2372 bool HasFatPointers = false;
2373 for (const BasicBlock &BB : F)
2374 for (const Instruction &I : BB) {
2375 HasFatPointers |= (I.getType() != TypeMap->remapType(SrcTy: I.getType()));
2376 // Catch null pointer constants in loads, stores, etc.
2377 for (const Value *V : I.operand_values())
2378 HasFatPointers |= (V->getType() != TypeMap->remapType(SrcTy: V->getType()));
2379 }
2380 return HasFatPointers;
2381}
2382
2383static bool hasFatPointerInterface(const Function &F,
2384 BufferFatPtrToStructTypeMap *TypeMap) {
2385 Type *Ty = F.getFunctionType();
2386 return Ty != TypeMap->remapType(SrcTy: Ty);
2387}
2388
2389/// Move the body of `OldF` into a new function, returning it.
2390static Function *moveFunctionAdaptingType(Function *OldF, FunctionType *NewTy,
2391 ValueToValueMapTy &CloneMap) {
2392 bool IsIntrinsic = OldF->isIntrinsic();
2393 Function *NewF =
2394 Function::Create(Ty: NewTy, Linkage: OldF->getLinkage(), AddrSpace: OldF->getAddressSpace());
2395 NewF->copyAttributesFrom(Src: OldF);
2396 NewF->copyMetadata(Src: OldF, Offset: 0);
2397 NewF->takeName(V: OldF);
2398 NewF->updateAfterNameChange();
2399 NewF->setDLLStorageClass(OldF->getDLLStorageClass());
2400 OldF->getParent()->getFunctionList().insertAfter(where: OldF->getIterator(), New: NewF);
2401
2402 while (!OldF->empty()) {
2403 BasicBlock *BB = &OldF->front();
2404 BB->removeFromParent();
2405 BB->insertInto(Parent: NewF);
2406 CloneMap[BB] = BB;
2407 for (Instruction &I : *BB) {
2408 CloneMap[&I] = &I;
2409 }
2410 }
2411
2412 SmallVector<AttributeSet> ArgAttrs;
2413 AttributeList OldAttrs = OldF->getAttributes();
2414
2415 for (auto [I, OldArg, NewArg] : enumerate(First: OldF->args(), Rest: NewF->args())) {
2416 CloneMap[&NewArg] = &OldArg;
2417 NewArg.takeName(V: &OldArg);
2418 Type *OldArgTy = OldArg.getType(), *NewArgTy = NewArg.getType();
2419 // Temporarily mutate type of `NewArg` to allow RAUW to work.
2420 NewArg.mutateType(Ty: OldArgTy);
2421 OldArg.replaceAllUsesWith(V: &NewArg);
2422 NewArg.mutateType(Ty: NewArgTy);
2423
2424 AttributeSet ArgAttr = OldAttrs.getParamAttrs(ArgNo: I);
2425 // Intrinsics get their attributes fixed later.
2426 if (OldArgTy != NewArgTy && !IsIntrinsic)
2427 ArgAttr = ArgAttr.removeAttributes(
2428 C&: NewF->getContext(),
2429 AttrsToRemove: AttributeFuncs::typeIncompatible(Ty: NewArgTy, AS: ArgAttr));
2430 ArgAttrs.push_back(Elt: ArgAttr);
2431 }
2432 AttributeSet RetAttrs = OldAttrs.getRetAttrs();
2433 if (OldF->getReturnType() != NewF->getReturnType() && !IsIntrinsic)
2434 RetAttrs = RetAttrs.removeAttributes(
2435 C&: NewF->getContext(),
2436 AttrsToRemove: AttributeFuncs::typeIncompatible(Ty: NewF->getReturnType(), AS: RetAttrs));
2437 NewF->setAttributes(AttributeList::get(
2438 C&: NewF->getContext(), FnAttrs: OldAttrs.getFnAttrs(), RetAttrs, ArgAttrs));
2439 return NewF;
2440}
2441
2442static void makeCloneInPraceMap(Function *F, ValueToValueMapTy &CloneMap) {
2443 for (Argument &A : F->args())
2444 CloneMap[&A] = &A;
2445 for (BasicBlock &BB : *F) {
2446 CloneMap[&BB] = &BB;
2447 for (Instruction &I : BB)
2448 CloneMap[&I] = &I;
2449 }
2450}
2451
2452bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
2453 bool Changed = false;
2454 const DataLayout &DL = M.getDataLayout();
2455 // Record the functions which need to be remapped.
2456 // The second element of the pair indicates whether the function has to have
2457 // its arguments or return types adjusted.
2458 SmallVector<std::pair<Function *, bool>> NeedsRemap;
2459
2460 LLVMContext &Ctx = M.getContext();
2461
2462 BufferFatPtrToStructTypeMap StructTM(DL);
2463 BufferFatPtrToIntTypeMap IntTM(DL);
2464 for (const GlobalVariable &GV : M.globals()) {
2465 if (GV.getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
2466 // FIXME: Use DiagnosticInfo unsupported but it requires a Function
2467 Ctx.emitError(ErrorStr: "global variables with a buffer fat pointer address "
2468 "space (7) are not supported");
2469 continue;
2470 }
2471
2472 Type *VT = GV.getValueType();
2473 if (VT != StructTM.remapType(SrcTy: VT)) {
2474 // FIXME: Use DiagnosticInfo unsupported but it requires a Function
2475 Ctx.emitError(ErrorStr: "global variables that contain buffer fat pointers "
2476 "(address space 7 pointers) are unsupported. Use "
2477 "buffer resource pointers (address space 8) instead");
2478 continue;
2479 }
2480 }
2481
2482 {
2483 // Collect all constant exprs and aggregates referenced by any function.
2484 SmallVector<Constant *, 8> Worklist;
2485 for (Function &F : M.functions())
2486 for (Instruction &I : instructions(F))
2487 for (Value *Op : I.operands())
2488 if (isa<ConstantExpr, ConstantAggregate>(Val: Op))
2489 Worklist.push_back(Elt: cast<Constant>(Val: Op));
2490
2491 // Recursively look for any referenced buffer pointer constants.
2492 SmallPtrSet<Constant *, 8> Visited;
2493 SetVector<Constant *> BufferFatPtrConsts;
2494 while (!Worklist.empty()) {
2495 Constant *C = Worklist.pop_back_val();
2496 if (!Visited.insert(Ptr: C).second)
2497 continue;
2498 if (isBufferFatPtrOrVector(Ty: C->getType()))
2499 BufferFatPtrConsts.insert(X: C);
2500 for (Value *Op : C->operands())
2501 if (isa<ConstantExpr, ConstantAggregate>(Val: Op))
2502 Worklist.push_back(Elt: cast<Constant>(Val: Op));
2503 }
2504
2505 // Expand all constant expressions using fat buffer pointers to
2506 // instructions.
2507 Changed |= convertUsersOfConstantsToInstructions(
2508 Consts: BufferFatPtrConsts.getArrayRef(), /*RestrictToFunc=*/nullptr,
2509 /*RemoveDeadConstants=*/false, /*IncludeSelf=*/true);
2510 }
2511
2512 StoreFatPtrsAsIntsAndExpandMemcpyVisitor MemOpsRewrite(&IntTM, DL,
2513 M.getContext(), &TM);
2514 LegalizeBufferContentTypesVisitor BufferContentsTypeRewrite(DL,
2515 M.getContext());
2516 for (Function &F : M.functions()) {
2517 bool InterfaceChange = hasFatPointerInterface(F, TypeMap: &StructTM);
2518 bool BodyChanges = containsBufferFatPointers(F, TypeMap: &StructTM);
2519 Changed |= MemOpsRewrite.processFunction(F);
2520 if (InterfaceChange || BodyChanges) {
2521 NeedsRemap.push_back(Elt: std::make_pair(x: &F, y&: InterfaceChange));
2522 Changed |= BufferContentsTypeRewrite.processFunction(F);
2523 }
2524 }
2525 if (NeedsRemap.empty())
2526 return Changed;
2527
2528 SmallVector<Function *> NeedsPostProcess;
2529 SmallVector<Function *> Intrinsics;
2530 // Keep one big map so as to memoize constants across functions.
2531 ValueToValueMapTy CloneMap;
2532 FatPtrConstMaterializer Materializer(&StructTM, CloneMap);
2533
2534 ValueMapper LowerInFuncs(CloneMap, RF_None, &StructTM, &Materializer);
2535 for (auto [F, InterfaceChange] : NeedsRemap) {
2536 Function *NewF = F;
2537 if (InterfaceChange)
2538 NewF = moveFunctionAdaptingType(
2539 OldF: F, NewTy: cast<FunctionType>(Val: StructTM.remapType(SrcTy: F->getFunctionType())),
2540 CloneMap);
2541 else
2542 makeCloneInPraceMap(F, CloneMap);
2543 LowerInFuncs.remapFunction(F&: *NewF);
2544 if (NewF->isIntrinsic())
2545 Intrinsics.push_back(Elt: NewF);
2546 else
2547 NeedsPostProcess.push_back(Elt: NewF);
2548 if (InterfaceChange) {
2549 F->replaceAllUsesWith(V: NewF);
2550 F->eraseFromParent();
2551 }
2552 Changed = true;
2553 }
2554 StructTM.clear();
2555 IntTM.clear();
2556 CloneMap.clear();
2557
2558 SplitPtrStructs Splitter(DL, M.getContext(), &TM);
2559 for (Function *F : NeedsPostProcess)
2560 Splitter.processFunction(F&: *F);
2561 for (Function *F : Intrinsics) {
2562 // use_empty() can also occur with cases like masked load, which will
2563 // have been rewritten out of the module by now but not erased.
2564 if (F->use_empty() || isRemovablePointerIntrinsic(IID: F->getIntrinsicID())) {
2565 F->eraseFromParent();
2566 } else {
2567 std::optional<Function *> NewF = Intrinsic::remangleIntrinsicFunction(F);
2568 if (NewF)
2569 F->replaceAllUsesWith(V: *NewF);
2570 }
2571 }
2572 return Changed;
2573}
2574
2575bool AMDGPULowerBufferFatPointers::runOnModule(Module &M) {
2576 TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
2577 const TargetMachine &TM = TPC.getTM<TargetMachine>();
2578 return run(M, TM);
2579}
2580
2581char AMDGPULowerBufferFatPointers::ID = 0;
2582
2583char &llvm::AMDGPULowerBufferFatPointersID = AMDGPULowerBufferFatPointers::ID;
2584
2585void AMDGPULowerBufferFatPointers::getAnalysisUsage(AnalysisUsage &AU) const {
2586 AU.addRequired<TargetPassConfig>();
2587}
2588
2589#define PASS_DESC "Lower buffer fat pointer operations to buffer resources"
2590INITIALIZE_PASS_BEGIN(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC,
2591 false, false)
2592INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
2593INITIALIZE_PASS_END(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC, false,
2594 false)
2595#undef PASS_DESC
2596
2597ModulePass *llvm::createAMDGPULowerBufferFatPointersPass() {
2598 return new AMDGPULowerBufferFatPointers();
2599}
2600
2601PreservedAnalyses
2602AMDGPULowerBufferFatPointersPass::run(Module &M, ModuleAnalysisManager &MA) {
2603 return AMDGPULowerBufferFatPointers().run(M, TM) ? PreservedAnalyses::none()
2604 : PreservedAnalyses::all();
2605}
2606