1//===- SPIRVLegalizeZeroSizeArrays.cpp - Legalize zero-size arrays -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// SPIR-V does not support zero-size arrays outside of shaders. This pass
// legalizes zero-size arrays ([0 x T]) in the cases where they are not
// supported.
11//
12//===----------------------------------------------------------------------===//
13
14#include "SPIRVLegalizeZeroSizeArrays.h"
15#include "SPIRV.h"
16#include "SPIRVTargetMachine.h"
17#include "SPIRVUtils.h"
18#include "llvm/ADT/DenseMap.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/IR/IRBuilder.h"
21#include "llvm/IR/InstIterator.h"
22#include "llvm/IR/InstVisitor.h"
23#include "llvm/Pass.h"
24#include "llvm/Support/Debug.h"
25
26#define DEBUG_TYPE "spirv-legalize-zero-size-arrays"
27
28using namespace llvm;
29
30namespace {
31
32bool hasZeroSizeArray(const Type *Ty) {
33 if (const ArrayType *ArrTy = dyn_cast<ArrayType>(Val: Ty)) {
34 if (ArrTy->getNumElements() == 0)
35 return true;
36 return hasZeroSizeArray(Ty: ArrTy->getElementType());
37 }
38
39 if (const StructType *StructTy = dyn_cast<StructType>(Val: Ty)) {
40 for (Type *ElemTy : StructTy->elements()) {
41 if (hasZeroSizeArray(Ty: ElemTy))
42 return true;
43 }
44 }
45
46 return false;
47}
48
49bool shouldLegalizeInstType(const Type *Ty) {
50 // This recursive function will always terminate because we only look inside
51 // array types, and those can't be recursive.
52 if (const ArrayType *ArrTy = dyn_cast_if_present<ArrayType>(Val: Ty)) {
53 return ArrTy->getNumElements() == 0 ||
54 shouldLegalizeInstType(Ty: ArrTy->getElementType());
55 }
56 return false;
57}
58
59class SPIRVLegalizeZeroSizeArraysImpl
60 : public InstVisitor<SPIRVLegalizeZeroSizeArraysImpl> {
61 friend class InstVisitor<SPIRVLegalizeZeroSizeArraysImpl>;
62
63public:
64 SPIRVLegalizeZeroSizeArraysImpl(const SPIRVTargetMachine &TM)
65 : InstVisitor(), TM(TM) {}
66 bool runOnModule(Module &M);
67
68 // TODO: Handle GEP, PHI.
69 void visitAllocaInst(AllocaInst &AI);
70 void visitLoadInst(LoadInst &LI);
71 void visitStoreInst(StoreInst &SI);
72 void visitSelectInst(SelectInst &Sel);
73 void visitExtractValueInst(ExtractValueInst &EVI);
74 void visitInsertValueInst(InsertValueInst &IVI);
75
76private:
77 Type *legalizeType(Type *Ty);
78 Constant *legalizeConstant(Constant *C);
79
80 const SPIRVTargetMachine &TM;
81 DenseMap<Type *, Type *> TypeMap;
82 DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
83 SmallVector<Instruction *, 16> ToErase;
84 bool Modified = false;
85};
86
87class SPIRVLegalizeZeroSizeArraysLegacy : public ModulePass {
88public:
89 static char ID;
90 SPIRVLegalizeZeroSizeArraysLegacy(const SPIRVTargetMachine &TM)
91 : ModulePass(ID), TM(TM) {}
92 StringRef getPassName() const override {
93 return "SPIRV Legalize Zero-Size Arrays";
94 }
95 bool runOnModule(Module &M) override {
96 SPIRVLegalizeZeroSizeArraysImpl Impl(TM);
97 return Impl.runOnModule(M);
98 }
99
100private:
101 const SPIRVTargetMachine &TM;
102};
103
104// Legalize a type. There are only two cases we need to care about:
105// arrays and structs.
106//
107// For arrays, we just replace the entire array type with a ptr.
108//
109// For structs, we create a new type with any members containing
110// nested arrays legalized.
111
112Type *SPIRVLegalizeZeroSizeArraysImpl::legalizeType(Type *Ty) {
113 auto It = TypeMap.find(Val: Ty);
114 if (It != TypeMap.end())
115 return It->second;
116
117 Type *LegalizedTy = Ty;
118
119 if (isa<ArrayType>(Val: Ty)) {
120 LegalizedTy = PointerType::get(
121 C&: Ty->getContext(),
122 AddressSpace: storageClassToAddressSpace(SC: SPIRV::StorageClass::Generic));
123
124 } else if (StructType *StructTy = dyn_cast<StructType>(Val: Ty)) {
125 SmallVector<Type *, 8> ElemTypes;
126 bool Changed = false;
127 for (Type *ElemTy : StructTy->elements()) {
128 Type *LegalizedElemTy = legalizeType(Ty: ElemTy);
129 ElemTypes.push_back(Elt: LegalizedElemTy);
130 Changed |= LegalizedElemTy != ElemTy;
131 }
132 if (Changed) {
133 LegalizedTy =
134 StructTy->hasName()
135 ? StructType::create(Context&: StructTy->getContext(), Elements: ElemTypes,
136 Name: (StructTy->getName() + ".legalized").str(),
137 isPacked: StructTy->isPacked())
138 : StructType::get(Context&: StructTy->getContext(), Elements: ElemTypes,
139 isPacked: StructTy->isPacked());
140 }
141 }
142
143 TypeMap[Ty] = LegalizedTy;
144 return LegalizedTy;
145}
146
147Constant *SPIRVLegalizeZeroSizeArraysImpl::legalizeConstant(Constant *C) {
148 if (!C || !hasZeroSizeArray(Ty: C->getType()))
149 return C;
150
151 if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: C)) {
152 if (GlobalVariable *NewGV = GlobalMap.lookup(Val: GV))
153 return NewGV;
154 return C;
155 }
156
157 Type *NewTy = legalizeType(Ty: C->getType());
158 if (isa<UndefValue>(Val: C))
159 return PoisonValue::get(T: NewTy);
160 if (isa<ConstantAggregateZero>(Val: C))
161 return Constant::getNullValue(Ty: NewTy);
162 if (ConstantArray *CA = dyn_cast<ConstantArray>(Val: C)) {
163 SmallVector<Constant *, 8> Elems;
164 for (Use &U : CA->operands())
165 Elems.push_back(Elt: legalizeConstant(C: cast<Constant>(Val&: U)));
166 return ConstantArray::get(T: cast<ArrayType>(Val: NewTy), V: Elems);
167 }
168
169 if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Val: C)) {
170 SmallVector<Constant *, 8> Fields;
171 for (Use &U : CS->operands())
172 Fields.push_back(Elt: legalizeConstant(C: cast<Constant>(Val&: U)));
173 return ConstantStruct::get(T: cast<StructType>(Val: NewTy), V: Fields);
174 }
175
176 if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Val: C)) {
177 // Don't legalize GEP constant expressions, the backend deals with them
178 // fine.
179 if (CE->getOpcode() == Instruction::GetElementPtr)
180 return CE;
181 SmallVector<Constant *, 4> Ops;
182 bool Changed = false;
183 for (Use &U : CE->operands()) {
184 Constant *LegalizedOp = legalizeConstant(C: cast<Constant>(Val&: U));
185 Ops.push_back(Elt: LegalizedOp);
186 Changed |= LegalizedOp != cast<Constant>(Val: U.get());
187 }
188 if (Changed)
189 return CE->getWithOperands(Ops);
190 }
191
192 return C;
193}
194
195void SPIRVLegalizeZeroSizeArraysImpl::visitAllocaInst(AllocaInst &AI) {
196 if (!hasZeroSizeArray(Ty: AI.getAllocatedType()))
197 return;
198
199 // TODO: Handle structs containing zero-size arrays.
200 ArrayType *ArrTy = dyn_cast<ArrayType>(Val: AI.getAllocatedType());
201 if (shouldLegalizeInstType(Ty: ArrTy)) {
202 // Allocate a generic pointer instead of an empty array.
203 IRBuilder<> Builder(&AI);
204 AllocaInst *NewAI = Builder.CreateAlloca(
205 Ty: PointerType::get(
206 C&: ArrTy->getContext(),
207 AddressSpace: storageClassToAddressSpace(SC: SPIRV::StorageClass::Generic)),
208 /*ArraySize=*/nullptr, Name: AI.getName());
209 NewAI->setAlignment(AI.getAlign());
210 NewAI->setDebugLoc(AI.getDebugLoc());
211 AI.replaceAllUsesWith(V: NewAI);
212 ToErase.push_back(Elt: &AI);
213 Modified = true;
214 }
215}
216
217void SPIRVLegalizeZeroSizeArraysImpl::visitLoadInst(LoadInst &LI) {
218 if (!hasZeroSizeArray(Ty: LI.getType()))
219 return;
220
221 // TODO: Handle structs containing zero-size arrays.
222 ArrayType *ArrTy = dyn_cast<ArrayType>(Val: LI.getType());
223 if (shouldLegalizeInstType(Ty: ArrTy)) {
224 LI.replaceAllUsesWith(V: PoisonValue::get(T: LI.getType()));
225 ToErase.push_back(Elt: &LI);
226 Modified = true;
227 }
228}
229
230void SPIRVLegalizeZeroSizeArraysImpl::visitStoreInst(StoreInst &SI) {
231 Type *StoreTy = SI.getValueOperand()->getType();
232
233 // TODO: Handle structs containing zero-size arrays.
234 ArrayType *ArrTy = dyn_cast<ArrayType>(Val: StoreTy);
235 if (shouldLegalizeInstType(Ty: ArrTy)) {
236 ToErase.push_back(Elt: &SI);
237 Modified = true;
238 }
239}
240
241void SPIRVLegalizeZeroSizeArraysImpl::visitSelectInst(SelectInst &Sel) {
242 if (!hasZeroSizeArray(Ty: Sel.getType()))
243 return;
244
245 // TODO: Handle structs containing zero-size arrays.
246 ArrayType *ArrTy = dyn_cast<ArrayType>(Val: Sel.getType());
247 if (shouldLegalizeInstType(Ty: ArrTy)) {
248 Sel.replaceAllUsesWith(V: PoisonValue::get(T: Sel.getType()));
249 ToErase.push_back(Elt: &Sel);
250 Modified = true;
251 }
252}
253
254void SPIRVLegalizeZeroSizeArraysImpl::visitExtractValueInst(
255 ExtractValueInst &EVI) {
256 if (!hasZeroSizeArray(Ty: EVI.getAggregateOperand()->getType()))
257 return;
258
259 // TODO: Handle structs containing zero-size arrays.
260 ArrayType *ArrTy = dyn_cast<ArrayType>(Val: EVI.getType());
261 if (shouldLegalizeInstType(Ty: ArrTy)) {
262 EVI.replaceAllUsesWith(V: PoisonValue::get(T: EVI.getType()));
263 ToErase.push_back(Elt: &EVI);
264 Modified = true;
265 }
266}
267
268void SPIRVLegalizeZeroSizeArraysImpl::visitInsertValueInst(
269 InsertValueInst &IVI) {
270 if (!hasZeroSizeArray(Ty: IVI.getAggregateOperand()->getType()))
271 return;
272
273 // TODO: Handle structs containing zero-size arrays.
274 ArrayType *ArrTy =
275 dyn_cast<ArrayType>(Val: IVI.getInsertedValueOperand()->getType());
276 if (shouldLegalizeInstType(Ty: ArrTy)) {
277 IVI.replaceAllUsesWith(V: IVI.getAggregateOperand());
278 ToErase.push_back(Elt: &IVI);
279 Modified = true;
280 }
281}
282
283bool SPIRVLegalizeZeroSizeArraysImpl::runOnModule(Module &M) {
284 TypeMap.clear();
285 GlobalMap.clear();
286 ToErase.clear();
287 Modified = false;
288
289 // Runtime arrays are allowed for shaders, so we don't need to do anything.
290 if (TM.getSubtargetImpl()->isShader())
291 return false;
292 // 0-sized arrays are handled differently for AMDGCN flavoured SPIRV.
293 if (M.getTargetTriple().getVendor() == Triple::VendorType::AMD)
294 return false;
295
296 // First pass: create new globals (legalizing the initializer as needed) and
297 // track mapping (don't erase old ones yet).
298 SmallVector<GlobalVariable *, 8> OldGlobals;
299 for (GlobalVariable &GV : M.globals()) {
300 if (!hasZeroSizeArray(Ty: GV.getValueType()))
301 continue;
302
303 Type *NewTy = legalizeType(Ty: GV.getValueType());
304 Constant *LegalizedInitializer = legalizeConstant(C: GV.getInitializer());
305
306 // Use an empty name for now, we will update it in the
307 // following step.
308 GlobalVariable *NewGV = new GlobalVariable(
309 M, NewTy, GV.isConstant(), GV.getLinkage(), LegalizedInitializer,
310 /*Name=*/"", &GV, GV.getThreadLocalMode(), GV.getAddressSpace(),
311 GV.isExternallyInitialized());
312 NewGV->copyAttributesFrom(Src: &GV);
313 NewGV->copyMetadata(Src: &GV, Offset: 0);
314 NewGV->setComdat(GV.getComdat());
315 NewGV->setAlignment(GV.getAlign());
316 GlobalMap[&GV] = NewGV;
317 OldGlobals.push_back(Elt: &GV);
318 Modified = true;
319 }
320
321 // Second pass: replace uses, transfer names, and erase old globals.
322 for (GlobalVariable *GV : OldGlobals) {
323 GlobalVariable *NewGV = GlobalMap[GV];
324 GV->replaceAllUsesWith(V: ConstantExpr::getBitCast(C: NewGV, Ty: GV->getType()));
325 NewGV->takeName(V: GV);
326 GV->eraseFromParent();
327 }
328
329 for (Function &F : M)
330 for (Instruction &I : instructions(F))
331 visit(I);
332
333 for (Instruction *I : ToErase)
334 I->eraseFromParent();
335
336 return Modified;
337}
338
339} // namespace
340
341PreservedAnalyses SPIRVLegalizeZeroSizeArrays::run(Module &M,
342 ModuleAnalysisManager &AM) {
343 SPIRVLegalizeZeroSizeArraysImpl Impl(TM);
344 if (Impl.runOnModule(M))
345 return PreservedAnalyses::none();
346 return PreservedAnalyses::all();
347}
348
349char SPIRVLegalizeZeroSizeArraysLegacy::ID = 0;
350
351INITIALIZE_PASS(SPIRVLegalizeZeroSizeArraysLegacy,
352 "spirv-legalize-zero-size-arrays",
353 "Legalize SPIR-V zero-size arrays", false, false)
354
355ModulePass *
356llvm::createSPIRVLegalizeZeroSizeArraysPass(const SPIRVTargetMachine &TM) {
357 return new SPIRVLegalizeZeroSizeArraysLegacy(TM);
358}
359