1//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIRVTargetLowering class.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVISelLowering.h"
14#include "SPIRV.h"
15#include "SPIRVInstrInfo.h"
16#include "SPIRVRegisterBankInfo.h"
17#include "SPIRVRegisterInfo.h"
18#include "SPIRVSubtarget.h"
19#include "llvm/CodeGen/MachineInstrBuilder.h"
20#include "llvm/CodeGen/MachineRegisterInfo.h"
21#include "llvm/IR/IntrinsicsSPIRV.h"
22
23#define DEBUG_TYPE "spirv-lower"
24
25using namespace llvm;
26
27// Returns true of the types logically match, as defined in
28// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical.
29static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2,
30 SPIRVGlobalRegistry &GR) {
31 if (Ty1->getOpcode() != Ty2->getOpcode())
32 return false;
33
34 if (Ty1->getNumOperands() != Ty2->getNumOperands())
35 return false;
36
37 if (Ty1->getOpcode() == SPIRV::OpTypeArray) {
38 // Array must have the same size.
39 if (Ty1->getOperand(i: 2).getReg() != Ty2->getOperand(i: 2).getReg())
40 return false;
41
42 SPIRVType *ElemType1 = GR.getSPIRVTypeForVReg(VReg: Ty1->getOperand(i: 1).getReg());
43 SPIRVType *ElemType2 = GR.getSPIRVTypeForVReg(VReg: Ty2->getOperand(i: 1).getReg());
44 return ElemType1 == ElemType2 ||
45 typesLogicallyMatch(Ty1: ElemType1, Ty2: ElemType2, GR);
46 }
47
48 if (Ty1->getOpcode() == SPIRV::OpTypeStruct) {
49 for (unsigned I = 1; I < Ty1->getNumOperands(); I++) {
50 SPIRVType *ElemType1 =
51 GR.getSPIRVTypeForVReg(VReg: Ty1->getOperand(i: I).getReg());
52 SPIRVType *ElemType2 =
53 GR.getSPIRVTypeForVReg(VReg: Ty2->getOperand(i: I).getReg());
54 if (ElemType1 != ElemType2 &&
55 !typesLogicallyMatch(Ty1: ElemType1, Ty2: ElemType2, GR))
56 return false;
57 }
58 return true;
59 }
60 return false;
61}
62
63unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
64 LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
65 // This code avoids CallLowering fail inside getVectorTypeBreakdown
66 // on v3i1 arguments. Maybe we need to return 1 for all types.
67 // TODO: remove it once this case is supported by the default implementation.
68 if (VT.isVector() && VT.getVectorNumElements() == 3 &&
69 (VT.getVectorElementType() == MVT::i1 ||
70 VT.getVectorElementType() == MVT::i8))
71 return 1;
72 if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
73 return 1;
74 return getNumRegisters(Context, VT);
75}
76
77MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
78 CallingConv::ID CC,
79 EVT VT) const {
80 // This code avoids CallLowering fail inside getVectorTypeBreakdown
81 // on v3i1 arguments. Maybe we need to return i32 for all types.
82 // TODO: remove it once this case is supported by the default implementation.
83 if (VT.isVector() && VT.getVectorNumElements() == 3) {
84 if (VT.getVectorElementType() == MVT::i1)
85 return MVT::v4i1;
86 else if (VT.getVectorElementType() == MVT::i8)
87 return MVT::v4i8;
88 }
89 return getRegisterType(Context, VT);
90}
91
92bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
93 const CallInst &I,
94 MachineFunction &MF,
95 unsigned Intrinsic) const {
96 unsigned AlignIdx = 3;
97 switch (Intrinsic) {
98 case Intrinsic::spv_load:
99 AlignIdx = 2;
100 [[fallthrough]];
101 case Intrinsic::spv_store: {
102 if (I.getNumOperands() >= AlignIdx + 1) {
103 auto *AlignOp = cast<ConstantInt>(Val: I.getOperand(i_nocapture: AlignIdx));
104 Info.align = Align(AlignOp->getZExtValue());
105 }
106 Info.flags = static_cast<MachineMemOperand::Flags>(
107 cast<ConstantInt>(Val: I.getOperand(i_nocapture: AlignIdx - 1))->getZExtValue());
108 Info.memVT = MVT::i64;
109 // TODO: take into account opaque pointers (don't use getElementType).
110 // MVT::getVT(PtrTy->getElementType());
111 return true;
112 break;
113 }
114 default:
115 break;
116 }
117 return false;
118}
119
120std::pair<unsigned, const TargetRegisterClass *>
121SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
122 StringRef Constraint,
123 MVT VT) const {
124 const TargetRegisterClass *RC = nullptr;
125 if (Constraint.starts_with(Prefix: "{"))
126 return std::make_pair(x: 0u, y&: RC);
127
128 if (VT.isFloatingPoint())
129 RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
130 else if (VT.isInteger())
131 RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
132 else
133 RC = &SPIRV::iIDRegClass;
134
135 return std::make_pair(x: 0u, y&: RC);
136}
137
138inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
139 SPIRVType *TypeInst = MRI->getVRegDef(Reg: OpReg);
140 return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
141 ? TypeInst->getOperand(i: 1).getReg()
142 : OpReg;
143}
144
145static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
146 SPIRVGlobalRegistry &GR, MachineInstr &I,
147 Register OpReg, unsigned OpIdx,
148 SPIRVType *NewPtrType) {
149 MachineIRBuilder MIB(I);
150 Register NewReg = createVirtualRegister(SpvType: NewPtrType, GR: &GR, MRI, MF: MIB.getMF());
151 bool Res = MIB.buildInstr(Opcode: SPIRV::OpBitcast)
152 .addDef(RegNo: NewReg)
153 .addUse(RegNo: GR.getSPIRVTypeID(SpirvType: NewPtrType))
154 .addUse(RegNo: OpReg)
155 .constrainAllUses(TII: *STI.getInstrInfo(), TRI: *STI.getRegisterInfo(),
156 RBI: *STI.getRegBankInfo());
157 if (!Res)
158 report_fatal_error(reason: "insert validation bitcast: cannot constrain all uses");
159 I.getOperand(i: OpIdx).setReg(NewReg);
160}
161
162static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
163 SPIRVType *OpType, bool ReuseType,
164 SPIRVType *ResType, const Type *ResTy) {
165 SPIRV::StorageClass::StorageClass SC =
166 static_cast<SPIRV::StorageClass::StorageClass>(
167 OpType->getOperand(i: 1).getImm());
168 MachineIRBuilder MIB(I);
169 SPIRVType *NewBaseType =
170 ReuseType ? ResType
171 : GR.getOrCreateSPIRVType(
172 Type: ResTy, MIRBuilder&: MIB, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: false);
173 return GR.getOrCreateSPIRVPointerType(BaseType: NewBaseType, MIRBuilder&: MIB, SC);
174}
175
176// Insert a bitcast before the instruction to keep SPIR-V code valid
177// when there is a type mismatch between results and operand types.
178static void validatePtrTypes(const SPIRVSubtarget &STI,
179 MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
180 MachineInstr &I, unsigned OpIdx,
181 SPIRVType *ResType, const Type *ResTy = nullptr) {
182 // Get operand type
183 MachineFunction *MF = I.getParent()->getParent();
184 Register OpReg = I.getOperand(i: OpIdx).getReg();
185 Register OpTypeReg = getTypeReg(MRI, OpReg);
186 SPIRVType *OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF);
187 if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
188 return;
189 // Get operand's pointee type
190 Register ElemTypeReg = OpType->getOperand(i: 2).getReg();
191 SPIRVType *ElemType = GR.getSPIRVTypeForVReg(VReg: ElemTypeReg, MF);
192 if (!ElemType)
193 return;
194 // Check if we need a bitcast to make a statement valid
195 bool IsSameMF = MF == ResType->getParent()->getParent();
196 bool IsEqualTypes = IsSameMF ? ElemType == ResType
197 : GR.getTypeForSPIRVType(Ty: ElemType) == ResTy;
198 if (IsEqualTypes)
199 return;
200 // There is a type mismatch between results and operand types
201 // and we insert a bitcast before the instruction to keep SPIR-V code valid
202 SPIRVType *NewPtrType =
203 createNewPtrType(GR, I, OpType, ReuseType: IsSameMF, ResType, ResTy);
204 if (!GR.isBitcastCompatible(Type1: NewPtrType, Type2: OpType))
205 report_fatal_error(
206 reason: "insert validation bitcast: incompatible result and operand types");
207 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
208}
209
210// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
211// that doesn't point to OpTypeEvent.
212static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
213 MachineRegisterInfo *MRI,
214 SPIRVGlobalRegistry &GR,
215 MachineInstr &I) {
216 constexpr unsigned OpIdx = 2;
217 MachineFunction *MF = I.getParent()->getParent();
218 Register OpReg = I.getOperand(i: OpIdx).getReg();
219 Register OpTypeReg = getTypeReg(MRI, OpReg);
220 SPIRVType *OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF);
221 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
222 return;
223 SPIRVType *ElemType = GR.getSPIRVTypeForVReg(VReg: OpType->getOperand(i: 2).getReg());
224 if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
225 return;
226 // Insert a bitcast before the instruction to keep SPIR-V code valid.
227 LLVMContext &Context = MF->getFunction().getContext();
228 SPIRVType *NewPtrType =
229 createNewPtrType(GR, I, OpType, ReuseType: false, ResType: nullptr,
230 ResTy: TargetExtType::get(Context, Name: "spirv.Event"));
231 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
232}
233
234static void validateLifetimeStart(const SPIRVSubtarget &STI,
235 MachineRegisterInfo *MRI,
236 SPIRVGlobalRegistry &GR, MachineInstr &I) {
237 Register PtrReg = I.getOperand(i: 0).getReg();
238 MachineFunction *MF = I.getParent()->getParent();
239 Register PtrTypeReg = getTypeReg(MRI, OpReg: PtrReg);
240 SPIRVType *PtrType = GR.getSPIRVTypeForVReg(VReg: PtrTypeReg, MF);
241 SPIRVType *PonteeElemType = PtrType ? GR.getPointeeType(PtrType) : nullptr;
242 if (!PonteeElemType || PonteeElemType->getOpcode() == SPIRV::OpTypeVoid ||
243 (PonteeElemType->getOpcode() == SPIRV::OpTypeInt &&
244 PonteeElemType->getOperand(i: 1).getImm() == 8))
245 return;
246 // To keep the code valid a bitcast must be inserted
247 SPIRV::StorageClass::StorageClass SC =
248 static_cast<SPIRV::StorageClass::StorageClass>(
249 PtrType->getOperand(i: 1).getImm());
250 MachineIRBuilder MIB(I);
251 LLVMContext &Context = MF->getFunction().getContext();
252 SPIRVType *NewPtrType =
253 GR.getOrCreateSPIRVPointerType(BaseType: IntegerType::getInt8Ty(C&: Context), MIRBuilder&: MIB, SC);
254 doInsertBitcast(STI, MRI, GR, I, OpReg: PtrReg, OpIdx: 0, NewPtrType);
255}
256
257static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
258 MachineRegisterInfo *MRI,
259 SPIRVGlobalRegistry &GR,
260 MachineInstr &I, unsigned OpIdx) {
261 MachineFunction *MF = I.getParent()->getParent();
262 Register OpReg = I.getOperand(i: OpIdx).getReg();
263 Register OpTypeReg = getTypeReg(MRI, OpReg);
264 SPIRVType *OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF);
265 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
266 return;
267 SPIRVType *ElemType = GR.getSPIRVTypeForVReg(VReg: OpType->getOperand(i: 2).getReg());
268 if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
269 ElemType->getNumOperands() != 2)
270 return;
271 // It's a structure-wrapper around another type with a single member field.
272 SPIRVType *MemberType =
273 GR.getSPIRVTypeForVReg(VReg: ElemType->getOperand(i: 1).getReg());
274 if (!MemberType)
275 return;
276 unsigned MemberTypeOp = MemberType->getOpcode();
277 if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
278 MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
279 return;
280 // It's a structure-wrapper around a valid type. Insert a bitcast before the
281 // instruction to keep SPIR-V code valid.
282 SPIRV::StorageClass::StorageClass SC =
283 static_cast<SPIRV::StorageClass::StorageClass>(
284 OpType->getOperand(i: 1).getImm());
285 MachineIRBuilder MIB(I);
286 SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(BaseType: MemberType, MIRBuilder&: MIB, SC);
287 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
288}
289
290// Insert a bitcast before the function call instruction to keep SPIR-V code
291// valid when there is a type mismatch between actual and expected types of an
292// argument:
293// %formal = OpFunctionParameter %formal_type
294// ...
295// %res = OpFunctionCall %ty %fun %actual ...
296// implies that %actual is of %formal_type, and in case of opaque pointers.
297// We may need to insert a bitcast to ensure this.
298void validateFunCallMachineDef(const SPIRVSubtarget &STI,
299 MachineRegisterInfo *DefMRI,
300 MachineRegisterInfo *CallMRI,
301 SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
302 MachineInstr *FunDef) {
303 if (FunDef->getOpcode() != SPIRV::OpFunction)
304 return;
305 unsigned OpIdx = 3;
306 for (FunDef = FunDef->getNextNode();
307 FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
308 OpIdx < FunCall.getNumOperands();
309 FunDef = FunDef->getNextNode(), OpIdx++) {
310 SPIRVType *DefPtrType = DefMRI->getVRegDef(Reg: FunDef->getOperand(i: 1).getReg());
311 SPIRVType *DefElemType =
312 DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
313 ? GR.getSPIRVTypeForVReg(VReg: DefPtrType->getOperand(i: 2).getReg(),
314 MF: DefPtrType->getParent()->getParent())
315 : nullptr;
316 if (DefElemType) {
317 const Type *DefElemTy = GR.getTypeForSPIRVType(Ty: DefElemType);
318 // validatePtrTypes() works in the context if the call site
319 // When we process historical records about forward calls
320 // we need to switch context to the (forward) call site and
321 // then restore it back to the current machine function.
322 MachineFunction *CurMF =
323 GR.setCurrentFunc(*FunCall.getParent()->getParent());
324 validatePtrTypes(STI, MRI: CallMRI, GR, I&: FunCall, OpIdx, ResType: DefElemType,
325 ResTy: DefElemTy);
326 GR.setCurrentFunc(*CurMF);
327 }
328 }
329}
330
331// Ensure there is no mismatch between actual and expected arg types: calls
332// with a processed definition. Return Function pointer if it's a forward
333// call (ahead of definition), and nullptr otherwise.
334const Function *validateFunCall(const SPIRVSubtarget &STI,
335 MachineRegisterInfo *CallMRI,
336 SPIRVGlobalRegistry &GR,
337 MachineInstr &FunCall) {
338 const GlobalValue *GV = FunCall.getOperand(i: 2).getGlobal();
339 const Function *F = dyn_cast<Function>(Val: GV);
340 MachineInstr *FunDef =
341 const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
342 if (!FunDef)
343 return F;
344 MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
345 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
346 return nullptr;
347}
348
349// Ensure there is no mismatch between actual and expected arg types: calls
350// ahead of a processed definition.
351void validateForwardCalls(const SPIRVSubtarget &STI,
352 MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
353 MachineInstr &FunDef) {
354 const Function *F = GR.getFunctionByDefinition(MI: &FunDef);
355 if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
356 for (MachineInstr *FunCall : *FwdCalls) {
357 MachineRegisterInfo *CallMRI =
358 &FunCall->getParent()->getParent()->getRegInfo();
359 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall&: *FunCall, FunDef: &FunDef);
360 }
361}
362
363// Validation of an access chain.
364void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
365 SPIRVGlobalRegistry &GR, MachineInstr &I) {
366 SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(VReg: I.getOperand(i: 0).getReg());
367 if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
368 SPIRVType *BaseElemType =
369 GR.getSPIRVTypeForVReg(VReg: BaseTypeInst->getOperand(i: 2).getReg());
370 validatePtrTypes(STI, MRI, GR, I, OpIdx: 2, ResType: BaseElemType);
371 }
372}
373
374// TODO: the logic of inserting additional bitcast's is to be moved
375// to pre-IRTranslation passes eventually
376void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
377 // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
378 // We'd like to avoid the needless second processing pass.
379 if (ProcessedMF.find(x: &MF) != ProcessedMF.end())
380 return;
381
382 MachineRegisterInfo *MRI = &MF.getRegInfo();
383 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
384 GR.setCurrentFunc(MF);
385 for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
386 MachineBasicBlock *MBB = &*I;
387 SmallPtrSet<MachineInstr *, 8> ToMove;
388 for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
389 MBBI != MBBE;) {
390 MachineInstr &MI = *MBBI++;
391 switch (MI.getOpcode()) {
392 case SPIRV::OpAtomicLoad:
393 case SPIRV::OpAtomicExchange:
394 case SPIRV::OpAtomicCompareExchange:
395 case SPIRV::OpAtomicCompareExchangeWeak:
396 case SPIRV::OpAtomicIIncrement:
397 case SPIRV::OpAtomicIDecrement:
398 case SPIRV::OpAtomicIAdd:
399 case SPIRV::OpAtomicISub:
400 case SPIRV::OpAtomicSMin:
401 case SPIRV::OpAtomicUMin:
402 case SPIRV::OpAtomicSMax:
403 case SPIRV::OpAtomicUMax:
404 case SPIRV::OpAtomicAnd:
405 case SPIRV::OpAtomicOr:
406 case SPIRV::OpAtomicXor:
407 // for the above listed instructions
408 // OpAtomicXXX <ResType>, ptr %Op, ...
409 // implies that %Op is a pointer to <ResType>
410 case SPIRV::OpLoad:
411 // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
412 if (enforcePtrTypeCompatibility(I&: MI, PtrOpIdx: 2, OpIdx: 0))
413 break;
414
415 validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: 2,
416 ResType: GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 0).getReg()));
417 break;
418 case SPIRV::OpAtomicStore:
419 // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
420 // implies that %Op points to the <Obj>'s type
421 validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: 0,
422 ResType: GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 3).getReg()));
423 break;
424 case SPIRV::OpStore:
425 // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
426 validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: 0,
427 ResType: GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 1).getReg()));
428 break;
429 case SPIRV::OpPtrCastToGeneric:
430 case SPIRV::OpGenericCastToPtr:
431 case SPIRV::OpGenericCastToPtrExplicit:
432 validateAccessChain(STI, MRI, GR, I&: MI);
433 break;
434 case SPIRV::OpPtrAccessChain:
435 case SPIRV::OpInBoundsPtrAccessChain:
436 if (MI.getNumOperands() == 4)
437 validateAccessChain(STI, MRI, GR, I&: MI);
438 break;
439
440 case SPIRV::OpFunctionCall:
441 // ensure there is no mismatch between actual and expected arg types:
442 // calls with a processed definition
443 if (MI.getNumOperands() > 3)
444 if (const Function *F = validateFunCall(STI, CallMRI: MRI, GR, FunCall&: MI))
445 GR.addForwardCall(F, MI: &MI);
446 break;
447 case SPIRV::OpFunction:
448 // ensure there is no mismatch between actual and expected arg types:
449 // calls ahead of a processed definition
450 validateForwardCalls(STI, DefMRI: MRI, GR, FunDef&: MI);
451 break;
452
453 // ensure that LLVM IR add/sub instructions result in logical SPIR-V
454 // instructions when applied to bool type
455 case SPIRV::OpIAddS:
456 case SPIRV::OpIAddV:
457 case SPIRV::OpISubS:
458 case SPIRV::OpISubV:
459 if (GR.isScalarOrVectorOfType(VReg: MI.getOperand(i: 1).getReg(),
460 TypeOpcode: SPIRV::OpTypeBool))
461 MI.setDesc(STI.getInstrInfo()->get(Opcode: SPIRV::OpLogicalNotEqual));
462 break;
463
464 // ensure that LLVM IR bitwise instructions result in logical SPIR-V
465 // instructions when applied to bool type
466 case SPIRV::OpBitwiseOrS:
467 case SPIRV::OpBitwiseOrV:
468 if (GR.isScalarOrVectorOfType(VReg: MI.getOperand(i: 1).getReg(),
469 TypeOpcode: SPIRV::OpTypeBool))
470 MI.setDesc(STI.getInstrInfo()->get(Opcode: SPIRV::OpLogicalOr));
471 break;
472 case SPIRV::OpBitwiseAndS:
473 case SPIRV::OpBitwiseAndV:
474 if (GR.isScalarOrVectorOfType(VReg: MI.getOperand(i: 1).getReg(),
475 TypeOpcode: SPIRV::OpTypeBool))
476 MI.setDesc(STI.getInstrInfo()->get(Opcode: SPIRV::OpLogicalAnd));
477 break;
478 case SPIRV::OpBitwiseXorS:
479 case SPIRV::OpBitwiseXorV:
480 if (GR.isScalarOrVectorOfType(VReg: MI.getOperand(i: 1).getReg(),
481 TypeOpcode: SPIRV::OpTypeBool))
482 MI.setDesc(STI.getInstrInfo()->get(Opcode: SPIRV::OpLogicalNotEqual));
483 break;
484 case SPIRV::OpLifetimeStart:
485 case SPIRV::OpLifetimeStop:
486 if (MI.getOperand(i: 1).getImm() > 0)
487 validateLifetimeStart(STI, MRI, GR, I&: MI);
488 break;
489 case SPIRV::OpGroupAsyncCopy:
490 validatePtrUnwrapStructField(STI, MRI, GR, I&: MI, OpIdx: 3);
491 validatePtrUnwrapStructField(STI, MRI, GR, I&: MI, OpIdx: 4);
492 break;
493 case SPIRV::OpGroupWaitEvents:
494 // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
495 validateGroupWaitEventsPtr(STI, MRI, GR, I&: MI);
496 break;
497 case SPIRV::OpConstantI: {
498 SPIRVType *Type = GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 1).getReg());
499 if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(i: 2).isImm() &&
500 MI.getOperand(i: 2).getImm() == 0) {
501 // Validate the null constant of a target extension type
502 MI.setDesc(STI.getInstrInfo()->get(Opcode: SPIRV::OpConstantNull));
503 for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
504 MI.removeOperand(OpNo: i);
505 }
506 } break;
507 case SPIRV::OpPhi: {
508 // Phi refers to a type definition that goes after the Phi
509 // instruction, so that the virtual register definition of the type
510 // doesn't dominate all uses. Let's place the type definition
511 // instruction at the end of the predecessor.
512 MachineBasicBlock *Curr = MI.getParent();
513 SPIRVType *Type = GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 1).getReg());
514 if (Type->getParent() == Curr && !Curr->pred_empty())
515 ToMove.insert(Ptr: const_cast<MachineInstr *>(Type));
516 } break;
517 case SPIRV::OpExtInst: {
518 // prefetch
519 if (!MI.getOperand(i: 2).isImm() || !MI.getOperand(i: 3).isImm() ||
520 MI.getOperand(i: 2).getImm() != SPIRV::InstructionSet::OpenCL_std)
521 continue;
522 switch (MI.getOperand(i: 3).getImm()) {
523 case SPIRV::OpenCLExtInst::frexp:
524 case SPIRV::OpenCLExtInst::lgamma_r:
525 case SPIRV::OpenCLExtInst::remquo: {
526 // The last operand must be of a pointer to i32 or vector of i32
527 // values.
528 MachineIRBuilder MIB(MI);
529 SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(BitWidth: 32, MIRBuilder&: MIB);
530 SPIRVType *RetType = MRI->getVRegDef(Reg: MI.getOperand(i: 1).getReg());
531 assert(RetType && "Expected return type");
532 validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: MI.getNumOperands() - 1,
533 ResType: RetType->getOpcode() != SPIRV::OpTypeVector
534 ? Int32Type
535 : GR.getOrCreateSPIRVVectorType(
536 BaseType: Int32Type, NumElements: RetType->getOperand(i: 2).getImm(),
537 MIRBuilder&: MIB, EmitIR: false));
538 } break;
539 case SPIRV::OpenCLExtInst::fract:
540 case SPIRV::OpenCLExtInst::modf:
541 case SPIRV::OpenCLExtInst::sincos:
542 // The last operand must be of a pointer to the base type represented
543 // by the previous operand.
544 assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
545 "Expected v-reg");
546 validatePtrTypes(
547 STI, MRI, GR, I&: MI, OpIdx: MI.getNumOperands() - 1,
548 ResType: GR.getSPIRVTypeForVReg(
549 VReg: MI.getOperand(i: MI.getNumOperands() - 2).getReg()));
550 break;
551 case SPIRV::OpenCLExtInst::prefetch:
552 // Expected `ptr` type is a pointer to float, integer or vector, but
553 // the pontee value can be wrapped into a struct.
554 assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
555 "Expected v-reg");
556 validatePtrUnwrapStructField(STI, MRI, GR, I&: MI,
557 OpIdx: MI.getNumOperands() - 2);
558 break;
559 }
560 } break;
561 }
562 }
563 for (MachineInstr *MI : ToMove) {
564 MachineBasicBlock *Curr = MI->getParent();
565 MachineBasicBlock *Pred = *Curr->pred_begin();
566 Pred->insert(I: Pred->getFirstTerminator(), MI: Curr->remove_instr(I: MI));
567 }
568 }
569 ProcessedMF.insert(x: &MF);
570 TargetLowering::finalizeLowering(MF);
571}
572
573// Modifies either operand PtrOpIdx or OpIdx so that the pointee type of
574// PtrOpIdx matches the type for operand OpIdx. Returns true if they already
575// match or if the instruction was modified to make them match.
576bool SPIRVTargetLowering::enforcePtrTypeCompatibility(
577 MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const {
578 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
579 SPIRVType *PtrType = GR.getResultType(VReg: I.getOperand(i: PtrOpIdx).getReg());
580 SPIRVType *PointeeType = GR.getPointeeType(PtrType);
581 SPIRVType *OpType = GR.getResultType(VReg: I.getOperand(i: OpIdx).getReg());
582
583 if (PointeeType == OpType)
584 return true;
585
586 if (typesLogicallyMatch(Ty1: PointeeType, Ty2: OpType, GR)) {
587 // Apply OpCopyLogical to OpIdx.
588 if (I.getOperand(i: OpIdx).isDef() &&
589 insertLogicalCopyOnResult(I, NewResultType: PointeeType)) {
590 return true;
591 }
592
593 llvm_unreachable("Unable to add OpCopyLogical yet.");
594 return false;
595 }
596
597 return false;
598}
599
600bool SPIRVTargetLowering::insertLogicalCopyOnResult(
601 MachineInstr &I, SPIRVType *NewResultType) const {
602 MachineRegisterInfo *MRI = &I.getMF()->getRegInfo();
603 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
604
605 Register NewResultReg =
606 createVirtualRegister(SpvType: NewResultType, GR: &GR, MRI, MF: *I.getMF());
607 Register NewTypeReg = GR.getSPIRVTypeID(SpirvType: NewResultType);
608
609 assert(std::distance(I.defs().begin(), I.defs().end()) == 1 &&
610 "Expected only one def");
611 MachineOperand &OldResult = *I.defs().begin();
612 Register OldResultReg = OldResult.getReg();
613 MachineOperand &OldType = *I.uses().begin();
614 Register OldTypeReg = OldType.getReg();
615
616 OldResult.setReg(NewResultReg);
617 OldType.setReg(NewTypeReg);
618
619 MachineIRBuilder MIB(*I.getNextNode());
620 return MIB.buildInstr(Opcode: SPIRV::OpCopyLogical)
621 .addDef(RegNo: OldResultReg)
622 .addUse(RegNo: OldTypeReg)
623 .addUse(RegNo: NewResultReg)
624 .constrainAllUses(TII: *STI.getInstrInfo(), TRI: *STI.getRegisterInfo(),
625 RBI: *STI.getRegBankInfo());
626}
627