1//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIRVTargetLowering class.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVISelLowering.h"
14#include "SPIRV.h"
15#include "SPIRVInstrInfo.h"
16#include "SPIRVRegisterBankInfo.h"
17#include "SPIRVRegisterInfo.h"
18#include "SPIRVSubtarget.h"
19#include "llvm/CodeGen/MachineInstrBuilder.h"
20#include "llvm/CodeGen/MachineRegisterInfo.h"
21#include "llvm/CodeGen/TargetLowering.h"
22#include "llvm/IR/Instructions.h"
23#include "llvm/IR/IntrinsicsSPIRV.h"
24
25#define DEBUG_TYPE "spirv-lower"
26
27using namespace llvm;
28
// Construct the SPIR-V lowering, recording the subtarget and configuring
// atomic-operation legality limits for the IR-level atomic expansion passes.
SPIRVTargetLowering::SPIRVTargetLowering(const TargetMachine &TM,
                                         const SPIRVSubtarget &ST)
    : TargetLowering(TM, ST), STI(ST) {
  // Even with SPV_ALTERA_arbitrary_precision_integers enabled, atomic sizes are
  // limited by atomicrmw xchg operation, which only supports operand up to 64
  // bits wide, as defined in SPIR-V legalizer. Currently, spirv-val doesn't
  // consider 128-bit OpTypeInt as valid either.
  setMaxAtomicSizeInBitsSupported(64);
  // Allow cmpxchg on operands as narrow as 8 bits.
  setMinCmpXchgSizeInBits(8);
}
39
40// Returns true of the types logically match, as defined in
41// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical.
42static bool typesLogicallyMatch(const SPIRVTypeInst Ty1,
43 const SPIRVTypeInst Ty2,
44 SPIRVGlobalRegistry &GR) {
45 if (Ty1->getOpcode() != Ty2->getOpcode())
46 return false;
47
48 if (Ty1->getNumOperands() != Ty2->getNumOperands())
49 return false;
50
51 if (Ty1->getOpcode() == SPIRV::OpTypeArray) {
52 // Array must have the same size.
53 if (Ty1->getOperand(i: 2).getReg() != Ty2->getOperand(i: 2).getReg())
54 return false;
55
56 SPIRVTypeInst ElemType1 =
57 GR.getSPIRVTypeForVReg(VReg: Ty1->getOperand(i: 1).getReg());
58 SPIRVTypeInst ElemType2 =
59 GR.getSPIRVTypeForVReg(VReg: Ty2->getOperand(i: 1).getReg());
60 return ElemType1 == ElemType2 ||
61 typesLogicallyMatch(Ty1: ElemType1, Ty2: ElemType2, GR);
62 }
63
64 if (Ty1->getOpcode() == SPIRV::OpTypeStruct) {
65 for (unsigned I = 1; I < Ty1->getNumOperands(); I++) {
66 SPIRVTypeInst ElemType1 =
67 GR.getSPIRVTypeForVReg(VReg: Ty1->getOperand(i: I).getReg());
68 SPIRVTypeInst ElemType2 =
69 GR.getSPIRVTypeForVReg(VReg: Ty2->getOperand(i: I).getReg());
70 if (ElemType1 != ElemType2 &&
71 !typesLogicallyMatch(Ty1: ElemType1, Ty2: ElemType2, GR))
72 return false;
73 }
74 return true;
75 }
76 return false;
77}
78
79unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
80 LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
81 // This code avoids CallLowering fail inside getVectorTypeBreakdown
82 // on v3i1 arguments. Maybe we need to return 1 for all types.
83 // TODO: remove it once this case is supported by the default implementation.
84 if (VT.isVector() && VT.getVectorNumElements() == 3 &&
85 (VT.getVectorElementType() == MVT::i1 ||
86 VT.getVectorElementType() == MVT::i8))
87 return 1;
88 if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
89 return 1;
90 return getNumRegisters(Context, VT);
91}
92
93MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
94 CallingConv::ID CC,
95 EVT VT) const {
96 // This code avoids CallLowering fail inside getVectorTypeBreakdown
97 // on v3i1 arguments. Maybe we need to return i32 for all types.
98 // TODO: remove it once this case is supported by the default implementation.
99 if (VT.isVector() && VT.getVectorNumElements() == 3) {
100 if (VT.getVectorElementType() == MVT::i1)
101 return MVT::v4i1;
102 else if (VT.getVectorElementType() == MVT::i8)
103 return MVT::v4i8;
104 }
105 return getRegisterType(Context, VT);
106}
107
108void SPIRVTargetLowering::getTgtMemIntrinsic(
109 SmallVectorImpl<IntrinsicInfo> &Infos, const CallBase &I,
110 MachineFunction &MF, unsigned Intrinsic) const {
111 IntrinsicInfo Info;
112 unsigned AlignIdx = 3;
113 switch (Intrinsic) {
114 case Intrinsic::spv_load:
115 AlignIdx = 2;
116 [[fallthrough]];
117 case Intrinsic::spv_store: {
118 if (I.getNumOperands() >= AlignIdx + 1) {
119 auto *AlignOp = cast<ConstantInt>(Val: I.getOperand(i_nocapture: AlignIdx));
120 Info.align = Align(AlignOp->getZExtValue());
121 }
122 Info.flags = static_cast<MachineMemOperand::Flags>(
123 cast<ConstantInt>(Val: I.getOperand(i_nocapture: AlignIdx - 1))->getZExtValue());
124 Info.memVT = MVT::i64;
125 // TODO: take into account opaque pointers (don't use getElementType).
126 // MVT::getVT(PtrTy->getElementType());
127 Infos.push_back(Elt: Info);
128 return;
129 }
130 default:
131 break;
132 }
133}
134
// Classify an inline-asm constraint string for operand lowering.
TargetLowering::ConstraintType
SPIRVTargetLowering::getConstraintType(StringRef Constraint) const {
  // SPIR-V represents inline assembly via OpAsmINTEL where constraints are
  // passed through as literals defined by client API. Return C_RegisterClass
  // for any constraint since SPIR-V does not distinguish between register,
  // immediate, or memory operands at this level.
  return C_RegisterClass;
}
143
144std::pair<unsigned, const TargetRegisterClass *>
145SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
146 StringRef Constraint,
147 MVT VT) const {
148 const TargetRegisterClass *RC = nullptr;
149 if (Constraint.starts_with(Prefix: "{"))
150 return std::make_pair(x: 0u, y&: RC);
151
152 if (VT.isFloatingPoint())
153 RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
154 else if (VT.isInteger())
155 RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
156 else
157 RC = &SPIRV::iIDRegClass;
158
159 return std::make_pair(x: 0u, y&: RC);
160}
161
162inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
163 const MachineInstr *Inst = MRI->getVRegDef(Reg: OpReg);
164 return Inst && Inst->getOpcode() == SPIRV::OpFunctionParameter
165 ? Inst->getOperand(i: 1).getReg()
166 : OpReg;
167}
168
169static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
170 SPIRVGlobalRegistry &GR, MachineInstr &I,
171 Register OpReg, unsigned OpIdx,
172 SPIRVTypeInst NewPtrType) {
173 MachineIRBuilder MIB(I);
174 Register NewReg = createVirtualRegister(SpvType: NewPtrType, GR: &GR, MRI, MF: MIB.getMF());
175 MIB.buildInstr(Opcode: SPIRV::OpBitcast)
176 .addDef(RegNo: NewReg)
177 .addUse(RegNo: GR.getSPIRVTypeID(SpirvType: NewPtrType))
178 .addUse(RegNo: OpReg)
179 .constrainAllUses(TII: *STI.getInstrInfo(), TRI: *STI.getRegisterInfo(),
180 RBI: *STI.getRegBankInfo());
181 I.getOperand(i: OpIdx).setReg(NewReg);
182}
183
184static SPIRVTypeInst createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
185 SPIRVTypeInst OpType, bool ReuseType,
186 SPIRVTypeInst ResType,
187 const Type *ResTy) {
188 SPIRV::StorageClass::StorageClass SC =
189 static_cast<SPIRV::StorageClass::StorageClass>(
190 OpType->getOperand(i: 1).getImm());
191 MachineIRBuilder MIB(I);
192 SPIRVTypeInst NewBaseType =
193 ReuseType ? ResType
194 : GR.getOrCreateSPIRVType(
195 Type: ResTy, MIRBuilder&: MIB, AQ: SPIRV::AccessQualifier::ReadWrite, EmitIR: false);
196 return GR.getOrCreateSPIRVPointerType(BaseType: NewBaseType, MIRBuilder&: MIB, SC);
197}
198
199// Insert a bitcast before the instruction to keep SPIR-V code valid
200// when there is a type mismatch between results and operand types.
201static void validatePtrTypes(const SPIRVSubtarget &STI,
202 MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
203 MachineInstr &I, unsigned OpIdx,
204 SPIRVTypeInst ResType,
205 const Type *ResTy = nullptr) {
206 // Get operand type
207 MachineFunction *MF = I.getParent()->getParent();
208 Register OpReg = I.getOperand(i: OpIdx).getReg();
209 Register OpTypeReg = getTypeReg(MRI, OpReg);
210 const MachineInstr *OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF);
211 if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
212 return;
213 // Get operand's pointee type
214 Register ElemTypeReg = OpType->getOperand(i: 2).getReg();
215 SPIRVTypeInst ElemType = GR.getSPIRVTypeForVReg(VReg: ElemTypeReg, MF);
216 if (!ElemType)
217 return;
218 // Check if we need a bitcast to make a statement valid
219 bool IsSameMF = MF == ResType->getParent()->getParent();
220 bool IsEqualTypes = IsSameMF ? ElemType == ResType
221 : GR.getTypeForSPIRVType(Ty: ElemType) == ResTy;
222 if (IsEqualTypes)
223 return;
224 // There is a type mismatch between results and operand types
225 // and we insert a bitcast before the instruction to keep SPIR-V code valid
226 SPIRVTypeInst NewPtrType =
227 createNewPtrType(GR, I, OpType, ReuseType: IsSameMF, ResType, ResTy);
228 if (!GR.isBitcastCompatible(Type1: NewPtrType, Type2: OpType))
229 report_fatal_error(
230 reason: "insert validation bitcast: incompatible result and operand types");
231 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
232}
233
234// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
235// that doesn't point to OpTypeEvent.
236static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
237 MachineRegisterInfo *MRI,
238 SPIRVGlobalRegistry &GR,
239 MachineInstr &I) {
240 constexpr unsigned OpIdx = 2;
241 MachineFunction *MF = I.getParent()->getParent();
242 Register OpReg = I.getOperand(i: OpIdx).getReg();
243 Register OpTypeReg = getTypeReg(MRI, OpReg);
244 SPIRVTypeInst OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF);
245 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
246 return;
247 SPIRVTypeInst ElemType =
248 GR.getSPIRVTypeForVReg(VReg: OpType->getOperand(i: 2).getReg());
249 if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
250 return;
251 // Insert a bitcast before the instruction to keep SPIR-V code valid.
252 LLVMContext &Context = MF->getFunction().getContext();
253 SPIRVTypeInst NewPtrType =
254 createNewPtrType(GR, I, OpType, ReuseType: false, ResType: nullptr,
255 ResTy: TargetExtType::get(Context, Name: "spirv.Event"));
256 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
257}
258
259static void validateLifetimeStart(const SPIRVSubtarget &STI,
260 MachineRegisterInfo *MRI,
261 SPIRVGlobalRegistry &GR, MachineInstr &I) {
262 Register PtrReg = I.getOperand(i: 0).getReg();
263 MachineFunction *MF = I.getParent()->getParent();
264 Register PtrTypeReg = getTypeReg(MRI, OpReg: PtrReg);
265 SPIRVTypeInst PtrType = GR.getSPIRVTypeForVReg(VReg: PtrTypeReg, MF);
266 SPIRVTypeInst PonteeElemType = PtrType ? GR.getPointeeType(PtrType) : nullptr;
267 if (!PonteeElemType || PonteeElemType->getOpcode() == SPIRV::OpTypeVoid ||
268 (PonteeElemType->getOpcode() == SPIRV::OpTypeInt &&
269 PonteeElemType->getOperand(i: 1).getImm() == 8))
270 return;
271 // To keep the code valid a bitcast must be inserted
272 SPIRV::StorageClass::StorageClass SC =
273 static_cast<SPIRV::StorageClass::StorageClass>(
274 PtrType->getOperand(i: 1).getImm());
275 MachineIRBuilder MIB(I);
276 LLVMContext &Context = MF->getFunction().getContext();
277 SPIRVTypeInst NewPtrType =
278 GR.getOrCreateSPIRVPointerType(BaseType: IntegerType::getInt8Ty(C&: Context), MIRBuilder&: MIB, SC);
279 doInsertBitcast(STI, MRI, GR, I, OpReg: PtrReg, OpIdx: 0, NewPtrType);
280}
281
282static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
283 MachineRegisterInfo *MRI,
284 SPIRVGlobalRegistry &GR,
285 MachineInstr &I, unsigned OpIdx) {
286 MachineFunction *MF = I.getParent()->getParent();
287 Register OpReg = I.getOperand(i: OpIdx).getReg();
288 Register OpTypeReg = getTypeReg(MRI, OpReg);
289 SPIRVTypeInst OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF);
290 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
291 return;
292 SPIRVTypeInst ElemType =
293 GR.getSPIRVTypeForVReg(VReg: OpType->getOperand(i: 2).getReg());
294 if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
295 ElemType->getNumOperands() != 2)
296 return;
297 // It's a structure-wrapper around another type with a single member field.
298 SPIRVTypeInst MemberType =
299 GR.getSPIRVTypeForVReg(VReg: ElemType->getOperand(i: 1).getReg());
300 if (!MemberType)
301 return;
302 unsigned MemberTypeOp = MemberType->getOpcode();
303 if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
304 MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
305 return;
306 // It's a structure-wrapper around a valid type. Insert a bitcast before the
307 // instruction to keep SPIR-V code valid.
308 SPIRV::StorageClass::StorageClass SC =
309 static_cast<SPIRV::StorageClass::StorageClass>(
310 OpType->getOperand(i: 1).getImm());
311 MachineIRBuilder MIB(I);
312 SPIRVTypeInst NewPtrType =
313 GR.getOrCreateSPIRVPointerType(BaseType: MemberType, MIRBuilder&: MIB, SC);
314 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
315}
316
// Insert a bitcast before the function call instruction to keep SPIR-V code
// valid when there is a type mismatch between actual and expected types of an
// argument:
// %formal = OpFunctionParameter %formal_type
// ...
// %res = OpFunctionCall %ty %fun %actual ...
// implies that %actual is of %formal_type, and in case of opaque pointers
// we may need to insert a bitcast to ensure this.
void validateFunCallMachineDef(const SPIRVSubtarget &STI,
                               MachineRegisterInfo *DefMRI,
                               MachineRegisterInfo *CallMRI,
                               SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
                               MachineInstr *FunDef) {
  if (FunDef->getOpcode() != SPIRV::OpFunction)
    return;
  // Call arguments start at operand 3 of OpFunctionCall; walk them in
  // lockstep with the OpFunctionParameter instructions that follow the
  // OpFunction definition.
  unsigned OpIdx = 3;
  for (FunDef = FunDef->getNextNode();
       FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
       OpIdx < FunCall.getNumOperands();
       FunDef = FunDef->getNextNode(), OpIdx++) {
    SPIRVTypeInst DefPtrType =
        DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
    // Only pointer-typed formals are checked; resolve their pointee type in
    // the machine function that owns the definition.
    SPIRVTypeInst DefElemType =
        DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
                                     DefPtrType->getParent()->getParent())
            : nullptr;
    if (DefElemType) {
      const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
      // validatePtrTypes() works in the context of the call site.
      // When we process historical records about forward calls
      // we need to switch context to the (forward) call site and
      // then restore it back to the current machine function.
      MachineFunction *CurMF =
          GR.setCurrentFunc(*FunCall.getParent()->getParent());
      validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
                       DefElemTy);
      GR.setCurrentFunc(*CurMF);
    }
  }
}
358
359// Ensure there is no mismatch between actual and expected arg types: calls
360// with a processed definition. Return Function pointer if it's a forward
361// call (ahead of definition), and nullptr otherwise.
362const Function *validateFunCall(const SPIRVSubtarget &STI,
363 MachineRegisterInfo *CallMRI,
364 SPIRVGlobalRegistry &GR,
365 MachineInstr &FunCall) {
366 const GlobalValue *GV = FunCall.getOperand(i: 2).getGlobal();
367 const Function *F = dyn_cast<Function>(Val: GV);
368 MachineInstr *FunDef =
369 const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
370 if (!FunDef)
371 return F;
372 MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
373 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
374 return nullptr;
375}
376
377// Ensure there is no mismatch between actual and expected arg types: calls
378// ahead of a processed definition.
379void validateForwardCalls(const SPIRVSubtarget &STI,
380 MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
381 MachineInstr &FunDef) {
382 const Function *F = GR.getFunctionByDefinition(MI: &FunDef);
383 if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
384 for (MachineInstr *FunCall : *FwdCalls) {
385 MachineRegisterInfo *CallMRI =
386 &FunCall->getParent()->getParent()->getRegInfo();
387 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall&: *FunCall, FunDef: &FunDef);
388 }
389}
390
391// Validation of an access chain.
392void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
393 SPIRVGlobalRegistry &GR, MachineInstr &I) {
394 SPIRVTypeInst BaseTypeInst = GR.getSPIRVTypeForVReg(VReg: I.getOperand(i: 0).getReg());
395 if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
396 SPIRVTypeInst BaseElemType =
397 GR.getSPIRVTypeForVReg(VReg: BaseTypeInst->getOperand(i: 2).getReg());
398 validatePtrTypes(STI, MRI, GR, I, OpIdx: 2, ResType: BaseElemType);
399 }
400}
401
// TODO: the logic of inserting additional bitcast's is to be moved
// to pre-IRTranslation passes eventually
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
  // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
  // We'd like to avoid the needless second processing pass.
  if (ProcessedMF.find(&MF) != ProcessedMF.end())
    return;

  MachineRegisterInfo *MRI = &MF.getRegInfo();
  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
  GR.setCurrentFunc(MF);
  for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
    MachineBasicBlock *MBB = &*I;
    // Advance the iterator before handling MI: the validation helpers may
    // insert instructions next to MI.
    for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
         MBBI != MBBE;) {
      MachineInstr &MI = *MBBI++;
      switch (MI.getOpcode()) {
      case SPIRV::OpAtomicLoad:
      case SPIRV::OpAtomicExchange:
      case SPIRV::OpAtomicCompareExchange:
      case SPIRV::OpAtomicCompareExchangeWeak:
      case SPIRV::OpAtomicIIncrement:
      case SPIRV::OpAtomicIDecrement:
      case SPIRV::OpAtomicIAdd:
      case SPIRV::OpAtomicISub:
      case SPIRV::OpAtomicSMin:
      case SPIRV::OpAtomicUMin:
      case SPIRV::OpAtomicSMax:
      case SPIRV::OpAtomicUMax:
      case SPIRV::OpAtomicAnd:
      case SPIRV::OpAtomicOr:
      case SPIRV::OpAtomicXor:
        // for the above listed instructions
        // OpAtomicXXX <ResType>, ptr %Op, ...
        // implies that %Op is a pointer to <ResType>
      case SPIRV::OpLoad:
        // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
        // If the pointee and result types already (logically) agree, no
        // further pointer validation is needed.
        if (enforcePtrTypeCompatibility(MI, 2, 0))
          break;

        validatePtrTypes(STI, MRI, GR, MI, 2,
                         GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
        break;
      case SPIRV::OpAtomicStore:
        // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
        // implies that %Op points to the <Obj>'s type
        validatePtrTypes(STI, MRI, GR, MI, 0,
                         GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
        break;
      case SPIRV::OpStore:
        // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
        validatePtrTypes(STI, MRI, GR, MI, 0,
                         GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
        break;
      case SPIRV::OpPtrCastToGeneric:
      case SPIRV::OpGenericCastToPtr:
      case SPIRV::OpGenericCastToPtrExplicit:
        validateAccessChain(STI, MRI, GR, MI);
        break;
      case SPIRV::OpPtrAccessChain:
      case SPIRV::OpInBoundsPtrAccessChain:
        // Only validate chains without index operands (4 operands total).
        if (MI.getNumOperands() == 4)
          validateAccessChain(STI, MRI, GR, MI);
        break;

      case SPIRV::OpFunctionCall:
        // ensure there is no mismatch between actual and expected arg types:
        // calls with a processed definition; calls whose definition hasn't
        // been processed yet are recorded and revisited at OpFunction below
        if (MI.getNumOperands() > 3)
          if (const Function *F = validateFunCall(STI, MRI, GR, MI))
            GR.addForwardCall(F, &MI);
        break;
      case SPIRV::OpFunction:
        // ensure there is no mismatch between actual and expected arg types:
        // calls ahead of a processed definition
        validateForwardCalls(STI, MRI, GR, MI);
        break;

      // ensure that LLVM IR add/sub instructions result in logical SPIR-V
      // instructions when applied to bool type
      case SPIRV::OpIAddS:
      case SPIRV::OpIAddV:
      case SPIRV::OpISubS:
      case SPIRV::OpISubV:
        // add/sub modulo 2 on booleans is XOR, i.e. logical not-equal
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
        break;

      // ensure that LLVM IR bitwise instructions result in logical SPIR-V
      // instructions when applied to bool type
      case SPIRV::OpBitwiseOrS:
      case SPIRV::OpBitwiseOrV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
        break;
      case SPIRV::OpBitwiseAndS:
      case SPIRV::OpBitwiseAndV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
        break;
      case SPIRV::OpBitwiseXorS:
      case SPIRV::OpBitwiseXorV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
        break;
      case SPIRV::OpLifetimeStart:
      case SPIRV::OpLifetimeStop:
        // only markers with a positive size need a pointer-type check
        if (MI.getOperand(1).getImm() > 0)
          validateLifetimeStart(STI, MRI, GR, MI);
        break;
      case SPIRV::OpGroupAsyncCopy:
        // both source and destination pointers may be struct-wrapped
        validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
        validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
        break;
      case SPIRV::OpGroupWaitEvents:
        // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
        validateGroupWaitEventsPtr(STI, MRI, GR, MI);
        break;
      case SPIRV::OpConstantI: {
        SPIRVTypeInst Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
        if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
            MI.getOperand(2).getImm() == 0) {
          // Validate the null constant of a target extension type:
          // rewrite "OpConstantI <non-int type>, 0" as OpConstantNull and
          // drop the now-superfluous value operands.
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
          for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
            MI.removeOperand(i);
        }
      } break;
      case SPIRV::OpExtInst: {
        // Only OpenCL extended instructions are validated here.
        if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
            MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
          continue;
        switch (MI.getOperand(3).getImm()) {
        case SPIRV::OpenCLExtInst::frexp:
        case SPIRV::OpenCLExtInst::lgamma_r:
        case SPIRV::OpenCLExtInst::remquo: {
          // The last operand must be of a pointer to i32 or vector of i32
          // values.
          MachineIRBuilder MIB(MI);
          SPIRVTypeInst Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
          SPIRVTypeInst RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
          assert(RetType && "Expected return type");
          validatePtrTypes(STI, MRI, GR, MI, MI.getNumOperands() - 1,
                           RetType->getOpcode() != SPIRV::OpTypeVector
                               ? Int32Type
                               : GR.getOrCreateSPIRVVectorType(
                                     Int32Type, RetType->getOperand(2).getImm(),
                                     MIB, false));
        } break;
        case SPIRV::OpenCLExtInst::fract:
        case SPIRV::OpenCLExtInst::modf:
        case SPIRV::OpenCLExtInst::sincos:
          // The last operand must be of a pointer to the base type represented
          // by the previous operand.
          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
                 "Expected v-reg");
          validatePtrTypes(
              STI, MRI, GR, MI, MI.getNumOperands() - 1,
              GR.getSPIRVTypeForVReg(
                  MI.getOperand(MI.getNumOperands() - 2).getReg()));
          break;
        case SPIRV::OpenCLExtInst::prefetch:
          // Expected `ptr` type is a pointer to float, integer or vector, but
          // the pointee value can be wrapped into a struct.
          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
                 "Expected v-reg");
          validatePtrUnwrapStructField(STI, MRI, GR, MI,
                                       MI.getNumOperands() - 2);
          break;
        }
      } break;
      }
    }
  }
  ProcessedMF.insert(&MF);
  TargetLowering::finalizeLowering(MF);
}
584
585// Modifies either operand PtrOpIdx or OpIdx so that the pointee type of
586// PtrOpIdx matches the type for operand OpIdx. Returns true if they already
587// match or if the instruction was modified to make them match.
588bool SPIRVTargetLowering::enforcePtrTypeCompatibility(
589 MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const {
590 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
591 SPIRVTypeInst PtrType = GR.getResultType(VReg: I.getOperand(i: PtrOpIdx).getReg());
592 SPIRVTypeInst PointeeType = GR.getPointeeType(PtrType);
593 SPIRVTypeInst OpType = GR.getResultType(VReg: I.getOperand(i: OpIdx).getReg());
594
595 if (PointeeType == OpType)
596 return true;
597
598 if (typesLogicallyMatch(Ty1: PointeeType, Ty2: OpType, GR)) {
599 // Apply OpCopyLogical to OpIdx.
600 if (I.getOperand(i: OpIdx).isDef() &&
601 insertLogicalCopyOnResult(I, NewResultType: PointeeType)) {
602 return true;
603 }
604
605 llvm_unreachable("Unable to add OpCopyLogical yet.");
606 return false;
607 }
608
609 return false;
610}
611
// Retype the single result of I to NewResultType and append an OpCopyLogical
// that converts the retyped value back into the register (and type) the
// original consumers expect. Always returns true.
bool SPIRVTargetLowering::insertLogicalCopyOnResult(
    MachineInstr &I, SPIRVTypeInst NewResultType) const {
  MachineRegisterInfo *MRI = &I.getMF()->getRegInfo();
  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();

  // Fresh register to carry I's result under the new type.
  Register NewResultReg =
      createVirtualRegister(NewResultType, &GR, MRI, *I.getMF());
  Register NewTypeReg = GR.getSPIRVTypeID(NewResultType);

  assert(llvm::size(I.defs()) == 1 && "Expected only one def");
  MachineOperand &OldResult = *I.defs().begin();
  Register OldResultReg = OldResult.getReg();
  // The first use operand of I holds its result-type id.
  MachineOperand &OldType = *I.uses().begin();
  Register OldTypeReg = OldType.getReg();

  // Point I at the new result register/type before building the copy, so the
  // OpCopyLogical can bridge new -> old.
  OldResult.setReg(NewResultReg);
  OldType.setReg(NewTypeReg);

  // Insert "OldResultReg = OpCopyLogical OldTypeReg, NewResultReg" right
  // after I, preserving the original defs for all downstream users.
  MachineIRBuilder MIB(*I.getNextNode());
  MIB.buildInstr(SPIRV::OpCopyLogical)
      .addDef(OldResultReg)
      .addUse(OldTypeReg)
      .addUse(NewResultReg)
      .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
                        *STI.getRegBankInfo());
  return true;
}
639
640TargetLowering::AtomicExpansionKind
641SPIRVTargetLowering::shouldExpandAtomicRMWInIR(const AtomicRMWInst *RMW) const {
642 switch (RMW->getOperation()) {
643 case AtomicRMWInst::FAdd:
644 case AtomicRMWInst::FSub:
645 case AtomicRMWInst::FMin:
646 case AtomicRMWInst::FMax:
647 return AtomicExpansionKind::None;
648 case AtomicRMWInst::UIncWrap:
649 case AtomicRMWInst::UDecWrap:
650 return AtomicExpansionKind::CmpXChg;
651 default:
652 return TargetLowering::shouldExpandAtomicRMWInIR(RMW);
653 }
654}
655
// Decide whether the atomicrmw operand should be cast before expansion;
// currently no cast is requested.
TargetLowering::AtomicExpansionKind
SPIRVTargetLowering::shouldCastAtomicRMWIInIR(AtomicRMWInst *RMWI) const {
  // TODO: Pointer operand should be cast to integer in atomicrmw xchg, since
  // SPIR-V only supports atomic exchange for integer and floating-point types.
  return AtomicExpansionKind::None;
}
662