1 | //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the SPIRVTargetLowering class. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "SPIRVISelLowering.h" |
14 | #include "SPIRV.h" |
15 | #include "SPIRVInstrInfo.h" |
16 | #include "SPIRVRegisterBankInfo.h" |
17 | #include "SPIRVRegisterInfo.h" |
18 | #include "SPIRVSubtarget.h" |
19 | #include "llvm/CodeGen/MachineInstrBuilder.h" |
20 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
21 | #include "llvm/IR/IntrinsicsSPIRV.h" |
22 | |
23 | #define DEBUG_TYPE "spirv-lower" |
24 | |
25 | using namespace llvm; |
26 | |
27 | // Returns true of the types logically match, as defined in |
28 | // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical. |
29 | static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2, |
30 | SPIRVGlobalRegistry &GR) { |
31 | if (Ty1->getOpcode() != Ty2->getOpcode()) |
32 | return false; |
33 | |
34 | if (Ty1->getNumOperands() != Ty2->getNumOperands()) |
35 | return false; |
36 | |
37 | if (Ty1->getOpcode() == SPIRV::OpTypeArray) { |
38 | // Array must have the same size. |
39 | if (Ty1->getOperand(i: 2).getReg() != Ty2->getOperand(i: 2).getReg()) |
40 | return false; |
41 | |
42 | SPIRVType *ElemType1 = GR.getSPIRVTypeForVReg(VReg: Ty1->getOperand(i: 1).getReg()); |
43 | SPIRVType *ElemType2 = GR.getSPIRVTypeForVReg(VReg: Ty2->getOperand(i: 1).getReg()); |
44 | return ElemType1 == ElemType2 || |
45 | typesLogicallyMatch(Ty1: ElemType1, Ty2: ElemType2, GR); |
46 | } |
47 | |
48 | if (Ty1->getOpcode() == SPIRV::OpTypeStruct) { |
49 | for (unsigned I = 1; I < Ty1->getNumOperands(); I++) { |
50 | SPIRVType *ElemType1 = |
51 | GR.getSPIRVTypeForVReg(VReg: Ty1->getOperand(i: I).getReg()); |
52 | SPIRVType *ElemType2 = |
53 | GR.getSPIRVTypeForVReg(VReg: Ty2->getOperand(i: I).getReg()); |
54 | if (ElemType1 != ElemType2 && |
55 | !typesLogicallyMatch(Ty1: ElemType1, Ty2: ElemType2, GR)) |
56 | return false; |
57 | } |
58 | return true; |
59 | } |
60 | return false; |
61 | } |
62 | |
63 | unsigned SPIRVTargetLowering::getNumRegistersForCallingConv( |
64 | LLVMContext &Context, CallingConv::ID CC, EVT VT) const { |
65 | // This code avoids CallLowering fail inside getVectorTypeBreakdown |
66 | // on v3i1 arguments. Maybe we need to return 1 for all types. |
67 | // TODO: remove it once this case is supported by the default implementation. |
68 | if (VT.isVector() && VT.getVectorNumElements() == 3 && |
69 | (VT.getVectorElementType() == MVT::i1 || |
70 | VT.getVectorElementType() == MVT::i8)) |
71 | return 1; |
72 | if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64) |
73 | return 1; |
74 | return getNumRegisters(Context, VT); |
75 | } |
76 | |
77 | MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context, |
78 | CallingConv::ID CC, |
79 | EVT VT) const { |
80 | // This code avoids CallLowering fail inside getVectorTypeBreakdown |
81 | // on v3i1 arguments. Maybe we need to return i32 for all types. |
82 | // TODO: remove it once this case is supported by the default implementation. |
83 | if (VT.isVector() && VT.getVectorNumElements() == 3) { |
84 | if (VT.getVectorElementType() == MVT::i1) |
85 | return MVT::v4i1; |
86 | else if (VT.getVectorElementType() == MVT::i8) |
87 | return MVT::v4i8; |
88 | } |
89 | return getRegisterType(Context, VT); |
90 | } |
91 | |
92 | bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, |
93 | const CallInst &I, |
94 | MachineFunction &MF, |
95 | unsigned Intrinsic) const { |
96 | unsigned AlignIdx = 3; |
97 | switch (Intrinsic) { |
98 | case Intrinsic::spv_load: |
99 | AlignIdx = 2; |
100 | [[fallthrough]]; |
101 | case Intrinsic::spv_store: { |
102 | if (I.getNumOperands() >= AlignIdx + 1) { |
103 | auto *AlignOp = cast<ConstantInt>(Val: I.getOperand(i_nocapture: AlignIdx)); |
104 | Info.align = Align(AlignOp->getZExtValue()); |
105 | } |
106 | Info.flags = static_cast<MachineMemOperand::Flags>( |
107 | cast<ConstantInt>(Val: I.getOperand(i_nocapture: AlignIdx - 1))->getZExtValue()); |
108 | Info.memVT = MVT::i64; |
109 | // TODO: take into account opaque pointers (don't use getElementType). |
110 | // MVT::getVT(PtrTy->getElementType()); |
111 | return true; |
112 | break; |
113 | } |
114 | default: |
115 | break; |
116 | } |
117 | return false; |
118 | } |
119 | |
120 | std::pair<unsigned, const TargetRegisterClass *> |
121 | SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, |
122 | StringRef Constraint, |
123 | MVT VT) const { |
124 | const TargetRegisterClass *RC = nullptr; |
125 | if (Constraint.starts_with(Prefix: "{" )) |
126 | return std::make_pair(x: 0u, y&: RC); |
127 | |
128 | if (VT.isFloatingPoint()) |
129 | RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass; |
130 | else if (VT.isInteger()) |
131 | RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass; |
132 | else |
133 | RC = &SPIRV::iIDRegClass; |
134 | |
135 | return std::make_pair(x: 0u, y&: RC); |
136 | } |
137 | |
138 | inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) { |
139 | SPIRVType *TypeInst = MRI->getVRegDef(Reg: OpReg); |
140 | return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter |
141 | ? TypeInst->getOperand(i: 1).getReg() |
142 | : OpReg; |
143 | } |
144 | |
145 | static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, |
146 | SPIRVGlobalRegistry &GR, MachineInstr &I, |
147 | Register OpReg, unsigned OpIdx, |
148 | SPIRVType *NewPtrType) { |
149 | MachineIRBuilder MIB(I); |
150 | Register NewReg = createVirtualRegister(SpvType: NewPtrType, GR: &GR, MRI, MF: MIB.getMF()); |
151 | bool Res = MIB.buildInstr(Opcode: SPIRV::OpBitcast) |
152 | .addDef(RegNo: NewReg) |
153 | .addUse(RegNo: GR.getSPIRVTypeID(SpirvType: NewPtrType)) |
154 | .addUse(RegNo: OpReg) |
155 | .constrainAllUses(TII: *STI.getInstrInfo(), TRI: *STI.getRegisterInfo(), |
156 | RBI: *STI.getRegBankInfo()); |
157 | if (!Res) |
158 | report_fatal_error(reason: "insert validation bitcast: cannot constrain all uses" ); |
159 | I.getOperand(i: OpIdx).setReg(NewReg); |
160 | } |
161 | |
162 | static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I, |
163 | SPIRVType *OpType, bool ReuseType, |
164 | SPIRVType *ResType, const Type *ResTy) { |
165 | SPIRV::StorageClass::StorageClass SC = |
166 | static_cast<SPIRV::StorageClass::StorageClass>( |
167 | OpType->getOperand(i: 1).getImm()); |
168 | MachineIRBuilder MIB(I); |
169 | SPIRVType *NewBaseType = |
170 | ReuseType ? ResType |
171 | : GR.getOrCreateSPIRVType( |
172 | Type: ResTy, MIRBuilder&: MIB, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: false); |
173 | return GR.getOrCreateSPIRVPointerType(BaseType: NewBaseType, MIRBuilder&: MIB, SC); |
174 | } |
175 | |
176 | // Insert a bitcast before the instruction to keep SPIR-V code valid |
177 | // when there is a type mismatch between results and operand types. |
178 | static void validatePtrTypes(const SPIRVSubtarget &STI, |
179 | MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, |
180 | MachineInstr &I, unsigned OpIdx, |
181 | SPIRVType *ResType, const Type *ResTy = nullptr) { |
182 | // Get operand type |
183 | MachineFunction *MF = I.getParent()->getParent(); |
184 | Register OpReg = I.getOperand(i: OpIdx).getReg(); |
185 | Register OpTypeReg = getTypeReg(MRI, OpReg); |
186 | SPIRVType *OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF); |
187 | if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer) |
188 | return; |
189 | // Get operand's pointee type |
190 | Register ElemTypeReg = OpType->getOperand(i: 2).getReg(); |
191 | SPIRVType *ElemType = GR.getSPIRVTypeForVReg(VReg: ElemTypeReg, MF); |
192 | if (!ElemType) |
193 | return; |
194 | // Check if we need a bitcast to make a statement valid |
195 | bool IsSameMF = MF == ResType->getParent()->getParent(); |
196 | bool IsEqualTypes = IsSameMF ? ElemType == ResType |
197 | : GR.getTypeForSPIRVType(Ty: ElemType) == ResTy; |
198 | if (IsEqualTypes) |
199 | return; |
200 | // There is a type mismatch between results and operand types |
201 | // and we insert a bitcast before the instruction to keep SPIR-V code valid |
202 | SPIRVType *NewPtrType = |
203 | createNewPtrType(GR, I, OpType, ReuseType: IsSameMF, ResType, ResTy); |
204 | if (!GR.isBitcastCompatible(Type1: NewPtrType, Type2: OpType)) |
205 | report_fatal_error( |
206 | reason: "insert validation bitcast: incompatible result and operand types" ); |
207 | doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType); |
208 | } |
209 | |
210 | // Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer |
211 | // that doesn't point to OpTypeEvent. |
212 | static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI, |
213 | MachineRegisterInfo *MRI, |
214 | SPIRVGlobalRegistry &GR, |
215 | MachineInstr &I) { |
216 | constexpr unsigned OpIdx = 2; |
217 | MachineFunction *MF = I.getParent()->getParent(); |
218 | Register OpReg = I.getOperand(i: OpIdx).getReg(); |
219 | Register OpTypeReg = getTypeReg(MRI, OpReg); |
220 | SPIRVType *OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF); |
221 | if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer) |
222 | return; |
223 | SPIRVType *ElemType = GR.getSPIRVTypeForVReg(VReg: OpType->getOperand(i: 2).getReg()); |
224 | if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent) |
225 | return; |
226 | // Insert a bitcast before the instruction to keep SPIR-V code valid. |
227 | LLVMContext &Context = MF->getFunction().getContext(); |
228 | SPIRVType *NewPtrType = |
229 | createNewPtrType(GR, I, OpType, ReuseType: false, ResType: nullptr, |
230 | ResTy: TargetExtType::get(Context, Name: "spirv.Event" )); |
231 | doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType); |
232 | } |
233 | |
234 | static void validateLifetimeStart(const SPIRVSubtarget &STI, |
235 | MachineRegisterInfo *MRI, |
236 | SPIRVGlobalRegistry &GR, MachineInstr &I) { |
237 | Register PtrReg = I.getOperand(i: 0).getReg(); |
238 | MachineFunction *MF = I.getParent()->getParent(); |
239 | Register PtrTypeReg = getTypeReg(MRI, OpReg: PtrReg); |
240 | SPIRVType *PtrType = GR.getSPIRVTypeForVReg(VReg: PtrTypeReg, MF); |
241 | SPIRVType *PonteeElemType = PtrType ? GR.getPointeeType(PtrType) : nullptr; |
242 | if (!PonteeElemType || PonteeElemType->getOpcode() == SPIRV::OpTypeVoid || |
243 | (PonteeElemType->getOpcode() == SPIRV::OpTypeInt && |
244 | PonteeElemType->getOperand(i: 1).getImm() == 8)) |
245 | return; |
246 | // To keep the code valid a bitcast must be inserted |
247 | SPIRV::StorageClass::StorageClass SC = |
248 | static_cast<SPIRV::StorageClass::StorageClass>( |
249 | PtrType->getOperand(i: 1).getImm()); |
250 | MachineIRBuilder MIB(I); |
251 | LLVMContext &Context = MF->getFunction().getContext(); |
252 | SPIRVType *NewPtrType = |
253 | GR.getOrCreateSPIRVPointerType(BaseType: IntegerType::getInt8Ty(C&: Context), MIRBuilder&: MIB, SC); |
254 | doInsertBitcast(STI, MRI, GR, I, OpReg: PtrReg, OpIdx: 0, NewPtrType); |
255 | } |
256 | |
257 | static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI, |
258 | MachineRegisterInfo *MRI, |
259 | SPIRVGlobalRegistry &GR, |
260 | MachineInstr &I, unsigned OpIdx) { |
261 | MachineFunction *MF = I.getParent()->getParent(); |
262 | Register OpReg = I.getOperand(i: OpIdx).getReg(); |
263 | Register OpTypeReg = getTypeReg(MRI, OpReg); |
264 | SPIRVType *OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF); |
265 | if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer) |
266 | return; |
267 | SPIRVType *ElemType = GR.getSPIRVTypeForVReg(VReg: OpType->getOperand(i: 2).getReg()); |
268 | if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct || |
269 | ElemType->getNumOperands() != 2) |
270 | return; |
271 | // It's a structure-wrapper around another type with a single member field. |
272 | SPIRVType *MemberType = |
273 | GR.getSPIRVTypeForVReg(VReg: ElemType->getOperand(i: 1).getReg()); |
274 | if (!MemberType) |
275 | return; |
276 | unsigned MemberTypeOp = MemberType->getOpcode(); |
277 | if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt && |
278 | MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool) |
279 | return; |
280 | // It's a structure-wrapper around a valid type. Insert a bitcast before the |
281 | // instruction to keep SPIR-V code valid. |
282 | SPIRV::StorageClass::StorageClass SC = |
283 | static_cast<SPIRV::StorageClass::StorageClass>( |
284 | OpType->getOperand(i: 1).getImm()); |
285 | MachineIRBuilder MIB(I); |
286 | SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(BaseType: MemberType, MIRBuilder&: MIB, SC); |
287 | doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType); |
288 | } |
289 | |
// Insert a bitcast before the function call instruction to keep SPIR-V code
// valid when there is a type mismatch between actual and expected types of an
// argument:
// %formal = OpFunctionParameter %formal_type
// ...
// %res = OpFunctionCall %ty %fun %actual ...
// implies that %actual is of %formal_type, and in case of opaque pointers.
// We may need to insert a bitcast to ensure this.
void validateFunCallMachineDef(const SPIRVSubtarget &STI,
                               MachineRegisterInfo *DefMRI,
                               MachineRegisterInfo *CallMRI,
                               SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
                               MachineInstr *FunDef) {
  if (FunDef->getOpcode() != SPIRV::OpFunction)
    return;
  // Call arguments start at operand 3 of OpFunctionCall. Walk them in
  // lock-step with the OpFunctionParameter instructions that immediately
  // follow the OpFunction definition.
  unsigned OpIdx = 3;
  for (FunDef = FunDef->getNextNode();
       FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
       OpIdx < FunCall.getNumOperands();
       FunDef = FunDef->getNextNode(), OpIdx++) {
    // Only pointer-typed formals need validation; resolve the formal's
    // pointee type (in the defining function's context) or null.
    SPIRVType *DefPtrType = DefMRI->getVRegDef(Reg: FunDef->getOperand(i: 1).getReg());
    SPIRVType *DefElemType =
        DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
            ? GR.getSPIRVTypeForVReg(VReg: DefPtrType->getOperand(i: 2).getReg(),
                                     MF: DefPtrType->getParent()->getParent())
            : nullptr;
    if (DefElemType) {
      const Type *DefElemTy = GR.getTypeForSPIRVType(Ty: DefElemType);
      // validatePtrTypes() works in the context of the call site.
      // When we process historical records about forward calls
      // we need to switch context to the (forward) call site and
      // then restore it back to the current machine function.
      MachineFunction *CurMF =
          GR.setCurrentFunc(*FunCall.getParent()->getParent());
      validatePtrTypes(STI, MRI: CallMRI, GR, I&: FunCall, OpIdx, ResType: DefElemType,
                       ResTy: DefElemTy);
      GR.setCurrentFunc(*CurMF);
    }
  }
}
330 | |
331 | // Ensure there is no mismatch between actual and expected arg types: calls |
332 | // with a processed definition. Return Function pointer if it's a forward |
333 | // call (ahead of definition), and nullptr otherwise. |
334 | const Function *validateFunCall(const SPIRVSubtarget &STI, |
335 | MachineRegisterInfo *CallMRI, |
336 | SPIRVGlobalRegistry &GR, |
337 | MachineInstr &FunCall) { |
338 | const GlobalValue *GV = FunCall.getOperand(i: 2).getGlobal(); |
339 | const Function *F = dyn_cast<Function>(Val: GV); |
340 | MachineInstr *FunDef = |
341 | const_cast<MachineInstr *>(GR.getFunctionDefinition(F)); |
342 | if (!FunDef) |
343 | return F; |
344 | MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo(); |
345 | validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef); |
346 | return nullptr; |
347 | } |
348 | |
349 | // Ensure there is no mismatch between actual and expected arg types: calls |
350 | // ahead of a processed definition. |
351 | void validateForwardCalls(const SPIRVSubtarget &STI, |
352 | MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR, |
353 | MachineInstr &FunDef) { |
354 | const Function *F = GR.getFunctionByDefinition(MI: &FunDef); |
355 | if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F)) |
356 | for (MachineInstr *FunCall : *FwdCalls) { |
357 | MachineRegisterInfo *CallMRI = |
358 | &FunCall->getParent()->getParent()->getRegInfo(); |
359 | validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall&: *FunCall, FunDef: &FunDef); |
360 | } |
361 | } |
362 | |
363 | // Validation of an access chain. |
364 | void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, |
365 | SPIRVGlobalRegistry &GR, MachineInstr &I) { |
366 | SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(VReg: I.getOperand(i: 0).getReg()); |
367 | if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) { |
368 | SPIRVType *BaseElemType = |
369 | GR.getSPIRVTypeForVReg(VReg: BaseTypeInst->getOperand(i: 2).getReg()); |
370 | validatePtrTypes(STI, MRI, GR, I, OpIdx: 2, ResType: BaseElemType); |
371 | } |
372 | } |
373 | |
// Post-instruction-selection fixup pass over the whole machine function:
// inserts validation bitcasts, rewrites bool arithmetic/bitwise ops to their
// logical SPIR-V forms, normalizes null constants of non-int types, and
// hoists type definitions referenced by OpPhi.
// TODO: the logic of inserting additional bitcast's is to be moved
// to pre-IRTranslation passes eventually
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
  // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
  // We'd like to avoid the needless second processing pass.
  if (ProcessedMF.find(x: &MF) != ProcessedMF.end())
    return;

  MachineRegisterInfo *MRI = &MF.getRegInfo();
  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
  GR.setCurrentFunc(MF);
  for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
    MachineBasicBlock *MBB = &*I;
    // Type-definition instructions to hoist into a predecessor block;
    // collected here and moved after the scan (see the OpPhi case below).
    SmallPtrSet<MachineInstr *, 8> ToMove;
    for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
         MBBI != MBBE;) {
      // Advance the iterator before handling MI, since handlers may insert
      // instructions around MI.
      MachineInstr &MI = *MBBI++;
      switch (MI.getOpcode()) {
      case SPIRV::OpAtomicLoad:
      case SPIRV::OpAtomicExchange:
      case SPIRV::OpAtomicCompareExchange:
      case SPIRV::OpAtomicCompareExchangeWeak:
      case SPIRV::OpAtomicIIncrement:
      case SPIRV::OpAtomicIDecrement:
      case SPIRV::OpAtomicIAdd:
      case SPIRV::OpAtomicISub:
      case SPIRV::OpAtomicSMin:
      case SPIRV::OpAtomicUMin:
      case SPIRV::OpAtomicSMax:
      case SPIRV::OpAtomicUMax:
      case SPIRV::OpAtomicAnd:
      case SPIRV::OpAtomicOr:
      case SPIRV::OpAtomicXor:
        // for the above listed instructions
        // OpAtomicXXX <ResType>, ptr %Op, ...
        // implies that %Op is a pointer to <ResType>
      case SPIRV::OpLoad:
        // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
        // First try reconciling via OpCopyLogical; fall back to a bitcast.
        if (enforcePtrTypeCompatibility(I&: MI, PtrOpIdx: 2, OpIdx: 0))
          break;

        validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: 2,
                         ResType: GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 0).getReg()));
        break;
      case SPIRV::OpAtomicStore:
        // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
        // implies that %Op points to the <Obj>'s type
        validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: 0,
                         ResType: GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 3).getReg()));
        break;
      case SPIRV::OpStore:
        // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
        validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: 0,
                         ResType: GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 1).getReg()));
        break;
      case SPIRV::OpPtrCastToGeneric:
      case SPIRV::OpGenericCastToPtr:
      case SPIRV::OpGenericCastToPtrExplicit:
        validateAccessChain(STI, MRI, GR, I&: MI);
        break;
      case SPIRV::OpPtrAccessChain:
      case SPIRV::OpInBoundsPtrAccessChain:
        // Only validate chains without index operands (4 operands total).
        if (MI.getNumOperands() == 4)
          validateAccessChain(STI, MRI, GR, I&: MI);
        break;

      case SPIRV::OpFunctionCall:
        // ensure there is no mismatch between actual and expected arg types:
        // calls with a processed definition; record forward calls for later.
        if (MI.getNumOperands() > 3)
          if (const Function *F = validateFunCall(STI, CallMRI: MRI, GR, FunCall&: MI))
            GR.addForwardCall(F, MI: &MI);
        break;
      case SPIRV::OpFunction:
        // ensure there is no mismatch between actual and expected arg types:
        // calls ahead of a processed definition
        validateForwardCalls(STI, DefMRI: MRI, GR, FunDef&: MI);
        break;

      // ensure that LLVM IR add/sub instructions result in logical SPIR-V
      // instructions when applied to bool type
      case SPIRV::OpIAddS:
      case SPIRV::OpIAddV:
      case SPIRV::OpISubS:
      case SPIRV::OpISubV:
        if (GR.isScalarOrVectorOfType(VReg: MI.getOperand(i: 1).getReg(),
                                      TypeOpcode: SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(Opcode: SPIRV::OpLogicalNotEqual));
        break;

      // ensure that LLVM IR bitwise instructions result in logical SPIR-V
      // instructions when applied to bool type
      case SPIRV::OpBitwiseOrS:
      case SPIRV::OpBitwiseOrV:
        if (GR.isScalarOrVectorOfType(VReg: MI.getOperand(i: 1).getReg(),
                                      TypeOpcode: SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(Opcode: SPIRV::OpLogicalOr));
        break;
      case SPIRV::OpBitwiseAndS:
      case SPIRV::OpBitwiseAndV:
        if (GR.isScalarOrVectorOfType(VReg: MI.getOperand(i: 1).getReg(),
                                      TypeOpcode: SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(Opcode: SPIRV::OpLogicalAnd));
        break;
      case SPIRV::OpBitwiseXorS:
      case SPIRV::OpBitwiseXorV:
        if (GR.isScalarOrVectorOfType(VReg: MI.getOperand(i: 1).getReg(),
                                      TypeOpcode: SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(Opcode: SPIRV::OpLogicalNotEqual));
        break;
      case SPIRV::OpLifetimeStart:
      case SPIRV::OpLifetimeStop:
        // Only validate when the size immediate is positive.
        if (MI.getOperand(i: 1).getImm() > 0)
          validateLifetimeStart(STI, MRI, GR, I&: MI);
        break;
      case SPIRV::OpGroupAsyncCopy:
        validatePtrUnwrapStructField(STI, MRI, GR, I&: MI, OpIdx: 3);
        validatePtrUnwrapStructField(STI, MRI, GR, I&: MI, OpIdx: 4);
        break;
      case SPIRV::OpGroupWaitEvents:
        // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
        validateGroupWaitEventsPtr(STI, MRI, GR, I&: MI);
        break;
      case SPIRV::OpConstantI: {
        SPIRVType *Type = GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 1).getReg());
        if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(i: 2).isImm() &&
            MI.getOperand(i: 2).getImm() == 0) {
          // Validate the null constant of a target extension type:
          // rewrite to OpConstantNull and drop the now-superfluous operands.
          MI.setDesc(STI.getInstrInfo()->get(Opcode: SPIRV::OpConstantNull));
          for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
            MI.removeOperand(OpNo: i);
        }
      } break;
      case SPIRV::OpPhi: {
        // Phi refers to a type definition that goes after the Phi
        // instruction, so that the virtual register definition of the type
        // doesn't dominate all uses. Let's place the type definition
        // instruction at the end of the predecessor.
        MachineBasicBlock *Curr = MI.getParent();
        SPIRVType *Type = GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 1).getReg());
        if (Type->getParent() == Curr && !Curr->pred_empty())
          ToMove.insert(Ptr: const_cast<MachineInstr *>(Type));
      } break;
      case SPIRV::OpExtInst: {
        // prefetch
        if (!MI.getOperand(i: 2).isImm() || !MI.getOperand(i: 3).isImm() ||
            MI.getOperand(i: 2).getImm() != SPIRV::InstructionSet::OpenCL_std)
          continue;
        switch (MI.getOperand(i: 3).getImm()) {
        case SPIRV::OpenCLExtInst::frexp:
        case SPIRV::OpenCLExtInst::lgamma_r:
        case SPIRV::OpenCLExtInst::remquo: {
          // The last operand must be of a pointer to i32 or vector of i32
          // values.
          MachineIRBuilder MIB(MI);
          SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(BitWidth: 32, MIRBuilder&: MIB);
          SPIRVType *RetType = MRI->getVRegDef(Reg: MI.getOperand(i: 1).getReg());
          assert(RetType && "Expected return type" );
          validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: MI.getNumOperands() - 1,
                           ResType: RetType->getOpcode() != SPIRV::OpTypeVector
                               ? Int32Type
                               : GR.getOrCreateSPIRVVectorType(
                                     BaseType: Int32Type, NumElements: RetType->getOperand(i: 2).getImm(),
                                     MIRBuilder&: MIB, EmitIR: false));
        } break;
        case SPIRV::OpenCLExtInst::fract:
        case SPIRV::OpenCLExtInst::modf:
        case SPIRV::OpenCLExtInst::sincos:
          // The last operand must be of a pointer to the base type represented
          // by the previous operand.
          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
                 "Expected v-reg" );
          validatePtrTypes(
              STI, MRI, GR, I&: MI, OpIdx: MI.getNumOperands() - 1,
              ResType: GR.getSPIRVTypeForVReg(
                  VReg: MI.getOperand(i: MI.getNumOperands() - 2).getReg()));
          break;
        case SPIRV::OpenCLExtInst::prefetch:
          // Expected `ptr` type is a pointer to float, integer or vector, but
          // the pointee value can be wrapped into a struct.
          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
                 "Expected v-reg" );
          validatePtrUnwrapStructField(STI, MRI, GR, I&: MI,
                                       OpIdx: MI.getNumOperands() - 2);
          break;
        }
      } break;
      }
    }
    // Hoist collected type definitions to the end (before the terminator) of
    // the first predecessor so they dominate their OpPhi uses.
    for (MachineInstr *MI : ToMove) {
      MachineBasicBlock *Curr = MI->getParent();
      MachineBasicBlock *Pred = *Curr->pred_begin();
      Pred->insert(I: Pred->getFirstTerminator(), MI: Curr->remove_instr(I: MI));
    }
  }
  ProcessedMF.insert(x: &MF);
  TargetLowering::finalizeLowering(MF);
}
572 | |
573 | // Modifies either operand PtrOpIdx or OpIdx so that the pointee type of |
574 | // PtrOpIdx matches the type for operand OpIdx. Returns true if they already |
575 | // match or if the instruction was modified to make them match. |
576 | bool SPIRVTargetLowering::enforcePtrTypeCompatibility( |
577 | MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const { |
578 | SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry(); |
579 | SPIRVType *PtrType = GR.getResultType(VReg: I.getOperand(i: PtrOpIdx).getReg()); |
580 | SPIRVType *PointeeType = GR.getPointeeType(PtrType); |
581 | SPIRVType *OpType = GR.getResultType(VReg: I.getOperand(i: OpIdx).getReg()); |
582 | |
583 | if (PointeeType == OpType) |
584 | return true; |
585 | |
586 | if (typesLogicallyMatch(Ty1: PointeeType, Ty2: OpType, GR)) { |
587 | // Apply OpCopyLogical to OpIdx. |
588 | if (I.getOperand(i: OpIdx).isDef() && |
589 | insertLogicalCopyOnResult(I, NewResultType: PointeeType)) { |
590 | return true; |
591 | } |
592 | |
593 | llvm_unreachable("Unable to add OpCopyLogical yet." ); |
594 | return false; |
595 | } |
596 | |
597 | return false; |
598 | } |
599 | |
600 | bool SPIRVTargetLowering::insertLogicalCopyOnResult( |
601 | MachineInstr &I, SPIRVType *NewResultType) const { |
602 | MachineRegisterInfo *MRI = &I.getMF()->getRegInfo(); |
603 | SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry(); |
604 | |
605 | Register NewResultReg = |
606 | createVirtualRegister(SpvType: NewResultType, GR: &GR, MRI, MF: *I.getMF()); |
607 | Register NewTypeReg = GR.getSPIRVTypeID(SpirvType: NewResultType); |
608 | |
609 | assert(std::distance(I.defs().begin(), I.defs().end()) == 1 && |
610 | "Expected only one def" ); |
611 | MachineOperand &OldResult = *I.defs().begin(); |
612 | Register OldResultReg = OldResult.getReg(); |
613 | MachineOperand &OldType = *I.uses().begin(); |
614 | Register OldTypeReg = OldType.getReg(); |
615 | |
616 | OldResult.setReg(NewResultReg); |
617 | OldType.setReg(NewTypeReg); |
618 | |
619 | MachineIRBuilder MIB(*I.getNextNode()); |
620 | return MIB.buildInstr(Opcode: SPIRV::OpCopyLogical) |
621 | .addDef(RegNo: OldResultReg) |
622 | .addUse(RegNo: OldTypeReg) |
623 | .addUse(RegNo: NewResultReg) |
624 | .constrainAllUses(TII: *STI.getInstrInfo(), TRI: *STI.getRegisterInfo(), |
625 | RBI: *STI.getRegBankInfo()); |
626 | } |
627 | |