1//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIRVTargetLowering class.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVISelLowering.h"
14#include "SPIRV.h"
15#include "SPIRVInstrInfo.h"
16#include "SPIRVRegisterBankInfo.h"
17#include "SPIRVRegisterInfo.h"
18#include "SPIRVSubtarget.h"
19#include "llvm/CodeGen/MachineInstrBuilder.h"
20#include "llvm/CodeGen/MachineRegisterInfo.h"
21#include "llvm/CodeGen/TargetLowering.h"
22#include "llvm/IR/Instructions.h"
23#include "llvm/IR/IntrinsicsSPIRV.h"
24
25#define DEBUG_TYPE "spirv-lower"
26
27using namespace llvm;
28
29SPIRVTargetLowering::SPIRVTargetLowering(const TargetMachine &TM,
30 const SPIRVSubtarget &ST)
31 : TargetLowering(TM, ST), STI(ST) {
32 // Even with SPV_ALTERA_arbitrary_precision_integers enabled, atomic sizes are
33 // limited by atomicrmw xchg operation, which only supports operand up to 64
34 // bits wide, as defined in SPIR-V legalizer. Currently, spirv-val doesn't
35 // consider 128-bit OpTypeInt as valid either.
36 setMaxAtomicSizeInBitsSupported(64);
37 setMinCmpXchgSizeInBits(8);
38}
39
40// Returns true of the types logically match, as defined in
41// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical.
42static bool typesLogicallyMatch(const SPIRVTypeInst Ty1,
43 const SPIRVTypeInst Ty2,
44 SPIRVGlobalRegistry &GR) {
45 if (Ty1->getOpcode() != Ty2->getOpcode())
46 return false;
47
48 if (Ty1->getNumOperands() != Ty2->getNumOperands())
49 return false;
50
51 if (Ty1->getOpcode() == SPIRV::OpTypeArray) {
52 // Array must have the same size.
53 if (Ty1->getOperand(i: 2).getReg() != Ty2->getOperand(i: 2).getReg())
54 return false;
55
56 SPIRVTypeInst ElemType1 =
57 GR.getSPIRVTypeForVReg(VReg: Ty1->getOperand(i: 1).getReg());
58 SPIRVTypeInst ElemType2 =
59 GR.getSPIRVTypeForVReg(VReg: Ty2->getOperand(i: 1).getReg());
60 return ElemType1 == ElemType2 ||
61 typesLogicallyMatch(Ty1: ElemType1, Ty2: ElemType2, GR);
62 }
63
64 if (Ty1->getOpcode() == SPIRV::OpTypeStruct) {
65 for (unsigned I = 1; I < Ty1->getNumOperands(); I++) {
66 SPIRVTypeInst ElemType1 =
67 GR.getSPIRVTypeForVReg(VReg: Ty1->getOperand(i: I).getReg());
68 SPIRVTypeInst ElemType2 =
69 GR.getSPIRVTypeForVReg(VReg: Ty2->getOperand(i: I).getReg());
70 if (ElemType1 != ElemType2 &&
71 !typesLogicallyMatch(Ty1: ElemType1, Ty2: ElemType2, GR))
72 return false;
73 }
74 return true;
75 }
76 return false;
77}
78
79unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
80 LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
81 // This code avoids CallLowering fail inside getVectorTypeBreakdown
82 // on v3i1 arguments. Maybe we need to return 1 for all types.
83 // TODO: remove it once this case is supported by the default implementation.
84 if (VT.isVector() && VT.getVectorNumElements() == 3 &&
85 (VT.getVectorElementType() == MVT::i1 ||
86 VT.getVectorElementType() == MVT::i8))
87 return 1;
88 if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
89 return 1;
90 return getNumRegisters(Context, VT);
91}
92
93MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
94 CallingConv::ID CC,
95 EVT VT) const {
96 // This code avoids CallLowering fail inside getVectorTypeBreakdown
97 // on v3i1 arguments. Maybe we need to return i32 for all types.
98 // TODO: remove it once this case is supported by the default implementation.
99 if (VT.isVector() && VT.getVectorNumElements() == 3) {
100 if (VT.getVectorElementType() == MVT::i1)
101 return MVT::v4i1;
102 else if (VT.getVectorElementType() == MVT::i8)
103 return MVT::v4i8;
104 }
105 return getRegisterType(Context, VT);
106}
107
108void SPIRVTargetLowering::getTgtMemIntrinsic(
109 SmallVectorImpl<IntrinsicInfo> &Infos, const CallBase &I,
110 MachineFunction &MF, unsigned Intrinsic) const {
111 IntrinsicInfo Info;
112 unsigned AlignIdx = 3;
113 switch (Intrinsic) {
114 case Intrinsic::spv_load:
115 AlignIdx = 2;
116 [[fallthrough]];
117 case Intrinsic::spv_store: {
118 if (I.getNumOperands() >= AlignIdx + 1) {
119 auto *AlignOp = cast<ConstantInt>(Val: I.getOperand(i_nocapture: AlignIdx));
120 Info.align = Align(AlignOp->getZExtValue());
121 }
122 Info.flags = static_cast<MachineMemOperand::Flags>(
123 cast<ConstantInt>(Val: I.getOperand(i_nocapture: AlignIdx - 1))->getZExtValue());
124 Info.memVT = MVT::i64;
125 // TODO: take into account opaque pointers (don't use getElementType).
126 // MVT::getVT(PtrTy->getElementType());
127 Infos.push_back(Elt: Info);
128 return;
129 }
130 default:
131 break;
132 }
133}
134
135std::pair<unsigned, const TargetRegisterClass *>
136SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
137 StringRef Constraint,
138 MVT VT) const {
139 const TargetRegisterClass *RC = nullptr;
140 if (Constraint.starts_with(Prefix: "{"))
141 return std::make_pair(x: 0u, y&: RC);
142
143 if (VT.isFloatingPoint())
144 RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
145 else if (VT.isInteger())
146 RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
147 else
148 RC = &SPIRV::iIDRegClass;
149
150 return std::make_pair(x: 0u, y&: RC);
151}
152
153inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
154 const MachineInstr *Inst = MRI->getVRegDef(Reg: OpReg);
155 return Inst && Inst->getOpcode() == SPIRV::OpFunctionParameter
156 ? Inst->getOperand(i: 1).getReg()
157 : OpReg;
158}
159
160static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
161 SPIRVGlobalRegistry &GR, MachineInstr &I,
162 Register OpReg, unsigned OpIdx,
163 SPIRVTypeInst NewPtrType) {
164 MachineIRBuilder MIB(I);
165 Register NewReg = createVirtualRegister(SpvType: NewPtrType, GR: &GR, MRI, MF: MIB.getMF());
166 MIB.buildInstr(Opcode: SPIRV::OpBitcast)
167 .addDef(RegNo: NewReg)
168 .addUse(RegNo: GR.getSPIRVTypeID(SpirvType: NewPtrType))
169 .addUse(RegNo: OpReg)
170 .constrainAllUses(TII: *STI.getInstrInfo(), TRI: *STI.getRegisterInfo(),
171 RBI: *STI.getRegBankInfo());
172 I.getOperand(i: OpIdx).setReg(NewReg);
173}
174
175static SPIRVTypeInst createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
176 SPIRVTypeInst OpType, bool ReuseType,
177 SPIRVTypeInst ResType,
178 const Type *ResTy) {
179 SPIRV::StorageClass::StorageClass SC =
180 static_cast<SPIRV::StorageClass::StorageClass>(
181 OpType->getOperand(i: 1).getImm());
182 MachineIRBuilder MIB(I);
183 SPIRVTypeInst NewBaseType =
184 ReuseType ? ResType
185 : GR.getOrCreateSPIRVType(
186 Type: ResTy, MIRBuilder&: MIB, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: false);
187 return GR.getOrCreateSPIRVPointerType(BaseType: NewBaseType, MIRBuilder&: MIB, SC);
188}
189
190// Insert a bitcast before the instruction to keep SPIR-V code valid
191// when there is a type mismatch between results and operand types.
192static void validatePtrTypes(const SPIRVSubtarget &STI,
193 MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
194 MachineInstr &I, unsigned OpIdx,
195 SPIRVTypeInst ResType,
196 const Type *ResTy = nullptr) {
197 // Get operand type
198 MachineFunction *MF = I.getParent()->getParent();
199 Register OpReg = I.getOperand(i: OpIdx).getReg();
200 Register OpTypeReg = getTypeReg(MRI, OpReg);
201 const MachineInstr *OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF);
202 if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
203 return;
204 // Get operand's pointee type
205 Register ElemTypeReg = OpType->getOperand(i: 2).getReg();
206 SPIRVTypeInst ElemType = GR.getSPIRVTypeForVReg(VReg: ElemTypeReg, MF);
207 if (!ElemType)
208 return;
209 // Check if we need a bitcast to make a statement valid
210 bool IsSameMF = MF == ResType->getParent()->getParent();
211 bool IsEqualTypes = IsSameMF ? ElemType == ResType
212 : GR.getTypeForSPIRVType(Ty: ElemType) == ResTy;
213 if (IsEqualTypes)
214 return;
215 // There is a type mismatch between results and operand types
216 // and we insert a bitcast before the instruction to keep SPIR-V code valid
217 SPIRVTypeInst NewPtrType =
218 createNewPtrType(GR, I, OpType, ReuseType: IsSameMF, ResType, ResTy);
219 if (!GR.isBitcastCompatible(Type1: NewPtrType, Type2: OpType))
220 report_fatal_error(
221 reason: "insert validation bitcast: incompatible result and operand types");
222 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
223}
224
225// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
226// that doesn't point to OpTypeEvent.
227static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
228 MachineRegisterInfo *MRI,
229 SPIRVGlobalRegistry &GR,
230 MachineInstr &I) {
231 constexpr unsigned OpIdx = 2;
232 MachineFunction *MF = I.getParent()->getParent();
233 Register OpReg = I.getOperand(i: OpIdx).getReg();
234 Register OpTypeReg = getTypeReg(MRI, OpReg);
235 SPIRVTypeInst OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF);
236 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
237 return;
238 SPIRVTypeInst ElemType =
239 GR.getSPIRVTypeForVReg(VReg: OpType->getOperand(i: 2).getReg());
240 if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
241 return;
242 // Insert a bitcast before the instruction to keep SPIR-V code valid.
243 LLVMContext &Context = MF->getFunction().getContext();
244 SPIRVTypeInst NewPtrType =
245 createNewPtrType(GR, I, OpType, ReuseType: false, ResType: nullptr,
246 ResTy: TargetExtType::get(Context, Name: "spirv.Event"));
247 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
248}
249
250static void validateLifetimeStart(const SPIRVSubtarget &STI,
251 MachineRegisterInfo *MRI,
252 SPIRVGlobalRegistry &GR, MachineInstr &I) {
253 Register PtrReg = I.getOperand(i: 0).getReg();
254 MachineFunction *MF = I.getParent()->getParent();
255 Register PtrTypeReg = getTypeReg(MRI, OpReg: PtrReg);
256 SPIRVTypeInst PtrType = GR.getSPIRVTypeForVReg(VReg: PtrTypeReg, MF);
257 SPIRVTypeInst PonteeElemType = PtrType ? GR.getPointeeType(PtrType) : nullptr;
258 if (!PonteeElemType || PonteeElemType->getOpcode() == SPIRV::OpTypeVoid ||
259 (PonteeElemType->getOpcode() == SPIRV::OpTypeInt &&
260 PonteeElemType->getOperand(i: 1).getImm() == 8))
261 return;
262 // To keep the code valid a bitcast must be inserted
263 SPIRV::StorageClass::StorageClass SC =
264 static_cast<SPIRV::StorageClass::StorageClass>(
265 PtrType->getOperand(i: 1).getImm());
266 MachineIRBuilder MIB(I);
267 LLVMContext &Context = MF->getFunction().getContext();
268 SPIRVTypeInst NewPtrType =
269 GR.getOrCreateSPIRVPointerType(BaseType: IntegerType::getInt8Ty(C&: Context), MIRBuilder&: MIB, SC);
270 doInsertBitcast(STI, MRI, GR, I, OpReg: PtrReg, OpIdx: 0, NewPtrType);
271}
272
273static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
274 MachineRegisterInfo *MRI,
275 SPIRVGlobalRegistry &GR,
276 MachineInstr &I, unsigned OpIdx) {
277 MachineFunction *MF = I.getParent()->getParent();
278 Register OpReg = I.getOperand(i: OpIdx).getReg();
279 Register OpTypeReg = getTypeReg(MRI, OpReg);
280 SPIRVTypeInst OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF);
281 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
282 return;
283 SPIRVTypeInst ElemType =
284 GR.getSPIRVTypeForVReg(VReg: OpType->getOperand(i: 2).getReg());
285 if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
286 ElemType->getNumOperands() != 2)
287 return;
288 // It's a structure-wrapper around another type with a single member field.
289 SPIRVTypeInst MemberType =
290 GR.getSPIRVTypeForVReg(VReg: ElemType->getOperand(i: 1).getReg());
291 if (!MemberType)
292 return;
293 unsigned MemberTypeOp = MemberType->getOpcode();
294 if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
295 MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
296 return;
297 // It's a structure-wrapper around a valid type. Insert a bitcast before the
298 // instruction to keep SPIR-V code valid.
299 SPIRV::StorageClass::StorageClass SC =
300 static_cast<SPIRV::StorageClass::StorageClass>(
301 OpType->getOperand(i: 1).getImm());
302 MachineIRBuilder MIB(I);
303 SPIRVTypeInst NewPtrType =
304 GR.getOrCreateSPIRVPointerType(BaseType: MemberType, MIRBuilder&: MIB, SC);
305 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
306}
307
308// Insert a bitcast before the function call instruction to keep SPIR-V code
309// valid when there is a type mismatch between actual and expected types of an
310// argument:
311// %formal = OpFunctionParameter %formal_type
312// ...
313// %res = OpFunctionCall %ty %fun %actual ...
314// implies that %actual is of %formal_type, and in case of opaque pointers.
315// We may need to insert a bitcast to ensure this.
316void validateFunCallMachineDef(const SPIRVSubtarget &STI,
317 MachineRegisterInfo *DefMRI,
318 MachineRegisterInfo *CallMRI,
319 SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
320 MachineInstr *FunDef) {
321 if (FunDef->getOpcode() != SPIRV::OpFunction)
322 return;
323 unsigned OpIdx = 3;
324 for (FunDef = FunDef->getNextNode();
325 FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
326 OpIdx < FunCall.getNumOperands();
327 FunDef = FunDef->getNextNode(), OpIdx++) {
328 SPIRVTypeInst DefPtrType =
329 DefMRI->getVRegDef(Reg: FunDef->getOperand(i: 1).getReg());
330 SPIRVTypeInst DefElemType =
331 DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
332 ? GR.getSPIRVTypeForVReg(VReg: DefPtrType->getOperand(i: 2).getReg(),
333 MF: DefPtrType->getParent()->getParent())
334 : nullptr;
335 if (DefElemType) {
336 const Type *DefElemTy = GR.getTypeForSPIRVType(Ty: DefElemType);
337 // validatePtrTypes() works in the context if the call site
338 // When we process historical records about forward calls
339 // we need to switch context to the (forward) call site and
340 // then restore it back to the current machine function.
341 MachineFunction *CurMF =
342 GR.setCurrentFunc(*FunCall.getParent()->getParent());
343 validatePtrTypes(STI, MRI: CallMRI, GR, I&: FunCall, OpIdx, ResType: DefElemType,
344 ResTy: DefElemTy);
345 GR.setCurrentFunc(*CurMF);
346 }
347 }
348}
349
350// Ensure there is no mismatch between actual and expected arg types: calls
351// with a processed definition. Return Function pointer if it's a forward
352// call (ahead of definition), and nullptr otherwise.
353const Function *validateFunCall(const SPIRVSubtarget &STI,
354 MachineRegisterInfo *CallMRI,
355 SPIRVGlobalRegistry &GR,
356 MachineInstr &FunCall) {
357 const GlobalValue *GV = FunCall.getOperand(i: 2).getGlobal();
358 const Function *F = dyn_cast<Function>(Val: GV);
359 MachineInstr *FunDef =
360 const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
361 if (!FunDef)
362 return F;
363 MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
364 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
365 return nullptr;
366}
367
368// Ensure there is no mismatch between actual and expected arg types: calls
369// ahead of a processed definition.
370void validateForwardCalls(const SPIRVSubtarget &STI,
371 MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
372 MachineInstr &FunDef) {
373 const Function *F = GR.getFunctionByDefinition(MI: &FunDef);
374 if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
375 for (MachineInstr *FunCall : *FwdCalls) {
376 MachineRegisterInfo *CallMRI =
377 &FunCall->getParent()->getParent()->getRegInfo();
378 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall&: *FunCall, FunDef: &FunDef);
379 }
380}
381
382// Validation of an access chain.
383void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
384 SPIRVGlobalRegistry &GR, MachineInstr &I) {
385 SPIRVTypeInst BaseTypeInst = GR.getSPIRVTypeForVReg(VReg: I.getOperand(i: 0).getReg());
386 if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
387 SPIRVTypeInst BaseElemType =
388 GR.getSPIRVTypeForVReg(VReg: BaseTypeInst->getOperand(i: 2).getReg());
389 validatePtrTypes(STI, MRI, GR, I, OpIdx: 2, ResType: BaseElemType);
390 }
391}
392
// TODO: the logic of inserting additional bitcast's is to be moved
// to pre-IRTranslation passes eventually
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
  // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
  // We'd like to avoid the needless second processing pass.
  if (ProcessedMF.find(&MF) != ProcessedMF.end())
    return;

  MachineRegisterInfo *MRI = &MF.getRegInfo();
  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
  GR.setCurrentFunc(MF);
  for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
    MachineBasicBlock *MBB = &*I;
    // Type-definition instructions to hoist into a predecessor so that they
    // dominate their OpPhi uses (collected here, moved after the block scan).
    SmallPtrSet<MachineInstr *, 8> ToMove;
    for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
         MBBI != MBBE;) {
      // Advance the iterator before validation: validators may insert
      // bitcasts in front of MI.
      MachineInstr &MI = *MBBI++;
      switch (MI.getOpcode()) {
      case SPIRV::OpAtomicLoad:
      case SPIRV::OpAtomicExchange:
      case SPIRV::OpAtomicCompareExchange:
      case SPIRV::OpAtomicCompareExchangeWeak:
      case SPIRV::OpAtomicIIncrement:
      case SPIRV::OpAtomicIDecrement:
      case SPIRV::OpAtomicIAdd:
      case SPIRV::OpAtomicISub:
      case SPIRV::OpAtomicSMin:
      case SPIRV::OpAtomicUMin:
      case SPIRV::OpAtomicSMax:
      case SPIRV::OpAtomicUMax:
      case SPIRV::OpAtomicAnd:
      case SPIRV::OpAtomicOr:
      case SPIRV::OpAtomicXor:
        // for the above listed instructions
        // OpAtomicXXX <ResType>, ptr %Op, ...
        // implies that %Op is a pointer to <ResType>
      case SPIRV::OpLoad:
        // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
        // First try a logical-match fix-up (OpCopyLogical); fall back to a
        // pointer bitcast if the types don't logically match.
        if (enforcePtrTypeCompatibility(MI, 2, 0))
          break;

        validatePtrTypes(STI, MRI, GR, MI, 2,
                         GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
        break;
      case SPIRV::OpAtomicStore:
        // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
        // implies that %Op points to the <Obj>'s type
        validatePtrTypes(STI, MRI, GR, MI, 0,
                         GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
        break;
      case SPIRV::OpStore:
        // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
        validatePtrTypes(STI, MRI, GR, MI, 0,
                         GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
        break;
      case SPIRV::OpPtrCastToGeneric:
      case SPIRV::OpGenericCastToPtr:
      case SPIRV::OpGenericCastToPtrExplicit:
        validateAccessChain(STI, MRI, GR, MI);
        break;
      case SPIRV::OpPtrAccessChain:
      case SPIRV::OpInBoundsPtrAccessChain:
        // Only validate the no-indices form (result, type, base, element);
        // with indices present the result pointee legitimately differs.
        if (MI.getNumOperands() == 4)
          validateAccessChain(STI, MRI, GR, MI);
        break;

      case SPIRV::OpFunctionCall:
        // ensure there is no mismatch between actual and expected arg types:
        // calls with a processed definition; record forward calls (ahead of
        // their definition) for later validation
        if (MI.getNumOperands() > 3)
          if (const Function *F = validateFunCall(STI, MRI, GR, MI))
            GR.addForwardCall(F, &MI);
        break;
      case SPIRV::OpFunction:
        // ensure there is no mismatch between actual and expected arg types:
        // calls ahead of a processed definition
        validateForwardCalls(STI, MRI, GR, MI);
        break;

      // ensure that LLVM IR add/sub instructions result in logical SPIR-V
      // instructions when applied to bool type
      case SPIRV::OpIAddS:
      case SPIRV::OpIAddV:
      case SPIRV::OpISubS:
      case SPIRV::OpISubV:
        // i1 add/sub modulo 2 is XOR, i.e. logical not-equal in SPIR-V terms.
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
        break;

      // ensure that LLVM IR bitwise instructions result in logical SPIR-V
      // instructions when applied to bool type
      case SPIRV::OpBitwiseOrS:
      case SPIRV::OpBitwiseOrV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
        break;
      case SPIRV::OpBitwiseAndS:
      case SPIRV::OpBitwiseAndV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
        break;
      case SPIRV::OpBitwiseXorS:
      case SPIRV::OpBitwiseXorV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
        break;
      case SPIRV::OpLifetimeStart:
      case SPIRV::OpLifetimeStop:
        // Operand 1 is the size immediate; only sized lifetimes need the
        // i8-pointer fix-up.
        if (MI.getOperand(1).getImm() > 0)
          validateLifetimeStart(STI, MRI, GR, MI);
        break;
      case SPIRV::OpGroupAsyncCopy:
        // Both source (3) and destination (4) pointers may be wrapped in a
        // single-member struct that must be unwrapped.
        validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
        validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
        break;
      case SPIRV::OpGroupWaitEvents:
        // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
        validateGroupWaitEventsPtr(STI, MRI, GR, MI);
        break;
      case SPIRV::OpConstantI: {
        SPIRVTypeInst Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
        if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
            MI.getOperand(2).getImm() == 0) {
          // Validate the null constant of a target extension type: rewrite
          // OpConstantI 0 into OpConstantNull and drop the value operands.
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
          for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
            MI.removeOperand(i);
        }
      } break;
      case SPIRV::OpPhi: {
        // Phi refers to a type definition that goes after the Phi
        // instruction, so that the virtual register definition of the type
        // doesn't dominate all uses. Let's place the type definition
        // instruction at the end of the predecessor.
        MachineBasicBlock *Curr = MI.getParent();
        SPIRVTypeInst Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
        if (Type->getParent() == Curr && !Curr->pred_empty())
          ToMove.insert(const_cast<MachineInstr *>(&*Type));
      } break;
      case SPIRV::OpExtInst: {
        // Validate pointer arguments of selected OpenCL extended
        // instructions (operand 2 = instruction set, operand 3 = opcode).
        if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
            MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
          continue;
        switch (MI.getOperand(3).getImm()) {
        case SPIRV::OpenCLExtInst::frexp:
        case SPIRV::OpenCLExtInst::lgamma_r:
        case SPIRV::OpenCLExtInst::remquo: {
          // The last operand must be of a pointer to i32 or vector of i32
          // values.
          MachineIRBuilder MIB(MI);
          SPIRVTypeInst Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
          SPIRVTypeInst RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
          assert(RetType && "Expected return type");
          validatePtrTypes(STI, MRI, GR, MI, MI.getNumOperands() - 1,
                           RetType->getOpcode() != SPIRV::OpTypeVector
                               ? Int32Type
                               : GR.getOrCreateSPIRVVectorType(
                                     Int32Type, RetType->getOperand(2).getImm(),
                                     MIB, false));
        } break;
        case SPIRV::OpenCLExtInst::fract:
        case SPIRV::OpenCLExtInst::modf:
        case SPIRV::OpenCLExtInst::sincos:
          // The last operand must be of a pointer to the base type represented
          // by the previous operand.
          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
                 "Expected v-reg");
          validatePtrTypes(
              STI, MRI, GR, MI, MI.getNumOperands() - 1,
              GR.getSPIRVTypeForVReg(
                  MI.getOperand(MI.getNumOperands() - 2).getReg()));
          break;
        case SPIRV::OpenCLExtInst::prefetch:
          // Expected `ptr` type is a pointer to float, integer or vector, but
          // the pointee value can be wrapped into a struct.
          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
                 "Expected v-reg");
          validatePtrUnwrapStructField(STI, MRI, GR, MI,
                                       MI.getNumOperands() - 2);
          break;
        }
      } break;
      }
    }
    // Hoist the collected type definitions to the end of a predecessor block
    // (before its terminator) so they dominate the OpPhi uses.
    for (MachineInstr *MI : ToMove) {
      MachineBasicBlock *Curr = MI->getParent();
      MachineBasicBlock *Pred = *Curr->pred_begin();
      Pred->insert(Pred->getFirstTerminator(), Curr->remove_instr(MI));
    }
  }
  ProcessedMF.insert(&MF);
  TargetLowering::finalizeLowering(MF);
}
591
592// Modifies either operand PtrOpIdx or OpIdx so that the pointee type of
593// PtrOpIdx matches the type for operand OpIdx. Returns true if they already
594// match or if the instruction was modified to make them match.
595bool SPIRVTargetLowering::enforcePtrTypeCompatibility(
596 MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const {
597 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
598 SPIRVTypeInst PtrType = GR.getResultType(VReg: I.getOperand(i: PtrOpIdx).getReg());
599 SPIRVTypeInst PointeeType = GR.getPointeeType(PtrType);
600 SPIRVTypeInst OpType = GR.getResultType(VReg: I.getOperand(i: OpIdx).getReg());
601
602 if (PointeeType == OpType)
603 return true;
604
605 if (typesLogicallyMatch(Ty1: PointeeType, Ty2: OpType, GR)) {
606 // Apply OpCopyLogical to OpIdx.
607 if (I.getOperand(i: OpIdx).isDef() &&
608 insertLogicalCopyOnResult(I, NewResultType: PointeeType)) {
609 return true;
610 }
611
612 llvm_unreachable("Unable to add OpCopyLogical yet.");
613 return false;
614 }
615
616 return false;
617}
618
619bool SPIRVTargetLowering::insertLogicalCopyOnResult(
620 MachineInstr &I, SPIRVTypeInst NewResultType) const {
621 MachineRegisterInfo *MRI = &I.getMF()->getRegInfo();
622 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
623
624 Register NewResultReg =
625 createVirtualRegister(SpvType: NewResultType, GR: &GR, MRI, MF: *I.getMF());
626 Register NewTypeReg = GR.getSPIRVTypeID(SpirvType: NewResultType);
627
628 assert(llvm::size(I.defs()) == 1 && "Expected only one def");
629 MachineOperand &OldResult = *I.defs().begin();
630 Register OldResultReg = OldResult.getReg();
631 MachineOperand &OldType = *I.uses().begin();
632 Register OldTypeReg = OldType.getReg();
633
634 OldResult.setReg(NewResultReg);
635 OldType.setReg(NewTypeReg);
636
637 MachineIRBuilder MIB(*I.getNextNode());
638 MIB.buildInstr(Opcode: SPIRV::OpCopyLogical)
639 .addDef(RegNo: OldResultReg)
640 .addUse(RegNo: OldTypeReg)
641 .addUse(RegNo: NewResultReg)
642 .constrainAllUses(TII: *STI.getInstrInfo(), TRI: *STI.getRegisterInfo(),
643 RBI: *STI.getRegBankInfo());
644 return true;
645}
646
647TargetLowering::AtomicExpansionKind
648SPIRVTargetLowering::shouldExpandAtomicRMWInIR(const AtomicRMWInst *RMW) const {
649 switch (RMW->getOperation()) {
650 case AtomicRMWInst::FAdd:
651 case AtomicRMWInst::FSub:
652 case AtomicRMWInst::FMin:
653 case AtomicRMWInst::FMax:
654 return AtomicExpansionKind::None;
655 case AtomicRMWInst::UIncWrap:
656 case AtomicRMWInst::UDecWrap:
657 return AtomicExpansionKind::CmpXChg;
658 default:
659 return TargetLowering::shouldExpandAtomicRMWInIR(RMW);
660 }
661}
662
663TargetLowering::AtomicExpansionKind
664SPIRVTargetLowering::shouldCastAtomicRMWIInIR(AtomicRMWInst *RMWI) const {
665 // TODO: Pointer operand should be cast to integer in atomicrmw xchg, since
666 // SPIR-V only supports atomic exchange for integer and floating-point types.
667 return AtomicExpansionKind::None;
668}
669