1//===- SPIRVLegalizeZeroSizeArrays.cpp - Legalize zero-size arrays -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// SPIR-V does not support zero-size arrays outside of shaders (shaders can
// express them as runtime arrays). This pass legalizes zero-size arrays
// ([0 x T]) in the unsupported cases.
11//
12//===----------------------------------------------------------------------===//
13
14#include "SPIRVLegalizeZeroSizeArrays.h"
15#include "SPIRV.h"
16#include "SPIRVTargetMachine.h"
17#include "SPIRVUtils.h"
18#include "llvm/ADT/DenseMap.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/IR/IRBuilder.h"
21#include "llvm/IR/InstIterator.h"
22#include "llvm/IR/InstVisitor.h"
23#include "llvm/Pass.h"
24#include "llvm/Support/Debug.h"
25
26#define DEBUG_TYPE "spirv-legalize-zero-size-arrays"
27
28using namespace llvm;
29
30namespace {
31
32bool hasZeroSizeArray(const Type *Ty) {
33 if (const ArrayType *ArrTy = dyn_cast<ArrayType>(Val: Ty)) {
34 if (ArrTy->getNumElements() == 0)
35 return true;
36 return hasZeroSizeArray(Ty: ArrTy->getElementType());
37 }
38
39 if (const StructType *StructTy = dyn_cast<StructType>(Val: Ty)) {
40 for (Type *ElemTy : StructTy->elements()) {
41 if (hasZeroSizeArray(Ty: ElemTy))
42 return true;
43 }
44 }
45
46 return false;
47}
48
49bool shouldLegalizeInstType(const Type *Ty) {
50 // This recursive function will always terminate because we only look inside
51 // array types, and those can't be recursive.
52 if (const ArrayType *ArrTy = dyn_cast_if_present<ArrayType>(Val: Ty)) {
53 return ArrTy->getNumElements() == 0 ||
54 shouldLegalizeInstType(Ty: ArrTy->getElementType());
55 }
56 return false;
57}
58
59class SPIRVLegalizeZeroSizeArraysImpl
60 : public InstVisitor<SPIRVLegalizeZeroSizeArraysImpl> {
61 friend class InstVisitor<SPIRVLegalizeZeroSizeArraysImpl>;
62
63public:
64 SPIRVLegalizeZeroSizeArraysImpl(const SPIRVTargetMachine &TM)
65 : InstVisitor(), TM(TM) {}
66 bool runOnModule(Module &M);
67
68 // TODO: Handle GEP, PHI.
69 void visitAllocaInst(AllocaInst &AI);
70 void visitLoadInst(LoadInst &LI);
71 void visitStoreInst(StoreInst &SI);
72 void visitSelectInst(SelectInst &Sel);
73 void visitExtractValueInst(ExtractValueInst &EVI);
74 void visitInsertValueInst(InsertValueInst &IVI);
75
76private:
77 Type *legalizeType(Type *Ty);
78 Constant *legalizeConstant(Constant *C);
79
80 const SPIRVTargetMachine &TM;
81 DenseMap<Type *, Type *> TypeMap;
82 DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
83 SmallVector<Instruction *, 16> ToErase;
84 bool Modified = false;
85};
86
87class SPIRVLegalizeZeroSizeArraysLegacy : public ModulePass {
88public:
89 static char ID;
90 SPIRVLegalizeZeroSizeArraysLegacy(const SPIRVTargetMachine &TM)
91 : ModulePass(ID), TM(TM) {}
92 StringRef getPassName() const override {
93 return "SPIRV Legalize Zero-Size Arrays";
94 }
95 bool runOnModule(Module &M) override {
96 SPIRVLegalizeZeroSizeArraysImpl Impl(TM);
97 return Impl.runOnModule(M);
98 }
99
100private:
101 const SPIRVTargetMachine &TM;
102};
103
104// Legalize a type. There are only two cases we need to care about:
105// arrays and structs.
106//
107// For arrays, we just replace the entire array type with a ptr.
108//
109// For structs, we create a new type with any members containing
110// nested arrays legalized.
111
112Type *SPIRVLegalizeZeroSizeArraysImpl::legalizeType(Type *Ty) {
113 auto It = TypeMap.find(Val: Ty);
114 if (It != TypeMap.end())
115 return It->second;
116
117 Type *LegalizedTy = Ty;
118
119 if (isa<ArrayType>(Val: Ty)) {
120 LegalizedTy = PointerType::get(
121 C&: Ty->getContext(),
122 AddressSpace: storageClassToAddressSpace(SC: SPIRV::StorageClass::Generic));
123
124 } else if (StructType *StructTy = dyn_cast<StructType>(Val: Ty)) {
125 SmallVector<Type *, 8> ElemTypes;
126 bool Changed = false;
127 for (Type *ElemTy : StructTy->elements()) {
128 Type *LegalizedElemTy = legalizeType(Ty: ElemTy);
129 ElemTypes.push_back(Elt: LegalizedElemTy);
130 Changed |= LegalizedElemTy != ElemTy;
131 }
132 if (Changed) {
133 LegalizedTy =
134 StructTy->hasName()
135 ? StructType::create(Context&: StructTy->getContext(), Elements: ElemTypes,
136 Name: (StructTy->getName() + ".legalized").str(),
137 isPacked: StructTy->isPacked())
138 : StructType::get(Context&: StructTy->getContext(), Elements: ElemTypes,
139 isPacked: StructTy->isPacked());
140 }
141 }
142
143 TypeMap[Ty] = LegalizedTy;
144 return LegalizedTy;
145}
146
147Constant *SPIRVLegalizeZeroSizeArraysImpl::legalizeConstant(Constant *C) {
148 if (!C || !hasZeroSizeArray(Ty: C->getType()))
149 return C;
150
151 if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: C)) {
152 if (GlobalVariable *NewGV = GlobalMap.lookup(Val: GV))
153 return NewGV;
154 return C;
155 }
156
157 Type *NewTy = legalizeType(Ty: C->getType());
158 if (isa<UndefValue>(Val: C))
159 return PoisonValue::get(T: NewTy);
160 if (isa<ConstantAggregateZero>(Val: C))
161 return Constant::getNullValue(Ty: NewTy);
162 if (ConstantArray *CA = dyn_cast<ConstantArray>(Val: C)) {
163 SmallVector<Constant *, 8> Elems;
164 for (Use &U : CA->operands())
165 Elems.push_back(Elt: legalizeConstant(C: cast<Constant>(Val&: U)));
166 return ConstantArray::get(T: cast<ArrayType>(Val: NewTy), V: Elems);
167 }
168
169 if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Val: C)) {
170 SmallVector<Constant *, 8> Fields;
171 for (Use &U : CS->operands())
172 Fields.push_back(Elt: legalizeConstant(C: cast<Constant>(Val&: U)));
173 return ConstantStruct::get(T: cast<StructType>(Val: NewTy), V: Fields);
174 }
175
176 if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Val: C)) {
177 // Don't legalize GEP constant expressions, the backend deals with them
178 // fine.
179 if (CE->getOpcode() == Instruction::GetElementPtr)
180 return CE;
181 SmallVector<Constant *, 4> Ops;
182 bool Changed = false;
183 for (Use &U : CE->operands()) {
184 Constant *LegalizedOp = legalizeConstant(C: cast<Constant>(Val&: U));
185 Ops.push_back(Elt: LegalizedOp);
186 Changed |= LegalizedOp != cast<Constant>(Val: U.get());
187 }
188 if (Changed)
189 return CE->getWithOperands(Ops);
190 }
191
192 return C;
193}
194
195void SPIRVLegalizeZeroSizeArraysImpl::visitAllocaInst(AllocaInst &AI) {
196 // Check if allocation size is known-zero
197 const DataLayout &DL = AI.getModule()->getDataLayout();
198 std::optional<TypeSize> Size = AI.getAllocationSize(DL);
199 if (!Size || !Size->isZero())
200 return;
201
202 // Allocate a byte instead of an empty alloca.
203 IRBuilder<> Builder(&AI);
204 AllocaInst *NewAI = Builder.CreateAlloca(Ty: Builder.getInt8Ty());
205 NewAI->takeName(V: &AI);
206 NewAI->setAlignment(AI.getAlign());
207 NewAI->setDebugLoc(AI.getDebugLoc());
208 AI.replaceAllUsesWith(V: NewAI);
209 ToErase.push_back(Elt: &AI);
210 Modified = true;
211}
212
213void SPIRVLegalizeZeroSizeArraysImpl::visitLoadInst(LoadInst &LI) {
214 if (!hasZeroSizeArray(Ty: LI.getType()))
215 return;
216
217 // TODO: Handle structs containing zero-size arrays.
218 ArrayType *ArrTy = dyn_cast<ArrayType>(Val: LI.getType());
219 if (shouldLegalizeInstType(Ty: ArrTy)) {
220 LI.replaceAllUsesWith(V: PoisonValue::get(T: LI.getType()));
221 ToErase.push_back(Elt: &LI);
222 Modified = true;
223 }
224}
225
226void SPIRVLegalizeZeroSizeArraysImpl::visitStoreInst(StoreInst &SI) {
227 Type *StoreTy = SI.getValueOperand()->getType();
228
229 // TODO: Handle structs containing zero-size arrays.
230 ArrayType *ArrTy = dyn_cast<ArrayType>(Val: StoreTy);
231 if (shouldLegalizeInstType(Ty: ArrTy)) {
232 ToErase.push_back(Elt: &SI);
233 Modified = true;
234 }
235}
236
237void SPIRVLegalizeZeroSizeArraysImpl::visitSelectInst(SelectInst &Sel) {
238 if (!hasZeroSizeArray(Ty: Sel.getType()))
239 return;
240
241 // TODO: Handle structs containing zero-size arrays.
242 ArrayType *ArrTy = dyn_cast<ArrayType>(Val: Sel.getType());
243 if (shouldLegalizeInstType(Ty: ArrTy)) {
244 Sel.replaceAllUsesWith(V: PoisonValue::get(T: Sel.getType()));
245 ToErase.push_back(Elt: &Sel);
246 Modified = true;
247 }
248}
249
250void SPIRVLegalizeZeroSizeArraysImpl::visitExtractValueInst(
251 ExtractValueInst &EVI) {
252 if (!hasZeroSizeArray(Ty: EVI.getAggregateOperand()->getType()))
253 return;
254
255 // TODO: Handle structs containing zero-size arrays.
256 ArrayType *ArrTy = dyn_cast<ArrayType>(Val: EVI.getType());
257 if (shouldLegalizeInstType(Ty: ArrTy)) {
258 EVI.replaceAllUsesWith(V: PoisonValue::get(T: EVI.getType()));
259 ToErase.push_back(Elt: &EVI);
260 Modified = true;
261 }
262}
263
264void SPIRVLegalizeZeroSizeArraysImpl::visitInsertValueInst(
265 InsertValueInst &IVI) {
266 if (!hasZeroSizeArray(Ty: IVI.getAggregateOperand()->getType()))
267 return;
268
269 // TODO: Handle structs containing zero-size arrays.
270 ArrayType *ArrTy =
271 dyn_cast<ArrayType>(Val: IVI.getInsertedValueOperand()->getType());
272 if (shouldLegalizeInstType(Ty: ArrTy)) {
273 IVI.replaceAllUsesWith(V: IVI.getAggregateOperand());
274 ToErase.push_back(Elt: &IVI);
275 Modified = true;
276 }
277}
278
279bool SPIRVLegalizeZeroSizeArraysImpl::runOnModule(Module &M) {
280 TypeMap.clear();
281 GlobalMap.clear();
282 ToErase.clear();
283 Modified = false;
284
285 // Runtime arrays are allowed for shaders, so we don't need to do anything.
286 if (TM.getSubtargetImpl()->isShader())
287 return false;
288 // 0-sized arrays are handled differently for AMDGCN flavoured SPIRV.
289 if (M.getTargetTriple().getVendor() == Triple::VendorType::AMD)
290 return false;
291
292 // First pass: create new globals (legalizing the initializer as needed) and
293 // track mapping (don't erase old ones yet).
294 SmallVector<GlobalVariable *, 8> OldGlobals;
295 for (GlobalVariable &GV : M.globals()) {
296 if (!hasZeroSizeArray(Ty: GV.getValueType()))
297 continue;
298
299 Type *NewTy = legalizeType(Ty: GV.getValueType());
300 Constant *LegalizedInitializer = legalizeConstant(C: GV.getInitializer());
301
302 // Use an empty name for now, we will update it in the
303 // following step.
304 GlobalVariable *NewGV = new GlobalVariable(
305 M, NewTy, GV.isConstant(), GV.getLinkage(), LegalizedInitializer,
306 /*Name=*/"", &GV, GV.getThreadLocalMode(), GV.getAddressSpace(),
307 GV.isExternallyInitialized());
308 NewGV->copyAttributesFrom(Src: &GV);
309 NewGV->copyMetadata(Src: &GV, Offset: 0);
310 NewGV->setComdat(GV.getComdat());
311 NewGV->setAlignment(GV.getAlign());
312 GlobalMap[&GV] = NewGV;
313 OldGlobals.push_back(Elt: &GV);
314 Modified = true;
315 }
316
317 // Second pass: replace uses, transfer names, and erase old globals.
318 for (GlobalVariable *GV : OldGlobals) {
319 GlobalVariable *NewGV = GlobalMap[GV];
320 GV->replaceAllUsesWith(V: ConstantExpr::getBitCast(C: NewGV, Ty: GV->getType()));
321 NewGV->takeName(V: GV);
322 GV->eraseFromParent();
323 }
324
325 for (Function &F : M)
326 for (Instruction &I : instructions(F))
327 visit(I);
328
329 for (Instruction *I : ToErase)
330 I->eraseFromParent();
331
332 return Modified;
333}
334
335} // namespace
336
337PreservedAnalyses SPIRVLegalizeZeroSizeArrays::run(Module &M,
338 ModuleAnalysisManager &AM) {
339 SPIRVLegalizeZeroSizeArraysImpl Impl(TM);
340 if (Impl.runOnModule(M))
341 return PreservedAnalyses::none();
342 return PreservedAnalyses::all();
343}
344
345char SPIRVLegalizeZeroSizeArraysLegacy::ID = 0;
346
347INITIALIZE_PASS(SPIRVLegalizeZeroSizeArraysLegacy,
348 "spirv-legalize-zero-size-arrays",
349 "Legalize SPIR-V zero-size arrays", false, false)
350
351ModulePass *
352llvm::createSPIRVLegalizeZeroSizeArraysPass(const SPIRVTargetMachine &TM) {
353 return new SPIRVLegalizeZeroSizeArraysLegacy(TM);
354}
355