1//===- SPIRVCBufferAccess.cpp - Translate CBuffer Loads ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass replaces all accesses to constant buffer global variables with
10// accesses to the proper SPIR-V resource.
11//
12// The pass operates as follows:
13// 1. It finds all constant buffers by looking for the `!hlsl.cbs` metadata.
14// 2. For each cbuffer, it finds the global variable holding the resource handle
15// and the global variables for each of the cbuffer's members.
// 3. For each member variable, it creates a call to the
//    `llvm.spv.resource.getpointer` intrinsic. This intrinsic takes the
//    resource handle and the index of the cbuffer layout-struct element that
//    contains the member's byte offset as arguments. The result is a pointer
//    to that member within the SPIR-V resource.
20// 4. It then replaces all uses of the original member global variable with the
21// pointer returned by the `getpointer` intrinsic. This effectively retargets
22// all loads and GEPs to the new resource pointer.
23// 5. Finally, it cleans up by deleting the original global variables and the
24// `!hlsl.cbs` metadata.
25//
26// This approach allows subsequent passes, like SPIRVEmitIntrinsics, to
27// correctly handle GEPs that operate on the result of the `getpointer` call,
28// folding them into a single OpAccessChain instruction.
29//
30//===----------------------------------------------------------------------===//
31
#include "SPIRVCBufferAccess.h"
#include "SPIRV.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Frontend/HLSL/CBuffer.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
40
41#define DEBUG_TYPE "spirv-cbuffer-access"
42using namespace llvm;
43
44// Finds the single instruction that defines the resource handle. This is
45// typically a call to `llvm.spv.resource.handlefrombinding`.
46static Instruction *findHandleDef(GlobalVariable *HandleVar) {
47 for (User *U : HandleVar->users()) {
48 if (auto *SI = dyn_cast<StoreInst>(Val: U)) {
49 if (auto *I = dyn_cast<Instruction>(Val: SI->getValueOperand())) {
50 return I;
51 }
52 }
53 }
54 return nullptr;
55}
56
57static bool replaceCBufferAccesses(Module &M) {
58 std::optional<hlsl::CBufferMetadata> CBufMD =
59 hlsl::CBufferMetadata::get(M, IsPadding: [](Type *Ty) {
60 if (auto *TET = dyn_cast<TargetExtType>(Val: Ty))
61 return TET->getName() == "spirv.Padding";
62 return false;
63 });
64 if (!CBufMD)
65 return false;
66
67 SmallPtrSet<GlobalVariable *, 8> CBufferHandles;
68 SmallVector<Constant *> CBufferGlobals;
69 for (const hlsl::CBufferMapping &Mapping : *CBufMD) {
70 CBufferHandles.insert(Ptr: Mapping.Handle);
71 for (const hlsl::CBufferMember &Member : Mapping.Members)
72 CBufferGlobals.push_back(Elt: Member.GV);
73 }
74 convertUsersOfConstantsToInstructions(Consts: CBufferGlobals);
75
76 for (const hlsl::CBufferMapping &Mapping : *CBufMD) {
77 Instruction *HandleDef = findHandleDef(HandleVar: Mapping.Handle);
78 if (!HandleDef) {
79 report_fatal_error(reason: "Could not find handle definition for cbuffer: " +
80 Mapping.Handle->getName());
81 }
82
83 // The handle definition should dominate all uses of the cbuffer members.
84 // We'll insert our getpointer calls right after it.
85 IRBuilder<> Builder(HandleDef->getNextNode());
86 auto *HandleTy = cast<TargetExtType>(Val: Mapping.Handle->getValueType());
87 auto *LayoutTy = cast<StructType>(Val: HandleTy->getTypeParameter(i: 0));
88 const StructLayout *SL = M.getDataLayout().getStructLayout(Ty: LayoutTy);
89
90 for (const hlsl::CBufferMember &Member : Mapping.Members) {
91 GlobalVariable *MemberGV = Member.GV;
92 if (MemberGV->use_empty()) {
93 continue;
94 }
95
96 uint32_t IndexInStruct = SL->getElementContainingOffset(FixedOffset: Member.Offset);
97
98 // Create the getpointer intrinsic call.
99 Value *IndexVal = Builder.getInt32(C: IndexInStruct);
100 Type *PtrType = MemberGV->getType();
101 Value *GetPointerCall = Builder.CreateIntrinsic(
102 RetTy: PtrType, ID: Intrinsic::spv_resource_getpointer, Args: {HandleDef, IndexVal});
103
104 MemberGV->replaceAllUsesWith(V: GetPointerCall);
105 }
106 }
107
108 // Remove cbuffer handle globals from @llvm.compiler.used list.
109 llvm::removeFromUsedLists(M, ShouldRemove: [&](Constant *C) -> bool {
110 auto *GV = dyn_cast<GlobalVariable>(Val: C);
111 return GV && CBufferHandles.contains(Ptr: GV);
112 });
113 for (GlobalVariable *HandleGV : CBufferHandles)
114 HandleGV->removeDeadConstantUsers();
115
116 // Now that all uses are replaced, clean up the globals and metadata.
117 for (const hlsl::CBufferMapping &Mapping : *CBufMD) {
118 for (const auto &Member : Mapping.Members) {
119 Member.GV->eraseFromParent();
120 }
121 // Erase the stores to the handle variable before erasing the handle itself.
122 SmallVector<Instruction *, 4> HandleStores;
123 for (User *U : Mapping.Handle->users()) {
124 if (auto *SI = dyn_cast<StoreInst>(Val: U)) {
125 HandleStores.push_back(Elt: SI);
126 }
127 }
128 for (Instruction *I : HandleStores) {
129 I->eraseFromParent();
130 }
131 Mapping.Handle->eraseFromParent();
132 }
133
134 CBufMD->eraseFromModule();
135 return true;
136}
137
138PreservedAnalyses SPIRVCBufferAccess::run(Module &M,
139 ModuleAnalysisManager &AM) {
140 if (replaceCBufferAccesses(M)) {
141 return PreservedAnalyses::none();
142 }
143 return PreservedAnalyses::all();
144}
145
146namespace {
147class SPIRVCBufferAccessLegacy : public ModulePass {
148public:
149 bool runOnModule(Module &M) override { return replaceCBufferAccesses(M); }
150 StringRef getPassName() const override { return "SPIRV CBuffer Access"; }
151 SPIRVCBufferAccessLegacy() : ModulePass(ID) {}
152
153 static char ID; // Pass identification.
154};
155char SPIRVCBufferAccessLegacy::ID = 0;
156} // end anonymous namespace
157
158INITIALIZE_PASS(SPIRVCBufferAccessLegacy, DEBUG_TYPE, "SPIRV CBuffer Access",
159 false, false)
160
161ModulePass *llvm::createSPIRVCBufferAccessLegacyPass() {
162 return new SPIRVCBufferAccessLegacy();
163}
164