//===- SPIRVLegalizeImplicitBinding.cpp - Legalize implicit bindings ------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
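// This pass legalizes the @llvm.spv.resource.handlefromimplicitbinding and
// @llvm.spv.resource.counterhandlefromimplicitbinding intrinsics by replacing
// them with calls to @llvm.spv.resource.handlefrombinding and
// @llvm.spv.resource.counterhandlefrombinding, respectively.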
13//
14//===----------------------------------------------------------------------===//
15
16#include "SPIRVLegalizeImplicitBinding.h"
17#include "SPIRV.h"
18#include "llvm/ADT/BitVector.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/IR/IRBuilder.h"
21#include "llvm/IR/InstVisitor.h"
22#include "llvm/IR/Intrinsics.h"
23#include "llvm/IR/IntrinsicsSPIRV.h"
24#include "llvm/IR/Module.h"
25#include "llvm/Pass.h"
26#include <algorithm>
27#include <vector>
28
29using namespace llvm;
30
31namespace {
32class SPIRVLegalizeImplicitBindingImpl {
33public:
34 bool runOnModule(Module &M);
35
36private:
37 void collectBindingInfo(Module &M);
38 uint32_t getAndReserveFirstUnusedBinding(uint32_t DescSet);
39 void replaceImplicitBindingCalls(Module &M);
40 void replaceResourceHandleCall(Module &M, CallInst *OldCI,
41 uint32_t NewBinding);
42 void replaceCounterHandleCall(Module &M, CallInst *OldCI,
43 uint32_t NewBinding);
44 void verifyUniqueOrderIdPerResource(SmallVectorImpl<CallInst *> &Calls);
45
46 // A map from descriptor set to a bit vector of used binding numbers.
47 std::vector<BitVector> UsedBindings;
48 // A list of all implicit binding calls, to be sorted by order ID.
49 SmallVector<CallInst *, 16> ImplicitBindingCalls;
50};
51
52class SPIRVLegalizeImplicitBindingLegacy : public ModulePass {
53public:
54 static char ID;
55 SPIRVLegalizeImplicitBindingLegacy() : ModulePass(ID) {}
56 StringRef getPassName() const override {
57 return "SPIRV Legalize Implicit Binding";
58 }
59 bool runOnModule(Module &M) override {
60 return SPIRVLegalizeImplicitBindingImpl().runOnModule(M);
61 }
62};
63
64struct BindingInfoCollector : public InstVisitor<BindingInfoCollector> {
65 std::vector<BitVector> &UsedBindings;
66 SmallVector<CallInst *, 16> &ImplicitBindingCalls;
67
68 BindingInfoCollector(std::vector<BitVector> &UsedBindings,
69 SmallVector<CallInst *, 16> &ImplicitBindingCalls)
70 : UsedBindings(UsedBindings), ImplicitBindingCalls(ImplicitBindingCalls) {
71 }
72
73 void addBinding(uint32_t DescSet, uint32_t Binding) {
74 if (UsedBindings.size() <= DescSet) {
75 UsedBindings.resize(new_size: DescSet + 1);
76 UsedBindings[DescSet].resize(N: 64);
77 }
78 if (UsedBindings[DescSet].size() <= Binding) {
79 UsedBindings[DescSet].resize(N: 2 * Binding + 1);
80 }
81 UsedBindings[DescSet].set(Binding);
82 }
83
84 void visitCallInst(CallInst &CI) {
85 if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefrombinding) {
86 const uint32_t DescSet =
87 cast<ConstantInt>(Val: CI.getArgOperand(i: 0))->getZExtValue();
88 const uint32_t Binding =
89 cast<ConstantInt>(Val: CI.getArgOperand(i: 1))->getZExtValue();
90 addBinding(DescSet, Binding);
91 } else if (CI.getIntrinsicID() ==
92 Intrinsic::spv_resource_handlefromimplicitbinding) {
93 ImplicitBindingCalls.push_back(Elt: &CI);
94 } else if (CI.getIntrinsicID() ==
95 Intrinsic::spv_resource_counterhandlefrombinding) {
96 const uint32_t DescSet =
97 cast<ConstantInt>(Val: CI.getArgOperand(i: 2))->getZExtValue();
98 const uint32_t Binding =
99 cast<ConstantInt>(Val: CI.getArgOperand(i: 1))->getZExtValue();
100 addBinding(DescSet, Binding);
101 } else if (CI.getIntrinsicID() ==
102 Intrinsic::spv_resource_counterhandlefromimplicitbinding) {
103 ImplicitBindingCalls.push_back(Elt: &CI);
104 }
105 }
106};
107
108static uint32_t getOrderId(const CallInst *CI) {
109 uint32_t OrderIdArgIdx = 0;
110 switch (CI->getIntrinsicID()) {
111 case Intrinsic::spv_resource_handlefromimplicitbinding:
112 OrderIdArgIdx = 0;
113 break;
114 case Intrinsic::spv_resource_counterhandlefromimplicitbinding:
115 OrderIdArgIdx = 1;
116 break;
117 default:
118 llvm_unreachable("CallInst is not an implicit binding intrinsic");
119 }
120 return cast<ConstantInt>(Val: CI->getArgOperand(i: OrderIdArgIdx))->getZExtValue();
121}
122
123static uint32_t getDescSet(const CallInst *CI) {
124 uint32_t DescSetArgIdx;
125 switch (CI->getIntrinsicID()) {
126 case Intrinsic::spv_resource_handlefromimplicitbinding:
127 case Intrinsic::spv_resource_handlefrombinding:
128 DescSetArgIdx = 1;
129 break;
130 case Intrinsic::spv_resource_counterhandlefromimplicitbinding:
131 case Intrinsic::spv_resource_counterhandlefrombinding:
132 DescSetArgIdx = 2;
133 break;
134 default:
135 llvm_unreachable("CallInst is not an implicit binding intrinsic");
136 }
137 return cast<ConstantInt>(Val: CI->getArgOperand(i: DescSetArgIdx))->getZExtValue();
138}
139
140void SPIRVLegalizeImplicitBindingImpl::collectBindingInfo(Module &M) {
141 BindingInfoCollector InfoCollector(UsedBindings, ImplicitBindingCalls);
142 InfoCollector.visit(M);
143
144 // Sort the collected calls by their order ID.
145 std::sort(first: ImplicitBindingCalls.begin(), last: ImplicitBindingCalls.end(),
146 comp: [](const CallInst *A, const CallInst *B) {
147 return getOrderId(CI: A) < getOrderId(CI: B);
148 });
149}
150
151void SPIRVLegalizeImplicitBindingImpl::verifyUniqueOrderIdPerResource(
152 SmallVectorImpl<CallInst *> &Calls) {
153 // Check that the order Id is unique per resource.
154 for (uint32_t i = 1; i < Calls.size(); ++i) {
155 const uint32_t OrderA = getOrderId(CI: Calls[i - 1]);
156 const uint32_t OrderB = getOrderId(CI: Calls[i]);
157 if (OrderA == OrderB) {
158 const uint32_t DescSetA = getDescSet(CI: Calls[i - 1]);
159 const uint32_t DescSetB = getDescSet(CI: Calls[i]);
160 if (DescSetA != DescSetB) {
161 report_fatal_error(reason: "Implicit binding calls with the same order ID must "
162 "have the same descriptor set");
163 }
164 }
165 }
166}
167
168uint32_t SPIRVLegalizeImplicitBindingImpl::getAndReserveFirstUnusedBinding(
169 uint32_t DescSet) {
170 if (UsedBindings.size() <= DescSet) {
171 UsedBindings.resize(new_size: DescSet + 1);
172 UsedBindings[DescSet].resize(N: 64);
173 }
174
175 int NewBinding = UsedBindings[DescSet].find_first_unset();
176 if (NewBinding == -1) {
177 NewBinding = UsedBindings[DescSet].size();
178 UsedBindings[DescSet].resize(N: 2 * NewBinding + 1);
179 }
180
181 UsedBindings[DescSet].set(NewBinding);
182 return NewBinding;
183}
184
185void SPIRVLegalizeImplicitBindingImpl::replaceImplicitBindingCalls(Module &M) {
186 uint32_t lastOrderId = -1;
187 uint32_t lastBindingNumber = -1;
188
189 for (CallInst *OldCI : ImplicitBindingCalls) {
190 const uint32_t OrderId = getOrderId(CI: OldCI);
191 uint32_t BindingNumber;
192 if (OrderId == lastOrderId) {
193 BindingNumber = lastBindingNumber;
194 } else {
195 const uint32_t DescSet = getDescSet(CI: OldCI);
196 BindingNumber = getAndReserveFirstUnusedBinding(DescSet);
197 }
198
199 if (OldCI->getIntrinsicID() ==
200 Intrinsic::spv_resource_handlefromimplicitbinding) {
201 replaceResourceHandleCall(M, OldCI, NewBinding: BindingNumber);
202 } else {
203 assert(OldCI->getIntrinsicID() ==
204 Intrinsic::spv_resource_counterhandlefromimplicitbinding &&
205 "Unexpected implicit binding intrinsic");
206 replaceCounterHandleCall(M, OldCI, NewBinding: BindingNumber);
207 }
208 lastOrderId = OrderId;
209 lastBindingNumber = BindingNumber;
210 }
211}
212
213bool SPIRVLegalizeImplicitBindingImpl::runOnModule(Module &M) {
214 collectBindingInfo(M);
215 if (ImplicitBindingCalls.empty()) {
216 return false;
217 }
218 verifyUniqueOrderIdPerResource(Calls&: ImplicitBindingCalls);
219
220 replaceImplicitBindingCalls(M);
221 return true;
222}
223} // namespace
224
225PreservedAnalyses SPIRVLegalizeImplicitBinding::run(Module &M,
226 ModuleAnalysisManager &AM) {
227 return SPIRVLegalizeImplicitBindingImpl().runOnModule(M)
228 ? PreservedAnalyses::none()
229 : PreservedAnalyses::all();
230}
231
232char SPIRVLegalizeImplicitBindingLegacy::ID = 0;
233
234INITIALIZE_PASS(SPIRVLegalizeImplicitBindingLegacy,
235 "legalize-spirv-implicit-binding",
236 "Legalize SPIR-V implicit bindings", false, false)
237
238ModulePass *llvm::createSPIRVLegalizeImplicitBindingPass() {
239 return new SPIRVLegalizeImplicitBindingLegacy();
240}
241
242void SPIRVLegalizeImplicitBindingImpl::replaceResourceHandleCall(
243 Module &M, CallInst *OldCI, uint32_t NewBinding) {
244 IRBuilder<> Builder(OldCI);
245 const uint32_t DescSet =
246 cast<ConstantInt>(Val: OldCI->getArgOperand(i: 1))->getZExtValue();
247
248 SmallVector<Value *, 8> Args;
249 Args.push_back(Elt: Builder.getInt32(C: DescSet));
250 Args.push_back(Elt: Builder.getInt32(C: NewBinding));
251
252 // Copy the remaining arguments from the old call.
253 for (uint32_t i = 2; i < OldCI->arg_size(); ++i) {
254 Args.push_back(Elt: OldCI->getArgOperand(i));
255 }
256
257 Function *NewFunc = Intrinsic::getOrInsertDeclaration(
258 M: &M, id: Intrinsic::spv_resource_handlefrombinding, OverloadTys: OldCI->getType());
259 CallInst *NewCI = Builder.CreateCall(Callee: NewFunc, Args);
260 NewCI->setCallingConv(OldCI->getCallingConv());
261
262 OldCI->replaceAllUsesWith(V: NewCI);
263 OldCI->eraseFromParent();
264}
265
266void SPIRVLegalizeImplicitBindingImpl::replaceCounterHandleCall(
267 Module &M, CallInst *OldCI, uint32_t NewBinding) {
268 IRBuilder<> Builder(OldCI);
269 const uint32_t DescSet =
270 cast<ConstantInt>(Val: OldCI->getArgOperand(i: 2))->getZExtValue();
271
272 SmallVector<Value *, 8> Args;
273 Args.push_back(Elt: OldCI->getArgOperand(i: 0));
274 Args.push_back(Elt: Builder.getInt32(C: NewBinding));
275 Args.push_back(Elt: Builder.getInt32(C: DescSet));
276
277 Type *Tys[] = {OldCI->getType(), OldCI->getArgOperand(i: 0)->getType()};
278 Function *NewFunc = Intrinsic::getOrInsertDeclaration(
279 M: &M, id: Intrinsic::spv_resource_counterhandlefrombinding, OverloadTys: Tys);
280 CallInst *NewCI = Builder.CreateCall(Callee: NewFunc, Args);
281 NewCI->setCallingConv(OldCI->getCallingConv());
282
283 OldCI->replaceAllUsesWith(V: NewCI);
284 OldCI->eraseFromParent();
285}
286