1//===-- SPIRVLegalizePointerCast.cpp ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// LLVM IR has multiple legal patterns that we cannot lower to logical SPIR-V.
// This pass rewrites the loads and stores that rely on such patterns so that
// the resulting IR can be lowered directly to valid logical SPIR-V.
// OpenCL does not need this because it relies on ptrcast, which logical
// SPIR-V does not support.
14//
// This pass relies on the assign_ptr_type intrinsic to deduce the type of the
// pointed-to values, and it must replace all occurrences of `ptrcast`. This is
// why unhandled cases are reported as unreachable: we MUST cover all cases.
18//
19// 1. Loading the first element of an array
20//
21// %array = [10 x i32]
22// %value = load i32, ptr %array
23//
24// LLVM can skip the GEP instruction, and only request loading the first 4
25// bytes. In logical SPIR-V, we need an OpAccessChain to access the first
26// element. This pass will add a getelementptr instruction before the load.
27//
28//
29// 2. Implicit downcast from load
30//
31// %1 = getelementptr <4 x i32>, ptr %vec4, i64 0
32// %2 = load <3 x i32>, ptr %1
33//
34// The pointer in the GEP instruction is only used for offset computations,
35// but it doesn't NEED to match the pointed type. OpAccessChain however
36// requires this. Also, LLVM loads define the bitwidth of the load, not the
37// pointer. In this example, we can guess %vec4 is a vec4 thanks to the GEP
38// instruction basetype, but we only want to load the first 3 elements, hence
39// do a partial load. In logical SPIR-V, this is not legal. What we must do
40// is load the full vector (basetype), extract 3 elements, and recombine them
41// to form a 3-element vector.
42//
43//===----------------------------------------------------------------------===//
44
45#include "SPIRV.h"
46#include "SPIRVSubtarget.h"
47#include "SPIRVTargetMachine.h"
48#include "SPIRVUtils.h"
49#include "llvm/IR/IRBuilder.h"
50#include "llvm/IR/IntrinsicInst.h"
51#include "llvm/IR/Intrinsics.h"
52#include "llvm/IR/IntrinsicsSPIRV.h"
53#include "llvm/Transforms/Utils/Cloning.h"
54#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
55
56using namespace llvm;
57
58namespace {
59class SPIRVLegalizePointerCast : public FunctionPass {
60
61 // Builds the `spv_assign_type` assigning |Ty| to |Value| at the current
62 // builder position.
63 void buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg) {
64 Value *OfType = PoisonValue::get(T: Ty);
65 CallInst *AssignCI = buildIntrWithMD(IntrID: Intrinsic::spv_assign_type,
66 Types: {Arg->getType()}, Arg: OfType, Arg2: Arg, Imms: {}, B);
67 GR->addAssignPtrTypeInstr(Val: Arg, AssignPtrTyCI: AssignCI);
68 }
69
70 // Loads parts of the vector of type |SourceType| from the pointer |Source|
71 // and create a new vector of type |TargetType|. |TargetType| must be a vector
72 // type, and element types of |TargetType| and |SourceType| must match.
73 // Returns the loaded value.
74 Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
75 FixedVectorType *TargetType, Value *Source) {
76 LoadInst *NewLoad = B.CreateLoad(Ty: SourceType, Ptr: Source);
77 buildAssignType(B, Ty: SourceType, Arg: NewLoad);
78 Value *AssignValue = NewLoad;
79 if (TargetType->getElementType() != SourceType->getElementType()) {
80 const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
81 [[maybe_unused]] TypeSize TargetTypeSize =
82 DL.getTypeSizeInBits(Ty: TargetType);
83 [[maybe_unused]] TypeSize SourceTypeSize =
84 DL.getTypeSizeInBits(Ty: SourceType);
85 assert(TargetTypeSize == SourceTypeSize);
86 AssignValue = B.CreateIntrinsic(ID: Intrinsic::spv_bitcast,
87 Types: {TargetType, SourceType}, Args: {NewLoad});
88 buildAssignType(B, Ty: TargetType, Arg: AssignValue);
89 return AssignValue;
90 }
91
92 assert(TargetType->getNumElements() < SourceType->getNumElements());
93 SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
94 for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
95 Mask[I] = I;
96 Value *Output = B.CreateShuffleVector(V1: AssignValue, V2: AssignValue, Mask);
97 buildAssignType(B, Ty: TargetType, Arg: Output);
98 return Output;
99 }
100
101 // Loads the first value in an aggregate pointed by |Source| of containing
102 // elements of type |ElementType|. Load flags will be copied from |BadLoad|,
103 // which should be the load being legalized. Returns the loaded value.
104 Value *loadFirstValueFromAggregate(IRBuilder<> &B, Type *ElementType,
105 Value *Source, LoadInst *BadLoad) {
106 SmallVector<Type *, 2> Types = {BadLoad->getPointerOperandType(),
107 Source->getType()};
108 SmallVector<Value *, 8> Args{/* isInBounds= */ B.getInt1(V: false), Source};
109
110 Type *AggregateType = GR->findDeducedElementType(Val: Source);
111 assert(AggregateType && "Could not deduce aggregate type");
112 buildGEPIndexChain(B, Search: ElementType, Aggregate: AggregateType, Indices&: Args);
113
114 auto *GEP = B.CreateIntrinsic(ID: Intrinsic::spv_gep, Types: {Types}, Args: {Args});
115 GR->buildAssignPtr(B, ElemTy: ElementType, Arg: GEP);
116
117 LoadInst *LI = B.CreateLoad(Ty: ElementType, Ptr: GEP);
118 LI->setAlignment(BadLoad->getAlign());
119 buildAssignType(B, Ty: ElementType, Arg: LI);
120 return LI;
121 }
122
123 // Loads elements from an array and constructs a vector.
124 Value *loadVectorFromArray(IRBuilder<> &B, FixedVectorType *TargetType,
125 Value *Source) {
126 // Load each element of the array.
127 SmallVector<Value *, 4> LoadedElements;
128 for (unsigned i = 0; i < TargetType->getNumElements(); ++i) {
129 // Create a GEP to access the i-th element of the array.
130 SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
131 SmallVector<Value *, 4> Args;
132 Args.push_back(Elt: B.getInt1(V: false));
133 Args.push_back(Elt: Source);
134 Args.push_back(Elt: B.getInt32(C: 0));
135 Args.push_back(Elt: ConstantInt::get(Ty: B.getInt32Ty(), V: i));
136 auto *ElementPtr = B.CreateIntrinsic(ID: Intrinsic::spv_gep, Types: {Types}, Args: {Args});
137 GR->buildAssignPtr(B, ElemTy: TargetType->getElementType(), Arg: ElementPtr);
138
139 // Load the value from the element pointer.
140 Value *Load = B.CreateLoad(Ty: TargetType->getElementType(), Ptr: ElementPtr);
141 buildAssignType(B, Ty: TargetType->getElementType(), Arg: Load);
142 LoadedElements.push_back(Elt: Load);
143 }
144
145 // Build the vector from the loaded elements.
146 Value *NewVector = PoisonValue::get(T: TargetType);
147 buildAssignType(B, Ty: TargetType, Arg: NewVector);
148
149 for (unsigned i = 0; i < TargetType->getNumElements(); ++i) {
150 Value *Index = B.getInt32(C: i);
151 SmallVector<Type *, 4> Types = {TargetType, TargetType,
152 TargetType->getElementType(),
153 Index->getType()};
154 SmallVector<Value *> Args = {NewVector, LoadedElements[i], Index};
155 NewVector = B.CreateIntrinsic(ID: Intrinsic::spv_insertelt, Types: {Types}, Args: {Args});
156 buildAssignType(B, Ty: TargetType, Arg: NewVector);
157 }
158 return NewVector;
159 }
160
161 // Stores elements from a vector into an array.
162 void storeArrayFromVector(IRBuilder<> &B, Value *SrcVector,
163 Value *DstArrayPtr, ArrayType *ArrTy,
164 Align Alignment) {
165 auto *VecTy = cast<FixedVectorType>(Val: SrcVector->getType());
166
167 // Ensure the element types of the array and vector are the same.
168 assert(VecTy->getElementType() == ArrTy->getElementType() &&
169 "Element types of array and vector must be the same.");
170
171 const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
172 uint64_t ElemSize = DL.getTypeAllocSize(Ty: ArrTy->getElementType());
173
174 for (unsigned i = 0; i < VecTy->getNumElements(); ++i) {
175 // Create a GEP to access the i-th element of the array.
176 SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
177 DstArrayPtr->getType()};
178 SmallVector<Value *, 4> Args;
179 Args.push_back(Elt: B.getInt1(V: false));
180 Args.push_back(Elt: DstArrayPtr);
181 Args.push_back(Elt: B.getInt32(C: 0));
182 Args.push_back(Elt: ConstantInt::get(Ty: B.getInt32Ty(), V: i));
183 auto *ElementPtr = B.CreateIntrinsic(ID: Intrinsic::spv_gep, Types: {Types}, Args: {Args});
184 GR->buildAssignPtr(B, ElemTy: ArrTy->getElementType(), Arg: ElementPtr);
185
186 // Extract the element from the vector and store it.
187 Value *Index = B.getInt32(C: i);
188 SmallVector<Type *, 3> EltTypes = {VecTy->getElementType(), VecTy,
189 Index->getType()};
190 SmallVector<Value *, 2> EltArgs = {SrcVector, Index};
191 Value *Element =
192 B.CreateIntrinsic(ID: Intrinsic::spv_extractelt, Types: {EltTypes}, Args: {EltArgs});
193 buildAssignType(B, Ty: VecTy->getElementType(), Arg: Element);
194
195 Types = {Element->getType(), ElementPtr->getType()};
196 Align NewAlign = commonAlignment(A: Alignment, Offset: i * ElemSize);
197 Args = {Element, ElementPtr, B.getInt16(C: 2), B.getInt8(C: NewAlign.value())};
198 B.CreateIntrinsic(ID: Intrinsic::spv_store, Types: {Types}, Args: {Args});
199 }
200 }
201
202 // Replaces the load instruction to get rid of the ptrcast used as source
203 // operand.
204 void transformLoad(IRBuilder<> &B, LoadInst *LI, Value *CastedOperand,
205 Value *OriginalOperand) {
206 Type *FromTy = GR->findDeducedElementType(Val: OriginalOperand);
207 Type *ToTy = GR->findDeducedElementType(Val: CastedOperand);
208 Value *Output = nullptr;
209
210 auto *SAT = dyn_cast<ArrayType>(Val: FromTy);
211 auto *SVT = dyn_cast<FixedVectorType>(Val: FromTy);
212 auto *DVT = dyn_cast<FixedVectorType>(Val: ToTy);
213
214 B.SetInsertPoint(LI);
215
216 // Destination is the element type of some member of FromTy. For example,
217 // loading the 1st element of an array:
218 // - float a = array[0];
219 if (isTypeFirstElementAggregate(Search: ToTy, Aggregate: FromTy))
220 Output = loadFirstValueFromAggregate(B, ElementType: ToTy, Source: OriginalOperand, BadLoad: LI);
221 // Destination is a smaller vector than source or different vector type.
222 // - float3 v3 = vector4;
223 // - float4 v2 = int4;
224 else if (SVT && DVT)
225 Output = loadVectorFromVector(B, SourceType: SVT, TargetType: DVT, Source: OriginalOperand);
226 else if (SAT && DVT && SAT->getElementType() == DVT->getElementType())
227 Output = loadVectorFromArray(B, TargetType: DVT, Source: OriginalOperand);
228 else
229 llvm_unreachable("Unimplemented implicit down-cast from load.");
230
231 GR->replaceAllUsesWith(Old: LI, New: Output, /* DeleteOld= */ true);
232 DeadInstructions.push_back(x: LI);
233 }
234
235 // Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
236 Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
237 unsigned Index) {
238 Type *Int32Ty = Type::getInt32Ty(C&: B.getContext());
239 SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),
240 Element->getType(), Int32Ty};
241 SmallVector<Value *> Args = {Vector, Element, B.getInt32(C: Index)};
242 Instruction *NewI =
243 B.CreateIntrinsic(ID: Intrinsic::spv_insertelt, Types: {Types}, Args: {Args});
244 buildAssignType(B, Ty: Vector->getType(), Arg: NewI);
245 return NewI;
246 }
247
248 // Creates an spv_extractelt instruction (equivalent to llvm's
249 // extractelement).
250 Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
251 unsigned Index) {
252 Type *Int32Ty = Type::getInt32Ty(C&: B.getContext());
253 SmallVector<Type *, 3> Types = {ElementType, Vector->getType(), Int32Ty};
254 SmallVector<Value *> Args = {Vector, B.getInt32(C: Index)};
255 Instruction *NewI =
256 B.CreateIntrinsic(ID: Intrinsic::spv_extractelt, Types: {Types}, Args: {Args});
257 buildAssignType(B, Ty: ElementType, Arg: NewI);
258 return NewI;
259 }
260
261 // Stores the given Src vector operand into the Dst vector, adjusting the size
262 // if required.
263 Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
264 Align Alignment) {
265 FixedVectorType *SrcType = cast<FixedVectorType>(Val: Src->getType());
266 FixedVectorType *DstType =
267 cast<FixedVectorType>(Val: GR->findDeducedElementType(Val: Dst));
268 auto dstNumElements = DstType->getNumElements();
269 auto srcNumElements = SrcType->getNumElements();
270
271 // if the element type differs, it is a bitcast.
272 if (DstType->getElementType() != SrcType->getElementType()) {
273 // Support bitcast between vectors of different sizes only if
274 // the total bitwidth is the same.
275 [[maybe_unused]] auto dstBitWidth =
276 DstType->getElementType()->getScalarSizeInBits() * dstNumElements;
277 [[maybe_unused]] auto srcBitWidth =
278 SrcType->getElementType()->getScalarSizeInBits() * srcNumElements;
279 assert(dstBitWidth == srcBitWidth &&
280 "Unsupported bitcast between vectors of different sizes.");
281
282 Src =
283 B.CreateIntrinsic(ID: Intrinsic::spv_bitcast, Types: {DstType, SrcType}, Args: {Src});
284 buildAssignType(B, Ty: DstType, Arg: Src);
285 SrcType = DstType;
286
287 StoreInst *SI = B.CreateStore(Val: Src, Ptr: Dst);
288 SI->setAlignment(Alignment);
289 return SI;
290 }
291
292 assert(DstType->getNumElements() >= SrcType->getNumElements());
293 LoadInst *LI = B.CreateLoad(Ty: DstType, Ptr: Dst);
294 LI->setAlignment(Alignment);
295 Value *OldValues = LI;
296 buildAssignType(B, Ty: OldValues->getType(), Arg: OldValues);
297 Value *NewValues = Src;
298
299 for (unsigned I = 0; I < SrcType->getNumElements(); ++I) {
300 Value *Element =
301 makeExtractElement(B, ElementType: SrcType->getElementType(), Vector: NewValues, Index: I);
302 OldValues = makeInsertElement(B, Vector: OldValues, Element, Index: I);
303 }
304
305 StoreInst *SI = B.CreateStore(Val: OldValues, Ptr: Dst);
306 SI->setAlignment(Alignment);
307 return SI;
308 }
309
310 void buildGEPIndexChain(IRBuilder<> &B, Type *Search, Type *Aggregate,
311 SmallVectorImpl<Value *> &Indices) {
312 Indices.push_back(Elt: B.getInt32(C: 0));
313
314 if (Search == Aggregate)
315 return;
316
317 if (auto *ST = dyn_cast<StructType>(Val: Aggregate))
318 buildGEPIndexChain(B, Search, Aggregate: ST->getTypeAtIndex(N: 0u), Indices);
319 else if (auto *AT = dyn_cast<ArrayType>(Val: Aggregate))
320 buildGEPIndexChain(B, Search, Aggregate: AT->getElementType(), Indices);
321 else if (auto *VT = dyn_cast<FixedVectorType>(Val: Aggregate))
322 buildGEPIndexChain(B, Search, Aggregate: VT->getElementType(), Indices);
323 else
324 llvm_unreachable("Bad access chain?");
325 }
326
327 // Stores the given Src value into the first entry of the Dst aggregate.
328 Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
329 Type *DstPointeeType, Align Alignment) {
330 SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
331 SmallVector<Value *, 8> Args{/* isInBounds= */ B.getInt1(V: true), Dst};
332 buildGEPIndexChain(B, Search: Src->getType(), Aggregate: DstPointeeType, Indices&: Args);
333 auto *GEP = B.CreateIntrinsic(ID: Intrinsic::spv_gep, Types: {Types}, Args: {Args});
334 GR->buildAssignPtr(B, ElemTy: Src->getType(), Arg: GEP);
335 StoreInst *SI = B.CreateStore(Val: Src, Ptr: GEP);
336 SI->setAlignment(Alignment);
337 return SI;
338 }
339
340 bool isTypeFirstElementAggregate(Type *Search, Type *Aggregate) {
341 if (Search == Aggregate)
342 return true;
343 if (auto *ST = dyn_cast<StructType>(Val: Aggregate))
344 return isTypeFirstElementAggregate(Search, Aggregate: ST->getTypeAtIndex(N: 0u));
345 if (auto *VT = dyn_cast<FixedVectorType>(Val: Aggregate))
346 return isTypeFirstElementAggregate(Search, Aggregate: VT->getElementType());
347 if (auto *AT = dyn_cast<ArrayType>(Val: Aggregate))
348 return isTypeFirstElementAggregate(Search, Aggregate: AT->getElementType());
349 return false;
350 }
351
352 // Transforms a store instruction (or SPV intrinsic) using a ptrcast as
353 // operand into a valid logical SPIR-V store with no ptrcast.
354 void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
355 Value *Dst, Align Alignment) {
356 Type *ToTy = GR->findDeducedElementType(Val: Dst);
357 Type *FromTy = Src->getType();
358
359 auto *S_VT = dyn_cast<FixedVectorType>(Val: FromTy);
360 auto *D_VT = dyn_cast<FixedVectorType>(Val: ToTy);
361 auto *D_AT = dyn_cast<ArrayType>(Val: ToTy);
362
363 B.SetInsertPoint(BadStore);
364 if (isTypeFirstElementAggregate(Search: FromTy, Aggregate: ToTy))
365 storeToFirstValueAggregate(B, Src, Dst, DstPointeeType: ToTy, Alignment);
366 else if (D_VT && S_VT)
367 storeVectorFromVector(B, Src, Dst, Alignment);
368 else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
369 storeToFirstValueAggregate(B, Src, Dst, DstPointeeType: D_VT, Alignment);
370 else if (D_AT && S_VT && S_VT->getElementType() == D_AT->getElementType())
371 storeArrayFromVector(B, SrcVector: Src, DstArrayPtr: Dst, ArrTy: D_AT, Alignment);
372 else
373 llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
374
375 DeadInstructions.push_back(x: BadStore);
376 }
377
378 void legalizePointerCast(IntrinsicInst *II) {
379 Value *CastedOperand = II;
380 Value *OriginalOperand = II->getOperand(i_nocapture: 0);
381
382 IRBuilder<> B(II->getContext());
383 std::vector<Value *> Users;
384 for (Use &U : II->uses())
385 Users.push_back(x: U.getUser());
386
387 for (Value *User : Users) {
388 if (LoadInst *LI = dyn_cast<LoadInst>(Val: User)) {
389 transformLoad(B, LI, CastedOperand, OriginalOperand);
390 continue;
391 }
392
393 if (StoreInst *SI = dyn_cast<StoreInst>(Val: User)) {
394 transformStore(B, BadStore: SI, Src: SI->getValueOperand(), Dst: OriginalOperand,
395 Alignment: SI->getAlign());
396 continue;
397 }
398
399 if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(Val: User)) {
400 if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
401 DeadInstructions.push_back(x: Intrin);
402 continue;
403 }
404
405 if (Intrin->getIntrinsicID() == Intrinsic::spv_gep) {
406 GR->replaceAllUsesWith(Old: CastedOperand, New: OriginalOperand,
407 /* DeleteOld= */ false);
408 continue;
409 }
410
411 if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
412 Align Alignment;
413 if (ConstantInt *C = dyn_cast<ConstantInt>(Val: Intrin->getOperand(i_nocapture: 3)))
414 Alignment = Align(C->getZExtValue());
415 transformStore(B, BadStore: Intrin, Src: Intrin->getArgOperand(i: 0), Dst: OriginalOperand,
416 Alignment);
417 continue;
418 }
419 }
420
421 llvm_unreachable("Unsupported ptrcast user. Please fix.");
422 }
423
424 DeadInstructions.push_back(x: II);
425 }
426
427public:
428 SPIRVLegalizePointerCast(SPIRVTargetMachine *TM) : FunctionPass(ID), TM(TM) {}
429
430 bool runOnFunction(Function &F) override {
431 const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
432 GR = ST.getSPIRVGlobalRegistry();
433 DeadInstructions.clear();
434
435 std::vector<IntrinsicInst *> WorkList;
436 for (auto &BB : F) {
437 for (auto &I : BB) {
438 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
439 if (II && II->getIntrinsicID() == Intrinsic::spv_ptrcast)
440 WorkList.push_back(x: II);
441 }
442 }
443
444 for (IntrinsicInst *II : WorkList)
445 legalizePointerCast(II);
446
447 for (Instruction *I : DeadInstructions)
448 I->eraseFromParent();
449
450 return DeadInstructions.size() != 0;
451 }
452
453private:
454 SPIRVTargetMachine *TM = nullptr;
455 SPIRVGlobalRegistry *GR = nullptr;
456 std::vector<Instruction *> DeadInstructions;
457
458public:
459 static char ID;
460};
461} // namespace
462
463char SPIRVLegalizePointerCast::ID = 0;
464INITIALIZE_PASS(SPIRVLegalizePointerCast, "spirv-legalize-bitcast",
465 "SPIRV legalize bitcast pass", false, false)
466
467FunctionPass *llvm::createSPIRVLegalizePointerCastPass(SPIRVTargetMachine *TM) {
468 return new SPIRVLegalizePointerCast(TM);
469}
470