1//===-- SPIRVLegalizePointerCast.cpp ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The LLVM IR has multiple legal patterns we cannot lower to Logical SPIR-V.
10// This pass modifies such loads to have an IR we can directly lower to valid
11// logical SPIR-V.
12// OpenCL can avoid this because they rely on ptrcast, which is not supported
13// by logical SPIR-V.
14//
// This pass relies on the assign_ptr_type intrinsic to deduce the type of the
// pointed values, and must replace all occurrences of `ptrcast`. This is why
// unhandled cases are reported as unreachable: we MUST cover all cases.
18//
19// 1. Loading the first element of an array
20//
21// %array = [10 x i32]
22// %value = load i32, ptr %array
23//
24// LLVM can skip the GEP instruction, and only request loading the first 4
25// bytes. In logical SPIR-V, we need an OpAccessChain to access the first
26// element. This pass will add a getelementptr instruction before the load.
27//
28//
29// 2. Implicit downcast from load
30//
31// %1 = getelementptr <4 x i32>, ptr %vec4, i64 0
32// %2 = load <3 x i32>, ptr %1
33//
34// The pointer in the GEP instruction is only used for offset computations,
35// but it doesn't NEED to match the pointed type. OpAccessChain however
36// requires this. Also, LLVM loads define the bitwidth of the load, not the
37// pointer. In this example, we can guess %vec4 is a vec4 thanks to the GEP
38// instruction basetype, but we only want to load the first 3 elements, hence
39// do a partial load. In logical SPIR-V, this is not legal. What we must do
40// is load the full vector (basetype), extract 3 elements, and recombine them
41// to form a 3-element vector.
42//
43//===----------------------------------------------------------------------===//
44
45#include "SPIRV.h"
46#include "SPIRVSubtarget.h"
47#include "SPIRVTargetMachine.h"
48#include "SPIRVUtils.h"
49#include "llvm/CodeGen/IntrinsicLowering.h"
50#include "llvm/IR/IRBuilder.h"
51#include "llvm/IR/IntrinsicInst.h"
52#include "llvm/IR/Intrinsics.h"
53#include "llvm/IR/IntrinsicsSPIRV.h"
54#include "llvm/Transforms/Utils/Cloning.h"
55#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
56
57using namespace llvm;
58
59namespace {
60class SPIRVLegalizePointerCast : public FunctionPass {
61
62 // Builds the `spv_assign_type` assigning |Ty| to |Value| at the current
63 // builder position.
64 void buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg) {
65 Value *OfType = PoisonValue::get(T: Ty);
66 CallInst *AssignCI = buildIntrWithMD(IntrID: Intrinsic::spv_assign_type,
67 Types: {Arg->getType()}, Arg: OfType, Arg2: Arg, Imms: {}, B);
68 GR->addAssignPtrTypeInstr(Val: Arg, AssignPtrTyCI: AssignCI);
69 }
70
71 // Loads parts of the vector of type |SourceType| from the pointer |Source|
72 // and create a new vector of type |TargetType|. |TargetType| must be a vector
73 // type, and element types of |TargetType| and |SourceType| must match.
74 // Returns the loaded value.
75 Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
76 FixedVectorType *TargetType, Value *Source) {
77 // We expect the codegen to avoid doing implicit bitcast from a load.
78 assert(TargetType->getElementType() == SourceType->getElementType());
79 assert(TargetType->getNumElements() < SourceType->getNumElements());
80
81 LoadInst *NewLoad = B.CreateLoad(Ty: SourceType, Ptr: Source);
82 buildAssignType(B, Ty: SourceType, Arg: NewLoad);
83
84 SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
85 for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
86 Mask[I] = I;
87 Value *Output = B.CreateShuffleVector(V1: NewLoad, V2: NewLoad, Mask);
88 buildAssignType(B, Ty: TargetType, Arg: Output);
89 return Output;
90 }
91
92 // Loads the first value in an aggregate pointed by |Source| of containing
93 // elements of type |ElementType|. Load flags will be copied from |BadLoad|,
94 // which should be the load being legalized. Returns the loaded value.
95 Value *loadFirstValueFromAggregate(IRBuilder<> &B, Type *ElementType,
96 Value *Source, LoadInst *BadLoad) {
97 SmallVector<Type *, 2> Types = {BadLoad->getPointerOperandType(),
98 BadLoad->getPointerOperandType()};
99 SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(V: false), Source,
100 B.getInt32(C: 0), B.getInt32(C: 0)};
101 auto *GEP = B.CreateIntrinsic(ID: Intrinsic::spv_gep, Types: {Types}, Args: {Args});
102 GR->buildAssignPtr(B, ElemTy: ElementType, Arg: GEP);
103
104 LoadInst *LI = B.CreateLoad(Ty: ElementType, Ptr: GEP);
105 LI->setAlignment(BadLoad->getAlign());
106 buildAssignType(B, Ty: ElementType, Arg: LI);
107 return LI;
108 }
109
110 // Replaces the load instruction to get rid of the ptrcast used as source
111 // operand.
112 void transformLoad(IRBuilder<> &B, LoadInst *LI, Value *CastedOperand,
113 Value *OriginalOperand) {
114 Type *FromTy = GR->findDeducedElementType(Val: OriginalOperand);
115 Type *ToTy = GR->findDeducedElementType(Val: CastedOperand);
116 Value *Output = nullptr;
117
118 auto *SAT = dyn_cast<ArrayType>(Val: FromTy);
119 auto *SVT = dyn_cast<FixedVectorType>(Val: FromTy);
120 auto *SST = dyn_cast<StructType>(Val: FromTy);
121 auto *DVT = dyn_cast<FixedVectorType>(Val: ToTy);
122
123 B.SetInsertPoint(LI);
124
125 // Destination is the element type of Source, and source is an array ->
126 // Loading 1st element.
127 // - float a = array[0];
128 if (SAT && SAT->getElementType() == ToTy)
129 Output = loadFirstValueFromAggregate(B, ElementType: SAT->getElementType(),
130 Source: OriginalOperand, BadLoad: LI);
131 // Destination is the element type of Source, and source is a vector ->
132 // Vector to scalar.
133 // - float a = vector.x;
134 else if (!DVT && SVT && SVT->getElementType() == ToTy) {
135 Output = loadFirstValueFromAggregate(B, ElementType: SVT->getElementType(),
136 Source: OriginalOperand, BadLoad: LI);
137 }
138 // Destination is a smaller vector than source.
139 // - float3 v3 = vector4;
140 else if (SVT && DVT)
141 Output = loadVectorFromVector(B, SourceType: SVT, TargetType: DVT, Source: OriginalOperand);
142 // Destination is the scalar type stored at the start of an aggregate.
143 // - struct S { float m };
144 // - float v = s.m;
145 else if (SST && SST->getTypeAtIndex(N: 0u) == ToTy)
146 Output = loadFirstValueFromAggregate(B, ElementType: ToTy, Source: OriginalOperand, BadLoad: LI);
147 else
148 llvm_unreachable("Unimplemented implicit down-cast from load.");
149
150 GR->replaceAllUsesWith(Old: LI, New: Output, /* DeleteOld= */ true);
151 DeadInstructions.push_back(x: LI);
152 }
153
154 // Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
155 Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
156 unsigned Index) {
157 Type *Int32Ty = Type::getInt32Ty(C&: B.getContext());
158 SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),
159 Element->getType(), Int32Ty};
160 SmallVector<Value *> Args = {Vector, Element, B.getInt32(C: Index)};
161 Instruction *NewI =
162 B.CreateIntrinsic(ID: Intrinsic::spv_insertelt, Types: {Types}, Args: {Args});
163 buildAssignType(B, Ty: Vector->getType(), Arg: NewI);
164 return NewI;
165 }
166
167 // Creates an spv_extractelt instruction (equivalent to llvm's
168 // extractelement).
169 Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
170 unsigned Index) {
171 Type *Int32Ty = Type::getInt32Ty(C&: B.getContext());
172 SmallVector<Type *, 3> Types = {ElementType, Vector->getType(), Int32Ty};
173 SmallVector<Value *> Args = {Vector, B.getInt32(C: Index)};
174 Instruction *NewI =
175 B.CreateIntrinsic(ID: Intrinsic::spv_extractelt, Types: {Types}, Args: {Args});
176 buildAssignType(B, Ty: ElementType, Arg: NewI);
177 return NewI;
178 }
179
180 // Stores the given Src vector operand into the Dst vector, adjusting the size
181 // if required.
182 Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
183 Align Alignment) {
184 FixedVectorType *SrcType = cast<FixedVectorType>(Val: Src->getType());
185 FixedVectorType *DstType =
186 cast<FixedVectorType>(Val: GR->findDeducedElementType(Val: Dst));
187 assert(DstType->getNumElements() >= SrcType->getNumElements());
188
189 LoadInst *LI = B.CreateLoad(Ty: DstType, Ptr: Dst);
190 LI->setAlignment(Alignment);
191 Value *OldValues = LI;
192 buildAssignType(B, Ty: OldValues->getType(), Arg: OldValues);
193 Value *NewValues = Src;
194
195 for (unsigned I = 0; I < SrcType->getNumElements(); ++I) {
196 Value *Element =
197 makeExtractElement(B, ElementType: SrcType->getElementType(), Vector: NewValues, Index: I);
198 OldValues = makeInsertElement(B, Vector: OldValues, Element, Index: I);
199 }
200
201 StoreInst *SI = B.CreateStore(Val: OldValues, Ptr: Dst);
202 SI->setAlignment(Alignment);
203 return SI;
204 }
205
206 void buildGEPIndexChain(IRBuilder<> &B, Type *Search, Type *Aggregate,
207 SmallVectorImpl<Value *> &Indices) {
208 Indices.push_back(Elt: B.getInt32(C: 0));
209
210 if (Search == Aggregate)
211 return;
212
213 if (auto *ST = dyn_cast<StructType>(Val: Aggregate))
214 buildGEPIndexChain(B, Search, Aggregate: ST->getTypeAtIndex(N: 0u), Indices);
215 else if (auto *AT = dyn_cast<ArrayType>(Val: Aggregate))
216 buildGEPIndexChain(B, Search, Aggregate: AT->getElementType(), Indices);
217 else if (auto *VT = dyn_cast<FixedVectorType>(Val: Aggregate))
218 buildGEPIndexChain(B, Search, Aggregate: VT->getElementType(), Indices);
219 else
220 llvm_unreachable("Bad access chain?");
221 }
222
223 // Stores the given Src value into the first entry of the Dst aggregate.
224 Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
225 Type *DstPointeeType, Align Alignment) {
226 SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
227 SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(V: true), Dst};
228 buildGEPIndexChain(B, Search: Src->getType(), Aggregate: DstPointeeType, Indices&: Args);
229 auto *GEP = B.CreateIntrinsic(ID: Intrinsic::spv_gep, Types: {Types}, Args: {Args});
230 GR->buildAssignPtr(B, ElemTy: Src->getType(), Arg: GEP);
231 StoreInst *SI = B.CreateStore(Val: Src, Ptr: GEP);
232 SI->setAlignment(Alignment);
233 return SI;
234 }
235
236 bool isTypeFirstElementAggregate(Type *Search, Type *Aggregate) {
237 if (Search == Aggregate)
238 return true;
239 if (auto *ST = dyn_cast<StructType>(Val: Aggregate))
240 return isTypeFirstElementAggregate(Search, Aggregate: ST->getTypeAtIndex(N: 0u));
241 if (auto *VT = dyn_cast<FixedVectorType>(Val: Aggregate))
242 return isTypeFirstElementAggregate(Search, Aggregate: VT->getElementType());
243 if (auto *AT = dyn_cast<ArrayType>(Val: Aggregate))
244 return isTypeFirstElementAggregate(Search, Aggregate: AT->getElementType());
245 return false;
246 }
247
248 // Transforms a store instruction (or SPV intrinsic) using a ptrcast as
249 // operand into a valid logical SPIR-V store with no ptrcast.
250 void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
251 Value *Dst, Align Alignment) {
252 Type *ToTy = GR->findDeducedElementType(Val: Dst);
253 Type *FromTy = Src->getType();
254
255 auto *S_VT = dyn_cast<FixedVectorType>(Val: FromTy);
256 auto *D_ST = dyn_cast<StructType>(Val: ToTy);
257 auto *D_VT = dyn_cast<FixedVectorType>(Val: ToTy);
258
259 B.SetInsertPoint(BadStore);
260 if (D_ST && isTypeFirstElementAggregate(Search: FromTy, Aggregate: D_ST))
261 storeToFirstValueAggregate(B, Src, Dst, DstPointeeType: D_ST, Alignment);
262 else if (D_VT && S_VT)
263 storeVectorFromVector(B, Src, Dst, Alignment);
264 else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
265 storeToFirstValueAggregate(B, Src, Dst, DstPointeeType: D_VT, Alignment);
266 else
267 llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
268
269 DeadInstructions.push_back(x: BadStore);
270 }
271
272 void legalizePointerCast(IntrinsicInst *II) {
273 Value *CastedOperand = II;
274 Value *OriginalOperand = II->getOperand(i_nocapture: 0);
275
276 IRBuilder<> B(II->getContext());
277 std::vector<Value *> Users;
278 for (Use &U : II->uses())
279 Users.push_back(x: U.getUser());
280
281 for (Value *User : Users) {
282 if (LoadInst *LI = dyn_cast<LoadInst>(Val: User)) {
283 transformLoad(B, LI, CastedOperand, OriginalOperand);
284 continue;
285 }
286
287 if (StoreInst *SI = dyn_cast<StoreInst>(Val: User)) {
288 transformStore(B, BadStore: SI, Src: SI->getValueOperand(), Dst: OriginalOperand,
289 Alignment: SI->getAlign());
290 continue;
291 }
292
293 if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(Val: User)) {
294 if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
295 DeadInstructions.push_back(x: Intrin);
296 continue;
297 }
298
299 if (Intrin->getIntrinsicID() == Intrinsic::spv_gep) {
300 GR->replaceAllUsesWith(Old: CastedOperand, New: OriginalOperand,
301 /* DeleteOld= */ false);
302 continue;
303 }
304
305 if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
306 Align Alignment;
307 if (ConstantInt *C = dyn_cast<ConstantInt>(Val: Intrin->getOperand(i_nocapture: 3)))
308 Alignment = Align(C->getZExtValue());
309 transformStore(B, BadStore: Intrin, Src: Intrin->getArgOperand(i: 0), Dst: OriginalOperand,
310 Alignment);
311 continue;
312 }
313 }
314
315 llvm_unreachable("Unsupported ptrcast user. Please fix.");
316 }
317
318 DeadInstructions.push_back(x: II);
319 }
320
321public:
322 SPIRVLegalizePointerCast(SPIRVTargetMachine *TM) : FunctionPass(ID), TM(TM) {}
323
324 virtual bool runOnFunction(Function &F) override {
325 const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
326 GR = ST.getSPIRVGlobalRegistry();
327 DeadInstructions.clear();
328
329 std::vector<IntrinsicInst *> WorkList;
330 for (auto &BB : F) {
331 for (auto &I : BB) {
332 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
333 if (II && II->getIntrinsicID() == Intrinsic::spv_ptrcast)
334 WorkList.push_back(x: II);
335 }
336 }
337
338 for (IntrinsicInst *II : WorkList)
339 legalizePointerCast(II);
340
341 for (Instruction *I : DeadInstructions)
342 I->eraseFromParent();
343
344 return DeadInstructions.size() != 0;
345 }
346
347private:
348 SPIRVTargetMachine *TM = nullptr;
349 SPIRVGlobalRegistry *GR = nullptr;
350 std::vector<Instruction *> DeadInstructions;
351
352public:
353 static char ID;
354};
355} // namespace
356
357char SPIRVLegalizePointerCast::ID = 0;
358INITIALIZE_PASS(SPIRVLegalizePointerCast, "spirv-legalize-bitcast",
359 "SPIRV legalize bitcast pass", false, false)
360
361FunctionPass *llvm::createSPIRVLegalizePointerCastPass(SPIRVTargetMachine *TM) {
362 return new SPIRVLegalizePointerCast(TM);
363}
364