//===- SPIRVLegalizeImplicitBinding.cpp - Legalize implicit bindings ------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
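// This pass legalizes the @llvm.spv.resource.handlefromimplicitbinding and
// @llvm.spv.resource.counterhandlefromimplicitbinding intrinsics by replacing
// them with calls to @llvm.spv.resource.handlefrombinding and
// @llvm.spv.resource.counterhandlefrombinding, respectively, using a binding
// number chosen by the pass.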
13//
14//===----------------------------------------------------------------------===//
15
16#include "SPIRV.h"
17#include "llvm/ADT/BitVector.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/IR/IRBuilder.h"
20#include "llvm/IR/InstVisitor.h"
21#include "llvm/IR/Intrinsics.h"
22#include "llvm/IR/IntrinsicsSPIRV.h"
23#include "llvm/IR/Module.h"
24#include "llvm/Pass.h"
25#include <algorithm>
26#include <vector>
27
28using namespace llvm;
29
30namespace {
31class SPIRVLegalizeImplicitBinding : public ModulePass {
32public:
33 static char ID;
34 SPIRVLegalizeImplicitBinding() : ModulePass(ID) {}
35 StringRef getPassName() const override {
36 return "SPIRV Legalize Implicit Binding";
37 }
38 bool runOnModule(Module &M) override;
39
40private:
41 void collectBindingInfo(Module &M);
42 uint32_t getAndReserveFirstUnusedBinding(uint32_t DescSet);
43 void replaceImplicitBindingCalls(Module &M);
44 void replaceResourceHandleCall(Module &M, CallInst *OldCI,
45 uint32_t NewBinding);
46 void replaceCounterHandleCall(Module &M, CallInst *OldCI,
47 uint32_t NewBinding);
48 void verifyUniqueOrderIdPerResource(SmallVectorImpl<CallInst *> &Calls);
49
50 // A map from descriptor set to a bit vector of used binding numbers.
51 std::vector<BitVector> UsedBindings;
52 // A list of all implicit binding calls, to be sorted by order ID.
53 SmallVector<CallInst *, 16> ImplicitBindingCalls;
54};
55
56struct BindingInfoCollector : public InstVisitor<BindingInfoCollector> {
57 std::vector<BitVector> &UsedBindings;
58 SmallVector<CallInst *, 16> &ImplicitBindingCalls;
59
60 BindingInfoCollector(std::vector<BitVector> &UsedBindings,
61 SmallVector<CallInst *, 16> &ImplicitBindingCalls)
62 : UsedBindings(UsedBindings), ImplicitBindingCalls(ImplicitBindingCalls) {
63 }
64
65 void addBinding(uint32_t DescSet, uint32_t Binding) {
66 if (UsedBindings.size() <= DescSet) {
67 UsedBindings.resize(new_size: DescSet + 1);
68 UsedBindings[DescSet].resize(N: 64);
69 }
70 if (UsedBindings[DescSet].size() <= Binding) {
71 UsedBindings[DescSet].resize(N: 2 * Binding + 1);
72 }
73 UsedBindings[DescSet].set(Binding);
74 }
75
76 void visitCallInst(CallInst &CI) {
77 if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefrombinding) {
78 const uint32_t DescSet =
79 cast<ConstantInt>(Val: CI.getArgOperand(i: 0))->getZExtValue();
80 const uint32_t Binding =
81 cast<ConstantInt>(Val: CI.getArgOperand(i: 1))->getZExtValue();
82 addBinding(DescSet, Binding);
83 } else if (CI.getIntrinsicID() ==
84 Intrinsic::spv_resource_handlefromimplicitbinding) {
85 ImplicitBindingCalls.push_back(Elt: &CI);
86 } else if (CI.getIntrinsicID() ==
87 Intrinsic::spv_resource_counterhandlefrombinding) {
88 const uint32_t DescSet =
89 cast<ConstantInt>(Val: CI.getArgOperand(i: 2))->getZExtValue();
90 const uint32_t Binding =
91 cast<ConstantInt>(Val: CI.getArgOperand(i: 1))->getZExtValue();
92 addBinding(DescSet, Binding);
93 } else if (CI.getIntrinsicID() ==
94 Intrinsic::spv_resource_counterhandlefromimplicitbinding) {
95 ImplicitBindingCalls.push_back(Elt: &CI);
96 }
97 }
98};
99
100static uint32_t getOrderId(const CallInst *CI) {
101 uint32_t OrderIdArgIdx = 0;
102 switch (CI->getIntrinsicID()) {
103 case Intrinsic::spv_resource_handlefromimplicitbinding:
104 OrderIdArgIdx = 0;
105 break;
106 case Intrinsic::spv_resource_counterhandlefromimplicitbinding:
107 OrderIdArgIdx = 1;
108 break;
109 default:
110 llvm_unreachable("CallInst is not an implicit binding intrinsic");
111 }
112 return cast<ConstantInt>(Val: CI->getArgOperand(i: OrderIdArgIdx))->getZExtValue();
113}
114
115static uint32_t getDescSet(const CallInst *CI) {
116 uint32_t DescSetArgIdx;
117 switch (CI->getIntrinsicID()) {
118 case Intrinsic::spv_resource_handlefromimplicitbinding:
119 case Intrinsic::spv_resource_handlefrombinding:
120 DescSetArgIdx = 1;
121 break;
122 case Intrinsic::spv_resource_counterhandlefromimplicitbinding:
123 case Intrinsic::spv_resource_counterhandlefrombinding:
124 DescSetArgIdx = 2;
125 break;
126 default:
127 llvm_unreachable("CallInst is not an implicit binding intrinsic");
128 }
129 return cast<ConstantInt>(Val: CI->getArgOperand(i: DescSetArgIdx))->getZExtValue();
130}
131
132void SPIRVLegalizeImplicitBinding::collectBindingInfo(Module &M) {
133 BindingInfoCollector InfoCollector(UsedBindings, ImplicitBindingCalls);
134 InfoCollector.visit(M);
135
136 // Sort the collected calls by their order ID.
137 std::sort(first: ImplicitBindingCalls.begin(), last: ImplicitBindingCalls.end(),
138 comp: [](const CallInst *A, const CallInst *B) {
139 return getOrderId(CI: A) < getOrderId(CI: B);
140 });
141}
142
143void SPIRVLegalizeImplicitBinding::verifyUniqueOrderIdPerResource(
144 SmallVectorImpl<CallInst *> &Calls) {
145 // Check that the order Id is unique per resource.
146 for (uint32_t i = 1; i < Calls.size(); ++i) {
147 const uint32_t OrderA = getOrderId(CI: Calls[i - 1]);
148 const uint32_t OrderB = getOrderId(CI: Calls[i]);
149 if (OrderA == OrderB) {
150 const uint32_t DescSetA = getDescSet(CI: Calls[i - 1]);
151 const uint32_t DescSetB = getDescSet(CI: Calls[i]);
152 if (DescSetA != DescSetB) {
153 report_fatal_error(reason: "Implicit binding calls with the same order ID must "
154 "have the same descriptor set");
155 }
156 }
157 }
158}
159
160uint32_t SPIRVLegalizeImplicitBinding::getAndReserveFirstUnusedBinding(
161 uint32_t DescSet) {
162 if (UsedBindings.size() <= DescSet) {
163 UsedBindings.resize(new_size: DescSet + 1);
164 UsedBindings[DescSet].resize(N: 64);
165 }
166
167 int NewBinding = UsedBindings[DescSet].find_first_unset();
168 if (NewBinding == -1) {
169 NewBinding = UsedBindings[DescSet].size();
170 UsedBindings[DescSet].resize(N: 2 * NewBinding + 1);
171 }
172
173 UsedBindings[DescSet].set(NewBinding);
174 return NewBinding;
175}
176
177void SPIRVLegalizeImplicitBinding::replaceImplicitBindingCalls(Module &M) {
178 uint32_t lastOrderId = -1;
179 uint32_t lastBindingNumber = -1;
180
181 for (CallInst *OldCI : ImplicitBindingCalls) {
182 const uint32_t OrderId = getOrderId(CI: OldCI);
183 uint32_t BindingNumber;
184 if (OrderId == lastOrderId) {
185 BindingNumber = lastBindingNumber;
186 } else {
187 const uint32_t DescSet = getDescSet(CI: OldCI);
188 BindingNumber = getAndReserveFirstUnusedBinding(DescSet);
189 }
190
191 if (OldCI->getIntrinsicID() ==
192 Intrinsic::spv_resource_handlefromimplicitbinding) {
193 replaceResourceHandleCall(M, OldCI, NewBinding: BindingNumber);
194 } else {
195 assert(OldCI->getIntrinsicID() ==
196 Intrinsic::spv_resource_counterhandlefromimplicitbinding &&
197 "Unexpected implicit binding intrinsic");
198 replaceCounterHandleCall(M, OldCI, NewBinding: BindingNumber);
199 }
200 lastOrderId = OrderId;
201 lastBindingNumber = BindingNumber;
202 }
203}
204
205bool SPIRVLegalizeImplicitBinding::runOnModule(Module &M) {
206 collectBindingInfo(M);
207 if (ImplicitBindingCalls.empty()) {
208 return false;
209 }
210 verifyUniqueOrderIdPerResource(Calls&: ImplicitBindingCalls);
211
212 replaceImplicitBindingCalls(M);
213 return true;
214}
215} // namespace
216
// Definition of the pass's static ID member, which the constructor hands to
// the ModulePass base class.
char SPIRVLegalizeImplicitBinding::ID = 0;

// Registers the pass under the argument string
// "legalize-spirv-implicit-binding" with the description
// "Legalize SPIR-V implicit bindings".
// NOTE(review): the trailing `false, false` are presumably the CFG-only and
// is-analysis flags of INITIALIZE_PASS -- confirm against PassSupport.h.
INITIALIZE_PASS(SPIRVLegalizeImplicitBinding, "legalize-spirv-implicit-binding",
                "Legalize SPIR-V implicit bindings", false, false)
221
// Factory for the pass: returns a newly heap-allocated
// SPIRVLegalizeImplicitBinding instance.
// NOTE(review): ownership presumably passes to the pass manager the result is
// added to -- confirm at the call site.
ModulePass *llvm::createSPIRVLegalizeImplicitBindingPass() {
  return new SPIRVLegalizeImplicitBinding();
}
225
226void SPIRVLegalizeImplicitBinding::replaceResourceHandleCall(
227 Module &M, CallInst *OldCI, uint32_t NewBinding) {
228 IRBuilder<> Builder(OldCI);
229 const uint32_t DescSet =
230 cast<ConstantInt>(Val: OldCI->getArgOperand(i: 1))->getZExtValue();
231
232 SmallVector<Value *, 8> Args;
233 Args.push_back(Elt: Builder.getInt32(C: DescSet));
234 Args.push_back(Elt: Builder.getInt32(C: NewBinding));
235
236 // Copy the remaining arguments from the old call.
237 for (uint32_t i = 2; i < OldCI->arg_size(); ++i) {
238 Args.push_back(Elt: OldCI->getArgOperand(i));
239 }
240
241 Function *NewFunc = Intrinsic::getOrInsertDeclaration(
242 M: &M, id: Intrinsic::spv_resource_handlefrombinding, Tys: OldCI->getType());
243 CallInst *NewCI = Builder.CreateCall(Callee: NewFunc, Args);
244 NewCI->setCallingConv(OldCI->getCallingConv());
245
246 OldCI->replaceAllUsesWith(V: NewCI);
247 OldCI->eraseFromParent();
248}
249
250void SPIRVLegalizeImplicitBinding::replaceCounterHandleCall(
251 Module &M, CallInst *OldCI, uint32_t NewBinding) {
252 IRBuilder<> Builder(OldCI);
253 const uint32_t DescSet =
254 cast<ConstantInt>(Val: OldCI->getArgOperand(i: 2))->getZExtValue();
255
256 SmallVector<Value *, 8> Args;
257 Args.push_back(Elt: OldCI->getArgOperand(i: 0));
258 Args.push_back(Elt: Builder.getInt32(C: NewBinding));
259 Args.push_back(Elt: Builder.getInt32(C: DescSet));
260
261 Type *Tys[] = {OldCI->getType(), OldCI->getArgOperand(i: 0)->getType()};
262 Function *NewFunc = Intrinsic::getOrInsertDeclaration(
263 M: &M, id: Intrinsic::spv_resource_counterhandlefrombinding, Tys);
264 CallInst *NewCI = Builder.CreateCall(Callee: NewFunc, Args);
265 NewCI->setCallingConv(OldCI->getCallingConv());
266
267 OldCI->replaceAllUsesWith(V: NewCI);
268 OldCI->eraseFromParent();
269}
270