//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the SPIRVTargetLowering class.
//
//===----------------------------------------------------------------------===//

#include "SPIRVISelLowering.h"
#include "SPIRV.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVRegisterBankInfo.h"
#include "SPIRVRegisterInfo.h"
#include "SPIRVSubtarget.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/IR/IntrinsicsSPIRV.h"

#define DEBUG_TYPE "spirv-lower"

using namespace llvm;

SPIRVTargetLowering::SPIRVTargetLowering(const TargetMachine &TM,
                                         const SPIRVSubtarget &ST)
    : TargetLowering(TM), STI(ST) {}

// Returns true if the types logically match, as defined in
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical.
static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2,
                                SPIRVGlobalRegistry &GR) {
  if (Ty1->getOpcode() != Ty2->getOpcode())
    return false;

  if (Ty1->getNumOperands() != Ty2->getNumOperands())
    return false;

  if (Ty1->getOpcode() == SPIRV::OpTypeArray) {
    // Arrays must have the same size.
    if (Ty1->getOperand(2).getReg() != Ty2->getOperand(2).getReg())
      return false;

    SPIRVType *ElemType1 = GR.getSPIRVTypeForVReg(Ty1->getOperand(1).getReg());
    SPIRVType *ElemType2 = GR.getSPIRVTypeForVReg(Ty2->getOperand(1).getReg());
    return ElemType1 == ElemType2 ||
           typesLogicallyMatch(ElemType1, ElemType2, GR);
  }

  if (Ty1->getOpcode() == SPIRV::OpTypeStruct) {
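    // Operand 0 of a type instruction is its result <id>; member types start
    // at operand 1.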
    for (unsigned I = 1; I < Ty1->getNumOperands(); I++) {
      SPIRVType *ElemType1 =
          GR.getSPIRVTypeForVReg(Ty1->getOperand(I).getReg());
      SPIRVType *ElemType2 =
          GR.getSPIRVTypeForVReg(Ty2->getOperand(I).getReg());
      if (ElemType1 != ElemType2 &&
          !typesLogicallyMatch(ElemType1, ElemType2, GR))
        return false;
    }
    return true;
  }
  return false;
}

unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
    LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
  // This code avoids a CallLowering failure inside getVectorTypeBreakdown
  // on v3i1 arguments. Maybe we need to return 1 for all types.
  // TODO: remove it once this case is supported by the default implementation.
  if (VT.isVector() && VT.getVectorNumElements() == 3 &&
      (VT.getVectorElementType() == MVT::i1 ||
       VT.getVectorElementType() == MVT::i8))
    return 1;
  if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
    return 1;
  return getNumRegisters(Context, VT);
}

MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
                                                       CallingConv::ID CC,
                                                       EVT VT) const {
  // This code avoids a CallLowering failure inside getVectorTypeBreakdown
  // on v3i1 arguments. Maybe we need to return i32 for all types.
  // TODO: remove it once this case is supported by the default implementation.
  if (VT.isVector() && VT.getVectorNumElements() == 3) {
    if (VT.getVectorElementType() == MVT::i1)
      return MVT::v4i1;
    else if (VT.getVectorElementType() == MVT::i8)
      return MVT::v4i8;
  }
  return getRegisterType(Context, VT);
}

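// Describe the memory touched by the spv_load/spv_store intrinsics: the
// alignment constant sits at operand 2 for loads and operand 3 for stores,
// and the MachineMemOperand flags are encoded as a constant in the operand
// just before the alignment.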
void SPIRVTargetLowering::getTgtMemIntrinsic(
    SmallVectorImpl<IntrinsicInfo> &Infos, const CallBase &I,
    MachineFunction &MF, unsigned Intrinsic) const {
  IntrinsicInfo Info;
  unsigned AlignIdx = 3;
  switch (Intrinsic) {
  case Intrinsic::spv_load:
    AlignIdx = 2;
    [[fallthrough]];
  case Intrinsic::spv_store: {
    if (I.getNumOperands() >= AlignIdx + 1) {
      auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
      Info.align = Align(AlignOp->getZExtValue());
    }
    Info.flags = static_cast<MachineMemOperand::Flags>(
        cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
    Info.memVT = MVT::i64;
    // TODO: take into account opaque pointers (don't use getElementType).
    // MVT::getVT(PtrTy->getElementType());
    Infos.push_back(Info);
    return;
  }
  default:
    break;
  }
}

std::pair<unsigned, const TargetRegisterClass *>
SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
                                                  StringRef Constraint,
                                                  MVT VT) const {
  const TargetRegisterClass *RC = nullptr;
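  // SPIR-V has no physical registers, so explicit "{reg}" constraints cannot
  // be honored.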
  if (Constraint.starts_with("{"))
    return std::make_pair(0u, RC);

  if (VT.isFloatingPoint())
    RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
  else if (VT.isInteger())
    RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
  else
    RC = &SPIRV::iIDRegClass;

  return std::make_pair(0u, RC);
}

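// For a value defined by OpFunctionParameter, take the type directly from the
// parameter's explicit type operand; for any other value, the register itself
// is used to look up its assigned SPIR-V type in the global registry.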
inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
  SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
  return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
             ? TypeInst->getOperand(1).getReg()
             : OpReg;
}

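// Emit "%NewReg = OpBitcast %NewPtrType %OpReg" right before I and rewire
// I's operand OpIdx to use the bitcasted value.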
static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
                            SPIRVGlobalRegistry &GR, MachineInstr &I,
                            Register OpReg, unsigned OpIdx,
                            SPIRVType *NewPtrType) {
  MachineIRBuilder MIB(I);
  Register NewReg = createVirtualRegister(NewPtrType, &GR, MRI, MIB.getMF());
  bool Res = MIB.buildInstr(SPIRV::OpBitcast)
                 .addDef(NewReg)
                 .addUse(GR.getSPIRVTypeID(NewPtrType))
                 .addUse(OpReg)
                 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
                                   *STI.getRegBankInfo());
  if (!Res)
    report_fatal_error("insert validation bitcast: cannot constrain all uses");
  I.getOperand(OpIdx).setReg(NewReg);
}

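// Build a pointer type with the same storage class as OpType, pointing to
// ResType when we stay within one machine function (ReuseType), or to a type
// recreated from the LLVM IR type ResTy otherwise.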
static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
                                   SPIRVType *OpType, bool ReuseType,
                                   SPIRVType *ResType, const Type *ResTy) {
  SPIRV::StorageClass::StorageClass SC =
      static_cast<SPIRV::StorageClass::StorageClass>(
          OpType->getOperand(1).getImm());
  MachineIRBuilder MIB(I);
  SPIRVType *NewBaseType =
      ReuseType ? ResType
                : GR.getOrCreateSPIRVType(
                      ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
  return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
}

// Insert a bitcast before the instruction to keep SPIR-V code valid
// when there is a type mismatch between the result and operand types.
static void validatePtrTypes(const SPIRVSubtarget &STI,
                             MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
                             MachineInstr &I, unsigned OpIdx,
                             SPIRVType *ResType, const Type *ResTy = nullptr) {
  // Get the operand type.
  MachineFunction *MF = I.getParent()->getParent();
  Register OpReg = I.getOperand(OpIdx).getReg();
  Register OpTypeReg = getTypeReg(MRI, OpReg);
  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
  if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
    return;
  // Get the operand's pointee type.
  Register ElemTypeReg = OpType->getOperand(2).getReg();
  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
  if (!ElemType)
    return;
  // Check whether a bitcast is needed to keep the instruction valid.
  bool IsSameMF = MF == ResType->getParent()->getParent();
  bool IsEqualTypes = IsSameMF ? ElemType == ResType
                               : GR.getTypeForSPIRVType(ElemType) == ResTy;
  if (IsEqualTypes)
    return;
  // There is a type mismatch between the result and operand types, so insert
  // a bitcast before the instruction to keep SPIR-V code valid.
  SPIRVType *NewPtrType =
      createNewPtrType(GR, I, OpType, IsSameMF, ResType, ResTy);
  if (!GR.isBitcastCompatible(NewPtrType, OpType))
    report_fatal_error(
        "insert validation bitcast: incompatible result and operand types");
  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}

// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
// that doesn't point to OpTypeEvent.
static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
                                       MachineRegisterInfo *MRI,
                                       SPIRVGlobalRegistry &GR,
                                       MachineInstr &I) {
  constexpr unsigned OpIdx = 2;
  MachineFunction *MF = I.getParent()->getParent();
  Register OpReg = I.getOperand(OpIdx).getReg();
  Register OpTypeReg = getTypeReg(MRI, OpReg);
  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
  if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
    return;
  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
  if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
    return;
  // Insert a bitcast before the instruction to keep SPIR-V code valid.
  LLVMContext &Context = MF->getFunction().getContext();
  SPIRVType *NewPtrType =
      createNewPtrType(GR, I, OpType, false, nullptr,
                       TargetExtType::get(Context, "spirv.Event"));
  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}

static void validateLifetimeStart(const SPIRVSubtarget &STI,
                                  MachineRegisterInfo *MRI,
                                  SPIRVGlobalRegistry &GR, MachineInstr &I) {
  Register PtrReg = I.getOperand(0).getReg();
  MachineFunction *MF = I.getParent()->getParent();
  Register PtrTypeReg = getTypeReg(MRI, PtrReg);
  SPIRVType *PtrType = GR.getSPIRVTypeForVReg(PtrTypeReg, MF);
  SPIRVType *PointeeElemType = PtrType ? GR.getPointeeType(PtrType) : nullptr;
  if (!PointeeElemType || PointeeElemType->getOpcode() == SPIRV::OpTypeVoid ||
      (PointeeElemType->getOpcode() == SPIRV::OpTypeInt &&
       PointeeElemType->getOperand(1).getImm() == 8))
    return;
  // To keep the code valid, a bitcast to an i8 pointer must be inserted.
  SPIRV::StorageClass::StorageClass SC =
      static_cast<SPIRV::StorageClass::StorageClass>(
          PtrType->getOperand(1).getImm());
  MachineIRBuilder MIB(I);
  LLVMContext &Context = MF->getFunction().getContext();
  SPIRVType *NewPtrType =
      GR.getOrCreateSPIRVPointerType(IntegerType::getInt8Ty(Context), MIB, SC);
  doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
}

static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
                                         MachineRegisterInfo *MRI,
                                         SPIRVGlobalRegistry &GR,
                                         MachineInstr &I, unsigned OpIdx) {
  MachineFunction *MF = I.getParent()->getParent();
  Register OpReg = I.getOperand(OpIdx).getReg();
  Register OpTypeReg = getTypeReg(MRI, OpReg);
  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
  if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
    return;
  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
  if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
      ElemType->getNumOperands() != 2)
    return;
  // It's a structure-wrapper around another type with a single member field.
  SPIRVType *MemberType =
      GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
  if (!MemberType)
    return;
  unsigned MemberTypeOp = MemberType->getOpcode();
  if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
      MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
    return;
  // It's a structure-wrapper around a valid type. Insert a bitcast before the
  // instruction to keep SPIR-V code valid.
  SPIRV::StorageClass::StorageClass SC =
      static_cast<SPIRV::StorageClass::StorageClass>(
          OpType->getOperand(1).getImm());
  MachineIRBuilder MIB(I);
  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}

// Insert a bitcast before the function call instruction to keep SPIR-V code
// valid when there is a type mismatch between the actual and expected types
// of an argument:
// %formal = OpFunctionParameter %formal_type
// ...
// %res = OpFunctionCall %ty %fun %actual ...
// implies that %actual is of %formal_type; with opaque pointers we may need
// to insert a bitcast to ensure this.
void validateFunCallMachineDef(const SPIRVSubtarget &STI,
                               MachineRegisterInfo *DefMRI,
                               MachineRegisterInfo *CallMRI,
                               SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
                               MachineInstr *FunDef) {
  if (FunDef->getOpcode() != SPIRV::OpFunction)
    return;
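  // Call arguments start at operand 3 of OpFunctionCall; operands 0-2 hold
  // the result, the result type, and the callee.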
  unsigned OpIdx = 3;
  for (FunDef = FunDef->getNextNode();
       FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
       OpIdx < FunCall.getNumOperands();
       FunDef = FunDef->getNextNode(), OpIdx++) {
    SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
    SPIRVType *DefElemType =
        DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
                                     DefPtrType->getParent()->getParent())
            : nullptr;
    if (DefElemType) {
      const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
      // validatePtrTypes() works in the context of the call site.
      // When we process records of forward calls, we need to switch
      // context to the (forward) call site and then restore it back
      // to the current machine function.
      MachineFunction *CurMF =
          GR.setCurrentFunc(*FunCall.getParent()->getParent());
      validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
                       DefElemTy);
      GR.setCurrentFunc(*CurMF);
    }
  }
}

// Ensure there is no mismatch between actual and expected arg types for calls
// with an already-processed definition. Returns the Function pointer if it's
// a forward call (ahead of the definition), and nullptr otherwise.
const Function *validateFunCall(const SPIRVSubtarget &STI,
                                MachineRegisterInfo *CallMRI,
                                SPIRVGlobalRegistry &GR,
                                MachineInstr &FunCall) {
  const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
  const Function *F = dyn_cast<Function>(GV);
  MachineInstr *FunDef =
      const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
  if (!FunDef)
    return F;
  MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
  validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
  return nullptr;
}

// Ensure there is no mismatch between actual and expected arg types: calls
// ahead of a processed definition.
void validateForwardCalls(const SPIRVSubtarget &STI,
                          MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
                          MachineInstr &FunDef) {
  const Function *F = GR.getFunctionByDefinition(&FunDef);
  if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
    for (MachineInstr *FunCall : *FwdCalls) {
      MachineRegisterInfo *CallMRI =
          &FunCall->getParent()->getParent()->getRegInfo();
      validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
    }
}

// Validation of an access chain.
void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
                         SPIRVGlobalRegistry &GR, MachineInstr &I) {
  SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
  if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
    SPIRVType *BaseElemType =
        GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
    validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
  }
}

// TODO: the logic of inserting additional bitcasts is to be moved
// to pre-IRTranslation passes eventually.
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
  // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp);
  // we'd like to avoid the needless second processing pass.
  if (ProcessedMF.find(&MF) != ProcessedMF.end())
    return;

  MachineRegisterInfo *MRI = &MF.getRegInfo();
  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
  GR.setCurrentFunc(MF);
  for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
    MachineBasicBlock *MBB = &*I;
    SmallPtrSet<MachineInstr *, 8> ToMove;
    for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
         MBBI != MBBE;) {
      MachineInstr &MI = *MBBI++;
      switch (MI.getOpcode()) {
      case SPIRV::OpAtomicLoad:
      case SPIRV::OpAtomicExchange:
      case SPIRV::OpAtomicCompareExchange:
      case SPIRV::OpAtomicCompareExchangeWeak:
      case SPIRV::OpAtomicIIncrement:
      case SPIRV::OpAtomicIDecrement:
      case SPIRV::OpAtomicIAdd:
      case SPIRV::OpAtomicISub:
      case SPIRV::OpAtomicSMin:
      case SPIRV::OpAtomicUMin:
      case SPIRV::OpAtomicSMax:
      case SPIRV::OpAtomicUMax:
      case SPIRV::OpAtomicAnd:
      case SPIRV::OpAtomicOr:
      case SPIRV::OpAtomicXor:
        // For the above listed instructions
        // OpAtomicXXX <ResType>, ptr %Op, ...
        // implies that %Op is a pointer to <ResType>.
      case SPIRV::OpLoad:
        // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>.
        if (enforcePtrTypeCompatibility(MI, 2, 0))
          break;

        validatePtrTypes(STI, MRI, GR, MI, 2,
                         GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
        break;
      case SPIRV::OpAtomicStore:
        // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
        // implies that %Op points to the <Obj>'s type.
        validatePtrTypes(STI, MRI, GR, MI, 0,
                         GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
        break;
      case SPIRV::OpStore:
        // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type.
        validatePtrTypes(STI, MRI, GR, MI, 0,
                         GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
        break;
      case SPIRV::OpPtrCastToGeneric:
      case SPIRV::OpGenericCastToPtr:
      case SPIRV::OpGenericCastToPtrExplicit:
        validateAccessChain(STI, MRI, GR, MI);
        break;
      case SPIRV::OpPtrAccessChain:
      case SPIRV::OpInBoundsPtrAccessChain:
        if (MI.getNumOperands() == 4)
          validateAccessChain(STI, MRI, GR, MI);
        break;

      case SPIRV::OpFunctionCall:
        // Ensure there is no mismatch between actual and expected arg types:
        // calls with a processed definition.
        if (MI.getNumOperands() > 3)
          if (const Function *F = validateFunCall(STI, MRI, GR, MI))
            GR.addForwardCall(F, &MI);
        break;
      case SPIRV::OpFunction:
        // Ensure there is no mismatch between actual and expected arg types:
        // calls ahead of a processed definition.
        validateForwardCalls(STI, MRI, GR, MI);
        break;

      // Ensure that LLVM IR add/sub instructions result in logical SPIR-V
      // instructions when applied to bool type.
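      // (For i1, add/sub is addition modulo 2, i.e. XOR, which on booleans
      // is exactly OpLogicalNotEqual.)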
      case SPIRV::OpIAddS:
      case SPIRV::OpIAddV:
      case SPIRV::OpISubS:
      case SPIRV::OpISubV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
        break;

      // Ensure that LLVM IR bitwise instructions result in logical SPIR-V
      // instructions when applied to bool type.
      case SPIRV::OpBitwiseOrS:
      case SPIRV::OpBitwiseOrV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
        break;
      case SPIRV::OpBitwiseAndS:
      case SPIRV::OpBitwiseAndV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
        break;
      case SPIRV::OpBitwiseXorS:
      case SPIRV::OpBitwiseXorV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
        break;
      case SPIRV::OpLifetimeStart:
      case SPIRV::OpLifetimeStop:
        if (MI.getOperand(1).getImm() > 0)
          validateLifetimeStart(STI, MRI, GR, MI);
        break;
      case SPIRV::OpGroupAsyncCopy:
        validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
        validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
        break;
      case SPIRV::OpGroupWaitEvents:
        // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
        validateGroupWaitEventsPtr(STI, MRI, GR, MI);
        break;
      case SPIRV::OpConstantI: {
        SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
        if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
            MI.getOperand(2).getImm() == 0) {
          // Validate the null constant of a target extension type.
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
          for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
            MI.removeOperand(i);
        }
      } break;
      case SPIRV::OpPhi: {
        // Phi refers to a type definition that goes after the Phi
        // instruction, so that the virtual register definition of the type
        // doesn't dominate all uses. Let's place the type definition
        // instruction at the end of the predecessor.
        MachineBasicBlock *Curr = MI.getParent();
        SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
        if (Type->getParent() == Curr && !Curr->pred_empty())
          ToMove.insert(const_cast<MachineInstr *>(Type));
      } break;
      case SPIRV::OpExtInst: {
        // Validate OpenCL extended instructions that take a pointer operand.
        if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
            MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
          continue;
        switch (MI.getOperand(3).getImm()) {
        case SPIRV::OpenCLExtInst::frexp:
        case SPIRV::OpenCLExtInst::lgamma_r:
        case SPIRV::OpenCLExtInst::remquo: {
          // The last operand must be a pointer to i32 or to a vector of i32
          // values.
          MachineIRBuilder MIB(MI);
          SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
          SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
          assert(RetType && "Expected return type");
          validatePtrTypes(STI, MRI, GR, MI, MI.getNumOperands() - 1,
                           RetType->getOpcode() != SPIRV::OpTypeVector
                               ? Int32Type
                               : GR.getOrCreateSPIRVVectorType(
                                     Int32Type, RetType->getOperand(2).getImm(),
                                     MIB, false));
        } break;
        case SPIRV::OpenCLExtInst::fract:
        case SPIRV::OpenCLExtInst::modf:
        case SPIRV::OpenCLExtInst::sincos:
          // The last operand must be a pointer to the base type represented
          // by the previous operand.
          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
                 "Expected v-reg");
          validatePtrTypes(
              STI, MRI, GR, MI, MI.getNumOperands() - 1,
              GR.getSPIRVTypeForVReg(
                  MI.getOperand(MI.getNumOperands() - 2).getReg()));
          break;
        case SPIRV::OpenCLExtInst::prefetch:
          // The expected `ptr` type is a pointer to float, integer or vector,
          // but the pointee value can be wrapped into a struct.
          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
                 "Expected v-reg");
          validatePtrUnwrapStructField(STI, MRI, GR, MI,
                                       MI.getNumOperands() - 2);
          break;
        }
      } break;
      }
    }
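    // Relocate the type definitions collected above to the end of the
    // predecessor block (before its terminator) so that they dominate their
    // OpPhi uses.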
    for (MachineInstr *MI : ToMove) {
      MachineBasicBlock *Curr = MI->getParent();
      MachineBasicBlock *Pred = *Curr->pred_begin();
      Pred->insert(Pred->getFirstTerminator(), Curr->remove_instr(MI));
    }
  }
  ProcessedMF.insert(&MF);
  TargetLowering::finalizeLowering(MF);
}

// Modifies either operand PtrOpIdx or OpIdx so that the pointee type of
// PtrOpIdx matches the type of operand OpIdx. Returns true if they already
// match or if the instruction was modified to make them match.
bool SPIRVTargetLowering::enforcePtrTypeCompatibility(
    MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const {
  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
  SPIRVType *PtrType = GR.getResultType(I.getOperand(PtrOpIdx).getReg());
  SPIRVType *PointeeType = GR.getPointeeType(PtrType);
  SPIRVType *OpType = GR.getResultType(I.getOperand(OpIdx).getReg());

  if (PointeeType == OpType)
    return true;

  if (typesLogicallyMatch(PointeeType, OpType, GR)) {
    // Apply OpCopyLogical to OpIdx.
    if (I.getOperand(OpIdx).isDef() &&
        insertLogicalCopyOnResult(I, PointeeType)) {
      return true;
    }

    llvm_unreachable("Unable to add OpCopyLogical yet.");
    return false;
  }

  return false;
}

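// Give I a fresh result register of NewResultType, then insert an
// OpCopyLogical right after I that copies the new value back into I's
// original result register with its original type.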
bool SPIRVTargetLowering::insertLogicalCopyOnResult(
    MachineInstr &I, SPIRVType *NewResultType) const {
  MachineRegisterInfo *MRI = &I.getMF()->getRegInfo();
  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();

  Register NewResultReg =
      createVirtualRegister(NewResultType, &GR, MRI, *I.getMF());
  Register NewTypeReg = GR.getSPIRVTypeID(NewResultType);

  assert(llvm::size(I.defs()) == 1 && "Expected only one def");
  MachineOperand &OldResult = *I.defs().begin();
  Register OldResultReg = OldResult.getReg();
  MachineOperand &OldType = *I.uses().begin();
  Register OldTypeReg = OldType.getReg();

  OldResult.setReg(NewResultReg);
  OldType.setReg(NewTypeReg);

  MachineIRBuilder MIB(*I.getNextNode());
  return MIB.buildInstr(SPIRV::OpCopyLogical)
      .addDef(OldResultReg)
      .addUse(OldTypeReg)
      .addUse(NewResultReg)
      .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
                        *STI.getRegBankInfo());
}