1//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interfaces that NVPTX uses to lower LLVM code into a
10// selection DAG.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXISelLowering.h"
15#include "MCTargetDesc/NVPTXBaseInfo.h"
16#include "NVPTX.h"
17#include "NVPTXISelDAGToDAG.h"
18#include "NVPTXSelectionDAGInfo.h"
19#include "NVPTXSubtarget.h"
20#include "NVPTXTargetMachine.h"
21#include "NVPTXTargetObjectFile.h"
22#include "NVPTXUtilities.h"
23#include "llvm/ADT/APFloat.h"
24#include "llvm/ADT/APInt.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallVector.h"
27#include "llvm/ADT/StringRef.h"
28#include "llvm/CodeGen/Analysis.h"
29#include "llvm/CodeGen/ISDOpcodes.h"
30#include "llvm/CodeGen/MachineFunction.h"
31#include "llvm/CodeGen/MachineJumpTableInfo.h"
32#include "llvm/CodeGen/MachineMemOperand.h"
33#include "llvm/CodeGen/SDPatternMatch.h"
34#include "llvm/CodeGen/SelectionDAG.h"
35#include "llvm/CodeGen/SelectionDAGNodes.h"
36#include "llvm/CodeGen/TargetCallingConv.h"
37#include "llvm/CodeGen/TargetLowering.h"
38#include "llvm/CodeGen/ValueTypes.h"
39#include "llvm/CodeGenTypes/MachineValueType.h"
40#include "llvm/IR/Argument.h"
41#include "llvm/IR/Attributes.h"
42#include "llvm/IR/Constants.h"
43#include "llvm/IR/DataLayout.h"
44#include "llvm/IR/DerivedTypes.h"
45#include "llvm/IR/DiagnosticInfo.h"
46#include "llvm/IR/FPEnv.h"
47#include "llvm/IR/Function.h"
48#include "llvm/IR/GlobalValue.h"
49#include "llvm/IR/IRBuilder.h"
50#include "llvm/IR/Instruction.h"
51#include "llvm/IR/Instructions.h"
52#include "llvm/IR/IntrinsicsNVPTX.h"
53#include "llvm/IR/Module.h"
54#include "llvm/IR/Type.h"
55#include "llvm/IR/Value.h"
56#include "llvm/Support/Alignment.h"
57#include "llvm/Support/AtomicOrdering.h"
58#include "llvm/Support/Casting.h"
59#include "llvm/Support/CodeGen.h"
60#include "llvm/Support/CommandLine.h"
61#include "llvm/Support/ErrorHandling.h"
62#include "llvm/Support/KnownBits.h"
63#include "llvm/Support/NVPTXAddrSpace.h"
64#include "llvm/Support/raw_ostream.h"
65#include "llvm/Target/TargetMachine.h"
66#include "llvm/Target/TargetOptions.h"
67#include <algorithm>
68#include <cassert>
69#include <cmath>
70#include <cstdint>
71#include <iterator>
72#include <optional>
73#include <string>
74#include <tuple>
75#include <utility>
76#include <vector>
77
78#define DEBUG_TYPE "nvptx-lower"
79
80using namespace llvm;
81
82static cl::opt<bool> sched4reg(
83 "nvptx-sched4reg",
84 cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(Val: false));
85
86static cl::opt<unsigned> FMAContractLevelOpt(
87 "nvptx-fma-level", cl::Hidden,
88 cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
89 " 1: do it 2: do it aggressively"),
90 cl::init(Val: 2));
91
// Command-line override for the precision of f32 fdiv lowering. When given
// explicitly it takes precedence over per-node flags (see getDivF32Level).
static cl::opt<NVPTX::DivPrecisionLevel> UsePrecDivF32(
    "nvptx-prec-divf32", cl::Hidden,
    cl::desc(
        "NVPTX Specific: Override the precision of the lowering for f32 fdiv"),
    cl::values(
        clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0", "Use div.approx"),
        clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"),
        clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2",
                   "Use IEEE Compliant F32 div.rnd if available (default)"),
        clEnumValN(NVPTX::DivPrecisionLevel::IEEE754_NoFTZ, "3",
                   "Use IEEE Compliant F32 div.rnd if available, no FTZ")),
    cl::init(Val: NVPTX::DivPrecisionLevel::IEEE754));
104
// Command-line override for f32 sqrt lowering. When given explicitly it takes
// precedence over per-node flags (see usePrecSqrtF32).
static cl::opt<bool> UsePrecSqrtF32(
    "nvptx-prec-sqrtf32", cl::Hidden,
    cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
    cl::init(Val: true));
109
/// Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it
/// does NOT use lg2.approx for log2, so this is disabled by default. Passing
/// this flag enables the faster, less precise lg2.approx lowering for f32 log2.
static cl::opt<bool> UseApproxLog2F32(
    "nvptx-approx-log2f32",
    cl::desc("NVPTX Specific: whether to use lg2.approx for log2"),
    cl::init(Val: false));
116
117NVPTX::DivPrecisionLevel
118NVPTXTargetLowering::getDivF32Level(const MachineFunction &MF,
119 const SDNode &N) const {
120 // If nvptx-prec-div32=N is used on the command-line, always honor it
121 if (UsePrecDivF32.getNumOccurrences() > 0)
122 return UsePrecDivF32;
123
124 const SDNodeFlags Flags = N.getFlags();
125 if (Flags.hasApproximateFuncs())
126 return NVPTX::DivPrecisionLevel::Approx;
127
128 return NVPTX::DivPrecisionLevel::IEEE754;
129}
130
131bool NVPTXTargetLowering::usePrecSqrtF32(const SDNode *N) const {
132 // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
133 if (UsePrecSqrtF32.getNumOccurrences() > 0)
134 return UsePrecSqrtF32;
135
136 if (N) {
137 const SDNodeFlags Flags = N->getFlags();
138 if (Flags.hasApproximateFuncs())
139 return false;
140 }
141
142 return true;
143}
144
145bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const {
146 return MF.getDenormalMode(FPType: APFloat::IEEEsingle()).Output ==
147 DenormalMode::PreserveSign;
148}
149
150static bool IsPTXVectorType(MVT VT) {
151 switch (VT.SimpleTy) {
152 default:
153 return false;
154 case MVT::v2i1:
155 case MVT::v4i1:
156 case MVT::v2i8:
157 case MVT::v4i8:
158 case MVT::v8i8: // <2 x i8x4>
159 case MVT::v16i8: // <4 x i8x4>
160 case MVT::v2i16:
161 case MVT::v4i16:
162 case MVT::v8i16: // <4 x i16x2>
163 case MVT::v2i32:
164 case MVT::v4i32:
165 case MVT::v2i64:
166 case MVT::v2f16:
167 case MVT::v4f16:
168 case MVT::v8f16: // <4 x f16x2>
169 case MVT::v2bf16:
170 case MVT::v4bf16:
171 case MVT::v8bf16: // <4 x bf16x2>
172 case MVT::v2f32:
173 case MVT::v4f32:
174 case MVT::v2f64:
175 case MVT::v4i64:
176 case MVT::v4f64:
177 case MVT::v8i32:
178 case MVT::v8f32:
179 case MVT::v16f16: // <8 x f16x2>
180 case MVT::v16bf16: // <8 x bf16x2>
181 case MVT::v16i16: // <8 x i16x2>
182 case MVT::v32i8: // <8 x i8x4>
183 return true;
184 }
185}
186
// Called when legalizing vector loads/stores. This function does two things:
// 1. Determines whether the vector is something we want to custom lower;
//    std::nullopt is returned if we do not want to custom lower it.
// 2. If we do want to handle it, returns two parameters:
//    - unsigned int NumElts - The number of elements in the final vector
//    - EVT EltVT - The type of the elements in the final vector
// Packable element types (f16/bf16/i16 pairs, i8 quads, and — when supported —
// f32/i32 pairs) are regrouped into small sub-vectors so each fits one PTX
// register; 256-bit forms are only accepted when the subtarget/address space
// supports 256-bit loads/stores.
static std::optional<std::pair<unsigned int, MVT>>
getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
                       unsigned AddressSpace) {
  const bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AS: AddressSpace);

  if (CanLowerTo256Bit && VectorEVT.isScalarInteger() &&
      VectorEVT.getSizeInBits() == 256)
    return {{4, MVT::i64}};

  if (!VectorEVT.isSimple())
    return std::nullopt;
  const MVT VectorVT = VectorEVT.getSimpleVT();

  // Scalar i128/f128 are split into two i64 pieces.
  if (!VectorVT.isVector()) {
    if (VectorVT == MVT::i128 || VectorVT == MVT::f128)
      return {{2, MVT::i64}};
    return std::nullopt;
  }

  const MVT EltVT = VectorVT.getVectorElementType();
  const unsigned NumElts = VectorVT.getVectorNumElements();

  // The size of the PTX virtual register that holds a packed type.
  unsigned PackRegSize;

  // We only handle "native" vector sizes for now, e.g. <4 x double> is not
  // legal. We can (and should) split that into 2 stores of <2 x double> here
  // but I'm leaving that as a TODO for now.
  switch (VectorVT.SimpleTy) {
  default:
    return std::nullopt;

  case MVT::v4i64:
  case MVT::v4f64:
    // This is a "native" vector type iff the address space is global and the
    // target supports 256-bit loads/stores
    if (!CanLowerTo256Bit)
      return std::nullopt;
    [[fallthrough]];
  case MVT::v2i8:
  case MVT::v2i64:
  case MVT::v2f64:
    // This is a "native" vector type
    return std::pair(NumElts, EltVT);

  case MVT::v16f16:  // <8 x f16x2>
  case MVT::v16bf16: // <8 x bf16x2>
  case MVT::v16i16:  // <8 x i16x2>
  case MVT::v32i8:   // <8 x i8x4>
    // This can be upsized into a "native" vector type iff the address space is
    // global and the target supports 256-bit loads/stores.
    if (!CanLowerTo256Bit)
      return std::nullopt;
    [[fallthrough]];
  case MVT::v2i16:  // <1 x i16x2>
  case MVT::v2f16:  // <1 x f16x2>
  case MVT::v2bf16: // <1 x bf16x2>
  case MVT::v4i8:   // <1 x i8x4>
  case MVT::v4i16:  // <2 x i16x2>
  case MVT::v4f16:  // <2 x f16x2>
  case MVT::v4bf16: // <2 x bf16x2>
  case MVT::v8i8:   // <2 x i8x4>
  case MVT::v8f16:  // <4 x f16x2>
  case MVT::v8bf16: // <4 x bf16x2>
  case MVT::v8i16:  // <4 x i16x2>
  case MVT::v16i8:  // <4 x i8x4>
    PackRegSize = 32;
    break;

  case MVT::v8f32: // <4 x f32x2>
  case MVT::v8i32: // <4 x i32x2>
    // This is a "native" vector type iff the address space is global and the
    // target supports 256-bit loads/stores
    if (!CanLowerTo256Bit)
      return std::nullopt;
    [[fallthrough]];
  case MVT::v2f32: // <1 x f32x2>
  case MVT::v4f32: // <2 x f32x2>
  case MVT::v2i32: // <1 x i32x2>
  case MVT::v4i32: // <2 x i32x2>
    // Without packed f32x2/i32x2 instructions, fall back to element-wise ops.
    if (!STI.hasF32x2Instructions())
      return std::pair(NumElts, EltVT);
    PackRegSize = 64;
    break;
  }

  // If we reach here, then we can pack 2 or more elements into a single 32-bit
  // or 64-bit PTX register and treat the vector as a new vector containing
  // packed elements.

  // Number of elements to pack in one word.
  const unsigned NPerReg = PackRegSize / EltVT.getSizeInBits();

  return std::pair(NumElts / NPerReg, MVT::getVectorVT(VT: EltVT, NumElements: NPerReg));
}
289
290/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
291/// legal-ish MVTs that compose it. Unlike ComputeValueVTs, this will legalize
292/// the types as required by the calling convention (with special handling for
293/// i8s).
294/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
295/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
296/// LowerCall, and LowerReturn.
297static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
298 LLVMContext &Ctx, CallingConv::ID CallConv,
299 Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
300 SmallVectorImpl<uint64_t> &Offsets,
301 uint64_t StartingOffset = 0) {
302 SmallVector<EVT, 16> TempVTs;
303 SmallVector<uint64_t, 16> TempOffsets;
304 ComputeValueVTs(TLI, DL, Ty, ValueVTs&: TempVTs, /*MemVTs=*/nullptr, FixedOffsets: &TempOffsets,
305 StartingOffset);
306
307 for (const auto [VT, Off] : zip(t&: TempVTs, u&: TempOffsets)) {
308 MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Context&: Ctx, CC: CallConv, VT);
309 unsigned NumRegs = TLI.getNumRegistersForCallingConv(Context&: Ctx, CC: CallConv, VT);
310
311 // Since we actually can load/store b8, we need to ensure that we'll use
312 // the original sized type for any i8s or i8 vectors.
313 if (VT.getScalarType() == MVT::i8) {
314 if (RegisterVT == MVT::i16)
315 RegisterVT = MVT::i8;
316 else if (RegisterVT == MVT::v2i16)
317 RegisterVT = MVT::v2i8;
318 else
319 assert(RegisterVT == MVT::v4i8 &&
320 "Expected v4i8, v2i16, or i16 for i8 RegisterVT");
321 }
322
323 // TODO: This is horribly incorrect for cases where the vector elements are
324 // not a multiple of bytes (ex i1) and legal or i8. However, this problem
325 // has existed for as long as NVPTX has and no one has complained, so we'll
326 // leave it for now.
327 for (unsigned I : seq(Size: NumRegs)) {
328 ValueVTs.push_back(Elt: RegisterVT);
329 Offsets.push_back(Elt: Off + I * RegisterVT.getStoreSize());
330 }
331 }
332}
333
334// We return an EVT that can hold N VTs
335// If the VT is a vector, the resulting EVT is a flat vector with the same
336// element type as VT's element type.
337static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
338 if (N == 1)
339 return VT;
340
341 return VT.isVector() ? EVT::getVectorVT(Context&: C, VT: VT.getScalarType(),
342 NumElements: VT.getVectorNumElements() * N)
343 : EVT::getVectorVT(Context&: C, VT, NumElements: N);
344}
345
346static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
347 const SDLoc &dl, SelectionDAG &DAG) {
348 if (V.getValueType() == VT) {
349 assert(I == 0 && "Index must be 0 for scalar value");
350 return V;
351 }
352
353 if (!VT.isVector())
354 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT, N1: V,
355 N2: DAG.getVectorIdxConstant(Val: I, DL: dl));
356
357 return DAG.getNode(
358 Opcode: ISD::EXTRACT_SUBVECTOR, DL: dl, VT, N1: V,
359 N2: DAG.getVectorIdxConstant(Val: I * VT.getVectorNumElements(), DL: dl));
360}
361
362template <typename T>
363static inline SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl,
364 SelectionDAG &DAG, T GetElement) {
365 if (N == 1)
366 return GetElement(0);
367
368 SmallVector<SDValue, 8> Values;
369 for (const unsigned I : llvm::seq(Size: N)) {
370 SDValue Val = GetElement(I);
371 if (Val.getValueType().isVector())
372 DAG.ExtractVectorElements(Op: Val, Args&: Values);
373 else
374 Values.push_back(Elt: Val);
375 }
376
377 EVT VT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: Values[0].getValueType(),
378 NumElements: Values.size());
379 return DAG.getBuildVector(VT, DL: dl, Ops: Values);
380}
381
/// promoteScalarIntegerPTX
/// Used to make sure scalar-integer arguments/returns are suitable for passing
/// and promote them to a larger size (i8, i16, i32, or i64) if they're not.
///
/// Returns the promoted type, or \p VT unchanged when it is not a scalar
/// integer. Scalars wider than 64 bits are unsupported and assert.
static EVT promoteScalarIntegerPTX(const EVT VT) {
  if (VT.isScalarInteger()) {
    // Round the bit width up to the next power of two and pick the matching
    // PTX-supported integer type.
    switch (PowerOf2Ceil(A: VT.getFixedSizeInBits())) {
    default:
      llvm_unreachable(
          "Promotion is not suitable for scalars of size larger than 64-bits");
    case 1:
      return MVT::i1;
    case 2:
    case 4:
    case 8:
      return MVT::i8;
    case 16:
      return MVT::i16;
    case 32:
      return MVT::i32;
    case 64:
      return MVT::i64;
    }
  }
  return VT;
}
409
410// Check whether we can merge loads/stores of some of the pieces of a
411// flattened function parameter or return value into a single vector
412// load/store.
413//
414// The flattened parameter is represented as a list of EVTs and
415// offsets, and the whole structure is aligned to ParamAlignment. This
416// function determines whether we can load/store pieces of the
417// parameter starting at index Idx using a single vectorized op of
418// size AccessSize. If so, it returns the number of param pieces
419// covered by the vector op. Otherwise, it returns 1.
420template <typename T>
421static unsigned canMergeParamLoadStoresStartingAt(
422 unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
423 const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
424
425 // Can't vectorize if param alignment is not sufficient.
426 if (ParamAlignment < AccessSize)
427 return 1;
428 // Can't vectorize if offset is not aligned.
429 if (Offsets[Idx] & (AccessSize - 1))
430 return 1;
431
432 EVT EltVT = ValueVTs[Idx];
433 unsigned EltSize = EltVT.getStoreSize();
434
435 // Element is too large to vectorize.
436 if (EltSize >= AccessSize)
437 return 1;
438
439 unsigned NumElts = AccessSize / EltSize;
440 // Can't vectorize if AccessBytes if not a multiple of EltSize.
441 if (AccessSize != EltSize * NumElts)
442 return 1;
443
444 // We don't have enough elements to vectorize.
445 if (Idx + NumElts > ValueVTs.size())
446 return 1;
447
448 // PTX ISA can only deal with 2- and 4-element vector ops.
449 if (NumElts != 4 && NumElts != 2)
450 return 1;
451
452 for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
453 // Types do not match.
454 if (ValueVTs[j] != EltVT)
455 return 1;
456
457 // Elements are not contiguous.
458 if (Offsets[j] - Offsets[j - 1] != EltSize)
459 return 1;
460 }
461 // OK. We can vectorize ValueVTs[i..i+NumElts)
462 return NumElts;
463}
464
465// Computes whether and how we can vectorize the loads/stores of a
466// flattened function parameter or return value.
467//
468// The flattened parameter is represented as the list of ValueVTs and
469// Offsets, and is aligned to ParamAlignment bytes. We return a vector
470// of the same size as ValueVTs indicating how each piece should be
471// loaded/stored (i.e. as a scalar, or as part of a vector
472// load/store).
473template <typename T>
474static SmallVector<unsigned, 16>
475VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
476 const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
477 bool IsVAArg = false) {
478 // Set vector size to match ValueVTs and mark all elements as
479 // scalars by default.
480
481 if (IsVAArg)
482 return SmallVector<unsigned>(ValueVTs.size(), 1);
483
484 SmallVector<unsigned, 16> VectorInfo;
485
486 const auto GetNumElts = [&](unsigned I) -> unsigned {
487 for (const unsigned AccessSize : {16, 8, 4, 2}) {
488 const unsigned NumElts = canMergeParamLoadStoresStartingAt(
489 I, AccessSize, ValueVTs, Offsets, ParamAlignment);
490 assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
491 "Unexpected vectorization size");
492 if (NumElts != 1)
493 return NumElts;
494 }
495 return 1;
496 };
497
498 // Check what we can vectorize using 128/64/32-bit accesses.
499 for (unsigned I = 0, E = ValueVTs.size(); I != E;) {
500 const unsigned NumElts = GetNumElts(I);
501 VectorInfo.push_back(Elt: NumElts);
502 I += NumElts;
503 }
504 assert(std::accumulate(VectorInfo.begin(), VectorInfo.end(), 0u) ==
505 ValueVTs.size());
506 return VectorInfo;
507}
508
509// NVPTXTargetLowering Constructor.
510NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
511 const NVPTXSubtarget &STI)
512 : TargetLowering(TM, STI), nvTM(&TM), STI(STI), GlobalUniqueCallSite(0) {
513 // always lower memset, memcpy, and memmove intrinsics to load/store
514 // instructions, rather
515 // then generating calls to memset, mempcy or memmove.
516 MaxStoresPerMemset = MaxStoresPerMemsetOptSize = (unsigned)0xFFFFFFFF;
517 MaxStoresPerMemcpy = MaxStoresPerMemcpyOptSize = (unsigned) 0xFFFFFFFF;
518 MaxStoresPerMemmove = MaxStoresPerMemmoveOptSize = (unsigned) 0xFFFFFFFF;
519
520 setBooleanContents(ZeroOrNegativeOneBooleanContent);
521 setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);
522
523 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
524 // condition branches.
525 setJumpIsExpensive(true);
526
527 // Wide divides are _very_ slow. Try to reduce the width of the divide if
528 // possible.
529 addBypassSlowDiv(SlowBitWidth: 64, FastBitWidth: 32);
530
531 // By default, use the Source scheduling
532 if (sched4reg)
533 setSchedulingPreference(Sched::RegPressure);
534 else
535 setSchedulingPreference(Sched::Source);
536
537 auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
538 LegalizeAction NoF16Action) {
539 bool IsOpSupported = STI.allowFP16Math();
540 switch (Op) {
541 // Several FP16 instructions are available on sm_80 only.
542 case ISD::FMINNUM:
543 case ISD::FMAXNUM:
544 case ISD::FMAXNUM_IEEE:
545 case ISD::FMINNUM_IEEE:
546 case ISD::FMAXIMUM:
547 case ISD::FMINIMUM:
548 case ISD::FMAXIMUMNUM:
549 case ISD::FMINIMUMNUM:
550 IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
551 break;
552 case ISD::FEXP2:
553 IsOpSupported &= STI.getSmVersion() >= 75 && STI.getPTXVersion() >= 70;
554 break;
555 }
556 setOperationAction(Op, VT, Action: IsOpSupported ? Action : NoF16Action);
557 };
558
559 auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
560 LegalizeAction NoBF16Action) {
561 bool IsOpSupported = STI.hasNativeBF16Support(Opcode: Op);
562 setOperationAction(
563 Op, VT, Action: IsOpSupported ? Action : NoBF16Action);
564 };
565
566 auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
567 LegalizeAction NoI16x2Action) {
568 bool IsOpSupported = false;
569 // instructions are available on sm_90 only
570 switch (Op) {
571 case ISD::ADD:
572 case ISD::SMAX:
573 case ISD::SMIN:
574 case ISD::UMIN:
575 case ISD::UMAX:
576 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
577 break;
578 }
579 setOperationAction(Op, VT, Action: IsOpSupported ? Action : NoI16x2Action);
580 };
581
582 addRegisterClass(VT: MVT::i1, RC: &NVPTX::B1RegClass);
583 addRegisterClass(VT: MVT::i16, RC: &NVPTX::B16RegClass);
584 addRegisterClass(VT: MVT::v2i16, RC: &NVPTX::B32RegClass);
585 addRegisterClass(VT: MVT::v4i8, RC: &NVPTX::B32RegClass);
586 addRegisterClass(VT: MVT::i32, RC: &NVPTX::B32RegClass);
587 addRegisterClass(VT: MVT::i64, RC: &NVPTX::B64RegClass);
588 addRegisterClass(VT: MVT::f32, RC: &NVPTX::B32RegClass);
589 addRegisterClass(VT: MVT::f64, RC: &NVPTX::B64RegClass);
590 addRegisterClass(VT: MVT::f16, RC: &NVPTX::B16RegClass);
591 addRegisterClass(VT: MVT::v2f16, RC: &NVPTX::B32RegClass);
592 addRegisterClass(VT: MVT::bf16, RC: &NVPTX::B16RegClass);
593 addRegisterClass(VT: MVT::v2bf16, RC: &NVPTX::B32RegClass);
594
595 if (STI.hasF32x2Instructions()) {
596 addRegisterClass(VT: MVT::v2f32, RC: &NVPTX::B64RegClass);
597 addRegisterClass(VT: MVT::v2i32, RC: &NVPTX::B64RegClass);
598 }
599
600 // Conversion to/from FP16/FP16x2 is always legal.
601 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2f16, Action: Custom);
602 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2f16, Action: Custom);
603 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2f16, Action: Expand);
604 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2f16, Action: Expand);
605
606 setOperationAction(Op: ISD::READCYCLECOUNTER, VT: MVT::i64, Action: Legal);
607 if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
608 setOperationAction(Op: ISD::READSTEADYCOUNTER, VT: MVT::i64, Action: Legal);
609
610 setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
611 setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
612
613 // Conversion to/from BFP16/BFP16x2 is always legal.
614 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2bf16, Action: Custom);
615 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2bf16, Action: Custom);
616 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2bf16, Action: Expand);
617 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2bf16, Action: Expand);
618
619 setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
620 setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
621 if (getOperationAction(Op: ISD::SETCC, VT: MVT::bf16) == Promote)
622 AddPromotedToType(Opc: ISD::SETCC, OrigVT: MVT::bf16, DestVT: MVT::f32);
623
624 // Conversion to/from i16/i16x2 is always legal.
625 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2i16, Action: Custom);
626 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2i16, Action: Custom);
627 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2i16, Action: Expand);
628 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2i16, Action: Expand);
629
630 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v4i8, Action: Custom);
631 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v4i8, Action: Custom);
632 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v4i8, Action: Custom);
633 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v4i8, Action: Custom);
634
635 // No support for these operations with v2f32/v2i32
636 setOperationAction(Ops: ISD::INSERT_VECTOR_ELT, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
637 setOperationAction(Ops: ISD::VECTOR_SHUFFLE, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
638
639 setOperationAction(Op: ISD::TRUNCATE, VT: MVT::v2i16, Action: Expand);
640 setOperationAction(Ops: {ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND},
641 VT: MVT::v2i32, Action: Expand);
642
643 // Need custom lowering in case the index is dynamic.
644 if (STI.hasF32x2Instructions())
645 setOperationAction(Ops: ISD::EXTRACT_VECTOR_ELT, VTs: {MVT::v2f32, MVT::v2i32},
646 Action: Custom);
647
648 // Custom conversions to/from v2i8.
649 setOperationAction(Op: ISD::BITCAST, VT: MVT::v2i8, Action: Custom);
650
651 // Only logical ops can be done on v4i8/v2i32 directly, others must be done
652 // elementwise.
653 setOperationAction(
654 Ops: {ISD::ABS, ISD::ADD, ISD::ADDC, ISD::ADDE,
655 ISD::BITREVERSE, ISD::CTLZ, ISD::CTPOP, ISD::CTTZ,
656 ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FSHL, ISD::FSHR,
657 ISD::MUL, ISD::MULHS, ISD::MULHU, ISD::PARITY,
658 ISD::ROTL, ISD::ROTR, ISD::SADDO, ISD::SADDO_CARRY,
659 ISD::SADDSAT, ISD::SDIV, ISD::SDIVREM, ISD::SELECT_CC,
660 ISD::SETCC, ISD::SHL, ISD::SINT_TO_FP, ISD::SMAX,
661 ISD::SMIN, ISD::SMULO, ISD::SMUL_LOHI, ISD::SRA,
662 ISD::SREM, ISD::SRL, ISD::SSHLSAT, ISD::SSUBO,
663 ISD::SSUBO_CARRY, ISD::SSUBSAT, ISD::SUB, ISD::SUBC,
664 ISD::SUBE, ISD::UADDO, ISD::UADDO_CARRY, ISD::UADDSAT,
665 ISD::UDIV, ISD::UDIVREM, ISD::UINT_TO_FP, ISD::UMAX,
666 ISD::UMIN, ISD::UMULO, ISD::UMUL_LOHI, ISD::UREM,
667 ISD::USHLSAT, ISD::USUBO, ISD::USUBO_CARRY, ISD::VSELECT,
668 ISD::USUBSAT},
669 VTs: {MVT::v4i8, MVT::v2i32}, Action: Expand);
670
671 // Operations not directly supported by NVPTX.
672 for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
673 MVT::v2f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16,
674 MVT::v4i8, MVT::i32, MVT::v2i32, MVT::i64}) {
675 setOperationAction(Op: ISD::SELECT_CC, VT, Action: Expand);
676 setOperationAction(Op: ISD::BR_CC, VT, Action: Expand);
677 }
678
679 // We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
680 setOperationAction(Ops: ISD::VSELECT, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
681
682 // Some SIGN_EXTEND_INREG can be done using cvt instruction.
683 // For others we will expand to a SHL/SRA pair.
684 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i64, Action: Legal);
685 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i32, Action: Legal);
686 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i16, Action: Legal);
687 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i8 , Action: Legal);
688 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i1, Action: Expand);
689 setOperationAction(Ops: ISD::SIGN_EXTEND_INREG, VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
690
691 setOperationAction(Op: ISD::SHL_PARTS, VT: MVT::i32 , Action: Custom);
692 setOperationAction(Op: ISD::SRA_PARTS, VT: MVT::i32 , Action: Custom);
693 setOperationAction(Op: ISD::SRL_PARTS, VT: MVT::i32 , Action: Custom);
694 setOperationAction(Op: ISD::SHL_PARTS, VT: MVT::i64 , Action: Custom);
695 setOperationAction(Op: ISD::SRA_PARTS, VT: MVT::i64 , Action: Custom);
696 setOperationAction(Op: ISD::SRL_PARTS, VT: MVT::i64 , Action: Custom);
697
698 setOperationAction(Op: ISD::BITREVERSE, VT: MVT::i32, Action: Legal);
699 setOperationAction(Op: ISD::BITREVERSE, VT: MVT::i64, Action: Legal);
700
701 setOperationAction(Ops: {ISD::ROTL, ISD::ROTR},
702 VTs: {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
703 Action: Expand);
704
705 if (STI.hasHWROT32()) {
706 setOperationAction(Ops: {ISD::FSHL, ISD::FSHR}, VT: MVT::i32, Action: Legal);
707 setOperationAction(Ops: {ISD::ROTL, ISD::ROTR, ISD::FSHL, ISD::FSHR}, VT: MVT::i64,
708 Action: Custom);
709 }
710
711 setOperationAction(Op: ISD::BR_JT, VT: MVT::Other, Action: STI.hasBrx() ? Legal : Expand);
712 setOperationAction(Op: ISD::BRIND, VT: MVT::Other, Action: Expand);
713
714 // We want to legalize constant related memmove and memcopy
715 // intrinsics.
716 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT: MVT::Other, Action: Custom);
717
718 // FP extload/truncstore is not legal in PTX. We need to expand all these.
719 for (auto FloatVTs :
720 {MVT::fp_valuetypes(), MVT::fp_fixedlen_vector_valuetypes()}) {
721 for (MVT ValVT : FloatVTs) {
722 for (MVT MemVT : FloatVTs) {
723 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT, MemVT, Action: Expand);
724 setTruncStoreAction(ValVT, MemVT, Action: Expand);
725 }
726 }
727 }
728
729 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
730 // how they'll be lowered in ISel anyway, and by doing this a little earlier
731 // we allow for more DAG combine opportunities.
732 for (auto IntVTs :
733 {MVT::integer_valuetypes(), MVT::integer_fixedlen_vector_valuetypes()})
734 for (MVT ValVT : IntVTs)
735 for (MVT MemVT : IntVTs)
736 if (isTypeLegal(VT: ValVT))
737 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT, MemVT, Action: Custom);
738
739 // PTX does not support load / store predicate registers
740 setOperationAction(Ops: {ISD::LOAD, ISD::STORE}, VT: MVT::i1, Action: Custom);
741 for (MVT VT : MVT::integer_valuetypes()) {
742 setLoadExtAction(ExtTypes: {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, ValVT: VT, MemVT: MVT::i1,
743 Action: Promote);
744 setTruncStoreAction(ValVT: VT, MemVT: MVT::i1, Action: Expand);
745 }
746
747 // Disable generations of extload/truncstore for v2i32/v2i16/v2i8. The generic
748 // expansion for these nodes when they are unaligned is incorrect if the
749 // type is a vector.
750 //
751 // TODO: Fix the generic expansion for these nodes found in
752 // TargetLowering::expandUnalignedLoad/Store.
753 setLoadExtAction(ExtTypes: {ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, ValVT: MVT::v2i16,
754 MemVT: MVT::v2i8, Action: Expand);
755 setLoadExtAction(ExtTypes: {ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, ValVT: MVT::v2i32,
756 MemVTs: {MVT::v2i8, MVT::v2i16}, Action: Expand);
757 setTruncStoreAction(ValVT: MVT::v2i16, MemVT: MVT::v2i8, Action: Expand);
758 setTruncStoreAction(ValVT: MVT::v2i32, MemVT: MVT::v2i16, Action: Expand);
759 setTruncStoreAction(ValVT: MVT::v2i32, MemVT: MVT::v2i8, Action: Expand);
760
761 // Register custom handling for illegal type loads/stores. We'll try to custom
762 // lower almost all illegal types and logic in the lowering will discard cases
763 // we can't handle.
764 setOperationAction(Ops: {ISD::LOAD, ISD::STORE}, VTs: {MVT::i128, MVT::i256, MVT::f128},
765 Action: Custom);
766 for (MVT VT : MVT::fixedlen_vector_valuetypes())
767 if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
768 setOperationAction(Ops: {ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
769 Action: Custom);
770
771 // Custom legalization for LDU intrinsics.
772 // TODO: The logic to lower these is not very robust and we should rewrite it.
773 // Perhaps LDU should not be represented as an intrinsic at all.
774 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT: MVT::i8, Action: Custom);
775 for (MVT VT : MVT::fixedlen_vector_valuetypes())
776 if (IsPTXVectorType(VT))
777 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT, Action: Custom);
778
779 setCondCodeAction(CCs: {ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE,
780 ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT,
781 ISD::SETGE, ISD::SETLE},
782 VT: MVT::i1, Action: Expand);
783
784 // This is legal in NVPTX
785 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f64, Action: Legal);
786 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f32, Action: Legal);
787 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f16, Action: Legal);
788 setOperationAction(Op: ISD::ConstantFP, VT: MVT::bf16, Action: Legal);
789
790 setOperationAction(Ops: ISD::DYNAMIC_STACKALLOC, VTs: {MVT::i32, MVT::i64}, Action: Custom);
791 setOperationAction(Ops: {ISD::STACKRESTORE, ISD::STACKSAVE}, VT: MVT::Other, Action: Custom);
792
793 // TRAP can be lowered to PTX trap
794 setOperationAction(Op: ISD::TRAP, VT: MVT::Other, Action: Legal);
795 // DEBUGTRAP can be lowered to PTX brkpt
796 setOperationAction(Op: ISD::DEBUGTRAP, VT: MVT::Other, Action: Legal);
797
798 // Support varargs.
799 setOperationAction(Op: ISD::VASTART, VT: MVT::Other, Action: Custom);
800 setOperationAction(Op: ISD::VAARG, VT: MVT::Other, Action: Custom);
801 setOperationAction(Op: ISD::VACOPY, VT: MVT::Other, Action: Expand);
802 setOperationAction(Op: ISD::VAEND, VT: MVT::Other, Action: Expand);
803
804 setOperationAction(Ops: {ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
805 VTs: {MVT::i16, MVT::i32, MVT::i64}, Action: Legal);
806
807 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, VT: MVT::i16,
808 Action: Promote);
809 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ}, VT: MVT::i32, Action: Legal);
810 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ}, VT: MVT::i64, Action: Custom);
811
812 setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
813 setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
814 setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
815 setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
816 setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
817 setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
818 setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
819
820 setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
821 setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
822 setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
823 setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
824 setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
825 setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);
826
827 // Other arithmetic and logic ops are unsupported.
828 setOperationAction(Ops: {ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
829 ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
830 ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
831 VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
832
833 // v2i32 is not supported for any arithmetic operations
834 setOperationAction(Ops: {ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
835 ISD::CTPOP, ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
836 ISD::SHL, ISD::SRA, ISD::SRL, ISD::OR, ISD::AND, ISD::XOR,
837 ISD::SREM, ISD::UREM},
838 VT: MVT::v2i32, Action: Expand);
839
840 setOperationAction(Op: ISD::ADDC, VT: MVT::i32, Action: Legal);
841 setOperationAction(Op: ISD::ADDE, VT: MVT::i32, Action: Legal);
842 setOperationAction(Op: ISD::SUBC, VT: MVT::i32, Action: Legal);
843 setOperationAction(Op: ISD::SUBE, VT: MVT::i32, Action: Legal);
844 if (STI.getPTXVersion() >= 43) {
845 setOperationAction(Op: ISD::ADDC, VT: MVT::i64, Action: Legal);
846 setOperationAction(Op: ISD::ADDE, VT: MVT::i64, Action: Legal);
847 setOperationAction(Op: ISD::SUBC, VT: MVT::i64, Action: Legal);
848 setOperationAction(Op: ISD::SUBE, VT: MVT::i64, Action: Legal);
849 }
850
851 setOperationAction(Op: ISD::CTTZ, VT: MVT::i16, Action: Expand);
852 setOperationAction(Ops: ISD::CTTZ, VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
853 setOperationAction(Op: ISD::CTTZ, VT: MVT::i32, Action: Expand);
854 setOperationAction(Op: ISD::CTTZ, VT: MVT::i64, Action: Expand);
855
856 // PTX does not directly support SELP of i1, so promote to i32 first
857 setOperationAction(Op: ISD::SELECT, VT: MVT::i1, Action: Custom);
858
859 // PTX cannot multiply two i64s in a single instruction.
860 setOperationAction(Op: ISD::SMUL_LOHI, VT: MVT::i64, Action: Expand);
861 setOperationAction(Op: ISD::UMUL_LOHI, VT: MVT::i64, Action: Expand);
862
863 // We have some custom DAG combine patterns for these nodes
864 setTargetDAGCombine({ISD::ADD,
865 ISD::AND,
866 ISD::EXTRACT_VECTOR_ELT,
867 ISD::FADD,
868 ISD::FMAXNUM,
869 ISD::FMINNUM,
870 ISD::FMAXIMUM,
871 ISD::FMINIMUM,
872 ISD::FMAXIMUMNUM,
873 ISD::FMINIMUMNUM,
874 ISD::MUL,
875 ISD::SELECT,
876 ISD::SHL,
877 ISD::SREM,
878 ISD::UREM,
879 ISD::VSELECT,
880 ISD::BUILD_VECTOR,
881 ISD::ADDRSPACECAST,
882 ISD::LOAD,
883 ISD::STORE,
884 ISD::ZERO_EXTEND,
885 ISD::SIGN_EXTEND,
886 ISD::INTRINSIC_WO_CHAIN});
887
888 // If the vector operands require register coalescing, scalarize instead
889 if (STI.hasF32x2Instructions())
890 setTargetDAGCombine({ISD::FMA, ISD::FMUL, ISD::FSUB});
891
892 // setcc for f16x2 and bf16x2 needs special handling to prevent
893 // legalizer's attempt to scalarize it due to v2i1 not being legal.
894 if (STI.allowFP16Math() || STI.hasBF16Math())
895 setTargetDAGCombine(ISD::SETCC);
896
897 // Vector reduction operations. These may be turned into shuffle or tree
898 // reductions depending on what instructions are available for each type.
899 for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
900 MVT EltVT = VT.getVectorElementType();
901 if (EltVT == MVT::f32 || EltVT == MVT::f64) {
902 setOperationAction(Ops: {ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
903 ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
904 VT, Action: Custom);
905 }
906 }
907
908 // Promote fp16 arithmetic if fp16 hardware isn't available or the
909 // user passed --nvptx-no-fp16-math. The flag is useful because,
910 // although sm_53+ GPUs have some sort of FP16 support in
911 // hardware, only sm_53 and sm_60 have full implementation. Others
912 // only have token amount of hardware and are likely to run faster
913 // by using fp32 units instead.
914 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
915 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
916 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
917 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
918 // bf16 must be promoted to f32.
919 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
920 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
921 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
922 setOperationAction(Op, VT: MVT::v2f32,
923 Action: STI.hasF32x2Instructions() ? Legal : Expand);
924 }
925
926 // On SM80, we select add/mul/sub as fma to avoid promotion to float
927 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
928 for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
929 if (!STI.hasNativeBF16Support(Opcode: Op) && STI.hasNativeBF16Support(Opcode: ISD::FMA)) {
930 setOperationAction(Op, VT, Action: Custom);
931 }
932 }
933 }
934
935 // f16/f16x2 neg was introduced in PTX 60, SM_53.
936 const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
937 STI.getPTXVersion() >= 60 &&
938 STI.allowFP16Math();
939 for (const auto &VT : {MVT::f16, MVT::v2f16})
940 setOperationAction(Op: ISD::FNEG, VT,
941 Action: IsFP16FP16x2NegAvailable ? Legal : Expand);
942
943 setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
944 setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
945 setOperationAction(Op: ISD::FNEG, VT: MVT::v2f32, Action: Expand);
946 // (would be) Library functions.
947
948 // These map to conversion instructions for scalar FP types.
949 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
950 ISD::FROUNDEVEN, ISD::FTRUNC}) {
951 setOperationAction(Op, VT: MVT::f16, Action: Legal);
952 setOperationAction(Op, VT: MVT::f32, Action: Legal);
953 setOperationAction(Op, VT: MVT::f64, Action: Legal);
954 setOperationAction(Op, VT: MVT::v2f16, Action: Expand);
955 setOperationAction(Op, VT: MVT::v2bf16, Action: Expand);
956 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
957 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
958 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
959 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
960 }
961
962 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
963 setOperationAction(Op: ISD::BF16_TO_FP, VT: MVT::f32, Action: Expand);
964 }
965 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
966 for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
967 setOperationAction(Op: ISD::FP_EXTEND, VT, Action: Custom);
968 setOperationAction(Op: ISD::FP_ROUND, VT, Action: Custom);
969 }
970 }
971
972 // Expand v2f32 = fp_extend
973 setOperationAction(Op: ISD::FP_EXTEND, VT: MVT::v2f32, Action: Expand);
974 // Expand v2[b]f16 = fp_round v2f32
975 setOperationAction(Ops: ISD::FP_ROUND, VTs: {MVT::v2bf16, MVT::v2f16}, Action: Expand);
976
977 // sm_80 only has conversions between f32 and bf16. Custom lower all other
978 // bf16 conversions.
979 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
980 for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
981 setOperationAction(
982 Ops: {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
983 VT, Action: Custom);
984 }
985 setOperationAction(
986 Ops: {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
987 VT: MVT::bf16, Action: Custom);
988 }
989
990 setOperationAction(Op: ISD::FROUND, VT: MVT::f16, Action: Promote);
991 setOperationAction(Op: ISD::FROUND, VT: MVT::v2f16, Action: Expand);
992 setOperationAction(Op: ISD::FROUND, VT: MVT::v2bf16, Action: Expand);
993 setOperationAction(Op: ISD::FROUND, VT: MVT::f32, Action: Custom);
994 setOperationAction(Op: ISD::FROUND, VT: MVT::f64, Action: Custom);
995 setOperationAction(Op: ISD::FROUND, VT: MVT::bf16, Action: Promote);
996 AddPromotedToType(Opc: ISD::FROUND, OrigVT: MVT::bf16, DestVT: MVT::f32);
997
998 // 'Expand' implements FCOPYSIGN without calling an external library.
999 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f16, Action: Expand);
1000 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::v2f16, Action: Expand);
1001 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::bf16, Action: Expand);
1002 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::v2bf16, Action: Expand);
1003 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f32, Action: Custom);
1004 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f64, Action: Custom);
1005
1006 // These map to corresponding instructions for f32/f64. f16 must be
1007 // promoted to f32. v2f16 is expanded to f16, which is then promoted
1008 // to f32.
1009 for (const auto &Op :
1010 {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FTANH}) {
1011 setOperationAction(Op, VT: MVT::f16, Action: Promote);
1012 setOperationAction(Op, VT: MVT::f32, Action: Legal);
1013 // only div/rem/sqrt are legal for f64
1014 if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) {
1015 setOperationAction(Op, VT: MVT::f64, Action: Legal);
1016 }
1017 setOperationAction(Ops: Op, VTs: {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Action: Expand);
1018 setOperationAction(Op, VT: MVT::bf16, Action: Promote);
1019 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
1020 }
1021 setOperationAction(Ops: ISD::FREM, VTs: {MVT::f32, MVT::f64}, Action: Custom);
1022
1023 setOperationAction(Ops: ISD::FABS, VTs: {MVT::f32, MVT::f64}, Action: Legal);
1024 setOperationAction(Op: ISD::FABS, VT: MVT::v2f32, Action: Expand);
1025 if (STI.getPTXVersion() >= 65) {
1026 setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
1027 setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
1028 } else {
1029 setOperationAction(Op: ISD::FABS, VT: MVT::f16, Action: Promote);
1030 setOperationAction(Op: ISD::FABS, VT: MVT::v2f16, Action: Expand);
1031 }
1032 setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
1033 setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
1034 if (getOperationAction(Op: ISD::FABS, VT: MVT::bf16) == Promote)
1035 AddPromotedToType(Opc: ISD::FABS, OrigVT: MVT::bf16, DestVT: MVT::f32);
1036
1037 for (const auto &Op :
1038 {ISD::FMINNUM, ISD::FMAXNUM, ISD::FMINIMUMNUM, ISD::FMAXIMUMNUM}) {
1039 setOperationAction(Op, VT: MVT::f32, Action: Legal);
1040 setOperationAction(Op, VT: MVT::f64, Action: Legal);
1041 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
1042 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1043 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1044 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
1045 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
1046 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
1047 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
1048 }
1049 bool SupportsF32MinMaxNaN =
1050 STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
1051 for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
1052 setOperationAction(Op, VT: MVT::f32, Action: SupportsF32MinMaxNaN ? Legal : Expand);
1053 setFP16OperationAction(Op, MVT::f16, Legal, Expand);
1054 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1055 setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
1056 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1057 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
1058 }
1059
1060 // Custom lowering for inline asm with 128-bit operands
1061 setOperationAction(Op: ISD::CopyToReg, VT: MVT::i128, Action: Custom);
1062 setOperationAction(Op: ISD::CopyFromReg, VT: MVT::i128, Action: Custom);
1063
1064 // FEXP2 support:
1065 // - f32
1066 // - f16/f16x2 (sm_70+, PTX 7.0+)
1067 // - bf16/bf16x2 (sm_90+, PTX 7.8+)
1068 // When f16/bf16 types aren't supported, they are promoted/expanded to f32.
1069 setOperationAction(Op: ISD::FEXP2, VT: MVT::f32, Action: Legal);
1070 setOperationAction(Op: ISD::FEXP2, VT: MVT::v2f32, Action: Expand);
1071 setFP16OperationAction(ISD::FEXP2, MVT::f16, Legal, Promote);
1072 setFP16OperationAction(ISD::FEXP2, MVT::v2f16, Legal, Expand);
1073 setBF16OperationAction(ISD::FEXP2, MVT::bf16, Legal, Promote);
1074 setBF16OperationAction(ISD::FEXP2, MVT::v2bf16, Legal, Expand);
1075
1076 // FLOG2 supports f32 only
1077 // f16/bf16 types aren't supported, but they are promoted/expanded to f32.
1078 if (UseApproxLog2F32) {
1079 setOperationAction(Op: ISD::FLOG2, VT: MVT::f32, Action: Legal);
1080 setOperationPromotedToType(Opc: ISD::FLOG2, OrigVT: MVT::f16, DestVT: MVT::f32);
1081 setOperationPromotedToType(Opc: ISD::FLOG2, OrigVT: MVT::bf16, DestVT: MVT::f32);
1082 setOperationAction(Ops: ISD::FLOG2, VTs: {MVT::v2f16, MVT::v2bf16, MVT::v2f32},
1083 Action: Expand);
1084 }
1085
1086 setOperationAction(Ops: ISD::ADDRSPACECAST, VTs: {MVT::i32, MVT::i64}, Action: Custom);
1087
1088 setOperationAction(Ops: ISD::ATOMIC_LOAD_SUB, VTs: {MVT::i32, MVT::i64}, Action: Expand);
1089
1090 // atom.b128 is legal in PTX but since we don't represent i128 as a legal
1091 // type, we need to custom lower it.
1092 setOperationAction(Ops: {ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP}, VT: MVT::i128,
1093 Action: Custom);
1094
1095 // Now deduce the information based on the above mentioned
1096 // actions
1097 computeRegisterProperties(TRI: STI.getRegisterInfo());
1098
1099 // PTX support for 16-bit CAS is emulated. Only use 32+
1100 setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits());
1101 setMaxAtomicSizeInBitsSupported(STI.hasAtomSwap128() ? 128 : 64);
1102 setMaxDivRemBitWidthSupported(64);
1103
1104 // Custom lowering for tcgen05.ld vector operands
1105 setOperationAction(Ops: ISD::INTRINSIC_W_CHAIN,
1106 VTs: {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1107 MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::v2f32,
1108 MVT::v4f32, MVT::v8f32, MVT::v16f32, MVT::v32f32,
1109 MVT::v64f32, MVT::v128f32},
1110 Action: Custom);
1111
1112 // Custom lowering for tcgen05.st vector operands
1113 setOperationAction(Ops: ISD::INTRINSIC_VOID,
1114 VTs: {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1115 MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::Other},
1116 Action: Custom);
1117
1118 // Enable custom lowering for the following:
1119 // * MVT::i128 - clusterlaunchcontrol
1120 // * MVT::i32 - prmt
1121 // * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics
1122 // * MVT::Other - internal.addrspace.wrap
1123 setOperationAction(Ops: ISD::INTRINSIC_WO_CHAIN,
1124 VTs: {MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Action: Custom);
1125
1126 // Custom lowering for bswap
1127 setOperationAction(Ops: ISD::BSWAP, VTs: {MVT::i16, MVT::i32, MVT::i64, MVT::v2i16},
1128 Action: Custom);
1129}
1130
1131TargetLoweringBase::LegalizeTypeAction
1132NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
1133 if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
1134 VT.getScalarType() == MVT::i1)
1135 return TypeSplitVector;
1136 return TargetLoweringBase::getPreferredVectorAction(VT);
1137}
1138
1139SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
1140 int Enabled, int &ExtraSteps,
1141 bool &UseOneConst,
1142 bool Reciprocal) const {
1143 if (!(Enabled == ReciprocalEstimate::Enabled ||
1144 (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
1145 return SDValue();
1146
1147 if (ExtraSteps == ReciprocalEstimate::Unspecified)
1148 ExtraSteps = 0;
1149
1150 SDLoc DL(Operand);
1151 EVT VT = Operand.getValueType();
1152 bool Ftz = useF32FTZ(MF: DAG.getMachineFunction());
1153
1154 auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
1155 return DAG.getNode(Opcode: ISD::INTRINSIC_WO_CHAIN, DL, VT,
1156 N1: DAG.getConstant(Val: IID, DL, VT: MVT::i32), N2: Operand);
1157 };
1158
1159 // The sqrt and rsqrt refinement processes assume we always start out with an
1160 // approximation of the rsqrt. Therefore, if we're going to do any refinement
1161 // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
1162 // any refinement, we must return a regular sqrt.
1163 if (Reciprocal || ExtraSteps > 0) {
1164 if (VT == MVT::f32)
1165 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
1166 : Intrinsic::nvvm_rsqrt_approx_f);
1167 else if (VT == MVT::f64)
1168 return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
1169 else
1170 return SDValue();
1171 } else {
1172 if (VT == MVT::f32)
1173 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
1174 : Intrinsic::nvvm_sqrt_approx_f);
1175 else {
1176 // There's no sqrt.approx.f64 instruction, so we emit
1177 // reciprocal(rsqrt(x)). This is faster than
1178 // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
1179 // x * rsqrt(x).)
1180 return DAG.getNode(
1181 Opcode: ISD::INTRINSIC_WO_CHAIN, DL, VT,
1182 N1: DAG.getConstant(Val: Intrinsic::nvvm_rcp_approx_ftz_d, DL, VT: MVT::i32),
1183 N2: MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
1184 }
1185 }
1186}
1187
1188static Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
1189 const DataLayout &DL);
1190
// Build the textual PTX ".callprototype" declaration used for indirect calls.
// The string describes the return value and every parameter in .param space
// so the PTX assembler can type-check the call site.
//
// \param DL             Module data layout, used for type sizes.
// \param RetTy          IR return type of the callee.
// \param Args           Argument list from the call lowering info.
// \param Outs           Flattened output args; aggregates/large vectors may
//                       contribute several entries per IR argument.
// \param FirstVAArg     Index of the first variadic argument, if any.
// \param CB             The call instruction, used for alignment queries.
// \param UniqueCallSite Counter that makes the prototype name unique.
// \return The complete prototype string, e.g.
//         "prototype_N : .callprototype (...) _ (...);".
std::string NVPTXTargetLowering::getPrototype(
    const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
    const SmallVectorImpl<ISD::OutputArg> &Outs,
    std::optional<unsigned> FirstVAArg, const CallBase &CB,
    unsigned UniqueCallSite) const {
  auto PtrVT = getPointerTy(DL);

  std::string Prototype;
  raw_string_ostream O(Prototype);
  O << "prototype_" << UniqueCallSite << " : .callprototype ";

  if (RetTy->isVoidTy()) {
    O << "()";
  } else {
    O << "(";
    if (shouldPassAsArray(Ty: RetTy)) {
      // Aggregate-style returns become an aligned .b8 array in .param space.
      const Align RetAlign = getArgumentAlignment(CB: &CB, Ty: RetTy, Idx: 0, DL);
      O << ".param .align " << RetAlign.value() << " .b8 _["
        << DL.getTypeAllocSize(Ty: RetTy) << "]";
    } else if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy()) {
      unsigned size = 0;
      if (auto *ITy = dyn_cast<IntegerType>(Val: RetTy)) {
        size = ITy->getBitWidth();
      } else {
        assert(RetTy->isFloatingPointTy() &&
               "Floating point type expected here");
        size = RetTy->getPrimitiveSizeInBits();
      }
      // PTX ABI requires all scalar return values to be at least 32
      // bits in size. fp16 normally uses .b16 as its storage type in
      // PTX, so its size must be adjusted here, too.
      size = promoteScalarArgumentSize(size);

      O << ".param .b" << size << " _";
    } else if (isa<PointerType>(Val: RetTy)) {
      // Pointers are returned as an integer of pointer width.
      O << ".param .b" << PtrVT.getSizeInBits() << " _";
    } else {
      llvm_unreachable("Unknown return type");
    }
    O << ") ";
  }
  O << "_ (";

  // Tracks whether a separator is needed before the next parameter.
  bool first = true;

  // Only the fixed arguments get individual .param entries; the variadic tail
  // (if any) is emitted as a single unsized byte array after the loop.
  const unsigned NumArgs = FirstVAArg.value_or(u: Args.size());
  auto AllOuts = ArrayRef(Outs);
  for (const unsigned I : llvm::seq(Size: NumArgs)) {
    // Peel off the Outs entries that were flattened from IR argument I.
    const auto ArgOuts =
        AllOuts.take_while(Pred: [I](auto O) { return O.OrigArgIndex == I; });
    AllOuts = AllOuts.drop_front(N: ArgOuts.size());

    Type *Ty = Args[I].Ty;
    if (!first) {
      O << ", ";
    }
    first = false;

    if (ArgOuts[0].Flags.isByVal()) {
      // Indirect calls need strict ABI alignment so we disable optimizations by
      // not providing a function to optimize.
      Type *ETy = Args[I].IndirectType;
      Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
      Align ParamByValAlign =
          getFunctionByValParamAlign(/*F=*/nullptr, ArgTy: ETy, InitialAlign, DL);

      O << ".param .align " << ParamByValAlign.value() << " .b8 _["
        << ArgOuts[0].Flags.getByValSize() << "]";
    } else {
      if (shouldPassAsArray(Ty)) {
        // Aggregate argument: aligned .b8 array sized from the IR type.
        Align ParamAlign =
            getArgumentAlignment(CB: &CB, Ty, Idx: I + AttributeList::FirstArgIndex, DL);
        O << ".param .align " << ParamAlign.value() << " .b8 _["
          << DL.getTypeAllocSize(Ty) << "]";
        continue;
      }
      // i8 types in IR will be i16 types in SDAG
      assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
              (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
             "type mismatch between callee prototype and arguments");
      // scalar type
      unsigned sz = 0;
      if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
        // Integers below 32 bits are widened per the PTX ABI.
        sz = promoteScalarArgumentSize(size: ITy->getBitWidth());
      } else if (isa<PointerType>(Val: Ty)) {
        sz = PtrVT.getSizeInBits();
      } else {
        sz = Ty->getPrimitiveSizeInBits();
      }
      O << ".param .b" << sz << " _";
    }
  }

  // Variadic portion: a single unsized, maximally aligned byte array.
  if (FirstVAArg)
    O << (first ? "" : ",") << " .param .align "
      << STI.getMaxRequiredAlignment() << " .b8 _[]";
  O << ")";
  if (shouldEmitPTXNoReturn(V: &CB, TM: *nvTM))
    O << " .noreturn";
  O << ";";

  return Prototype;
}
1294
1295static Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
1296 const DataLayout &DL) {
1297 if (!CB) {
1298 // CallSite is zero, fallback to ABI type alignment
1299 return DL.getABITypeAlign(Ty);
1300 }
1301
1302 const Function *DirectCallee = CB->getCalledFunction();
1303
1304 if (!DirectCallee) {
1305 // We don't have a direct function symbol, but that may be because of
1306 // constant cast instructions in the call.
1307
1308 // With bitcast'd call targets, the instruction will be the call
1309 if (const auto *CI = dyn_cast<CallInst>(Val: CB)) {
1310 // Check if we have call alignment metadata
1311 if (MaybeAlign StackAlign = getAlign(*CI, Idx))
1312 return StackAlign.value();
1313 }
1314 DirectCallee = getMaybeBitcastedCallee(CB);
1315 }
1316
1317 // Check for function alignment information if we found that the
1318 // ultimate target is a Function
1319 if (DirectCallee)
1320 return getFunctionArgumentAlignment(F: DirectCallee, Ty, Idx, DL);
1321
1322 // Call is indirect, fall back to the ABI type alignment
1323 return DL.getABITypeAlign(Ty);
1324}
1325
1326static bool shouldConvertToIndirectCall(const CallBase *CB,
1327 const GlobalAddressSDNode *Func) {
1328 if (!Func)
1329 return false;
1330 if (auto *CalleeFunc = dyn_cast<Function>(Val: Func->getGlobal()))
1331 return CB->getFunctionType() != CalleeFunc->getFunctionType();
1332 return false;
1333}
1334
1335static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
1336 const DataLayout &DL,
1337 const TargetLowering &TL) {
1338 if (Ptr->getOpcode() == ISD::FrameIndex) {
1339 auto Ty = TL.getPointerTy(DL, AS: ADDRESS_SPACE_LOCAL);
1340 Ptr = DAG.getAddrSpaceCast(dl: SDLoc(), VT: Ty, Ptr, SrcAS: ADDRESS_SPACE_GENERIC,
1341 DestAS: ADDRESS_SPACE_LOCAL);
1342
1343 return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
1344 }
1345
1346 // Peel of an addrspacecast to generic and load directly from the specific
1347 // address space.
1348 if (Ptr->getOpcode() == ISD::ADDRSPACECAST) {
1349 const auto *ASC = cast<AddrSpaceCastSDNode>(Val&: Ptr);
1350 if (ASC->getDestAddressSpace() == ADDRESS_SPACE_GENERIC) {
1351 Ptr = ASC->getOperand(Num: 0);
1352 return MachinePointerInfo(ASC->getSrcAddressSpace());
1353 }
1354 }
1355
1356 return MachinePointerInfo();
1357}
1358
1359static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
1360 if (Flags.isSExt())
1361 return ISD::SIGN_EXTEND;
1362 if (Flags.isZExt())
1363 return ISD::ZERO_EXTEND;
1364 return ISD::ANY_EXTEND;
1365}
1366
1367static SDValue correctParamType(SDValue V, EVT ExpectedVT,
1368 ISD::ArgFlagsTy Flags, SelectionDAG &DAG,
1369 SDLoc dl) {
1370 const EVT ActualVT = V.getValueType();
1371 assert((ActualVT == ExpectedVT ||
1372 (ExpectedVT.isInteger() && ActualVT.isInteger())) &&
1373 "Non-integer argument type size mismatch");
1374 if (ExpectedVT.bitsGT(VT: ActualVT))
1375 return DAG.getNode(Opcode: getExtOpcode(Flags), DL: dl, VT: ExpectedVT, Operand: V);
1376 if (ExpectedVT.bitsLT(VT: ActualVT))
1377 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: dl, VT: ExpectedVT, Operand: V);
1378
1379 return V;
1380}
1381
1382SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1383 SmallVectorImpl<SDValue> &InVals) const {
1384
1385 if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
1386 report_fatal_error(
1387 reason: "Support for variadic functions (unsized array parameter) introduced "
1388 "in PTX ISA version 6.0 and requires target sm_30.");
1389
1390 SelectionDAG &DAG = CLI.DAG;
1391 SDLoc dl = CLI.DL;
1392 const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
1393 SDValue Callee = CLI.Callee;
1394 ArgListTy &Args = CLI.getArgs();
1395 Type *RetTy = CLI.RetTy;
1396 const CallBase *CB = CLI.CB;
1397 const DataLayout &DL = DAG.getDataLayout();
1398 LLVMContext &Ctx = *DAG.getContext();
1399
1400 const auto GetI32 = [&](const unsigned I) {
1401 return DAG.getConstant(Val: I, DL: dl, VT: MVT::i32);
1402 };
1403
1404 const unsigned UniqueCallSite = GlobalUniqueCallSite++;
1405 const SDValue CallChain = CLI.Chain;
1406 const SDValue StartChain =
1407 DAG.getCALLSEQ_START(Chain: CallChain, InSize: UniqueCallSite, OutSize: 0, DL: dl);
1408 SDValue DeclareGlue = StartChain.getValue(R: 1);
1409
1410 SmallVector<SDValue, 16> CallPrereqs{StartChain};
1411
1412 const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
1413 // PTX ABI requires integral types to be at least 32 bits in size. FP16 is
1414 // loaded/stored using i16, so it's handled here as well.
1415 const unsigned SizeBits = promoteScalarArgumentSize(size: Size * 8);
1416 SDValue Declare =
1417 DAG.getNode(Opcode: NVPTXISD::DeclareScalarParam, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1418 Ops: {StartChain, Symbol, GetI32(SizeBits), DeclareGlue});
1419 CallPrereqs.push_back(Elt: Declare);
1420 DeclareGlue = Declare.getValue(R: 1);
1421 return Declare;
1422 };
1423
1424 const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align,
1425 unsigned Size) {
1426 SDValue Declare = DAG.getNode(
1427 Opcode: NVPTXISD::DeclareArrayParam, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1428 Ops: {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
1429 CallPrereqs.push_back(Elt: Declare);
1430 DeclareGlue = Declare.getValue(R: 1);
1431 return Declare;
1432 };
1433
1434 // Variadic arguments.
1435 //
1436 // Normally, for each argument, we declare a param scalar or a param
1437 // byte array in the .param space, and store the argument value to that
1438 // param scalar or array starting at offset 0.
1439 //
1440 // In the case of the first variadic argument, we declare a vararg byte array
1441 // with size 0. The exact size of this array isn't known at this point, so
1442 // it'll be patched later. All the variadic arguments will be stored to this
1443 // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1444 // initially set to 0, so it can be used for non-variadic arguments (which use
1445 // 0 offset) to simplify the code.
1446 //
1447 // After all vararg is processed, 'VAOffset' holds the size of the
1448 // vararg byte array.
1449 assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) &&
1450 "Non-VarArg function with extra arguments");
1451
1452 const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
1453 unsigned VAOffset = 0; // current offset in the param array
1454
1455 const SDValue VADeclareParam =
1456 CLI.Args.size() > FirstVAArg
1457 ? MakeDeclareArrayParam(getCallParamSymbol(DAG, I: FirstVAArg, T: MVT::i32),
1458 Align(STI.getMaxRequiredAlignment()), 0)
1459 : SDValue();
1460
1461 // Args.size() and Outs.size() need not match.
1462 // Outs.size() will be larger
1463 // * if there is an aggregate argument with multiple fields (each field
1464 // showing up separately in Outs)
1465 // * if there is a vector argument with more than typical vector-length
1466 // elements (generally if more than 4) where each vector element is
1467 // individually present in Outs.
1468 // So a different index should be used for indexing into Outs/OutVals.
1469 // See similar issue in LowerFormalArguments.
1470 auto AllOuts = ArrayRef(CLI.Outs);
1471 auto AllOutVals = ArrayRef(CLI.OutVals);
1472 assert(AllOuts.size() == AllOutVals.size() &&
1473 "Outs and OutVals must be the same size");
1474 // Declare the .params or .reg need to pass values
1475 // to the function
1476 for (const auto E : llvm::enumerate(First&: Args)) {
1477 const auto ArgI = E.index();
1478 const auto Arg = E.value();
1479 const auto ArgOuts =
1480 AllOuts.take_while(Pred: [&](auto O) { return O.OrigArgIndex == ArgI; });
1481 const auto ArgOutVals = AllOutVals.take_front(N: ArgOuts.size());
1482 AllOuts = AllOuts.drop_front(N: ArgOuts.size());
1483 AllOutVals = AllOutVals.drop_front(N: ArgOuts.size());
1484
1485 const bool IsVAArg = (ArgI >= FirstVAArg);
1486 const bool IsByVal = Arg.IsByVal;
1487
1488 const SDValue ParamSymbol =
1489 getCallParamSymbol(DAG, I: IsVAArg ? FirstVAArg : ArgI, T: MVT::i32);
1490
1491 assert((!IsByVal || Arg.IndirectType) &&
1492 "byval arg must have indirect type");
1493 Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
1494
1495 const Align ArgAlign = [&]() {
1496 if (IsByVal) {
1497 // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
1498 // so we don't need to worry whether it's naturally aligned or not.
1499 // See TargetLowering::LowerCallTo().
1500 const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1501 return getFunctionByValParamAlign(F: CB->getCalledFunction(), ArgTy: ETy,
1502 InitialAlign, DL);
1503 }
1504 return getArgumentAlignment(CB, Ty: Arg.Ty, Idx: ArgI + 1, DL);
1505 }();
1506
1507 const unsigned TySize = DL.getTypeAllocSize(Ty: ETy);
1508 assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
1509 "type size mismatch");
1510
1511 const SDValue ArgDeclare = [&]() {
1512 if (IsVAArg)
1513 return VADeclareParam;
1514
1515 if (IsByVal || shouldPassAsArray(Ty: Arg.Ty))
1516 return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
1517
1518 assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
1519 assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
1520 "Only int and float types are supported as non-array arguments");
1521
1522 return MakeDeclareScalarParam(ParamSymbol, TySize);
1523 }();
1524
1525 if (IsByVal) {
1526 assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
1527 SDValue SrcPtr = ArgOutVals[0];
1528 const auto PointerInfo = refinePtrAS(Ptr&: SrcPtr, DAG, DL, TL: *this);
1529 const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1530
1531 if (IsVAArg)
1532 VAOffset = alignTo(Size: VAOffset, A: ArgAlign);
1533
1534 SmallVector<EVT, 4> ValueVTs, MemVTs;
1535 SmallVector<TypeSize, 4> Offsets;
1536 ComputeValueVTs(TLI: *this, DL, Ty: ETy, ValueVTs, MemVTs: &MemVTs, Offsets: &Offsets);
1537
1538 unsigned J = 0;
1539 const auto VI = VectorizePTXValueVTs(ValueVTs: MemVTs, Offsets, ParamAlignment: ArgAlign, IsVAArg);
1540 for (const unsigned NumElts : VI) {
1541 EVT LoadVT = getVectorizedVT(VT: MemVTs[J], N: NumElts, C&: Ctx);
1542 Align SrcAlign = commonAlignment(A: BaseSrcAlign, Offset: Offsets[J]);
1543 SDValue SrcAddr = DAG.getObjectPtrOffset(SL: dl, Ptr: SrcPtr, Offset: Offsets[J]);
1544 SDValue SrcLoad =
1545 DAG.getLoad(VT: LoadVT, dl, Chain: CallChain, Ptr: SrcAddr, PtrInfo: PointerInfo, Alignment: SrcAlign);
1546
1547 TypeSize ParamOffset = Offsets[J].getWithIncrement(RHS: VAOffset);
1548 Align ParamAlign = commonAlignment(A: ArgAlign, Offset: ParamOffset);
1549 SDValue ParamAddr =
1550 DAG.getObjectPtrOffset(SL: dl, Ptr: ParamSymbol, Offset: ParamOffset);
1551 SDValue StoreParam =
1552 DAG.getStore(Chain: ArgDeclare, dl, Val: SrcLoad, Ptr: ParamAddr,
1553 PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: ParamAlign);
1554 CallPrereqs.push_back(Elt: StoreParam);
1555
1556 J += NumElts;
1557 }
1558 if (IsVAArg)
1559 VAOffset += TySize;
1560 } else {
1561 SmallVector<EVT, 16> VTs;
1562 SmallVector<uint64_t, 16> Offsets;
1563 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv: CLI.CallConv, Ty: Arg.Ty, ValueVTs&: VTs, Offsets,
1564 StartingOffset: VAOffset);
1565 assert(VTs.size() == Offsets.size() && "Size mismatch");
1566 assert(VTs.size() == ArgOuts.size() && "Size mismatch");
1567
1568 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
1569 // than 32-bits are sign extended or zero extended, depending on
1570 // whether they are signed or unsigned types. This case applies
1571 // only to scalar parameters and not to aggregate values.
1572 const bool ExtendIntegerParam =
1573 Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: Arg.Ty) < 32;
1574
1575 const auto GetStoredValue = [&](const unsigned I) {
1576 SDValue StVal = ArgOutVals[I];
1577 assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
1578 StVal.getValueType() &&
1579 "OutVal type should always be legal");
1580
1581 const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
1582 const EVT StoreVT =
1583 ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1584
1585 return correctParamType(V: StVal, ExpectedVT: StoreVT, Flags: ArgOuts[I].Flags, DAG, dl);
1586 };
1587
1588 unsigned J = 0;
1589 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: ArgAlign, IsVAArg);
1590 for (const unsigned NumElts : VI) {
1591 const EVT EltVT = promoteScalarIntegerPTX(VT: VTs[J]);
1592
1593 unsigned Offset;
1594 if (IsVAArg) {
1595 // TODO: We may need to support vector types that can be passed
1596 // as scalars in variadic arguments.
1597 assert(NumElts == 1 &&
1598 "Vectorization should be disabled for vaargs.");
1599
1600 // Align each part of the variadic argument to their type.
1601 VAOffset = alignTo(Size: VAOffset, A: DAG.getEVTAlign(MemoryVT: EltVT));
1602 Offset = VAOffset;
1603
1604 const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
1605 VAOffset += DL.getTypeAllocSize(Ty: TheStoreType.getTypeForEVT(Context&: Ctx));
1606 } else {
1607 assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
1608 Offset = Offsets[J];
1609 }
1610
1611 SDValue Ptr =
1612 DAG.getObjectPtrOffset(SL: dl, Ptr: ParamSymbol, Offset: TypeSize::getFixed(ExactSize: Offset));
1613
1614 const MaybeAlign CurrentAlign = ExtendIntegerParam
1615 ? MaybeAlign(std::nullopt)
1616 : commonAlignment(A: ArgAlign, Offset);
1617
1618 SDValue Val =
1619 getBuildVectorizedValue(N: NumElts, dl, DAG, GetElement: [&](unsigned K) {
1620 return GetStoredValue(J + K);
1621 });
1622
1623 SDValue StoreParam =
1624 DAG.getStore(Chain: ArgDeclare, dl, Val, Ptr,
1625 PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: CurrentAlign);
1626 CallPrereqs.push_back(Elt: StoreParam);
1627
1628 J += NumElts;
1629 }
1630 }
1631 }
1632
1633 // Handle Result
1634 if (!Ins.empty()) {
1635 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "retval0", VT: MVT::i32);
1636 const unsigned ResultSize = DL.getTypeAllocSize(Ty: RetTy);
1637 if (shouldPassAsArray(Ty: RetTy)) {
1638 const Align RetAlign = getArgumentAlignment(CB, Ty: RetTy, Idx: 0, DL);
1639 MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize);
1640 } else {
1641 MakeDeclareScalarParam(RetSymbol, ResultSize);
1642 }
1643 }
1644
1645 // Set the size of the vararg param byte array if the callee is a variadic
1646 // function and the variadic part is not empty.
1647 if (VADeclareParam) {
1648 SDValue DeclareParamOps[] = {VADeclareParam.getOperand(i: 0),
1649 VADeclareParam.getOperand(i: 1),
1650 VADeclareParam.getOperand(i: 2), GetI32(VAOffset),
1651 VADeclareParam.getOperand(i: 4)};
1652 DAG.MorphNodeTo(N: VADeclareParam.getNode(), Opc: VADeclareParam.getOpcode(),
1653 VTs: VADeclareParam->getVTList(), Ops: DeclareParamOps);
1654 }
1655
1656 const auto *Func = dyn_cast<GlobalAddressSDNode>(Val: Callee.getNode());
1657 // If the type of the callsite does not match that of the function, convert
1658 // the callsite to an indirect call.
1659 const bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
1660
1661 // Both indirect calls and libcalls have nullptr Func. In order to distinguish
1662 // between them we must rely on the call site value which is valid for
1663 // indirect calls but is always null for libcalls.
1664 const bool IsIndirectCall = (!Func && CB) || ConvertToIndirectCall;
1665
1666 if (isa<ExternalSymbolSDNode>(Val: Callee)) {
1667 Function* CalleeFunc = nullptr;
1668
1669 // Try to find the callee in the current module.
1670 Callee = DAG.getSymbolFunctionGlobalAddress(Op: Callee, TargetFunction: &CalleeFunc);
1671 assert(CalleeFunc != nullptr && "Libcall callee must be set.");
1672
1673 // Set the "libcall callee" attribute to indicate that the function
1674 // must always have a declaration.
1675 CalleeFunc->addFnAttr(Kind: "nvptx-libcall-callee", Val: "true");
1676 }
1677
1678 if (IsIndirectCall) {
1679 // This is indirect function call case : PTX requires a prototype of the
1680 // form
1681 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
// to be emitted, and the label has to be used as the last arg of call
1683 // instruction.
1684 // The prototype is embedded in a string and put as the operand for a
1685 // CallPrototype SDNode which will print out to the value of the string.
1686 const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
1687 std::string Proto =
1688 getPrototype(DL, RetTy, Args, Outs: CLI.Outs,
1689 FirstVAArg: HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, CB: *CB,
1690 UniqueCallSite);
1691 const char *ProtoStr = nvTM->getStrPool().save(S: Proto).data();
1692 const SDValue PrototypeDeclare = DAG.getNode(
1693 Opcode: NVPTXISD::CallPrototype, DL: dl, VT: MVT::Other,
1694 Ops: {StartChain, DAG.getTargetExternalSymbol(Sym: ProtoStr, VT: MVT::i32)});
1695 CallPrereqs.push_back(Elt: PrototypeDeclare);
1696 }
1697
1698 const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
1699 const unsigned NumArgs =
1700 std::min<unsigned>(a: CLI.NumFixedArgs + 1, b: Args.size());
1701 /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
1702 /// NumParams, Callee, Proto)
1703 const SDValue CallToken = DAG.getTokenFactor(DL: dl, Vals&: CallPrereqs);
1704 const SDValue Call = DAG.getNode(
1705 Opcode: NVPTXISD::CALL, DL: dl, VT: MVT::Other,
1706 Ops: {CallToken, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
1707 GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee, GetI32(Proto)});
1708
1709 SmallVector<SDValue, 16> LoadChains{Call};
1710 SmallVector<SDValue, 16> ProxyRegOps;
1711 if (!Ins.empty()) {
1712 SmallVector<EVT, 16> VTs;
1713 SmallVector<uint64_t, 16> Offsets;
1714 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv: CLI.CallConv, Ty: RetTy, ValueVTs&: VTs, Offsets);
1715 assert(VTs.size() == Ins.size() && "Bad value decomposition");
1716
1717 const Align RetAlign = getArgumentAlignment(CB, Ty: RetTy, Idx: 0, DL);
1718 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "retval0", VT: MVT::i32);
1719
1720 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
1721 // 32-bits are sign extended or zero extended, depending on whether
1722 // they are signed or unsigned types.
1723 const bool ExtendIntegerRetVal =
1724 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: RetTy) < 32;
1725
1726 unsigned I = 0;
1727 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: RetAlign);
1728 for (const unsigned NumElts : VI) {
1729 const MaybeAlign CurrentAlign =
1730 ExtendIntegerRetVal ? MaybeAlign(std::nullopt)
1731 : commonAlignment(A: RetAlign, Offset: Offsets[I]);
1732
1733 const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
1734 const EVT LoadVT =
1735 ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1736 const EVT VecVT = getVectorizedVT(VT: LoadVT, N: NumElts, C&: Ctx);
1737 SDValue Ptr =
1738 DAG.getObjectPtrOffset(SL: dl, Ptr: RetSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));
1739
1740 SDValue R =
1741 DAG.getLoad(VT: VecVT, dl, Chain: Call, Ptr,
1742 PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: CurrentAlign);
1743
1744 LoadChains.push_back(Elt: R.getValue(R: 1));
1745 for (const unsigned J : llvm::seq(Size: NumElts))
1746 ProxyRegOps.push_back(Elt: getExtractVectorizedValue(V: R, I: J, VT: LoadVT, dl, DAG));
1747 I += NumElts;
1748 }
1749 }
1750
1751 const SDValue EndToken = DAG.getTokenFactor(DL: dl, Vals&: LoadChains);
1752 const SDValue CallEnd = DAG.getCALLSEQ_END(Chain: EndToken, Size1: UniqueCallSite,
1753 Size2: UniqueCallSite + 1, Glue: SDValue(), DL: dl);
1754
1755 // Append ProxyReg instructions to the chain to make sure that `callseq_end`
1756 // will not get lost. Otherwise, during libcalls expansion, the nodes can become
1757 // dangling.
1758 for (const auto [I, Reg] : llvm::enumerate(First&: ProxyRegOps)) {
1759 SDValue Proxy =
1760 DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: dl, VT: Reg.getValueType(), Ops: {CallEnd, Reg});
1761 SDValue Ret = correctParamType(V: Proxy, ExpectedVT: Ins[I].VT, Flags: Ins[I].Flags, DAG, dl);
1762 InVals.push_back(Elt: Ret);
1763 }
1764
1765 // set IsTailCall to false for now, until we figure out how to express
1766 // tail call optimization in PTX
1767 CLI.IsTailCall = false;
1768 return CallEnd;
1769}
1770
1771SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
1772 SelectionDAG &DAG) const {
1773
1774 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1775 const Function &Fn = DAG.getMachineFunction().getFunction();
1776
1777 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1778 Fn,
1779 "Support for dynamic alloca introduced in PTX ISA version 7.3 and "
1780 "requires target sm_52.",
1781 SDLoc(Op).getDebugLoc()));
1782 auto Ops = {DAG.getConstant(Val: 0, DL: SDLoc(), VT: Op.getValueType()),
1783 Op.getOperand(i: 0)};
1784 return DAG.getMergeValues(Ops, dl: SDLoc());
1785 }
1786
1787 SDLoc DL(Op.getNode());
1788 SDValue Chain = Op.getOperand(i: 0);
1789 SDValue Size = Op.getOperand(i: 1);
1790 uint64_t Align = Op.getConstantOperandVal(i: 2);
1791
1792 // The alignment on a ISD::DYNAMIC_STACKALLOC node may be 0 to indicate that
1793 // the default stack alignment should be used.
1794 if (Align == 0)
1795 Align = DAG.getSubtarget().getFrameLowering()->getStackAlign().value();
1796
1797 // The size for ptx alloca instruction is 64-bit for m64 and 32-bit for m32.
1798 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1799
1800 SDValue Alloc =
1801 DAG.getNode(Opcode: NVPTXISD::DYNAMIC_STACKALLOC, DL, ResultTys: {LocalVT, MVT::Other},
1802 Ops: {Chain, DAG.getZExtOrTrunc(Op: Size, DL, VT: LocalVT),
1803 DAG.getTargetConstant(Val: Align, DL, VT: MVT::i32)});
1804
1805 SDValue ASC = DAG.getAddrSpaceCast(
1806 dl: DL, VT: Op.getValueType(), Ptr: Alloc, SrcAS: ADDRESS_SPACE_LOCAL, DestAS: ADDRESS_SPACE_GENERIC);
1807
1808 return DAG.getMergeValues(Ops: {ASC, SDValue(Alloc.getNode(), 1)}, dl: DL);
1809}
1810
1811SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op,
1812 SelectionDAG &DAG) const {
1813 SDLoc DL(Op.getNode());
1814 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1815 const Function &Fn = DAG.getMachineFunction().getFunction();
1816
1817 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1818 Fn,
1819 "Support for stackrestore requires PTX ISA version >= 7.3 and target "
1820 ">= sm_52.",
1821 DL.getDebugLoc()));
1822 return Op.getOperand(i: 0);
1823 }
1824
1825 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1826 SDValue Chain = Op.getOperand(i: 0);
1827 SDValue Ptr = Op.getOperand(i: 1);
1828 SDValue ASC = DAG.getAddrSpaceCast(dl: DL, VT: LocalVT, Ptr, SrcAS: ADDRESS_SPACE_GENERIC,
1829 DestAS: ADDRESS_SPACE_LOCAL);
1830 return DAG.getNode(Opcode: NVPTXISD::STACKRESTORE, DL, VT: MVT::Other, Ops: {Chain, ASC});
1831}
1832
1833SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op,
1834 SelectionDAG &DAG) const {
1835 SDLoc DL(Op.getNode());
1836 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1837 const Function &Fn = DAG.getMachineFunction().getFunction();
1838
1839 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1840 Fn,
1841 "Support for stacksave requires PTX ISA version >= 7.3 and target >= "
1842 "sm_52.",
1843 DL.getDebugLoc()));
1844 auto Ops = {DAG.getConstant(Val: 0, DL, VT: Op.getValueType()), Op.getOperand(i: 0)};
1845 return DAG.getMergeValues(Ops, dl: DL);
1846 }
1847
1848 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1849 SDValue Chain = Op.getOperand(i: 0);
1850 SDValue SS =
1851 DAG.getNode(Opcode: NVPTXISD::STACKSAVE, DL, ResultTys: {LocalVT, MVT::Other}, Ops: Chain);
1852 SDValue ASC = DAG.getAddrSpaceCast(
1853 dl: DL, VT: Op.getValueType(), Ptr: SS, SrcAS: ADDRESS_SPACE_LOCAL, DestAS: ADDRESS_SPACE_GENERIC);
1854 return DAG.getMergeValues(Ops: {ASC, SDValue(SS.getNode(), 1)}, dl: DL);
1855}
1856
1857// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
1858// (see LegalizeDAG.cpp). This is slow and uses local memory.
1859// We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
1860SDValue
1861NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
1862 SDNode *Node = Op.getNode();
1863 SDLoc dl(Node);
1864 SmallVector<SDValue, 8> Ops;
1865 unsigned NumOperands = Node->getNumOperands();
1866 for (unsigned i = 0; i < NumOperands; ++i) {
1867 SDValue SubOp = Node->getOperand(Num: i);
1868 EVT VVT = SubOp.getNode()->getValueType(ResNo: 0);
1869 EVT EltVT = VVT.getVectorElementType();
1870 unsigned NumSubElem = VVT.getVectorNumElements();
1871 for (unsigned j = 0; j < NumSubElem; ++j) {
1872 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: SubOp,
1873 N2: DAG.getIntPtrConstant(Val: j, DL: dl)));
1874 }
1875 }
1876 return DAG.getBuildVector(VT: Node->getValueType(ResNo: 0), DL: dl, Ops);
1877}
1878
1879static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
1880 SelectionDAG &DAG,
1881 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1882 assert(A.getValueType() == MVT::i32 && B.getValueType() == MVT::i32 &&
1883 Selector.getValueType() == MVT::i32 && "PRMT must have i32 operands");
1884 return DAG.getNode(Opcode: NVPTXISD::PRMT, DL, VT: MVT::i32,
1885 Ops: {A, B, Selector, DAG.getConstant(Val: Mode, DL, VT: MVT::i32)});
1886}
1887
1888static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
1889 SelectionDAG &DAG,
1890 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1891 return getPRMT(A, B, Selector: DAG.getConstant(Val: Selector, DL, VT: MVT::i32), DL, DAG, Mode);
1892}
1893
/// Reduces the elements using the scalar operations provided. The operations
/// are sorted descending in number of inputs they take. The flags on the
/// original reduction operation will be propagated to each scalar operation.
/// Nearby elements are grouped in tree reduction, unlike the shuffle reduction
/// used in ExpandReductions and SelectionDAG.
///
/// \param Elements the scalar values to reduce.
/// \param EltTy    result type of every partial reduction node.
/// \param Ops      (opcode, arity) pairs sorted by descending arity; a wider
///                 op is used until a level becomes too small for it.
/// \param Flags    node flags (e.g. fast-math) copied onto each emitted node.
static SDValue buildTreeReduction(
    const SmallVector<SDValue> &Elements, EVT EltTy,
    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
  // Build the reduction tree at each level, starting with all the elements.
  SmallVector<SDValue> Level = Elements;

  unsigned OpIdx = 0;
  while (Level.size() > 1) {
    // Try to reduce this level using the current operator.
    const auto [Op, NumInputs] = Ops[OpIdx];

    // Build the next level by partially reducing all elements.
    SmallVector<SDValue> ReducedLevel;
    unsigned I = 0, E = Level.size();
    for (; I + NumInputs <= E; I += NumInputs) {
      // Reduce elements in groups of [NumInputs], as much as possible.
      ReducedLevel.push_back(Elt: DAG.getNode(
          Opcode: Op, DL, VT: EltTy, Ops: ArrayRef<SDValue>(Level).slice(N: I, M: NumInputs), Flags));
    }

    if (I < E) {
      // Handle leftover elements.

      if (ReducedLevel.empty()) {
        // We didn't reduce anything at this level. We need to pick a smaller
        // operator. Note: the level is retried (not consumed) with the
        // next-narrower operator.
        ++OpIdx;
        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
        continue;
      }

      // We reduced some things but there's still more left, meaning the
      // operator's number of inputs doesn't evenly divide this level size. Move
      // these elements to the next level.
      for (; I < E; ++I)
        ReducedLevel.push_back(Elt: Level[I]);
    }

    // Process the next level.
    Level = ReducedLevel;
  }

  // Exactly one element remains: the fully reduced value.
  return *Level.begin();
}
1944
1945// Get scalar reduction opcode
1946static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode) {
1947 switch (ReductionOpcode) {
1948 case ISD::VECREDUCE_FMAX:
1949 return ISD::FMAXNUM;
1950 case ISD::VECREDUCE_FMIN:
1951 return ISD::FMINNUM;
1952 case ISD::VECREDUCE_FMAXIMUM:
1953 return ISD::FMAXIMUM;
1954 case ISD::VECREDUCE_FMINIMUM:
1955 return ISD::FMINIMUM;
1956 default:
1957 llvm_unreachable("unhandled reduction opcode");
1958 }
1959}
1960
1961/// Get 3-input scalar reduction opcode
1962static std::optional<unsigned>
1963getScalar3OpcodeForReduction(unsigned ReductionOpcode) {
1964 switch (ReductionOpcode) {
1965 case ISD::VECREDUCE_FMAX:
1966 return NVPTXISD::FMAXNUM3;
1967 case ISD::VECREDUCE_FMIN:
1968 return NVPTXISD::FMINNUM3;
1969 case ISD::VECREDUCE_FMAXIMUM:
1970 return NVPTXISD::FMAXIMUM3;
1971 case ISD::VECREDUCE_FMINIMUM:
1972 return NVPTXISD::FMINIMUM3;
1973 default:
1974 return std::nullopt;
1975 }
1976}
1977
1978/// Lower reductions to either a sequence of operations or a tree if
1979/// reassociations are allowed. This method will use larger operations like
1980/// max3/min3 when the target supports them.
1981SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
1982 SelectionDAG &DAG) const {
1983 SDLoc DL(Op);
1984 const SDNodeFlags Flags = Op->getFlags();
1985 SDValue Vector = Op.getOperand(i: 0);
1986
1987 const unsigned Opcode = Op->getOpcode();
1988 const EVT EltTy = Vector.getValueType().getVectorElementType();
1989
1990 // Whether we can use 3-input min/max when expanding the reduction.
1991 const bool CanUseMinMax3 =
1992 EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
1993 STI.getPTXVersion() >= 88 &&
1994 (Opcode == ISD::VECREDUCE_FMAX || Opcode == ISD::VECREDUCE_FMIN ||
1995 Opcode == ISD::VECREDUCE_FMAXIMUM || Opcode == ISD::VECREDUCE_FMINIMUM);
1996
1997 // A list of SDNode opcodes with equivalent semantics, sorted descending by
1998 // number of inputs they take.
1999 SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2000
2001 if (auto Opcode3Elem = getScalar3OpcodeForReduction(ReductionOpcode: Opcode);
2002 CanUseMinMax3 && Opcode3Elem)
2003 ScalarOps.push_back(Elt: {*Opcode3Elem, 3});
2004 ScalarOps.push_back(Elt: {getScalarOpcodeForReduction(ReductionOpcode: Opcode), 2});
2005
2006 SmallVector<SDValue> Elements;
2007 DAG.ExtractVectorElements(Op: Vector, Args&: Elements);
2008
2009 return buildTreeReduction(Elements, EltTy, Ops: ScalarOps, DL, Flags, DAG);
2010}
2011
2012SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2013 // Handle bitcasting from v2i8 without hitting the default promotion
2014 // strategy which goes through stack memory.
2015 EVT FromVT = Op->getOperand(Num: 0)->getValueType(ResNo: 0);
2016 if (FromVT != MVT::v2i8) {
2017 return Op;
2018 }
2019
2020 // Pack vector elements into i16 and bitcast to final type
2021 SDLoc DL(Op);
2022 SDValue Vec0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8,
2023 N1: Op->getOperand(Num: 0), N2: DAG.getIntPtrConstant(Val: 0, DL));
2024 SDValue Vec1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8,
2025 N1: Op->getOperand(Num: 0), N2: DAG.getIntPtrConstant(Val: 1, DL));
2026 SDValue Extend0 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i16, Operand: Vec0);
2027 SDValue Extend1 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i16, Operand: Vec1);
2028 SDValue Const8 = DAG.getConstant(Val: 8, DL, VT: MVT::i16);
2029 SDValue AsInt = DAG.getNode(
2030 Opcode: ISD::OR, DL, VT: MVT::i16,
2031 Ops: {Extend0, DAG.getNode(Opcode: ISD::SHL, DL, VT: MVT::i16, Ops: {Extend1, Const8})});
2032 EVT ToVT = Op->getValueType(ResNo: 0);
2033 return DAG.getBitcast(VT: ToVT, V: AsInt);
2034}
2035
// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
// would get lowered as two constant loads and vector-packing move.
// Instead we want just a constant move:
// mov.b32 %r2, 0x40003C00
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
                                               SelectionDAG &DAG) const {
  EVT VT = Op->getValueType(ResNo: 0);
  // Only packed, 32-bit-wide vector types are handled here; everything else
  // keeps the default lowering.
  if (!(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector()))
    return Op;
  SDLoc DL(Op);

  // Not all-constant operands: only v4i8 gets special treatment below; other
  // packed types fall back to the default lowering.
  if (!llvm::all_of(Range: Op->ops(), P: [](SDValue Operand) {
        return Operand->isUndef() || isa<ConstantSDNode>(Val: Operand) ||
               isa<ConstantFPSDNode>(Val: Operand);
      })) {
    if (VT != MVT::v4i8)
      return Op;
    // Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
    // to optimize calculation of constant parts.
    auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
                       uint64_t SelectionValue) -> SDValue {
      SDValue L = Left;
      SDValue R = Right;
      if (Cast) {
        L = DAG.getAnyExtOrTrunc(Op: L, DL, VT: MVT::i32);
        R = DAG.getAnyExtOrTrunc(Op: R, DL, VT: MVT::i32);
      }
      return getPRMT(A: L, B: R, Selector: SelectionValue, DL, DAG);
    };
    // Combine operands pairwise, then merge the two halves into the final
    // 4-byte value.
    auto PRMT__10 = GetPRMT(Op->getOperand(Num: 0), Op->getOperand(Num: 1), true, 0x3340);
    auto PRMT__32 = GetPRMT(Op->getOperand(Num: 2), Op->getOperand(Num: 3), true, 0x3340);
    auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
    return DAG.getBitcast(VT, V: PRMT3210);
  }

  // Get value or the Nth operand as an APInt(32). Undef values treated as 0.
  auto GetOperand = [](SDValue Op, int N) -> APInt {
    const SDValue &Operand = Op->getOperand(Num: N);
    EVT VT = Op->getValueType(ResNo: 0);
    if (Operand->isUndef())
      return APInt(32, 0);
    APInt Value;
    if (VT == MVT::v2f16 || VT == MVT::v2bf16)
      Value = cast<ConstantFPSDNode>(Val: Operand)->getValueAPF().bitcastToAPInt();
    else if (VT == MVT::v2i16 || VT == MVT::v4i8)
      Value = Operand->getAsAPIntVal();
    else
      llvm_unreachable("Unsupported type");
    // i8 values are carried around as i16, so we need to zero out upper bits,
    // so they do not get in the way of combining individual byte values
    if (VT == MVT::v4i8)
      Value = Value.trunc(width: 8);
    return Value.zext(width: 32);
  };

  // Construct a 32-bit constant by shifting into place smaller values
  // (elements of the vector type VT).
  // For example, if VT has 2 elements, then N == 2:
  // ShiftAmount = 32 / N = 16
  // Value |= Op0 (b16) << 0
  // Value |= Op1 (b16) << 16
  // If N == 4:
  // ShiftAmount = 32 / N = 8
  // Value |= Op0 (b8) << 0
  // Value |= Op1 (b8) << 8
  // Value |= Op2 (b8) << 16
  // Value |= Op3 (b8) << 24
  // ...etc
  APInt Value(32, 0);
  const unsigned NumElements = VT.getVectorNumElements();
  assert(32 % NumElements == 0 && "must evenly divide bit length");
  const unsigned ShiftAmount = 32 / NumElements;
  for (unsigned ElementNo : seq(Size: NumElements))
    Value |= GetOperand(Op, ElementNo).shl(shiftAmt: ElementNo * ShiftAmount);
  // Emit the packed constant as a single i32 and bitcast to the vector type.
  SDValue Const = DAG.getConstant(Val: Value, DL, VT: MVT::i32);
  return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: Op->getValueType(ResNo: 0), Operand: Const);
}
2113
SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
                                                     SelectionDAG &DAG) const {
  SDValue Index = Op->getOperand(Num: 1);
  SDValue Vector = Op->getOperand(Num: 0);
  SDLoc DL(Op);
  EVT VectorVT = Vector.getValueType();

  // For v4i8, select the requested byte with a PRMT whose selector is built
  // from the (possibly non-constant) index.
  if (VectorVT == MVT::v4i8) {
    SDValue Selector = DAG.getNode(Opcode: ISD::OR, DL, VT: MVT::i32,
                                   N1: DAG.getZExtOrTrunc(Op: Index, DL, VT: MVT::i32),
                                   N2: DAG.getConstant(Val: 0x7770, DL, VT: MVT::i32));
    SDValue PRMT = getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: Vector),
                           B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector, DL, DAG);
    SDValue Ext = DAG.getAnyExtOrTrunc(Op: PRMT, DL, VT: Op->getValueType(ResNo: 0));
    // The selected byte lands in the low 8 bits of the PRMT result, so the
    // any-extend cannot wrap when the destination is at least 8 bits wide.
    SDNodeFlags Flags;
    Flags.setNoSignedWrap(Ext.getScalarValueSizeInBits() > 8);
    Flags.setNoUnsignedWrap(Ext.getScalarValueSizeInBits() >= 8);
    Ext->setFlags(Flags);
    return Ext;
  }

  // Constant index will be matched by tablegen.
  if (isa<ConstantSDNode>(Val: Index.getNode()))
    return Op;

  // Extract individual elements and select one of them.
  // Only 2-element packed vectors are expected on this path.
  assert(NVPTX::isPackedVectorTy(VectorVT) &&
         VectorVT.getVectorNumElements() == 2 && "Unexpected vector type.");
  EVT EltVT = VectorVT.getVectorElementType();

  SDLoc dl(Op.getNode());
  SDValue E0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: Vector,
                           N2: DAG.getIntPtrConstant(Val: 0, DL: dl));
  SDValue E1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: Vector,
                           N2: DAG.getIntPtrConstant(Val: 1, DL: dl));
  // Pick E0 when Index == 0, otherwise E1.
  return DAG.getSelectCC(DL: dl, LHS: Index, RHS: DAG.getIntPtrConstant(Val: 0, DL: dl), True: E0, False: E1,
                         Cond: ISD::CondCode::SETEQ);
}
2152
2153SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2154 SelectionDAG &DAG) const {
2155 SDValue Vector = Op->getOperand(Num: 0);
2156 EVT VectorVT = Vector.getValueType();
2157
2158 if (VectorVT != MVT::v4i8)
2159 return Op;
2160 SDLoc DL(Op);
2161 SDValue Value = Op->getOperand(Num: 1);
2162 if (Value->isUndef())
2163 return Vector;
2164
2165 SDValue Index = Op->getOperand(Num: 2);
2166
2167 SDValue BFI =
2168 DAG.getNode(Opcode: NVPTXISD::BFI, DL, VT: MVT::i32,
2169 Ops: {DAG.getZExtOrTrunc(Op: Value, DL, VT: MVT::i32), Vector,
2170 DAG.getNode(Opcode: ISD::MUL, DL, VT: MVT::i32,
2171 N1: DAG.getZExtOrTrunc(Op: Index, DL, VT: MVT::i32),
2172 N2: DAG.getConstant(Val: 8, DL, VT: MVT::i32)),
2173 DAG.getConstant(Val: 8, DL, VT: MVT::i32)});
2174 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: Op->getValueType(ResNo: 0), Operand: BFI);
2175}
2176
2177SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2178 SelectionDAG &DAG) const {
2179 SDValue V1 = Op.getOperand(i: 0);
2180 EVT VectorVT = V1.getValueType();
2181 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2182 return Op;
2183
2184 // Lower shuffle to PRMT instruction.
2185 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: Op.getNode());
2186 SDValue V2 = Op.getOperand(i: 1);
2187 uint32_t Selector = 0;
2188 for (auto I : llvm::enumerate(First: SVN->getMask())) {
2189 if (I.value() != -1) // -1 is a placeholder for undef.
2190 Selector |= (I.value() << (I.index() * 4));
2191 }
2192
2193 SDLoc DL(Op);
2194 SDValue PRMT = getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: V1),
2195 B: DAG.getBitcast(VT: MVT::i32, V: V2), Selector, DL, DAG);
2196 return DAG.getBitcast(VT: Op.getValueType(), V: PRMT);
2197}
/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
/// amount, or
/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
/// amount.
SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
                                                  SelectionDAG &DAG) const {
  assert(Op.getNumOperands() == 3 && "Not a double-shift!");
  assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);

  EVT VT = Op.getValueType();
  unsigned VTBits = VT.getSizeInBits();
  SDLoc dl(Op);
  SDValue ShOpLo = Op.getOperand(i: 0);
  SDValue ShOpHi = Op.getOperand(i: 1);
  SDValue ShAmt  = Op.getOperand(i: 2);
  // SRA_PARTS shifts the high half arithmetically, SRL_PARTS logically.
  unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;

  if (VTBits == 32 && STI.getSmVersion() >= 35) {
    // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
    // {dHi, dLo} = {aHi, aLo} >> Amt
    // dHi = aHi >> Amt
    // dLo = shf.r.clamp aLo, aHi, Amt

    SDValue Hi = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
    SDValue Lo =
        DAG.getNode(Opcode: NVPTXISD::FSHR_CLAMP, DL: dl, VT, N1: ShOpHi, N2: ShOpLo, N3: ShAmt);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  }
  else {
    // Generic expansion without a funnel shift:
    // {dHi, dLo} = {aHi, aLo} >> Amt
    // - if (Amt>=size) then
    //      dLo = aHi >> (Amt-size)
    //      dHi = aHi >> Amt (this is either all 0 or all 1)
    // else
    //      dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
    //      dHi = aHi >> Amt

    SDValue RevShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32,
                                   N1: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
                                   N2: ShAmt);
    SDValue Tmp1 = DAG.getNode(Opcode: ISD::SRL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
    SDValue ExtraShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32, N1: ShAmt,
                                     N2: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32));
    SDValue Tmp2 = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpHi, N2: RevShAmt);
    SDValue FalseVal = DAG.getNode(Opcode: ISD::OR, DL: dl, VT, N1: Tmp1, N2: Tmp2);
    SDValue TrueVal = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ExtraShAmt);

    // Select between the Amt>=size and Amt<size forms for the low half.
    SDValue Cmp = DAG.getSetCC(DL: dl, VT: MVT::i1, LHS: ShAmt,
                               RHS: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
                               Cond: ISD::SETGE);
    SDValue Hi = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
    SDValue Lo = DAG.getNode(Opcode: ISD::SELECT, DL: dl, VT, N1: Cmp, N2: TrueVal, N3: FalseVal);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  }
}
2258
/// LowerShiftLeftParts - Lower SHL_PARTS, which
/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
/// amount, or
/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
/// amount.
SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
                                                 SelectionDAG &DAG) const {
  assert(Op.getNumOperands() == 3 && "Not a double-shift!");
  assert(Op.getOpcode() == ISD::SHL_PARTS);

  EVT VT = Op.getValueType();
  unsigned VTBits = VT.getSizeInBits();
  SDLoc dl(Op);
  SDValue ShOpLo = Op.getOperand(i: 0);
  SDValue ShOpHi = Op.getOperand(i: 1);
  SDValue ShAmt  = Op.getOperand(i: 2);

  if (VTBits == 32 && STI.getSmVersion() >= 35) {
    // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
    // {dHi, dLo} = {aHi, aLo} << Amt
    // dHi = shf.l.clamp aLo, aHi, Amt
    // dLo = aLo << Amt

    SDValue Hi =
        DAG.getNode(Opcode: NVPTXISD::FSHL_CLAMP, DL: dl, VT, N1: ShOpHi, N2: ShOpLo, N3: ShAmt);
    SDValue Lo = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  }
  else {
    // Generic expansion without a funnel shift:
    // {dHi, dLo} = {aHi, aLo} << Amt
    // - if (Amt>=size) then
    //      dLo = aLo << Amt (all 0)
    //      dHi = aLo << (Amt-size)
    // else
    //      dLo = aLo << Amt
    //      dHi = (aHi << Amt) | (aLo >> (size-Amt))

    SDValue RevShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32,
                                   N1: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
                                   N2: ShAmt);
    SDValue Tmp1 = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
    SDValue ExtraShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32, N1: ShAmt,
                                     N2: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32));
    SDValue Tmp2 = DAG.getNode(Opcode: ISD::SRL, DL: dl, VT, N1: ShOpLo, N2: RevShAmt);
    SDValue FalseVal = DAG.getNode(Opcode: ISD::OR, DL: dl, VT, N1: Tmp1, N2: Tmp2);
    SDValue TrueVal = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ExtraShAmt);

    // Select between the Amt>=size and Amt<size forms for the high half.
    SDValue Cmp = DAG.getSetCC(DL: dl, VT: MVT::i1, LHS: ShAmt,
                               RHS: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
                               Cond: ISD::SETGE);
    SDValue Lo = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
    SDValue Hi = DAG.getNode(Opcode: ISD::SELECT, DL: dl, VT, N1: Cmp, N2: TrueVal, N3: FalseVal);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  }
}
2318
2319/// If the types match, convert the generic copysign to the NVPTXISD version,
2320/// otherwise bail ensuring that mismatched cases are properly expaned.
2321SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
2322 SelectionDAG &DAG) const {
2323 EVT VT = Op.getValueType();
2324 SDLoc DL(Op);
2325
2326 SDValue In1 = Op.getOperand(i: 0);
2327 SDValue In2 = Op.getOperand(i: 1);
2328 EVT SrcVT = In2.getValueType();
2329
2330 if (!SrcVT.bitsEq(VT))
2331 return SDValue();
2332
2333 return DAG.getNode(Opcode: NVPTXISD::FCOPYSIGN, DL, VT, N1: In1, N2: In2);
2334}
2335
2336SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2337 EVT VT = Op.getValueType();
2338
2339 if (VT == MVT::f32)
2340 return LowerFROUND32(Op, DAG);
2341
2342 if (VT == MVT::f64)
2343 return LowerFROUND64(Op, DAG);
2344
2345 llvm_unreachable("unhandled type");
2346}
2347
// This is the rounding method used in CUDA libdevice in C like code:
// float roundf(float A)
// {
//   float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
//   RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
//   return abs(A) < 0.5 ? (float)(int)A : RoundedA;
// }
SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
                                           SelectionDAG &DAG) const {
  SDLoc SL(Op);
  SDValue A = Op.getOperand(i: 0);
  EVT VT = Op.getValueType();

  SDValue AbsA = DAG.getNode(Opcode: ISD::FABS, DL: SL, VT, Operand: A);

  // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
  // Build +/-0.5 carrying A's sign by combining A's sign bit with the i32
  // bit pattern of 0.5f, then add it to A and truncate toward zero.
  SDValue Bitcast = DAG.getNode(Opcode: ISD::BITCAST, DL: SL, VT: MVT::i32, Operand: A);
  const unsigned SignBitMask = 0x80000000;
  SDValue Sign = DAG.getNode(Opcode: ISD::AND, DL: SL, VT: MVT::i32, N1: Bitcast,
                             N2: DAG.getConstant(Val: SignBitMask, DL: SL, VT: MVT::i32));
  const unsigned PointFiveInBits = 0x3F000000;
  SDValue PointFiveWithSignRaw =
      DAG.getNode(Opcode: ISD::OR, DL: SL, VT: MVT::i32, N1: Sign,
                  N2: DAG.getConstant(Val: PointFiveInBits, DL: SL, VT: MVT::i32));
  SDValue PointFiveWithSign =
      DAG.getNode(Opcode: ISD::BITCAST, DL: SL, VT, Operand: PointFiveWithSignRaw);
  SDValue AdjustedA = DAG.getNode(Opcode: ISD::FADD, DL: SL, VT, N1: A, N2: PointFiveWithSign);
  SDValue RoundedA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: AdjustedA);

  // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
  // f32 values with magnitude >= 2^23 are already integral; the +/-0.5
  // adjustment above would be lossy for them, so pass them through.
  EVT SetCCVT = getSetCCResultType(DL: DAG.getDataLayout(), Ctx&: *DAG.getContext(), VT);
  SDValue IsLarge =
      DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA, RHS: DAG.getConstantFP(Val: pow(x: 2.0, y: 23.0), DL: SL, VT),
                   Cond: ISD::SETOGT);
  RoundedA = DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsLarge, N2: A, N3: RoundedA);

  // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
  // For |A| < 0.5, truncation of A itself yields the correctly signed zero.
  SDValue IsSmall =DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA,
                                RHS: DAG.getConstantFP(Val: 0.5, DL: SL, VT), Cond: ISD::SETOLT);
  SDValue RoundedAForSmallA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: A);
  return DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsSmall, N2: RoundedAForSmallA, N3: RoundedA);
}
2390
2391// The implementation of round(double) is similar to that of round(float) in
2392// that they both separate the value range into three regions and use a method
2393// specific to the region to round the values. However, round(double) first
2394// calculates the round of the absolute value and then adds the sign back while
2395// round(float) directly rounds the value with sign.
2396SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2397 SelectionDAG &DAG) const {
2398 SDLoc SL(Op);
2399 SDValue A = Op.getOperand(i: 0);
2400 EVT VT = Op.getValueType();
2401
2402 SDValue AbsA = DAG.getNode(Opcode: ISD::FABS, DL: SL, VT, Operand: A);
2403
2404 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2405 SDValue AdjustedA = DAG.getNode(Opcode: ISD::FADD, DL: SL, VT, N1: AbsA,
2406 N2: DAG.getConstantFP(Val: 0.5, DL: SL, VT));
2407 SDValue RoundedA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: AdjustedA);
2408
2409 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2410 EVT SetCCVT = getSetCCResultType(DL: DAG.getDataLayout(), Ctx&: *DAG.getContext(), VT);
2411 SDValue IsSmall =DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA,
2412 RHS: DAG.getConstantFP(Val: 0.5, DL: SL, VT), Cond: ISD::SETOLT);
2413 RoundedA = DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsSmall,
2414 N2: DAG.getConstantFP(Val: 0, DL: SL, VT),
2415 N3: RoundedA);
2416
2417 // Add sign to rounded_A
2418 RoundedA = DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SL, VT, N1: RoundedA, N2: A);
2419 DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: A);
2420
2421 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2422 SDValue IsLarge =
2423 DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA, RHS: DAG.getConstantFP(Val: pow(x: 2.0, y: 52.0), DL: SL, VT),
2424 Cond: ISD::SETOGT);
2425 return DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsLarge, N2: A, N3: RoundedA);
2426}
2427
2428static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
2429 EVT VT = N->getValueType(ResNo: 0);
2430 EVT NVT = MVT::f32;
2431 if (VT.isVector()) {
2432 NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NVT, EC: VT.getVectorElementCount());
2433 }
2434 SDLoc DL(N);
2435 SDValue Tmp0 = DAG.getFPExtendOrRound(Op: N->getOperand(Num: 0), DL, VT: NVT);
2436 SDValue Tmp1 = DAG.getFPExtendOrRound(Op: N->getOperand(Num: 1), DL, VT: NVT);
2437 SDValue Res = DAG.getNode(Opcode: N->getOpcode(), DL, VT: NVT, N1: Tmp0, N2: Tmp1, Flags: N->getFlags());
2438 return DAG.getFPExtendOrRound(Op: Res, DL, VT);
2439}
2440
2441SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
2442 SelectionDAG &DAG) const {
2443 if (useF32FTZ(MF: DAG.getMachineFunction())) {
2444 return PromoteBinOpToF32(N: Op.getNode(), DAG);
2445 }
2446 return Op;
2447}
2448
2449SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2450 SelectionDAG &DAG) const {
2451 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2452
2453 if (Op.getValueType() == MVT::bf16) {
2454 SDLoc Loc(Op);
2455 return DAG.getNode(
2456 Opcode: ISD::FP_ROUND, DL: Loc, VT: MVT::bf16,
2457 N1: DAG.getNode(Opcode: Op.getOpcode(), DL: Loc, VT: MVT::f32, Operand: Op.getOperand(i: 0)),
2458 N2: DAG.getIntPtrConstant(Val: 0, DL: Loc, /*isTarget=*/true));
2459 }
2460
2461 // Everything else is considered legal.
2462 return Op;
2463}
2464
2465SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2466 SelectionDAG &DAG) const {
2467 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2468
2469 if (Op.getOperand(i: 0).getValueType() == MVT::bf16) {
2470 SDLoc Loc(Op);
2471 return DAG.getNode(
2472 Opcode: Op.getOpcode(), DL: Loc, VT: Op.getValueType(),
2473 Operand: DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: MVT::f32, Operand: Op.getOperand(i: 0)));
2474 }
2475
2476 // Everything else is considered legal.
2477 return Op;
2478}
2479
2480SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2481 SelectionDAG &DAG) const {
2482 EVT NarrowVT = Op.getValueType();
2483 SDValue Wide = Op.getOperand(i: 0);
2484 EVT WideVT = Wide.getValueType();
2485 if (NarrowVT.getScalarType() == MVT::bf16) {
2486 const TargetLowering *TLI = STI.getTargetLowering();
2487 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2488 return TLI->expandFP_ROUND(Node: Op.getNode(), DAG);
2489 }
2490 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2491 // This combination was the first to support f32 -> bf16.
2492 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2493 if (WideVT.getScalarType() == MVT::f32) {
2494 return Op;
2495 }
2496 if (WideVT.getScalarType() == MVT::f64) {
2497 SDLoc Loc(Op);
2498 // Round-inexact-to-odd f64 to f32, then do the final rounding using
2499 // the hardware f32 -> bf16 instruction.
2500 SDValue rod = TLI->expandRoundInexactToOdd(
2501 ResultVT: WideVT.changeElementType(Context&: *DAG.getContext(), EltVT: MVT::f32), Op: Wide, DL: Loc,
2502 DAG);
2503 return DAG.getFPExtendOrRound(Op: rod, DL: Loc, VT: NarrowVT);
2504 }
2505 }
2506 return TLI->expandFP_ROUND(Node: Op.getNode(), DAG);
2507 }
2508 }
2509
2510 // Everything else is considered legal.
2511 return Op;
2512}
2513
// Custom lowering for FP_EXTEND from bf16, gated on the target's bf16 cvt
// support level.
SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
                                            SelectionDAG &DAG) const {
  SDValue Narrow = Op.getOperand(i: 0);
  EVT NarrowVT = Narrow.getValueType();
  EVT WideVT = Op.getValueType();
  if (NarrowVT.getScalarType() == MVT::bf16) {
    // bf16 -> f32 without a native convert (pre sm_80/PTX 7.1): use the
    // BF16_TO_FP node instead.
    if (WideVT.getScalarType() == MVT::f32 &&
        (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
      SDLoc Loc(Op);
      return DAG.getNode(Opcode: ISD::BF16_TO_FP, DL: Loc, VT: WideVT, Operand: Narrow);
    }
    // bf16 -> f64 without a native one-step convert (pre sm_90/PTX 7.8):
    // go through f32 first. Use the native bf16 -> f32 extend when the
    // target has it (sm_80/PTX 7.1), otherwise BF16_TO_FP. Extending
    // through f32 is exact, so no precision is lost.
    if (WideVT.getScalarType() == MVT::f64 &&
        (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
      EVT F32 = NarrowVT.changeElementType(Context&: *DAG.getContext(), EltVT: MVT::f32);
      SDLoc Loc(Op);
      if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
        Op = DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: F32, Operand: Narrow);
      } else {
        Op = DAG.getNode(Opcode: ISD::BF16_TO_FP, DL: Loc, VT: F32, Operand: Narrow);
      }
      return DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: WideVT, Operand: Op);
    }
  }

  // Everything else is considered legal.
  return Op;
}
2541
2542static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
2543 SDLoc DL(Op);
2544 if (Op.getValueType() != MVT::v2i16)
2545 return Op;
2546 EVT EltVT = Op.getValueType().getVectorElementType();
2547 SmallVector<SDValue> VecElements;
2548 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2549 SmallVector<SDValue> ScalarArgs;
2550 llvm::transform(Range: Op->ops(), d_first: std::back_inserter(x&: ScalarArgs),
2551 F: [&](const SDUse &O) {
2552 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT,
2553 N1: O.get(), N2: DAG.getIntPtrConstant(Val: I, DL));
2554 });
2555 VecElements.push_back(Elt: DAG.getNode(Opcode: Op.getOpcode(), DL, VT: EltVT, Ops: ScalarArgs));
2556 }
2557 SDValue V =
2558 DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: Op.getValueType(), Ops: VecElements);
2559 return V;
2560}
2561
2562static SDValue lowerTcgen05St(SDValue Op, SelectionDAG &DAG,
2563 bool hasOffset = false) {
2564 // skip lowering if the vector operand is already legalized
2565 if (!Op->getOperand(Num: hasOffset ? 4 : 3).getValueType().isVector())
2566 return Op;
2567
2568 SDNode *N = Op.getNode();
2569 SDLoc DL(N);
2570 SmallVector<SDValue, 32> Ops;
2571
2572 // split the vector argument
2573 for (size_t I = 0; I < N->getNumOperands(); I++) {
2574 SDValue Val = N->getOperand(Num: I);
2575 EVT ValVT = Val.getValueType();
2576 if (ValVT.isVector()) {
2577 EVT EltVT = ValVT.getVectorElementType();
2578 for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
2579 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Val,
2580 N2: DAG.getIntPtrConstant(Val: J, DL)));
2581 } else
2582 Ops.push_back(Elt: Val);
2583 }
2584
2585 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
2586 SDValue Tcgen05StNode =
2587 DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_VOID, dl: DL, VTList: N->getVTList(), Ops,
2588 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
2589
2590 return Tcgen05StNode;
2591}
2592
// Lower ISD::BSWAP using the PTX 'prmt' byte-permute instruction (via
// getPRMT). The second PRMT input is a dummy zero; every selector nibble
// picks a byte of the first input for the corresponding result byte.
static SDValue lowerBSWAP(SDValue Op, SelectionDAG &DAG) {
  SDLoc DL(Op);
  SDValue Src = Op.getOperand(i: 0);
  EVT VT = Op.getValueType();

  switch (VT.getSimpleVT().SimpleTy) {
  case MVT::i16: {
    // Widen to i32, swap the two low bytes (selector 0x7701), then
    // truncate; the upper result bytes are discarded by the truncate.
    SDValue Extended = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: MVT::i32, Operand: Src);
    SDValue Swapped =
        getPRMT(A: Extended, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x7701, DL, DAG);
    return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i16, Operand: Swapped);
  }
  case MVT::i32: {
    // Selector 0x0123 reverses all four bytes in one permute.
    return getPRMT(A: Src, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123, DL, DAG);
  }
  case MVT::v2i16: {
    // Swap bytes within each 16-bit lane (selector 0x2301), operating on
    // the i32 bit pattern of the vector.
    SDValue Converted = DAG.getBitcast(VT: MVT::i32, V: Src);
    SDValue Swapped =
        getPRMT(A: Converted, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x2301, DL, DAG);
    return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: MVT::v2i16, Operand: Swapped);
  }
  case MVT::i64: {
    // Split into two i32 halves, byte-swap each half, then rebuild the i64
    // with the halves in exchanged order.
    SDValue UnpackSrc =
        DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: Src);
    SDValue SwappedLow =
        getPRMT(A: UnpackSrc.getValue(R: 0), B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123,
                DL, DAG);
    SDValue SwappedHigh =
        getPRMT(A: UnpackSrc.getValue(R: 1), B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123,
                DL, DAG);
    return DAG.getNode(Opcode: NVPTXISD::BUILD_VECTOR, DL, VT: MVT::i64,
                       Ops: {SwappedHigh, SwappedLow});
  }
  default:
    llvm_unreachable("unsupported type for bswap");
  }
}
2630
// Map a tcgen05.mma.*.disable_output_lane intrinsic ID to the matching
// NVPTXISD machine opcode. Aborts on any other intrinsic ID.
static unsigned getTcgen05MMADisableOutputLane(unsigned IID) {
  switch (IID) {
  case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
    return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
    return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
  case Intrinsic::
      nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
    return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
  case Intrinsic::
      nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
    return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
    return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
    return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::
      nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
    return NVPTXISD::
        TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
  case Intrinsic::
      nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
    return NVPTXISD::
        TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
  };
  llvm_unreachable("unhandled tcgen05.mma.disable_output_lane intrinsic");
}
2690
2691static SDValue LowerTcgen05MMADisableOutputLane(SDValue Op, SelectionDAG &DAG) {
2692 SDNode *N = Op.getNode();
2693 SDLoc DL(N);
2694 unsigned IID = cast<ConstantSDNode>(Val: N->getOperand(Num: 1))->getZExtValue();
2695
2696 SmallVector<SDValue, 16> Ops;
2697 // split the vector argument
2698 for (size_t I = 0; I < N->getNumOperands(); I++) {
2699 if (I == 1)
2700 continue; // skip IID
2701 SDValue Val = N->getOperand(Num: I);
2702 EVT ValVT = Val.getValueType();
2703 if (ValVT.isVector()) {
2704 EVT EltVT = ValVT.getVectorElementType();
2705 for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
2706 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Val,
2707 N2: DAG.getIntPtrConstant(Val: J, DL)));
2708 } else
2709 Ops.push_back(Elt: Val);
2710 }
2711
2712 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
2713 SDValue Tcgen05MMANode = DAG.getMemIntrinsicNode(
2714 Opcode: getTcgen05MMADisableOutputLane(IID), dl: DL, VTList: N->getVTList(), Ops,
2715 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
2716
2717 return Tcgen05MMANode;
2718}
2719
2720// Lower vector return type of tcgen05.ld intrinsics
2721static std::optional<std::pair<SDValue, SDValue>>
2722lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset = false) {
2723 SDLoc DL(N);
2724 EVT ResVT = N->getValueType(ResNo: 0);
2725 if (!ResVT.isVector())
2726 return {}; // already legalized.
2727
2728 const unsigned NumElts = ResVT.getVectorNumElements();
2729
2730 // Create the return type of the instructions
2731 SmallVector<EVT, 5> ListVTs;
2732 for (unsigned i = 0; i < NumElts; ++i)
2733 ListVTs.push_back(Elt: MVT::i32);
2734
2735 ListVTs.push_back(Elt: N->getValueType(ResNo: 1)); // Chain
2736
2737 SDVTList ResVTs = DAG.getVTList(VTs: ListVTs);
2738
2739 SmallVector<SDValue, 8> Ops{N->getOperand(Num: 0), N->getOperand(Num: 1),
2740 N->getOperand(Num: 2)};
2741
2742 if (HasOffset) {
2743 Ops.push_back(Elt: N->getOperand(Num: 3)); // offset
2744 Ops.push_back(Elt: N->getOperand(Num: 4)); // Pack flag
2745 } else
2746 Ops.push_back(Elt: N->getOperand(Num: 3)); // Pack flag
2747
2748 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
2749 SDValue NewNode =
2750 DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_W_CHAIN, dl: DL, VTList: ResVTs, Ops,
2751 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
2752
2753 // split the vector result
2754 SmallVector<SDValue, 4> ScalarRes;
2755 for (unsigned i = 0; i < NumElts; ++i) {
2756 SDValue Res = NewNode.getValue(R: i);
2757 ScalarRes.push_back(Elt: Res);
2758 }
2759
2760 SDValue Chain = NewNode.getValue(R: NumElts);
2761 SDValue BuildVector = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: ResVT, Ops: ScalarRes);
2762 return {{BuildVector, Chain}};
2763}
2764
2765static SDValue reportInvalidTensormapReplaceUsage(SDValue Op, SelectionDAG &DAG,
2766 unsigned Val) {
2767 SDNode *N = Op.getNode();
2768 SDLoc DL(N);
2769
2770 const Function &Fn = DAG.getMachineFunction().getFunction();
2771
2772 unsigned AS = 0;
2773 if (auto *MemN = dyn_cast<MemIntrinsicSDNode>(Val: N))
2774 AS = MemN->getAddressSpace();
2775 Type *PtrTy = PointerType::get(C&: *DAG.getContext(), AddressSpace: AS);
2776 Module *M = DAG.getMachineFunction().getFunction().getParent();
2777
2778 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
2779 Fn,
2780 "Intrinsic " +
2781 Intrinsic::getName(Id: N->getConstantOperandVal(Num: 1), Tys: {PtrTy}, M) +
2782 " with value " + Twine(Val) +
2783 " is not supported on the given target.",
2784 DL.getDebugLoc()));
2785 return Op.getOperand(i: 0);
2786}
2787
2788static SDValue lowerTensormapReplaceElemtype(SDValue Op, SelectionDAG &DAG) {
2789 SDNode *N = Op.getNode();
2790 SDLoc DL(N);
2791
2792 // immediate argument representing elemtype
2793 unsigned Val = N->getConstantOperandVal(Num: 3);
2794
2795 if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceElemtypeSupport(
2796 value: Val))
2797 return reportInvalidTensormapReplaceUsage(Op, DAG, Val);
2798
2799 return Op;
2800}
2801
2802static SDValue lowerTensormapReplaceSwizzleMode(SDValue Op, SelectionDAG &DAG) {
2803 SDNode *N = Op.getNode();
2804 SDLoc DL(N);
2805
2806 // immediate argument representing swizzle mode
2807 unsigned Val = N->getConstantOperandVal(Num: 3);
2808
2809 if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceSwizzleModeSupport(
2810 value: Val))
2811 return reportInvalidTensormapReplaceUsage(Op, DAG, Val);
2812
2813 return Op;
2814}
2815
// Custom-lower INTRINSIC_VOID nodes: tcgen05.st / tcgen05.mma variants need
// their vector operands scalarized, and tensormap.replace intrinsics need
// their immediate arguments validated against the target. All other
// intrinsics pass through unchanged.
static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
  SDNode *N = Op.getNode();
  SDValue Intrin = N->getOperand(Num: 1);

  // Get the intrinsic ID
  unsigned IntrinNo = cast<ConstantSDNode>(Val: Intrin.getNode())->getZExtValue();
  switch (IntrinNo) {
  default:
    break;
  // tcgen05.st variants without an offset operand.
  case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
    return lowerTcgen05St(Op, DAG);
  // tcgen05.st 16x32bx2 variants carry an extra offset operand.
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128:
    return lowerTcgen05St(Op, DAG, /* hasOffset */ true);
  // tcgen05.mma disable_output_lane variants.
  case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
  case Intrinsic::
      nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
  case Intrinsic::
      nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
  case Intrinsic::
      nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
  case Intrinsic::
      nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
    return LowerTcgen05MMADisableOutputLane(Op, DAG);
  // tensormap.replace intrinsics with target-dependent immediate ranges.
  case Intrinsic::nvvm_tensormap_replace_elemtype:
    return lowerTensormapReplaceElemtype(Op, DAG);
  case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
    return lowerTensormapReplaceSwizzleMode(Op, DAG);
  }
  return Op;
}
2897
2898static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
2899 SelectionDAG &DAG) {
2900
2901 SDNode *N = Op.getNode();
2902 if (N->getOperand(Num: 1).getValueType() != MVT::i128) {
2903 // return, if the operand is already lowered
2904 return SDValue();
2905 }
2906
2907 unsigned IID =
2908 cast<ConstantSDNode>(Val: N->getOperand(Num: 0).getNode())->getZExtValue();
2909 auto Opcode = [&]() {
2910 switch (IID) {
2911 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
2912 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED;
2913 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
2914 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X;
2915 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
2916 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y;
2917 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
2918 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z;
2919 default:
2920 llvm_unreachable("unsupported/unhandled intrinsic");
2921 }
2922 }();
2923
2924 SDLoc DL(N);
2925 SDValue TryCancelResponse = N->getOperand(Num: 1);
2926 SDValue Cast = DAG.getNode(Opcode: ISD::BITCAST, DL, VT: MVT::v2i64, Operand: TryCancelResponse);
2927 SDValue TryCancelResponse0 =
2928 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
2929 N2: DAG.getIntPtrConstant(Val: 0, DL));
2930 SDValue TryCancelResponse1 =
2931 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
2932 N2: DAG.getIntPtrConstant(Val: 1, DL));
2933
2934 return DAG.getNode(Opcode, DL, VTList: N->getVTList(),
2935 Ops: {TryCancelResponse0, TryCancelResponse1});
2936}
2937
// Lower the f32x4 -> packed narrow-FP "round stochastic" (RS) conversion
// intrinsics to their NVPTXISD CVT nodes. The emitted node takes the four
// extracted f32 lanes, the random bits, and the cvt mode flags.
static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
  SDNode *N = Op.getNode();
  SDLoc DL(N);
  SDValue F32Vec = N->getOperand(Num: 1);
  SDValue RBits = N->getOperand(Num: 2);

  unsigned IntrinsicID = N->getConstantOperandVal(Num: 0);

  // Extract the 4 float elements from the vector
  SmallVector<SDValue, 6> Ops;
  for (unsigned i = 0; i < 4; ++i)
    Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::f32, N1: F32Vec,
                               N2: DAG.getIntPtrConstant(Val: i, DL)));

  using NVPTX::PTXCvtMode::CvtMode;

  // Pick the target opcode, the packed result type (v4i8 for 8/6-bit
  // formats, i16 for the 4-bit e2m1 format), and the cvt mode flags.
  auto [OpCode, RetTy, CvtModeFlag] =
      [&]() -> std::tuple<unsigned, MVT::SimpleValueType, uint32_t> {
    switch (IntrinsicID) {
    case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
      return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
    case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
      return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
    case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
      return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
    case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
      return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
    case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
      return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16, CvtMode::RS};
    default:
      llvm_unreachable("unsupported/unhandled intrinsic");
    }
  }();

  Ops.push_back(Elt: RBits);
  Ops.push_back(Elt: DAG.getConstant(Val: CvtModeFlag, DL, VT: MVT::i32));

  return DAG.getNode(Opcode: OpCode, DL, VT: RetTy, Ops);
}
2992
2993static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
2994 const unsigned Mode = [&]() {
2995 switch (Op->getConstantOperandVal(Num: 0)) {
2996 case Intrinsic::nvvm_prmt:
2997 return NVPTX::PTXPrmtMode::NONE;
2998 case Intrinsic::nvvm_prmt_b4e:
2999 return NVPTX::PTXPrmtMode::B4E;
3000 case Intrinsic::nvvm_prmt_ecl:
3001 return NVPTX::PTXPrmtMode::ECL;
3002 case Intrinsic::nvvm_prmt_ecr:
3003 return NVPTX::PTXPrmtMode::ECR;
3004 case Intrinsic::nvvm_prmt_f4e:
3005 return NVPTX::PTXPrmtMode::F4E;
3006 case Intrinsic::nvvm_prmt_rc16:
3007 return NVPTX::PTXPrmtMode::RC16;
3008 case Intrinsic::nvvm_prmt_rc8:
3009 return NVPTX::PTXPrmtMode::RC8;
3010 default:
3011 llvm_unreachable("unsupported/unhandled intrinsic");
3012 }
3013 }();
3014 SDLoc DL(Op);
3015 SDValue A = Op->getOperand(Num: 1);
3016 SDValue B = Op.getNumOperands() == 4 ? Op.getOperand(i: 2)
3017 : DAG.getConstant(Val: 0, DL, VT: MVT::i32);
3018 SDValue Selector = (Op->op_end() - 1)->get();
3019 return getPRMT(A, B, Selector, DL, DAG, Mode);
3020}
3021
// Helper macros mapping a (shape, vector-width, element-type) triple to the
// corresponding tcgen05.ld.red intrinsic ID and NVPTXISD opcode. Intrinsic
// names use a lower-case type suffix; the ISD opcodes use upper-case.
#define TCGEN05_LD_RED_INTR(SHAPE, NUM, TYPE)                                  \
  Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_x##NUM##_##TYPE

#define TCGEN05_LD_RED_INST(SHAPE, NUM, TYPE)                                  \
  NVPTXISD::TCGEN05_LD_RED_##SHAPE##_X##NUM##_##TYPE

// Translate a tcgen05.ld.red intrinsic ID into the matching NVPTXISD opcode.
// Covers the 32x32b and 16x32bx2 shapes for widths x2..x128, in both f32 and
// i32 flavors. Any other ID is a front-end/selection bug.
static unsigned getTcgen05LdRedID(Intrinsic::ID IID) {
  switch (IID) {
  case TCGEN05_LD_RED_INTR(32x32b, 2, f32):
    return TCGEN05_LD_RED_INST(32x32b, 2, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 4, f32):
    return TCGEN05_LD_RED_INST(32x32b, 4, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 8, f32):
    return TCGEN05_LD_RED_INST(32x32b, 8, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 16, f32):
    return TCGEN05_LD_RED_INST(32x32b, 16, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 32, f32):
    return TCGEN05_LD_RED_INST(32x32b, 32, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 64, f32):
    return TCGEN05_LD_RED_INST(32x32b, 64, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 128, f32):
    return TCGEN05_LD_RED_INST(32x32b, 128, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 2, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 2, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 4, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 4, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 8, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 8, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 16, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 16, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 32, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 32, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 64, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 64, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 128, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 128, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 2, i32):
    return TCGEN05_LD_RED_INST(32x32b, 2, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 4, i32):
    return TCGEN05_LD_RED_INST(32x32b, 4, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 8, i32):
    return TCGEN05_LD_RED_INST(32x32b, 8, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 16, i32):
    return TCGEN05_LD_RED_INST(32x32b, 16, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 32, i32):
    return TCGEN05_LD_RED_INST(32x32b, 32, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 64, i32):
    return TCGEN05_LD_RED_INST(32x32b, 64, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 128, i32):
    return TCGEN05_LD_RED_INST(32x32b, 128, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 2, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 2, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 4, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 4, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 8, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 8, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 16, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 16, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 32, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 32, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 64, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 64, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 128, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 128, I32);
  default:
    llvm_unreachable("Invalid tcgen05.ld.red intrinsic ID");
  }
}
3090
// Lower vector return type of tcgen05.ld.red intrinsics.
//
// Rewrites the intrinsic as a target memory-intrinsic node that produces
// NumElts scalar element results, one reduction result, and a chain, then
// reassembles the element results into the original vector type.
// Returns std::nullopt when the result type is not a vector, i.e. the node
// is already in a legal form and needs no rewriting.
static std::optional<std::tuple<SDValue, SDValue, SDValue>>
lowerTcgen05LdRed(SDNode *N, SelectionDAG &DAG) {
  SDLoc DL(N);
  EVT ResVT = N->getValueType(ResNo: 0);
  if (!ResVT.isVector())
    return {}; // already legalized.

  const unsigned NumElts = ResVT.getVectorNumElements();

  // Create the return type of the instructions
  // +1 represents the reduction value; it is f32 for floating-point element
  // types and i32 otherwise.
  SmallVector<EVT, 132> ListVTs{
      NumElts + 1,
      ResVT.getVectorElementType().isFloatingPoint() ? MVT::f32 : MVT::i32};

  ListVTs.push_back(Elt: MVT::Other); // Chain

  SDVTList ResVTs = DAG.getVTList(VTs: ListVTs);

  // Prepare the Operands
  SmallVector<SDValue, 8> Ops{N->getOperand(Num: 0)}; // Chain

  // skip IID at index 1
  for (unsigned i = 2; i < N->getNumOperands(); i++)
    Ops.push_back(Elt: N->getOperand(Num: i));

  unsigned IID = cast<ConstantSDNode>(Val: N->getOperand(Num: 1))->getZExtValue();
  MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
  SDValue NewNode =
      DAG.getMemIntrinsicNode(Opcode: getTcgen05LdRedID(IID), dl: DL, VTList: ResVTs, Ops,
                              MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());

  // Split vector result: results 0..NumElts-1 are the loaded elements.
  SmallVector<SDValue, 132> ScalarRes;
  for (unsigned i = 0; i < NumElts; ++i) {
    SDValue Res = NewNode.getValue(R: i);
    ScalarRes.push_back(Elt: Res);
  }

  // Result NumElts is the reduction value; NumElts + 1 is the chain.
  SDValue BuildVector = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: ResVT, Ops: ScalarRes);
  SDValue RedResult = NewNode.getValue(R: NumElts);
  SDValue Chain = NewNode.getValue(R: NumElts + 1);
  return {{BuildVector, RedResult, Chain}};
}
3136
// Custom-lower INTRINSIC_W_CHAIN nodes. Only the tcgen05 load intrinsics
// listed below need handling here; every other intrinsic is returned
// unchanged for default processing.
static SDValue lowerIntrinsicWChain(SDValue Op, SelectionDAG &DAG) {
  switch (Op->getConstantOperandVal(Num: 1)) {
  default:
    return Op;

  // These tcgen05 intrinsics return a v2i32, which is legal, so we have to
  // lower them through LowerOperation() instead of ReplaceNodeResults().
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
    if (auto Res = lowerTcgen05Ld(N: Op.getNode(), DAG))
      return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(Op));
    return SDValue();

  // The 16x32bx2 shape carries an extra offset operand.
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
    if (auto Res = lowerTcgen05Ld(N: Op.getNode(), DAG, /*HasOffset=*/true))
      return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(Op));
    return SDValue();

  // tcgen05.ld.red produces {vector result, reduction value, chain}.
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_i32:
    if (auto Res = lowerTcgen05LdRed(N: Op.getNode(), DAG))
      return DAG.getMergeValues(
          Ops: {std::get<0>(t&: *Res), std::get<1>(t&: *Res), std::get<2>(t&: *Res)}, dl: SDLoc(Op));
    return SDValue();
  }
}
3166
// Custom-lower INTRINSIC_WO_CHAIN nodes, dispatching each handled intrinsic
// family to its dedicated lowering helper. Unhandled intrinsics are returned
// unchanged for default processing.
static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
  switch (Op->getConstantOperandVal(Num: 0)) {
  default:
    return Op;
  case Intrinsic::nvvm_prmt:
  case Intrinsic::nvvm_prmt_b4e:
  case Intrinsic::nvvm_prmt_ecl:
  case Intrinsic::nvvm_prmt_ecr:
  case Intrinsic::nvvm_prmt_f4e:
  case Intrinsic::nvvm_prmt_rc16:
  case Intrinsic::nvvm_prmt_rc8:
    return lowerPrmtIntrinsic(Op, DAG);
  // The address-space wrap is a no-op at DAG level: forward its pointer.
  case Intrinsic::nvvm_internal_addrspace_wrap:
    return Op.getOperand(i: 1);
  case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
  case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
  case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
  case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
    return LowerClusterLaunchControlQueryCancel(Op, DAG);
  case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
  case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
  case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
  case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
  case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
  case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
  case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
  case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
  case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
  case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
    return lowerCvtRSIntrinsics(Op, DAG);
  }
}
3199
3200// In PTX 64-bit CTLZ and CTPOP are supported, but they return a 32-bit value.
3201// Lower these into a node returning the correct type which is zero-extended
3202// back to the correct size.
3203static SDValue lowerCTLZCTPOP(SDValue Op, SelectionDAG &DAG) {
3204 SDValue V = Op->getOperand(Num: 0);
3205 assert(V.getValueType() == MVT::i64 &&
3206 "Unexpected CTLZ/CTPOP type to legalize");
3207
3208 SDLoc DL(Op);
3209 SDValue CT = DAG.getNode(Opcode: Op->getOpcode(), DL, VT: MVT::i32, Operand: V);
3210 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i64, Operand: CT, Flags: SDNodeFlags::NonNeg);
3211}
3212
// Expand a 64-bit funnel shift (FSHL/FSHR) with a constant shift amount into
// two 32-bit funnel shifts over the unpacked 32-bit halves. Returns an empty
// SDValue for non-constant amounts, leaving the node for default expansion.
static SDValue expandFSH64(SDValue A, SDValue B, SDValue ShiftAmount, SDLoc DL,
                           unsigned Opcode, SelectionDAG &DAG) {
  assert(A.getValueType() == MVT::i64 && B.getValueType() == MVT::i64);

  // Only constant shift amounts are handled by this expansion.
  const auto *AmtConst = dyn_cast<ConstantSDNode>(Val&: ShiftAmount);
  if (!AmtConst)
    return SDValue();
  const auto Amt = AmtConst->getZExtValue() & 63;

  SDValue UnpackA =
      DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: A);
  SDValue UnpackB =
      DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: B);

  // Arch is little endian: result 0 = low bits, result 1 = high bits
  SDValue ALo = UnpackA.getValue(R: 0);
  SDValue AHi = UnpackA.getValue(R: 1);
  SDValue BLo = UnpackB.getValue(R: 0);
  SDValue BHi = UnpackB.getValue(R: 1);

  // The bitfield consists of { AHi : ALo : BHi : BLo }
  //
  // * FSHL, Amt < 32  - The window will contain { AHi : ALo : BHi }
  // * FSHL, Amt >= 32 - The window will contain { ALo : BHi : BLo }
  // * FSHR, Amt < 32  - The window will contain { ALo : BHi : BLo }
  // * FSHR, Amt >= 32 - The window will contain { AHi : ALo : BHi }
  //
  // Note that Amt = 0 and Amt = 32 are special cases where 32-bit funnel shifts
  // are not needed at all. Amt = 0 is a no-op producing either A or B depending
  // on the direction. Amt = 32 can be implemented by a packing and unpacking
  // move to select and arrange the 32bit values. For simplicity, these cases
  // are not handled here explicitly and instead we rely on DAGCombiner to
  // remove the no-op funnel shifts we insert.
  auto [High, Mid, Low] = ((Opcode == ISD::FSHL) == (Amt < 32))
                              ? std::make_tuple(args&: AHi, args&: ALo, args&: BHi)
                              : std::make_tuple(args&: ALo, args&: BHi, args&: BLo);

  // Both 32-bit shifts use the amount reduced modulo 32.
  SDValue NewAmt = DAG.getConstant(Val: Amt & 31, DL, VT: MVT::i32);
  SDValue RHi = DAG.getNode(Opcode, DL, VT: MVT::i32, Ops: {High, Mid, NewAmt});
  SDValue RLo = DAG.getNode(Opcode, DL, VT: MVT::i32, Ops: {Mid, Low, NewAmt});

  // Repack the two 32-bit halves into the i64 result.
  return DAG.getNode(Opcode: NVPTXISD::BUILD_VECTOR, DL, VT: MVT::i64, Ops: {RLo, RHi});
}
3256
3257static SDValue lowerFSH(SDValue Op, SelectionDAG &DAG) {
3258 return expandFSH64(A: Op->getOperand(Num: 0), B: Op->getOperand(Num: 1), ShiftAmount: Op->getOperand(Num: 2),
3259 DL: SDLoc(Op), Opcode: Op->getOpcode(), DAG);
3260}
3261
3262static SDValue lowerROT(SDValue Op, SelectionDAG &DAG) {
3263 unsigned Opcode = Op->getOpcode() == ISD::ROTL ? ISD::FSHL : ISD::FSHR;
3264 return expandFSH64(A: Op->getOperand(Num: 0), B: Op->getOperand(Num: 0), ShiftAmount: Op->getOperand(Num: 1),
3265 DL: SDLoc(Op), Opcode, DAG);
3266}
3267
3268static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG) {
3269 // Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
3270 // i.e. "poor man's fmod()". When y is infinite, x is returned. This matches
3271 // the semantics of LLVM's frem.
3272 SDLoc DL(Op);
3273 SDValue X = Op->getOperand(Num: 0);
3274 SDValue Y = Op->getOperand(Num: 1);
3275 EVT Ty = Op.getValueType();
3276 SDNodeFlags Flags = Op->getFlags();
3277
3278 SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL, VT: Ty, N1: X, N2: Y, Flags);
3279 SDValue Trunc = DAG.getNode(Opcode: ISD::FTRUNC, DL, VT: Ty, Operand: Div, Flags);
3280 SDValue Mul = DAG.getNode(Opcode: ISD::FMUL, DL, VT: Ty, N1: Trunc, N2: Y,
3281 Flags: Flags | SDNodeFlags::AllowContract);
3282 SDValue Sub = DAG.getNode(Opcode: ISD::FSUB, DL, VT: Ty, N1: X, N2: Mul,
3283 Flags: Flags | SDNodeFlags::AllowContract);
3284
3285 if (Flags.hasNoInfs())
3286 return Sub;
3287
3288 // If Y is infinite, return X
3289 SDValue AbsY = DAG.getNode(Opcode: ISD::FABS, DL, VT: Ty, Operand: Y);
3290 SDValue Inf =
3291 DAG.getConstantFP(Val: APFloat::getInf(Sem: Ty.getFltSemantics()), DL, VT: Ty);
3292 SDValue IsInf = DAG.getSetCC(DL, VT: MVT::i1, LHS: AbsY, RHS: Inf, Cond: ISD::SETEQ);
3293 return DAG.getSelect(DL, VT: Ty, Cond: IsInf, LHS: X, RHS: Sub);
3294}
3295
// Custom-lower an i1 SELECT. When both arms are truncations, the select is
// hoisted above the truncates so it operates on the wider type; otherwise it
// is expanded into (Cond & TrueVal) | (~Cond & FalseVal).
static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
  assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");

  SDValue Cond = Op->getOperand(Num: 0);
  SDValue TrueVal = Op->getOperand(Num: 1);
  SDValue FalseVal = Op->getOperand(Num: 2);
  SDLoc DL(Op);

  // If both operands are truncated, we push the select through the truncates.
  if (TrueVal.getOpcode() == ISD::TRUNCATE &&
      FalseVal.getOpcode() == ISD::TRUNCATE) {
    TrueVal = TrueVal.getOperand(i: 0);
    FalseVal = FalseVal.getOperand(i: 0);

    // Select on the narrower of the two pre-truncation types and adjust the
    // other operand to match it.
    EVT VT = TrueVal.getSimpleValueType().bitsLE(VT: FalseVal.getSimpleValueType())
                 ? TrueVal.getValueType()
                 : FalseVal.getValueType();
    TrueVal = DAG.getAnyExtOrTrunc(Op: TrueVal, DL, VT);
    FalseVal = DAG.getAnyExtOrTrunc(Op: FalseVal, DL, VT);
    SDValue Select = DAG.getSelect(DL, VT, Cond, LHS: TrueVal, RHS: FalseVal);
    return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i1, Operand: Select);
  }

  // Otherwise, expand the select into a series of logical operations. These
  // often can be folded into other operations either by us or ptxas.
  // Freeze both arms first so poison cannot propagate through the AND/OR
  // expansion (SELECT only propagates poison from the chosen arm).
  TrueVal = DAG.getFreeze(V: TrueVal);
  FalseVal = DAG.getFreeze(V: FalseVal);
  SDValue And1 = DAG.getNode(Opcode: ISD::AND, DL, VT: MVT::i1, N1: Cond, N2: TrueVal);
  SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
  SDValue And2 = DAG.getNode(Opcode: ISD::AND, DL, VT: MVT::i1, N1: NotCond, N2: FalseVal);
  SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: MVT::i1, N1: And1, N2: And2);
  return Or;
}
3329
// Lower a masked vector store (ISD::MSTORE) to a NVPTX StoreV4/StoreV8 node.
// The i1 mask must be a constant BUILD_VECTOR; masked-off lanes are encoded
// as a sentinel null register operand which the instruction selector turns
// into an omitted lane in the emitted PTX.
static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
  SDNode *N = Op.getNode();

  SDValue Chain = N->getOperand(Num: 0);
  SDValue Val = N->getOperand(Num: 1);
  SDValue BasePtr = N->getOperand(Num: 2);
  SDValue Offset = N->getOperand(Num: 3);
  SDValue Mask = N->getOperand(Num: 4);

  SDLoc DL(N);
  EVT ValVT = Val.getValueType();
  MemSDNode *MemSD = cast<MemSDNode>(Val: N);
  assert(ValVT.isVector() && "Masked vector store must have vector type");
  assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
         "Unexpected alignment for masked store");

  // Only the 256-bit vector shapes are expected here.
  unsigned Opcode = 0;
  switch (ValVT.getSimpleVT().SimpleTy) {
  default:
    llvm_unreachable("Unexpected masked vector store type");
  case MVT::v4i64:
  case MVT::v4f64: {
    Opcode = NVPTXISD::StoreV4;
    break;
  }
  case MVT::v8i32:
  case MVT::v8f32: {
    Opcode = NVPTXISD::StoreV8;
    break;
  }
  }

  SmallVector<SDValue, 8> Ops;

  // Construct the new SDNode. First operand is the chain.
  Ops.push_back(Elt: Chain);

  // The next N operands are the values to store. Encode the mask into the
  // values using the sentinel register 0 to represent a masked-off element.
  assert(Mask.getValueType().isVector() &&
         Mask.getValueType().getVectorElementType() == MVT::i1 &&
         "Mask must be a vector of i1");
  assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
         "Mask expected to be a BUILD_VECTOR");
  assert(Mask.getValueType().getVectorNumElements() ==
             ValVT.getVectorNumElements() &&
         "Mask size must be the same as the vector size");
  for (auto [I, Op] : enumerate(First: Mask->ops())) {
    // Mask elements must be constants.
    if (Op.getNode()->getAsZExtVal() == 0) {
      // Append a sentinel register 0 to the Ops vector to represent a masked
      // off element, this will be handled in tablegen
      Ops.push_back(Elt: DAG.getRegister(Reg: MCRegister::NoRegister,
                                     VT: ValVT.getVectorElementType()));
    } else {
      // Extract the element from the vector to store
      SDValue ExtVal =
          DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ValVT.getVectorElementType(),
                      N1: Val, N2: DAG.getIntPtrConstant(Val: I, DL));
      Ops.push_back(Elt: ExtVal);
    }
  }

  // Next, the pointer operand.
  Ops.push_back(Elt: BasePtr);

  // Finally, the offset operand. We expect this to always be undef, and it will
  // be ignored in lowering, but to mirror the handling of the other vector
  // store instructions we include it in the new SDNode.
  assert(Offset.getOpcode() == ISD::UNDEF &&
         "Offset operand expected to be undef");
  Ops.push_back(Elt: Offset);

  SDValue NewSt =
      DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: DAG.getVTList(VT: MVT::Other), Ops,
                              MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());

  return NewSt;
}
3409
// Central dispatcher for all operations marked Custom in the NVPTX target:
// routes each opcode to its dedicated lowering routine. Returning the
// original Op means "use default handling"; an empty SDValue means the
// operation could not be lowered here.
SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
  switch (Op.getOpcode()) {
  // RETURNADDR/FRAMEADDR are not supported: return empty to expand to zero.
  case ISD::RETURNADDR:
    return SDValue();
  case ISD::FRAMEADDR:
    return SDValue();
  case ISD::ADDRSPACECAST:
    return LowerADDRSPACECAST(Op, DAG);
  case ISD::INTRINSIC_W_CHAIN:
    return lowerIntrinsicWChain(Op, DAG);
  case ISD::INTRINSIC_WO_CHAIN:
    return lowerIntrinsicWOChain(Op, DAG);
  case ISD::INTRINSIC_VOID:
    return lowerIntrinsicVoid(Op, DAG);
  case ISD::BUILD_VECTOR:
    return LowerBUILD_VECTOR(Op, DAG);
  case ISD::BITCAST:
    return LowerBITCAST(Op, DAG);
  case ISD::EXTRACT_SUBVECTOR:
    return Op;
  case ISD::EXTRACT_VECTOR_ELT:
    return LowerEXTRACT_VECTOR_ELT(Op, DAG);
  case ISD::INSERT_VECTOR_ELT:
    return LowerINSERT_VECTOR_ELT(Op, DAG);
  case ISD::VECTOR_SHUFFLE:
    return LowerVECTOR_SHUFFLE(Op, DAG);
  case ISD::CONCAT_VECTORS:
    return LowerCONCAT_VECTORS(Op, DAG);
  case ISD::VECREDUCE_FMAX:
  case ISD::VECREDUCE_FMIN:
  case ISD::VECREDUCE_FMAXIMUM:
  case ISD::VECREDUCE_FMINIMUM:
    return LowerVECREDUCE(Op, DAG);
  case ISD::STORE:
    return LowerSTORE(Op, DAG);
  case ISD::MSTORE: {
    assert(STI.has256BitVectorLoadStore(
               cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
           "Masked store vector not supported on subtarget.");
    return lowerMSTORE(Op, DAG);
  }
  case ISD::LOAD:
    return LowerLOAD(Op, DAG);
  case ISD::MLOAD:
    return LowerMLOAD(Op, DAG);
  case ISD::SHL_PARTS:
    return LowerShiftLeftParts(Op, DAG);
  case ISD::SRA_PARTS:
  case ISD::SRL_PARTS:
    return LowerShiftRightParts(Op, DAG);
  case ISD::SELECT:
    return lowerSELECT(Op, DAG);
  case ISD::FROUND:
    return LowerFROUND(Op, DAG);
  case ISD::FCOPYSIGN:
    return LowerFCOPYSIGN(Op, DAG);
  case ISD::SINT_TO_FP:
  case ISD::UINT_TO_FP:
    return LowerINT_TO_FP(Op, DAG);
  case ISD::FP_TO_SINT:
  case ISD::FP_TO_UINT:
    return LowerFP_TO_INT(Op, DAG);
  case ISD::FP_ROUND:
    return LowerFP_ROUND(Op, DAG);
  case ISD::FP_EXTEND:
    return LowerFP_EXTEND(Op, DAG);
  case ISD::VAARG:
    return LowerVAARG(Op, DAG);
  case ISD::VASTART:
    return LowerVASTART(Op, DAG);
  case ISD::FSHL:
  case ISD::FSHR:
    return lowerFSH(Op, DAG);
  case ISD::ROTL:
  case ISD::ROTR:
    return lowerROT(Op, DAG);
  case ISD::ABS:
  case ISD::SMIN:
  case ISD::SMAX:
  case ISD::UMIN:
  case ISD::UMAX:
  case ISD::ADD:
  case ISD::SUB:
  case ISD::MUL:
  case ISD::SHL:
  case ISD::SREM:
  case ISD::UREM:
    return LowerVectorArith(Op, DAG);
  case ISD::DYNAMIC_STACKALLOC:
    return LowerDYNAMIC_STACKALLOC(Op, DAG);
  case ISD::STACKRESTORE:
    return LowerSTACKRESTORE(Op, DAG);
  case ISD::STACKSAVE:
    return LowerSTACKSAVE(Op, DAG);
  case ISD::CopyToReg:
    return LowerCopyToReg_128(Op, DAG);
  case ISD::FADD:
  case ISD::FSUB:
  case ISD::FMUL:
    // Used only for bf16 on SM80, where we select fma for non-ftz operation
    return PromoteBinOpIfF32FTZ(Op, DAG);
  case ISD::CTPOP:
  case ISD::CTLZ:
    return lowerCTLZCTPOP(Op, DAG);
  case ISD::FREM:
    return lowerFREM(Op, DAG);
  case ISD::BSWAP:
    return lowerBSWAP(Op, DAG);
  default:
    llvm_unreachable("Custom lowering not defined for operation");
  }
}
3523
// Use inline jump-table encoding. This will prevent AsmPrinter from trying
// to print the jump tables itself; NVPTX emits them inline with the branch
// sequence instead.
unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
  return MachineJumpTableInfo::EK_Inline;
}
3528
3529SDValue NVPTXTargetLowering::LowerADDRSPACECAST(SDValue Op,
3530 SelectionDAG &DAG) const {
3531 AddrSpaceCastSDNode *N = cast<AddrSpaceCastSDNode>(Val: Op.getNode());
3532 unsigned SrcAS = N->getSrcAddressSpace();
3533 unsigned DestAS = N->getDestAddressSpace();
3534 if (SrcAS != llvm::ADDRESS_SPACE_GENERIC &&
3535 DestAS != llvm::ADDRESS_SPACE_GENERIC) {
3536 // Shared and SharedCluster can be converted to each other through generic
3537 // space
3538 if ((SrcAS == llvm::ADDRESS_SPACE_SHARED &&
3539 DestAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER) ||
3540 (SrcAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER &&
3541 DestAS == llvm::ADDRESS_SPACE_SHARED)) {
3542 SDLoc DL(Op.getNode());
3543 const MVT GenerictVT =
3544 getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_GENERIC);
3545 SDValue GenericConversion = DAG.getAddrSpaceCast(
3546 dl: DL, VT: GenerictVT, Ptr: Op.getOperand(i: 0), SrcAS, DestAS: ADDRESS_SPACE_GENERIC);
3547 SDValue SharedClusterConversion =
3548 DAG.getAddrSpaceCast(dl: DL, VT: Op.getValueType(), Ptr: GenericConversion,
3549 SrcAS: ADDRESS_SPACE_GENERIC, DestAS);
3550 return SharedClusterConversion;
3551 }
3552
3553 return DAG.getUNDEF(VT: Op.getValueType());
3554 }
3555
3556 return Op;
3557}
3558
// This function is almost a copy of SelectionDAG::expandVAArg().
// The only diff is that this one produces loads from local address space.
SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
  const TargetLowering *TLI = STI.getTargetLowering();
  SDLoc DL(Op);

  SDNode *Node = Op.getNode();
  const Value *V = cast<SrcValueSDNode>(Val: Node->getOperand(Num: 2))->getValue();
  EVT VT = Node->getValueType(ResNo: 0);
  auto *Ty = VT.getTypeForEVT(Context&: *DAG.getContext());
  SDValue Tmp1 = Node->getOperand(Num: 0); // Chain
  SDValue Tmp2 = Node->getOperand(Num: 1); // Pointer to the va_list
  const MaybeAlign MA(Node->getConstantOperandVal(Num: 3));

  // Load the current argument pointer out of the va_list.
  SDValue VAListLoad = DAG.getLoad(VT: TLI->getPointerTy(DL: DAG.getDataLayout()), dl: DL,
                                   Chain: Tmp1, Ptr: Tmp2, PtrInfo: MachinePointerInfo(V));
  SDValue VAList = VAListLoad;

  // Round the pointer up to the argument's alignment when it exceeds the
  // minimum stack argument alignment: VAList = (VAList + MA - 1) & -MA.
  if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
    VAList = DAG.getNode(
        Opcode: ISD::ADD, DL, VT: VAList.getValueType(), N1: VAList,
        N2: DAG.getConstant(Val: MA->value() - 1, DL, VT: VAList.getValueType()));

    VAList = DAG.getNode(Opcode: ISD::AND, DL, VT: VAList.getValueType(), N1: VAList,
                         N2: DAG.getSignedConstant(Val: -(int64_t)MA->value(), DL,
                                                VT: VAList.getValueType()));
  }

  // Increment the pointer, VAList, to the next vaarg
  Tmp1 = DAG.getNode(Opcode: ISD::ADD, DL, VT: VAList.getValueType(), N1: VAList,
                     N2: DAG.getConstant(Val: DAG.getDataLayout().getTypeAllocSize(Ty),
                                       DL, VT: VAList.getValueType()));

  // Store the incremented VAList to the legalized pointer
  Tmp1 = DAG.getStore(Chain: VAListLoad.getValue(R: 1), dl: DL, Val: Tmp1, Ptr: Tmp2,
                      PtrInfo: MachinePointerInfo(V));

  // The argument itself lives in the local address space.
  const Value *SrcV = Constant::getNullValue(
      Ty: PointerType::get(C&: *DAG.getContext(), AddressSpace: ADDRESS_SPACE_LOCAL));

  // Load the actual argument out of the pointer VAList
  return DAG.getLoad(VT, dl: DL, Chain: Tmp1, Ptr: VAList, PtrInfo: MachinePointerInfo(SrcV));
}
3602
3603SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
3604 const TargetLowering *TLI = STI.getTargetLowering();
3605 SDLoc DL(Op);
3606 EVT PtrVT = TLI->getPointerTy(DL: DAG.getDataLayout());
3607
3608 // Store the address of unsized array <function>_vararg[] in the ap object.
3609 SDValue VAReg = getParamSymbol(DAG, /* vararg */ I: -1, T: PtrVT);
3610
3611 const Value *SV = cast<SrcValueSDNode>(Val: Op.getOperand(i: 2))->getValue();
3612 return DAG.getStore(Chain: Op.getOperand(i: 0), dl: DL, Val: VAReg, Ptr: Op.getOperand(i: 1),
3613 PtrInfo: MachinePointerInfo(SV));
3614}
3615
// Convert an MLOAD (masked vector load) into a plain vector load plus a
// byte-granularity "used bytes" mask. Returns the new load node and the
// mask; the mask is UINT32_MAX when the subtarget cannot express it.
static std::pair<MemSDNode *, uint32_t>
convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG,
                                    const NVPTXSubtarget &STI) {
  SDValue Chain = N->getOperand(Num: 0);
  SDValue BasePtr = N->getOperand(Num: 1);
  SDValue Mask = N->getOperand(Num: 3);
  [[maybe_unused]] SDValue Passthru = N->getOperand(Num: 4);

  SDLoc DL(N);
  EVT ResVT = N->getValueType(ResNo: 0);
  assert(ResVT.isVector() && "Masked vector load must have vector type");
  // While we only expect poison passthru vectors as an input to the backend,
  // when the legalization framework splits a poison vector in half, it creates
  // two undef vectors, so we can technically expect those too.
  assert((Passthru.getOpcode() == ISD::POISON ||
          Passthru.getOpcode() == ISD::UNDEF) &&
         "Passthru operand expected to be poison or undef");

  // Extract the mask and convert it to a uint32_t representing the used bytes
  // of the entire vector load
  uint32_t UsedBytesMask = 0;
  uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
  assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
  uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
  uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;

  // Walk the i1 mask elements from the last lane to the first so that lane 0
  // ends up in the low bits of UsedBytesMask.
  for (SDValue Op : reverse(C: Mask->ops())) {
    // We technically only want to do this shift for every
    // iteration *but* the first, but in the first iteration UsedBytesMask is 0,
    // so this shift is a no-op.
    UsedBytesMask <<= ElementSizeInBytes;

    // Mask elements must be constants.
    if (Op->getAsZExtVal() != 0)
      UsedBytesMask |= ElementMask;
  }

  assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
         "Unexpected masked load with elements masked all on or all off");

  // Create a new load sd node to be handled normally by ReplaceLoadVector.
  MemSDNode *NewLD = cast<MemSDNode>(
      Val: DAG.getLoad(VT: ResVT, dl: DL, Chain, Ptr: BasePtr, MMO: N->getMemOperand()).getNode());

  // If our subtarget does not support the used bytes mask pragma, "drop" the
  // mask by setting it to UINT32_MAX
  if (!STI.hasUsedBytesMaskPragma())
    UsedBytesMask = UINT32_MAX;

  return {NewLD, UsedBytesMask};
}
3667
/// replaceLoadVector - Convert vector loads into multi-output scalar loads.
/// Returns the {value, chain} pair for the replacement, or std::nullopt when
/// the load must be left for default handling (extending load, unsupported
/// shape, or insufficient alignment).
static std::optional<std::pair<SDValue, SDValue>>
replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
  MemSDNode *LD = cast<MemSDNode>(Val: N);
  const EVT ResVT = LD->getValueType(ResNo: 0);
  const EVT MemVT = LD->getMemoryVT();

  // If we're doing sign/zero extension as part of the load, avoid lowering to
  // a LoadV node. TODO: consider relaxing this restriction.
  if (ResVT != MemVT)
    return std::nullopt;

  const auto NumEltsAndEltVT =
      getVectorLoweringShape(VectorEVT: ResVT, STI, AddressSpace: LD->getAddressSpace());
  if (!NumEltsAndEltVT)
    return std::nullopt;
  const auto [NumElts, EltVT] = NumEltsAndEltVT.value();

  Align Alignment = LD->getAlign();
  const auto &TD = DAG.getDataLayout();
  Align PrefAlign = TD.getPrefTypeAlign(Ty: MemVT.getTypeForEVT(Context&: *DAG.getContext()));
  if (Alignment < PrefAlign) {
    // This load is not sufficiently aligned, so bail out and let this vector
    // load be scalarized.  Note that we may still be able to emit smaller
    // vector loads.  For example, if we are loading a <4 x float> with an
    // alignment of 8, this check will fail but the legalizer will try again
    // with 2 x <2 x float>, which will succeed with an alignment of 8.
    return std::nullopt;
  }

  // If we have a masked load, convert it to a normal load now
  std::optional<uint32_t> UsedBytesMask = std::nullopt;
  if (LD->getOpcode() == ISD::MLOAD)
    std::tie(args&: LD, args&: UsedBytesMask) =
        convertMLOADToLoadWithUsedBytesMask(N: LD, DAG, STI);

  // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
  // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
  // loaded type to i16 and propagate the "real" type as the memory type.
  const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;

  unsigned Opcode;
  switch (NumElts) {
  default:
    return std::nullopt;
  case 2:
    Opcode = NVPTXISD::LoadV2;
    break;
  case 4:
    Opcode = NVPTXISD::LoadV4;
    break;
  case 8:
    Opcode = NVPTXISD::LoadV8;
    break;
  }
  // Result list: NumElts loaded values followed by the output chain.
  auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT);
  ListVTs.push_back(Elt: MVT::Other);
  SDVTList LdResVTs = DAG.getVTList(VTs: ListVTs);

  SDLoc DL(LD);

  // Copy regular operands
  SmallVector<SDValue, 8> OtherOps(LD->ops());

  // Append the used-bytes mask (UINT32_MAX when the load was not masked).
  OtherOps.push_back(
      Elt: DAG.getConstant(Val: UsedBytesMask.value_or(UINT32_MAX), DL, VT: MVT::i32));

  // The select routine does not have access to the LoadSDNode instance, so
  // pass along the extension information
  OtherOps.push_back(
      Elt: DAG.getIntPtrConstant(Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));

  SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: LdResVTs, Ops: OtherOps, MemVT,
                                          MMO: LD->getMemOperand());

  SmallVector<SDValue> ScalarRes;
  if (EltVT.isVector()) {
    assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
    assert(NumElts * EltVT.getVectorNumElements() ==
           ResVT.getVectorNumElements());
    // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
    // into individual elements.
    for (const unsigned I : llvm::seq(Size: NumElts)) {
      SDValue SubVector = NewLD.getValue(R: I);
      DAG.ExtractVectorElements(Op: SubVector, Args&: ScalarRes);
    }
  } else {
    for (const unsigned I : llvm::seq(Size: NumElts)) {
      SDValue Res = NewLD.getValue(R: I);
      // Truncate i16 results back down to the original i1/i8 element type.
      if (LoadEltVT != EltVT)
        Res = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: EltVT, Operand: Res);
      ScalarRes.push_back(Elt: Res);
    }
  }

  SDValue LoadChain = NewLD.getValue(R: NumElts);

  // Rebuild the original vector type from the scalar pieces.
  const MVT BuildVecVT =
      MVT::getVectorVT(VT: EltVT.getScalarType(), NumElements: ScalarRes.size());
  SDValue BuildVec = DAG.getBuildVector(VT: BuildVecVT, DL, Ops: ScalarRes);
  SDValue LoadValue = DAG.getBitcast(VT: ResVT, V: BuildVec);

  return {{LoadValue, LoadChain}};
}
3772
3773static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
3774 SmallVectorImpl<SDValue> &Results,
3775 const NVPTXSubtarget &STI) {
3776 if (auto Res = replaceLoadVector(N, DAG, STI))
3777 Results.append(IL: {Res->first, Res->second});
3778}
3779
3780static SDValue lowerLoadVector(SDNode *N, SelectionDAG &DAG,
3781 const NVPTXSubtarget &STI) {
3782 if (auto Res = replaceLoadVector(N, DAG, STI))
3783 return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(N));
3784 return SDValue();
3785}
3786
3787// v = ld i1* addr
3788// =>
3789// v1 = ld i8* addr (-> i16)
3790// v = trunc i16 to i1
3791static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG) {
3792 SDLoc dl(LD);
3793 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
3794 assert(LD->getValueType(0) == MVT::i1 && "Custom lowering for i1 load only");
3795 SDValue newLD = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl, VT: MVT::i16, Chain: LD->getChain(),
3796 Ptr: LD->getBasePtr(), PtrInfo: LD->getPointerInfo(),
3797 MemVT: MVT::i8, Alignment: LD->getAlign(),
3798 MMOFlags: LD->getMemOperand()->getFlags());
3799 SDValue result = DAG.getNode(Opcode: ISD::TRUNCATE, DL: dl, VT: MVT::i1, Operand: newLD);
3800 // The legalizer (the caller) is expecting two values from the legalized
3801 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
3802 // in LegalizeDAG.cpp which also uses MergeValues.
3803 return DAG.getMergeValues(Ops: {result, LD->getChain()}, dl);
3804}
3805
3806SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
3807 LoadSDNode *LD = cast<LoadSDNode>(Val&: Op);
3808
3809 if (Op.getValueType() == MVT::i1)
3810 return lowerLOADi1(LD, DAG);
3811
3812 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
3813 // how they'll be lowered in ISel anyway, and by doing this a little earlier
3814 // we allow for more DAG combine opportunities.
3815 if (LD->getExtensionType() == ISD::EXTLOAD) {
3816 assert(LD->getValueType(0).isInteger() && LD->getMemoryVT().isInteger() &&
3817 "Unexpected fpext-load");
3818 return DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(Op), VT: Op.getValueType(),
3819 Chain: LD->getChain(), Ptr: LD->getBasePtr(), MemVT: LD->getMemoryVT(),
3820 MMO: LD->getMemOperand());
3821 }
3822
3823 llvm_unreachable("Unexpected custom lowering for load");
3824}
3825
3826SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
3827 // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
3828 // masked loads of these types and have to handle them here.
3829 // v2f32 also needs to be handled here if the subtarget has f32x2
3830 // instructions, making it legal.
3831 //
3832 // Note: misaligned masked loads should never reach this point
3833 // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
3834 // will validate alignment. Therefore, we do not need to special case handle
3835 // them here.
3836 EVT VT = Op.getValueType();
3837 if (NVPTX::isPackedVectorTy(VT)) {
3838 auto Result = convertMLOADToLoadWithUsedBytesMask(
3839 N: cast<MemSDNode>(Val: Op.getNode()), DAG, STI);
3840 MemSDNode *LD = std::get<0>(in&: Result);
3841 uint32_t UsedBytesMask = std::get<1>(in&: Result);
3842
3843 SDLoc DL(LD);
3844
3845 // Copy regular operands
3846 SmallVector<SDValue, 8> OtherOps(LD->ops());
3847
3848 OtherOps.push_back(Elt: DAG.getConstant(Val: UsedBytesMask, DL, VT: MVT::i32));
3849
3850 // We currently are not lowering extending loads, but pass the extension
3851 // type anyway as later handling expects it.
3852 OtherOps.push_back(
3853 Elt: DAG.getIntPtrConstant(Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));
3854 SDValue NewLD =
3855 DAG.getMemIntrinsicNode(Opcode: NVPTXISD::MLoad, dl: DL, VTList: LD->getVTList(), Ops: OtherOps,
3856 MemVT: LD->getMemoryVT(), MMO: LD->getMemOperand());
3857 return NewLD;
3858 }
3859 return SDValue();
3860}
3861
3862static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
3863 const NVPTXSubtarget &STI) {
3864 MemSDNode *N = cast<MemSDNode>(Val: Op.getNode());
3865 SDValue Val = N->getOperand(Num: 1);
3866 SDLoc DL(N);
3867 const EVT ValVT = Val.getValueType();
3868 const EVT MemVT = N->getMemoryVT();
3869
3870 // If we're truncating as part of the store, avoid lowering to a StoreV node.
3871 // TODO: consider relaxing this restriction.
3872 if (ValVT != MemVT)
3873 return SDValue();
3874
3875 const auto NumEltsAndEltVT =
3876 getVectorLoweringShape(VectorEVT: ValVT, STI, AddressSpace: N->getAddressSpace());
3877 if (!NumEltsAndEltVT)
3878 return SDValue();
3879 const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
3880
3881 const DataLayout &TD = DAG.getDataLayout();
3882
3883 Align Alignment = N->getAlign();
3884 Align PrefAlign = TD.getPrefTypeAlign(Ty: ValVT.getTypeForEVT(Context&: *DAG.getContext()));
3885 if (Alignment < PrefAlign) {
3886 // This store is not sufficiently aligned, so bail out and let this vector
3887 // store be scalarized. Note that we may still be able to emit smaller
3888 // vector stores. For example, if we are storing a <4 x float> with an
3889 // alignment of 8, this check will fail but the legalizer will try again
3890 // with 2 x <2 x float>, which will succeed with an alignment of 8.
3891 return SDValue();
3892 }
3893
3894 unsigned Opcode;
3895 switch (NumElts) {
3896 default:
3897 return SDValue();
3898 case 2:
3899 Opcode = NVPTXISD::StoreV2;
3900 break;
3901 case 4:
3902 Opcode = NVPTXISD::StoreV4;
3903 break;
3904 case 8:
3905 Opcode = NVPTXISD::StoreV8;
3906 break;
3907 }
3908
3909 SmallVector<SDValue, 8> Ops;
3910
3911 // First is the chain
3912 Ops.push_back(Elt: N->getOperand(Num: 0));
3913
3914 // Then the split values
3915 if (EltVT.isVector()) {
3916 assert(EVT(EltVT.getVectorElementType()) == ValVT.getVectorElementType());
3917 assert(NumElts * EltVT.getVectorNumElements() ==
3918 ValVT.getVectorNumElements());
3919 // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
3920 // stored as b32s
3921 const unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
3922 for (const unsigned I : llvm::seq(Size: NumElts)) {
3923 SmallVector<SDValue, 4> SubVectorElts;
3924 DAG.ExtractVectorElements(Op: Val, Args&: SubVectorElts, Start: I * NumEltsPerSubVector,
3925 Count: NumEltsPerSubVector);
3926 Ops.push_back(Elt: DAG.getBuildVector(VT: EltVT, DL, Ops: SubVectorElts));
3927 }
3928 } else {
3929 SDValue V = DAG.getBitcast(VT: MVT::getVectorVT(VT: EltVT, NumElements: NumElts), V: Val);
3930 for (const unsigned I : llvm::seq(Size: NumElts)) {
3931 SDValue ExtVal = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: V,
3932 N2: DAG.getIntPtrConstant(Val: I, DL));
3933
3934 // Since StoreV2 is a target node, we cannot rely on DAG type
3935 // legalization. Therefore, we must ensure the type is legal. For i1 and
3936 // i8, we set the stored type to i16 and propagate the "real" type as the
3937 // memory type.
3938 if (EltVT.getSizeInBits() < 16)
3939 ExtVal = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: MVT::i16, Operand: ExtVal);
3940 Ops.push_back(Elt: ExtVal);
3941 }
3942 }
3943
3944 // Then any remaining arguments
3945 Ops.append(in_start: N->op_begin() + 2, in_end: N->op_end());
3946
3947 SDValue NewSt =
3948 DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: DAG.getVTList(VT: MVT::Other), Ops,
3949 MemVT: N->getMemoryVT(), MMO: N->getMemOperand());
3950
3951 // return DCI.CombineTo(N, NewSt, true);
3952 return NewSt;
3953}
3954
3955SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
3956 StoreSDNode *Store = cast<StoreSDNode>(Val&: Op);
3957 EVT VT = Store->getMemoryVT();
3958
3959 if (VT == MVT::i1)
3960 return LowerSTOREi1(Op, DAG);
3961
3962 // Lower store of any other vector type, including v2f32 as we want to break
3963 // it apart since this is not a widely-supported type.
3964 return lowerSTOREVector(Op, DAG, STI);
3965}
3966
3967// st i1 v, addr
3968// =>
3969// v1 = zxt v to i16
3970// st.u8 i16, addr
3971SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
3972 SDNode *Node = Op.getNode();
3973 SDLoc dl(Node);
3974 StoreSDNode *ST = cast<StoreSDNode>(Val: Node);
3975 SDValue Tmp1 = ST->getChain();
3976 SDValue Tmp2 = ST->getBasePtr();
3977 SDValue Tmp3 = ST->getValue();
3978 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
3979 Tmp3 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: dl, VT: MVT::i16, Operand: Tmp3);
3980 SDValue Result =
3981 DAG.getTruncStore(Chain: Tmp1, dl, Val: Tmp3, Ptr: Tmp2, PtrInfo: ST->getPointerInfo(), SVT: MVT::i8,
3982 Alignment: ST->getAlign(), MMOFlags: ST->getMemOperand()->getFlags());
3983 return Result;
3984}
3985
3986SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
3987 SelectionDAG &DAG) const {
3988 // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
3989 // operand so that it can pass the legalization.
3990
3991 assert(Op.getOperand(1).getValueType() == MVT::i128 &&
3992 "Custom lowering for 128-bit CopyToReg only");
3993
3994 SDNode *Node = Op.getNode();
3995 SDLoc DL(Node);
3996
3997 SDValue Cast = DAG.getBitcast(VT: MVT::v2i64, V: Op->getOperand(Num: 2));
3998 SDValue Lo = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
3999 N2: DAG.getIntPtrConstant(Val: 0, DL));
4000 SDValue Hi = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
4001 N2: DAG.getIntPtrConstant(Val: 1, DL));
4002
4003 SmallVector<SDValue, 5> NewOps(Op->getNumOperands() + 1);
4004 SmallVector<EVT, 3> ResultsType(Node->values());
4005
4006 NewOps[0] = Op->getOperand(Num: 0); // Chain
4007 NewOps[1] = Op->getOperand(Num: 1); // Dst Reg
4008 NewOps[2] = Lo; // Lower 64-bit
4009 NewOps[3] = Hi; // Higher 64-bit
4010 if (Op.getNumOperands() == 4)
4011 NewOps[4] = Op->getOperand(Num: 3); // Glue if exists
4012
4013 return DAG.getNode(Opcode: ISD::CopyToReg, DL, ResultTys: ResultsType, Ops: NewOps);
4014}
4015
4016unsigned NVPTXTargetLowering::getNumRegisters(
4017 LLVMContext &Context, EVT VT,
4018 std::optional<MVT> RegisterVT = std::nullopt) const {
4019 if (VT == MVT::i128 && RegisterVT == MVT::i128)
4020 return 1;
4021 return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
4022}
4023
4024bool NVPTXTargetLowering::splitValueIntoRegisterParts(
4025 SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
4026 unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
4027 if (Val.getValueType() == MVT::i128 && NumParts == 1) {
4028 Parts[0] = Val;
4029 return true;
4030 }
4031 return false;
4032}
4033
4034// This creates target external symbol for a function parameter.
4035// Name of the symbol is composed from its index and the function name.
4036// Negative index corresponds to special parameter (unsized array) used for
4037// passing variable arguments.
4038SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int I,
4039 EVT T) const {
4040 StringRef SavedStr = nvTM->getStrPool().save(
4041 S: getParamName(F: &DAG.getMachineFunction().getFunction(), Idx: I));
4042 return DAG.getExternalSymbol(Sym: SavedStr.data(), VT: T);
4043}
4044
4045SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, int I,
4046 EVT T) const {
4047 const StringRef SavedStr = nvTM->getStrPool().save(S: "param" + Twine(I));
4048 return DAG.getExternalSymbol(Sym: SavedStr.data(), VT: T);
4049}
4050
/// Lower a function's incoming (formal) arguments by reading them out of the
/// PTX param space.
///
/// Each IR argument is handled in one of three ways:
///  - a dead argument (no uses) contributes UNDEF values to \p InVals;
///  - a byval pointer argument becomes either the raw param symbol (kernels,
///    where NVPTXLowerArgs has already made it a grid_constant) or a
///    MoveParam plus a local->generic address-space cast (device functions);
///  - any other argument is decomposed by ComputePTXValueVTs and loaded from
///    the param space with (possibly vectorized) loads.
SDValue NVPTXTargetLowering::LowerFormalArguments(
    SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
    const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
    SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
  const DataLayout &DL = DAG.getDataLayout();
  LLVMContext &Ctx = *DAG.getContext();
  auto PtrVT = getPointerTy(DL: DAG.getDataLayout());

  const Function &F = DAG.getMachineFunction().getFunction();

  SDValue Root = DAG.getRoot();
  SmallVector<SDValue, 16> OutChains;

  // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
  // Ins.size() will be larger
  //   * if there is an aggregate argument with multiple fields (each field
  //     showing up separately in Ins)
  //   * if there is a vector argument with more than typical vector-length
  //     elements (generally if more than 4) where each vector element is
  //     individually present in Ins.
  // So a different index should be used for indexing into Ins.
  // See similar issue in LowerCall.

  auto AllIns = ArrayRef(Ins);
  for (const auto &Arg : F.args()) {
    // Peel off the leading Ins entries that belong to this IR argument; they
    // are contiguous and tagged with the argument's index.
    const auto ArgIns = AllIns.take_while(
        Pred: [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
    AllIns = AllIns.drop_front(N: ArgIns.size());

    Type *Ty = Arg.getType();

    if (ArgIns.empty())
      report_fatal_error(reason: "Empty parameter types are not supported");

    if (Arg.use_empty()) {
      // argument is dead
      for (const auto &In : ArgIns) {
        assert(!In.Used && "Arg.use_empty() is true but Arg is used?");
        InVals.push_back(Elt: DAG.getUNDEF(VT: In.VT));
      }
      continue;
    }

    // External symbol naming this argument's slot in the param space.
    SDValue ArgSymbol = getParamSymbol(DAG, I: Arg.getArgNo(), T: PtrVT);

    // In the following cases, assign a node order of "i+1"
    // to newly created nodes. The SDNodes for params have to
    // appear in the same order as their order of appearance
    // in the original function. "i+1" holds that order.
    if (Arg.hasByValAttr()) {
      // Param has ByVal attribute
      // Return MoveParam(param symbol).
      // Ideally, the param symbol can be returned directly,
      // but when SDNode builder decides to use it in a CopyToReg(),
      // machine instruction fails because TargetExternalSymbol
      // (not lowered) is target dependent, and CopyToReg assumes
      // the source is lowered.
      assert(ArgIns.size() == 1 && "ByVal argument must be a pointer");
      const auto &ByvalIn = ArgIns[0];
      assert(getValueType(DL, Ty) == ByvalIn.VT &&
             "Ins type did not match function type");
      assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");

      SDValue P;
      if (isKernelFunction(F)) {
        assert(isParamGridConstant(Arg) && "ByVal argument must be lowered to "
                                           "grid_constant by NVPTXLowerArgs");
        P = ArgSymbol;
        P.getNode()->setIROrder(Arg.getArgNo() + 1);
      } else {
        P = DAG.getNode(Opcode: NVPTXISD::MoveParam, DL: dl, VT: ByvalIn.VT, Operand: ArgSymbol);
        P.getNode()->setIROrder(Arg.getArgNo() + 1);
        // Device-function byval params live in local space; callers expect a
        // generic pointer, hence the cast.
        P = DAG.getAddrSpaceCast(dl, VT: ByvalIn.VT, Ptr: P, SrcAS: ADDRESS_SPACE_LOCAL,
                                 DestAS: ADDRESS_SPACE_GENERIC);
      }
      InVals.push_back(Elt: P);
    } else {
      // Break the argument into its constituent PTX values and byte offsets,
      // then load them in runs that can share one vectorized load.
      SmallVector<EVT, 16> VTs;
      SmallVector<uint64_t, 16> Offsets;
      ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv, Ty, ValueVTs&: VTs, Offsets);
      assert(VTs.size() == ArgIns.size() && "Size mismatch");
      assert(VTs.size() == Offsets.size() && "Size mismatch");

      const Align ArgAlign = getFunctionArgumentAlignment(
          F: &F, Ty, Idx: Arg.getArgNo() + AttributeList::FirstArgIndex, DL);

      unsigned I = 0;
      const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: ArgAlign);
      for (const unsigned NumElts : VI) {
        // i1 is loaded/stored as i8
        const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
        const EVT VecVT = getVectorizedVT(VT: LoadVT, N: NumElts, C&: Ctx);

        SDValue VecAddr = DAG.getObjectPtrOffset(
            SL: dl, Ptr: ArgSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));

        const Align PartAlign = commonAlignment(A: ArgAlign, Offset: Offsets[I]);
        // Param memory never changes during the function, so the loads are
        // marked dereferenceable + invariant.
        SDValue P =
            DAG.getLoad(VT: VecVT, dl, Chain: Root, Ptr: VecAddr,
                        PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: PartAlign,
                        MMOFlags: MachineMemOperand::MODereferenceable |
                            MachineMemOperand::MOInvariant);
        P.getNode()->setIROrder(Arg.getArgNo() + 1);
        for (const unsigned J : llvm::seq(Size: NumElts)) {
          SDValue Elt = getExtractVectorizedValue(V: P, I: J, VT: LoadVT, dl, DAG);

          Elt = correctParamType(V: Elt, ExpectedVT: ArgIns[I + J].VT, Flags: ArgIns[I + J].Flags,
                                 DAG, dl);
          InVals.push_back(Elt);
        }
        I += NumElts;
      }
    }
  }

  // NOTE(review): nothing in this function currently appends to OutChains, so
  // this setRoot is effectively dead code here — confirm before relying on it
  // or removing it.
  if (!OutChains.empty())
    DAG.setRoot(DAG.getTokenFactor(DL: dl, Vals&: OutChains));

  return Chain;
}
4171
/// Lower a function's return: store the return value(s) into the
/// "func_retval0" slot in the param address space, then emit
/// NVPTXISD::RET_GLUE. Void returns emit only the RET_GLUE.
SDValue
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                 bool isVarArg,
                                 const SmallVectorImpl<ISD::OutputArg> &Outs,
                                 const SmallVectorImpl<SDValue> &OutVals,
                                 const SDLoc &dl, SelectionDAG &DAG) const {
  const Function &F = DAG.getMachineFunction().getFunction();
  Type *RetTy = F.getReturnType();

  if (RetTy->isVoidTy()) {
    assert(OutVals.empty() && Outs.empty() && "Return value expected for void");
    return DAG.getNode(Opcode: NVPTXISD::RET_GLUE, DL: dl, VT: MVT::Other, Operand: Chain);
  }

  const DataLayout &DL = DAG.getDataLayout();
  LLVMContext &Ctx = *DAG.getContext();

  const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "func_retval0", VT: MVT::i32);
  const auto RetAlign = getFunctionParamOptimizedAlign(F: &F, ArgTy: RetTy, DL);

  // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
  // 32-bits are sign extended or zero extended, depending on whether
  // they are signed or unsigned types.
  const bool ExtendIntegerRetVal =
      RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: RetTy) < 32;

  // Decompose the return type into its constituent PTX values and byte
  // offsets within the return slot.
  SmallVector<EVT, 16> VTs;
  SmallVector<uint64_t, 16> Offsets;
  ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv, Ty: RetTy, ValueVTs&: VTs, Offsets);
  assert(VTs.size() == OutVals.size() && "Bad return value decomposition");

  // Fetch OutVals[I] coerced to the type it will actually be stored as
  // (i32 when widened per the interop rule above, i8 for i1, otherwise the
  // promoted scalar type).
  const auto GetRetVal = [&](unsigned I) -> SDValue {
    SDValue RetVal = OutVals[I];
    assert(promoteScalarIntegerPTX(RetVal.getValueType()) ==
               RetVal.getValueType() &&
           "OutVal type should always be legal");

    const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
    const EVT StoreVT =
        ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
    return correctParamType(V: RetVal, ExpectedVT: StoreVT, Flags: Outs[I].Flags, DAG, dl);
  };

  // Store the pieces in runs that can share one vectorized store.
  unsigned I = 0;
  const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: RetAlign);
  for (const unsigned NumElts : VI) {
    // No explicit alignment is attached when the value was widened to i32.
    const MaybeAlign CurrentAlign = ExtendIntegerRetVal
                                        ? MaybeAlign(std::nullopt)
                                        : commonAlignment(A: RetAlign, Offset: Offsets[I]);

    SDValue Val = getBuildVectorizedValue(
        N: NumElts, dl, DAG, GetElement: [&](unsigned K) { return GetRetVal(I + K); });

    SDValue Ptr =
        DAG.getObjectPtrOffset(SL: dl, Ptr: RetSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));

    Chain = DAG.getStore(Chain, dl, Val, Ptr,
                         PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: CurrentAlign);

    I += NumElts;
  }

  return DAG.getNode(Opcode: NVPTXISD::RET_GLUE, DL: dl, VT: MVT::Other, Operand: Chain);
}
4236
4237void NVPTXTargetLowering::LowerAsmOperandForConstraint(
4238 SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
4239 SelectionDAG &DAG) const {
4240 if (Constraint.size() > 1)
4241 return;
4242 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
4243}
4244
4245// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
4246// TgtMemIntrinsic
4247// because we need the information that is only available in the "Value" type
4248// of destination
4249// pointer. In particular, the address space information.
4250void NVPTXTargetLowering::getTgtMemIntrinsic(
4251 SmallVectorImpl<IntrinsicInfo> &Infos, const CallBase &I,
4252 MachineFunction &MF, unsigned Intrinsic) const {
4253 IntrinsicInfo Info;
4254 switch (Intrinsic) {
4255 default:
4256 return;
4257 case Intrinsic::nvvm_match_all_sync_i32p:
4258 case Intrinsic::nvvm_match_all_sync_i64p:
4259 Info.opc = ISD::INTRINSIC_W_CHAIN;
4260 // memVT is bogus. These intrinsics have IntrInaccessibleMemOnly attribute
4261 // in order to model data exchange with other threads, but perform no real
4262 // memory accesses.
4263 Info.memVT = MVT::i1;
4264
4265 // Our result depends on both our and other thread's arguments.
4266 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
4267 Infos.push_back(Elt: Info);
4268 return;
4269 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
4270 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
4271 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
4272 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
4273 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
4274 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
4275 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
4276 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
4277 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
4278 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
4279 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
4280 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
4281 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
4282 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
4283 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
4284 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
4285 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
4286 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
4287 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
4288 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
4289 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
4290 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
4291 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
4292 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
4293 Info.opc = ISD::INTRINSIC_W_CHAIN;
4294 Info.memVT = MVT::v8f16;
4295 Info.ptrVal = I.getArgOperand(i: 0);
4296 Info.offset = 0;
4297 Info.flags = MachineMemOperand::MOLoad;
4298 Info.align = Align(16);
4299 Infos.push_back(Elt: Info);
4300 return;
4301 }
4302 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
4303 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
4304 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
4305 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
4306 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
4307 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
4308 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
4309 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
4310 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
4311 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
4312 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
4313 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
4314 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
4315 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
4316 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
4317 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
4318 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
4319 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
4320 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
4321 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
4322 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
4323 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
4324 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
4325 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
4326 Info.opc = ISD::INTRINSIC_W_CHAIN;
4327 Info.memVT = MVT::v2i32;
4328 Info.ptrVal = I.getArgOperand(i: 0);
4329 Info.offset = 0;
4330 Info.flags = MachineMemOperand::MOLoad;
4331 Info.align = Align(8);
4332 Infos.push_back(Elt: Info);
4333 return;
4334 }
4335
4336 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
4337 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
4338 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
4339 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
4340 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
4341 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
4342 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
4343 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
4344 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
4345 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
4346 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
4347 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
4348 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
4349 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
4350 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
4351 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
4352
4353 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
4354 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
4355 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
4356 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
4357 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
4358 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
4359 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
4360 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
4361 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
4362 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
4363 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
4364 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
4365 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
4366 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
4367 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
4368 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
4369 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
4370 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
4371 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
4372 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
4373 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
4374 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
4375 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
4376 Info.opc = ISD::INTRINSIC_W_CHAIN;
4377 Info.memVT = MVT::v4i32;
4378 Info.ptrVal = I.getArgOperand(i: 0);
4379 Info.offset = 0;
4380 Info.flags = MachineMemOperand::MOLoad;
4381 Info.align = Align(16);
4382 Infos.push_back(Elt: Info);
4383 return;
4384 }
4385
4386 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
4387 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
4388 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
4389 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
4390 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
4391 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
4392 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
4393 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
4394
4395 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
4396 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
4397 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
4398 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
4399 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
4400 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
4401 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
4402 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
4403 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
4404 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
4405 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
4406 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
4407 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
4408 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
4409 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
4410 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
4411 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
4412 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
4413 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
4414 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
4415 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
4416 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
4417 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
4418 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
4419 Info.opc = ISD::INTRINSIC_W_CHAIN;
4420 Info.memVT = MVT::i32;
4421 Info.ptrVal = I.getArgOperand(i: 0);
4422 Info.offset = 0;
4423 Info.flags = MachineMemOperand::MOLoad;
4424 Info.align = Align(4);
4425 Infos.push_back(Elt: Info);
4426 return;
4427 }
4428
4429 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
4430 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
4431 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
4432 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
4433 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
4434 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
4435 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
4436 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
4437 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
4438 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
4439 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
4440 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
4441 Info.opc = ISD::INTRINSIC_W_CHAIN;
4442 Info.memVT = MVT::v4f16;
4443 Info.ptrVal = I.getArgOperand(i: 0);
4444 Info.offset = 0;
4445 Info.flags = MachineMemOperand::MOLoad;
4446 Info.align = Align(16);
4447 Infos.push_back(Elt: Info);
4448 return;
4449 }
4450
4451 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
4452 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
4453 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
4454 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
4455 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
4456 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
4457 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
4458 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
4459 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
4460 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
4461 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
4462 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
4463 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
4464 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
4465 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
4466 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
4467 Info.opc = ISD::INTRINSIC_W_CHAIN;
4468 Info.memVT = MVT::v8f32;
4469 Info.ptrVal = I.getArgOperand(i: 0);
4470 Info.offset = 0;
4471 Info.flags = MachineMemOperand::MOLoad;
4472 Info.align = Align(16);
4473 Infos.push_back(Elt: Info);
4474 return;
4475 }
4476
4477 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
4478 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
4479 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
4480 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
4481
4482 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
4483 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
4484 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
4485 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
4486
4487 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
4488 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
4489 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
4490 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
4491 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
4492 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
4493 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
4494 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
4495 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
4496 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
4497 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
4498 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
4499 Info.opc = ISD::INTRINSIC_W_CHAIN;
4500 Info.memVT = MVT::v8i32;
4501 Info.ptrVal = I.getArgOperand(i: 0);
4502 Info.offset = 0;
4503 Info.flags = MachineMemOperand::MOLoad;
4504 Info.align = Align(16);
4505 Infos.push_back(Elt: Info);
4506 return;
4507 }
4508
4509 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
4510 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
4511 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
4512 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
4513 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
4514 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
4515 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
4516 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
4517 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
4518 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
4519 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
4520 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
4521 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
4522 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
4523 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
4524 Info.opc = ISD::INTRINSIC_W_CHAIN;
4525 Info.memVT = MVT::v2i32;
4526 Info.ptrVal = I.getArgOperand(i: 0);
4527 Info.offset = 0;
4528 Info.flags = MachineMemOperand::MOLoad;
4529 Info.align = Align(8);
4530 Infos.push_back(Elt: Info);
4531 return;
4532 }
4533
4534 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
4535 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
4536 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
4537 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
4538
4539 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
4540 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
4541 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
4542 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
4543 Info.opc = ISD::INTRINSIC_W_CHAIN;
4544 Info.memVT = MVT::f64;
4545 Info.ptrVal = I.getArgOperand(i: 0);
4546 Info.offset = 0;
4547 Info.flags = MachineMemOperand::MOLoad;
4548 Info.align = Align(8);
4549 Infos.push_back(Elt: Info);
4550 return;
4551 }
4552
4553 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
4554 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
4555 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
4556 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
4557 Info.opc = ISD::INTRINSIC_W_CHAIN;
4558 Info.memVT = MVT::v2f64;
4559 Info.ptrVal = I.getArgOperand(i: 0);
4560 Info.offset = 0;
4561 Info.flags = MachineMemOperand::MOLoad;
4562 Info.align = Align(16);
4563 Infos.push_back(Elt: Info);
4564 return;
4565 }
4566
4567 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
4568 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
4569 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
4570 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
4571 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
4572 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
4573 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
4574 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
4575 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
4576 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
4577 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
4578 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
4579 Info.opc = ISD::INTRINSIC_VOID;
4580 Info.memVT = MVT::v4f16;
4581 Info.ptrVal = I.getArgOperand(i: 0);
4582 Info.offset = 0;
4583 Info.flags = MachineMemOperand::MOStore;
4584 Info.align = Align(16);
4585 Infos.push_back(Elt: Info);
4586 return;
4587 }
4588
4589 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
4590 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
4591 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
4592 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
4593 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
4594 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
4595 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
4596 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
4597 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
4598 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
4599 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
4600 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
4601 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
4602 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
4603 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
4604 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
4605 Info.opc = ISD::INTRINSIC_VOID;
4606 Info.memVT = MVT::v8f32;
4607 Info.ptrVal = I.getArgOperand(i: 0);
4608 Info.offset = 0;
4609 Info.flags = MachineMemOperand::MOStore;
4610 Info.align = Align(16);
4611 Infos.push_back(Elt: Info);
4612 return;
4613 }
4614
4615 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
4616 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
4617 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
4618 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
4619 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
4620 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
4621 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
4622 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
4623 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
4624 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
4625 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
4626 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
4627 Info.opc = ISD::INTRINSIC_VOID;
4628 Info.memVT = MVT::v8i32;
4629 Info.ptrVal = I.getArgOperand(i: 0);
4630 Info.offset = 0;
4631 Info.flags = MachineMemOperand::MOStore;
4632 Info.align = Align(16);
4633 Infos.push_back(Elt: Info);
4634 return;
4635 }
4636
4637 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
4638 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
4639 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
4640 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
4641 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
4642 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
4643 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
4644 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
4645 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
4646 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
4647 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
4648 Info.opc = ISD::INTRINSIC_VOID;
4649 Info.memVT = MVT::v2i32;
4650 Info.ptrVal = I.getArgOperand(i: 0);
4651 Info.offset = 0;
4652 Info.flags = MachineMemOperand::MOStore;
4653 Info.align = Align(8);
4654 Infos.push_back(Elt: Info);
4655 return;
4656 }
4657
4658 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
4659 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
4660 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
4661 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
4662 Info.opc = ISD::INTRINSIC_VOID;
4663 Info.memVT = MVT::v2f64;
4664 Info.ptrVal = I.getArgOperand(i: 0);
4665 Info.offset = 0;
4666 Info.flags = MachineMemOperand::MOStore;
4667 Info.align = Align(16);
4668 Infos.push_back(Elt: Info);
4669 return;
4670 }
4671
4672 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
4673 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
4674 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
4675 Info.opc = ISD::INTRINSIC_VOID;
4676 Info.memVT = MVT::i32;
4677 Info.ptrVal = I.getArgOperand(i: 0);
4678 Info.offset = 0;
4679 Info.flags = MachineMemOperand::MOStore;
4680 Info.align = Align(4);
4681 Infos.push_back(Elt: Info);
4682 return;
4683 }
4684
4685 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
4686 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
4687 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
4688 Info.opc = ISD::INTRINSIC_VOID;
4689 Info.memVT = MVT::v4i32;
4690 Info.ptrVal = I.getArgOperand(i: 0);
4691 Info.offset = 0;
4692 Info.flags = MachineMemOperand::MOStore;
4693 Info.align = Align(16);
4694 Infos.push_back(Elt: Info);
4695 return;
4696 }
4697
4698 case Intrinsic::nvvm_atomic_add_gen_f_cta:
4699 case Intrinsic::nvvm_atomic_add_gen_f_sys:
4700 case Intrinsic::nvvm_atomic_add_gen_i_cta:
4701 case Intrinsic::nvvm_atomic_add_gen_i_sys:
4702 case Intrinsic::nvvm_atomic_and_gen_i_cta:
4703 case Intrinsic::nvvm_atomic_and_gen_i_sys:
4704 case Intrinsic::nvvm_atomic_cas_gen_i_cta:
4705 case Intrinsic::nvvm_atomic_cas_gen_i_sys:
4706 case Intrinsic::nvvm_atomic_dec_gen_i_cta:
4707 case Intrinsic::nvvm_atomic_dec_gen_i_sys:
4708 case Intrinsic::nvvm_atomic_inc_gen_i_cta:
4709 case Intrinsic::nvvm_atomic_inc_gen_i_sys:
4710 case Intrinsic::nvvm_atomic_max_gen_i_cta:
4711 case Intrinsic::nvvm_atomic_max_gen_i_sys:
4712 case Intrinsic::nvvm_atomic_min_gen_i_cta:
4713 case Intrinsic::nvvm_atomic_min_gen_i_sys:
4714 case Intrinsic::nvvm_atomic_or_gen_i_cta:
4715 case Intrinsic::nvvm_atomic_or_gen_i_sys:
4716 case Intrinsic::nvvm_atomic_exch_gen_i_cta:
4717 case Intrinsic::nvvm_atomic_exch_gen_i_sys:
4718 case Intrinsic::nvvm_atomic_xor_gen_i_cta:
4719 case Intrinsic::nvvm_atomic_xor_gen_i_sys: {
4720 auto &DL = I.getDataLayout();
4721 Info.opc = ISD::INTRINSIC_W_CHAIN;
4722 Info.memVT = getValueType(DL, Ty: I.getType());
4723 Info.ptrVal = I.getArgOperand(i: 0);
4724 Info.offset = 0;
4725 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
4726 Info.align.reset();
4727 Infos.push_back(Elt: Info);
4728 return;
4729 }
4730
4731 case Intrinsic::nvvm_prefetch_tensormap: {
4732 auto &DL = I.getDataLayout();
4733 Info.opc = ISD::INTRINSIC_VOID;
4734 Info.memVT = getPointerTy(DL);
4735 Info.ptrVal = I.getArgOperand(i: 0);
4736 Info.offset = 0;
4737 Info.flags =
4738 MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable;
4739 Info.align.reset();
4740 Infos.push_back(Elt: Info);
4741 return;
4742 }
4743
4744 case Intrinsic::nvvm_tensormap_replace_global_address:
4745 case Intrinsic::nvvm_tensormap_replace_global_stride: {
4746 Info.opc = ISD::INTRINSIC_VOID;
4747 Info.memVT = MVT::i64;
4748 Info.ptrVal = I.getArgOperand(i: 0);
4749 Info.offset = 0;
4750 Info.flags = MachineMemOperand::MOStore;
4751 Info.align.reset();
4752 Infos.push_back(Elt: Info);
4753 return;
4754 }
4755
4756 case Intrinsic::nvvm_tensormap_replace_rank:
4757 case Intrinsic::nvvm_tensormap_replace_box_dim:
4758 case Intrinsic::nvvm_tensormap_replace_global_dim:
4759 case Intrinsic::nvvm_tensormap_replace_element_stride:
4760 case Intrinsic::nvvm_tensormap_replace_elemtype:
4761 case Intrinsic::nvvm_tensormap_replace_interleave_layout:
4762 case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
4763 case Intrinsic::nvvm_tensormap_replace_swizzle_atomicity:
4764 case Intrinsic::nvvm_tensormap_replace_fill_mode: {
4765 Info.opc = ISD::INTRINSIC_VOID;
4766 Info.memVT = MVT::i32;
4767 Info.ptrVal = I.getArgOperand(i: 0);
4768 Info.offset = 0;
4769 Info.flags = MachineMemOperand::MOStore;
4770 Info.align.reset();
4771 Infos.push_back(Elt: Info);
4772 return;
4773 }
4774
4775 case Intrinsic::nvvm_ldu_global_i:
4776 case Intrinsic::nvvm_ldu_global_f:
4777 case Intrinsic::nvvm_ldu_global_p: {
4778 Info.opc = ISD::INTRINSIC_W_CHAIN;
4779 Info.memVT = getValueType(DL: I.getDataLayout(), Ty: I.getType());
4780 Info.ptrVal = I.getArgOperand(i: 0);
4781 Info.offset = 0;
4782 Info.flags = MachineMemOperand::MOLoad;
4783 Info.align = cast<ConstantInt>(Val: I.getArgOperand(i: 1))->getMaybeAlignValue();
4784
4785 Infos.push_back(Elt: Info);
4786 return;
4787 }
4788 case Intrinsic::nvvm_tex_1d_v4f32_s32:
4789 case Intrinsic::nvvm_tex_1d_v4f32_f32:
4790 case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
4791 case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
4792 case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
4793 case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
4794 case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
4795 case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
4796 case Intrinsic::nvvm_tex_2d_v4f32_s32:
4797 case Intrinsic::nvvm_tex_2d_v4f32_f32:
4798 case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
4799 case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
4800 case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
4801 case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
4802 case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
4803 case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
4804 case Intrinsic::nvvm_tex_3d_v4f32_s32:
4805 case Intrinsic::nvvm_tex_3d_v4f32_f32:
4806 case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
4807 case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
4808 case Intrinsic::nvvm_tex_cube_v4f32_f32:
4809 case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
4810 case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
4811 case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
4812 case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
4813 case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
4814 case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
4815 case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
4816 case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
4817 case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
4818 case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
4819 case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
4820 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
4821 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
4822 case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
4823 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
4824 case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
4825 case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
4826 case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
4827 case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
4828 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
4829 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
4830 case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
4831 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
4832 case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
4833 case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
4834 case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
4835 case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
4836 case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
4837 case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
4838 case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
4839 case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
4840 case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
4841 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
4842 case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
4843 case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
4844 case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
4845 case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
4846 Info.opc = ISD::INTRINSIC_W_CHAIN;
4847 Info.memVT = MVT::v4f32;
4848 Info.ptrVal = nullptr;
4849 Info.offset = 0;
4850 Info.flags = MachineMemOperand::MOLoad;
4851 Info.align = Align(16);
4852 Infos.push_back(Elt: Info);
4853 return;
4854
4855 case Intrinsic::nvvm_tex_1d_v4s32_s32:
4856 case Intrinsic::nvvm_tex_1d_v4s32_f32:
4857 case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
4858 case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
4859 case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
4860 case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
4861 case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
4862 case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
4863 case Intrinsic::nvvm_tex_2d_v4s32_s32:
4864 case Intrinsic::nvvm_tex_2d_v4s32_f32:
4865 case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
4866 case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
4867 case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
4868 case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
4869 case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
4870 case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
4871 case Intrinsic::nvvm_tex_3d_v4s32_s32:
4872 case Intrinsic::nvvm_tex_3d_v4s32_f32:
4873 case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
4874 case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
4875 case Intrinsic::nvvm_tex_cube_v4s32_f32:
4876 case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
4877 case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
4878 case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
4879 case Intrinsic::nvvm_tex_cube_v4u32_f32:
4880 case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
4881 case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
4882 case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
4883 case Intrinsic::nvvm_tex_1d_v4u32_s32:
4884 case Intrinsic::nvvm_tex_1d_v4u32_f32:
4885 case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
4886 case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
4887 case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
4888 case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
4889 case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
4890 case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
4891 case Intrinsic::nvvm_tex_2d_v4u32_s32:
4892 case Intrinsic::nvvm_tex_2d_v4u32_f32:
4893 case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
4894 case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
4895 case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
4896 case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
4897 case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
4898 case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
4899 case Intrinsic::nvvm_tex_3d_v4u32_s32:
4900 case Intrinsic::nvvm_tex_3d_v4u32_f32:
4901 case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
4902 case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
4903 case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
4904 case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
4905 case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
4906 case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
4907 case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
4908 case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
4909 case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
4910 case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
4911 case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
4912 case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
4913 case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
4914 case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
4915 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
4916 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
4917 case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
4918 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
4919 case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
4920 case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
4921 case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
4922 case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
4923 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
4924 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
4925 case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
4926 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
4927 case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
4928 case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
4929 case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
4930 case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
4931 case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
4932 case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
4933 case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
4934 case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
4935 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
4936 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
4937 case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
4938 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
4939 case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
4940 case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
4941 case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
4942 case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
4943 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
4944 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
4945 case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
4946 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
4947 case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
4948 case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
4949 case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
4950 case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
4951 case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
4952 case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
4953 case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
4954 case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
4955 case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
4956 case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
4957 case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
4958 case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
4959 case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
4960 case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
4961 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
4962 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
4963 case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
4964 case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
4965 case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
4966 case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
4967 case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
4968 case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
4969 case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
4970 case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
4971 Info.opc = ISD::INTRINSIC_W_CHAIN;
4972 Info.memVT = MVT::v4i32;
4973 Info.ptrVal = nullptr;
4974 Info.offset = 0;
4975 Info.flags = MachineMemOperand::MOLoad;
4976 Info.align = Align(16);
4977 Infos.push_back(Elt: Info);
4978 return;
4979
4980 case Intrinsic::nvvm_suld_1d_i8_clamp:
4981 case Intrinsic::nvvm_suld_1d_v2i8_clamp:
4982 case Intrinsic::nvvm_suld_1d_v4i8_clamp:
4983 case Intrinsic::nvvm_suld_1d_array_i8_clamp:
4984 case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
4985 case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
4986 case Intrinsic::nvvm_suld_2d_i8_clamp:
4987 case Intrinsic::nvvm_suld_2d_v2i8_clamp:
4988 case Intrinsic::nvvm_suld_2d_v4i8_clamp:
4989 case Intrinsic::nvvm_suld_2d_array_i8_clamp:
4990 case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
4991 case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
4992 case Intrinsic::nvvm_suld_3d_i8_clamp:
4993 case Intrinsic::nvvm_suld_3d_v2i8_clamp:
4994 case Intrinsic::nvvm_suld_3d_v4i8_clamp:
4995 case Intrinsic::nvvm_suld_1d_i8_trap:
4996 case Intrinsic::nvvm_suld_1d_v2i8_trap:
4997 case Intrinsic::nvvm_suld_1d_v4i8_trap:
4998 case Intrinsic::nvvm_suld_1d_array_i8_trap:
4999 case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
5000 case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
5001 case Intrinsic::nvvm_suld_2d_i8_trap:
5002 case Intrinsic::nvvm_suld_2d_v2i8_trap:
5003 case Intrinsic::nvvm_suld_2d_v4i8_trap:
5004 case Intrinsic::nvvm_suld_2d_array_i8_trap:
5005 case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
5006 case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
5007 case Intrinsic::nvvm_suld_3d_i8_trap:
5008 case Intrinsic::nvvm_suld_3d_v2i8_trap:
5009 case Intrinsic::nvvm_suld_3d_v4i8_trap:
5010 case Intrinsic::nvvm_suld_1d_i8_zero:
5011 case Intrinsic::nvvm_suld_1d_v2i8_zero:
5012 case Intrinsic::nvvm_suld_1d_v4i8_zero:
5013 case Intrinsic::nvvm_suld_1d_array_i8_zero:
5014 case Intrinsic::nvvm_suld_1d_array_v2i8_zero:
5015 case Intrinsic::nvvm_suld_1d_array_v4i8_zero:
5016 case Intrinsic::nvvm_suld_2d_i8_zero:
5017 case Intrinsic::nvvm_suld_2d_v2i8_zero:
5018 case Intrinsic::nvvm_suld_2d_v4i8_zero:
5019 case Intrinsic::nvvm_suld_2d_array_i8_zero:
5020 case Intrinsic::nvvm_suld_2d_array_v2i8_zero:
5021 case Intrinsic::nvvm_suld_2d_array_v4i8_zero:
5022 case Intrinsic::nvvm_suld_3d_i8_zero:
5023 case Intrinsic::nvvm_suld_3d_v2i8_zero:
5024 case Intrinsic::nvvm_suld_3d_v4i8_zero:
5025 Info.opc = ISD::INTRINSIC_W_CHAIN;
5026 Info.memVT = MVT::i8;
5027 Info.ptrVal = nullptr;
5028 Info.offset = 0;
5029 Info.flags = MachineMemOperand::MOLoad;
5030 Info.align = Align(16);
5031 Infos.push_back(Elt: Info);
5032 return;
5033
5034 case Intrinsic::nvvm_suld_1d_i16_clamp:
5035 case Intrinsic::nvvm_suld_1d_v2i16_clamp:
5036 case Intrinsic::nvvm_suld_1d_v4i16_clamp:
5037 case Intrinsic::nvvm_suld_1d_array_i16_clamp:
5038 case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
5039 case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
5040 case Intrinsic::nvvm_suld_2d_i16_clamp:
5041 case Intrinsic::nvvm_suld_2d_v2i16_clamp:
5042 case Intrinsic::nvvm_suld_2d_v4i16_clamp:
5043 case Intrinsic::nvvm_suld_2d_array_i16_clamp:
5044 case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
5045 case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
5046 case Intrinsic::nvvm_suld_3d_i16_clamp:
5047 case Intrinsic::nvvm_suld_3d_v2i16_clamp:
5048 case Intrinsic::nvvm_suld_3d_v4i16_clamp:
5049 case Intrinsic::nvvm_suld_1d_i16_trap:
5050 case Intrinsic::nvvm_suld_1d_v2i16_trap:
5051 case Intrinsic::nvvm_suld_1d_v4i16_trap:
5052 case Intrinsic::nvvm_suld_1d_array_i16_trap:
5053 case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
5054 case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
5055 case Intrinsic::nvvm_suld_2d_i16_trap:
5056 case Intrinsic::nvvm_suld_2d_v2i16_trap:
5057 case Intrinsic::nvvm_suld_2d_v4i16_trap:
5058 case Intrinsic::nvvm_suld_2d_array_i16_trap:
5059 case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
5060 case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
5061 case Intrinsic::nvvm_suld_3d_i16_trap:
5062 case Intrinsic::nvvm_suld_3d_v2i16_trap:
5063 case Intrinsic::nvvm_suld_3d_v4i16_trap:
5064 case Intrinsic::nvvm_suld_1d_i16_zero:
5065 case Intrinsic::nvvm_suld_1d_v2i16_zero:
5066 case Intrinsic::nvvm_suld_1d_v4i16_zero:
5067 case Intrinsic::nvvm_suld_1d_array_i16_zero:
5068 case Intrinsic::nvvm_suld_1d_array_v2i16_zero:
5069 case Intrinsic::nvvm_suld_1d_array_v4i16_zero:
5070 case Intrinsic::nvvm_suld_2d_i16_zero:
5071 case Intrinsic::nvvm_suld_2d_v2i16_zero:
5072 case Intrinsic::nvvm_suld_2d_v4i16_zero:
5073 case Intrinsic::nvvm_suld_2d_array_i16_zero:
5074 case Intrinsic::nvvm_suld_2d_array_v2i16_zero:
5075 case Intrinsic::nvvm_suld_2d_array_v4i16_zero:
5076 case Intrinsic::nvvm_suld_3d_i16_zero:
5077 case Intrinsic::nvvm_suld_3d_v2i16_zero:
5078 case Intrinsic::nvvm_suld_3d_v4i16_zero:
5079 Info.opc = ISD::INTRINSIC_W_CHAIN;
5080 Info.memVT = MVT::i16;
5081 Info.ptrVal = nullptr;
5082 Info.offset = 0;
5083 Info.flags = MachineMemOperand::MOLoad;
5084 Info.align = Align(16);
5085 Infos.push_back(Elt: Info);
5086 return;
5087
5088 case Intrinsic::nvvm_suld_1d_i32_clamp:
5089 case Intrinsic::nvvm_suld_1d_v2i32_clamp:
5090 case Intrinsic::nvvm_suld_1d_v4i32_clamp:
5091 case Intrinsic::nvvm_suld_1d_array_i32_clamp:
5092 case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
5093 case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
5094 case Intrinsic::nvvm_suld_2d_i32_clamp:
5095 case Intrinsic::nvvm_suld_2d_v2i32_clamp:
5096 case Intrinsic::nvvm_suld_2d_v4i32_clamp:
5097 case Intrinsic::nvvm_suld_2d_array_i32_clamp:
5098 case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
5099 case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
5100 case Intrinsic::nvvm_suld_3d_i32_clamp:
5101 case Intrinsic::nvvm_suld_3d_v2i32_clamp:
5102 case Intrinsic::nvvm_suld_3d_v4i32_clamp:
5103 case Intrinsic::nvvm_suld_1d_i32_trap:
5104 case Intrinsic::nvvm_suld_1d_v2i32_trap:
5105 case Intrinsic::nvvm_suld_1d_v4i32_trap:
5106 case Intrinsic::nvvm_suld_1d_array_i32_trap:
5107 case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
5108 case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
5109 case Intrinsic::nvvm_suld_2d_i32_trap:
5110 case Intrinsic::nvvm_suld_2d_v2i32_trap:
5111 case Intrinsic::nvvm_suld_2d_v4i32_trap:
5112 case Intrinsic::nvvm_suld_2d_array_i32_trap:
5113 case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
5114 case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
5115 case Intrinsic::nvvm_suld_3d_i32_trap:
5116 case Intrinsic::nvvm_suld_3d_v2i32_trap:
5117 case Intrinsic::nvvm_suld_3d_v4i32_trap:
5118 case Intrinsic::nvvm_suld_1d_i32_zero:
5119 case Intrinsic::nvvm_suld_1d_v2i32_zero:
5120 case Intrinsic::nvvm_suld_1d_v4i32_zero:
5121 case Intrinsic::nvvm_suld_1d_array_i32_zero:
5122 case Intrinsic::nvvm_suld_1d_array_v2i32_zero:
5123 case Intrinsic::nvvm_suld_1d_array_v4i32_zero:
5124 case Intrinsic::nvvm_suld_2d_i32_zero:
5125 case Intrinsic::nvvm_suld_2d_v2i32_zero:
5126 case Intrinsic::nvvm_suld_2d_v4i32_zero:
5127 case Intrinsic::nvvm_suld_2d_array_i32_zero:
5128 case Intrinsic::nvvm_suld_2d_array_v2i32_zero:
5129 case Intrinsic::nvvm_suld_2d_array_v4i32_zero:
5130 case Intrinsic::nvvm_suld_3d_i32_zero:
5131 case Intrinsic::nvvm_suld_3d_v2i32_zero:
5132 case Intrinsic::nvvm_suld_3d_v4i32_zero:
5133 Info.opc = ISD::INTRINSIC_W_CHAIN;
5134 Info.memVT = MVT::i32;
5135 Info.ptrVal = nullptr;
5136 Info.offset = 0;
5137 Info.flags = MachineMemOperand::MOLoad;
5138 Info.align = Align(16);
5139 Infos.push_back(Elt: Info);
5140 return;
5141
5142 case Intrinsic::nvvm_suld_1d_i64_clamp:
5143 case Intrinsic::nvvm_suld_1d_v2i64_clamp:
5144 case Intrinsic::nvvm_suld_1d_array_i64_clamp:
5145 case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
5146 case Intrinsic::nvvm_suld_2d_i64_clamp:
5147 case Intrinsic::nvvm_suld_2d_v2i64_clamp:
5148 case Intrinsic::nvvm_suld_2d_array_i64_clamp:
5149 case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
5150 case Intrinsic::nvvm_suld_3d_i64_clamp:
5151 case Intrinsic::nvvm_suld_3d_v2i64_clamp:
5152 case Intrinsic::nvvm_suld_1d_i64_trap:
5153 case Intrinsic::nvvm_suld_1d_v2i64_trap:
5154 case Intrinsic::nvvm_suld_1d_array_i64_trap:
5155 case Intrinsic::nvvm_suld_1d_array_v2i64_trap:
5156 case Intrinsic::nvvm_suld_2d_i64_trap:
5157 case Intrinsic::nvvm_suld_2d_v2i64_trap:
5158 case Intrinsic::nvvm_suld_2d_array_i64_trap:
5159 case Intrinsic::nvvm_suld_2d_array_v2i64_trap:
5160 case Intrinsic::nvvm_suld_3d_i64_trap:
5161 case Intrinsic::nvvm_suld_3d_v2i64_trap:
5162 case Intrinsic::nvvm_suld_1d_i64_zero:
5163 case Intrinsic::nvvm_suld_1d_v2i64_zero:
5164 case Intrinsic::nvvm_suld_1d_array_i64_zero:
5165 case Intrinsic::nvvm_suld_1d_array_v2i64_zero:
5166 case Intrinsic::nvvm_suld_2d_i64_zero:
5167 case Intrinsic::nvvm_suld_2d_v2i64_zero:
5168 case Intrinsic::nvvm_suld_2d_array_i64_zero:
5169 case Intrinsic::nvvm_suld_2d_array_v2i64_zero:
5170 case Intrinsic::nvvm_suld_3d_i64_zero:
5171 case Intrinsic::nvvm_suld_3d_v2i64_zero:
5172 Info.opc = ISD::INTRINSIC_W_CHAIN;
5173 Info.memVT = MVT::i64;
5174 Info.ptrVal = nullptr;
5175 Info.offset = 0;
5176 Info.flags = MachineMemOperand::MOLoad;
5177 Info.align = Align(16);
5178 Infos.push_back(Elt: Info);
5179 return;
5180
5181 case Intrinsic::nvvm_tcgen05_ld_16x64b_x1:
5182 case Intrinsic::nvvm_tcgen05_ld_32x32b_x1:
5183 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x1: {
5184 Info.opc = ISD::INTRINSIC_W_CHAIN;
5185 Info.memVT = MVT::v1i32;
5186 Info.ptrVal = I.getArgOperand(i: 0);
5187 Info.offset = 0;
5188 Info.flags = MachineMemOperand::MOLoad;
5189 Info.align.reset();
5190 Infos.push_back(Elt: Info);
5191 return;
5192 }
5193
5194 case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
5195 case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
5196 case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
5197 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
5198 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_i32:
5199 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_i32: {
5200 Info.opc = ISD::INTRINSIC_W_CHAIN;
5201 Info.memVT = MVT::v2i32;
5202 Info.ptrVal = I.getArgOperand(i: 0);
5203 Info.offset = 0;
5204 Info.flags = MachineMemOperand::MOLoad;
5205 Info.align.reset();
5206 Infos.push_back(Elt: Info);
5207 return;
5208 }
5209
5210 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32:
5211 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_f32: {
5212 Info.opc = ISD::INTRINSIC_W_CHAIN;
5213 Info.memVT = MVT::v2f32;
5214 Info.ptrVal = I.getArgOperand(i: 0);
5215 Info.offset = 0;
5216 Info.flags = MachineMemOperand::MOLoad;
5217 Info.align.reset();
5218 Infos.push_back(Elt: Info);
5219 return;
5220 }
5221
5222 case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
5223 case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
5224 case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
5225 case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
5226 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
5227 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_i32:
5228 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_i32: {
5229 Info.opc = ISD::INTRINSIC_W_CHAIN;
5230 Info.memVT = MVT::v4i32;
5231 Info.ptrVal = I.getArgOperand(i: 0);
5232 Info.offset = 0;
5233 Info.flags = MachineMemOperand::MOLoad;
5234 Info.align.reset();
5235 Infos.push_back(Elt: Info);
5236 return;
5237 }
5238
5239 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_f32:
5240 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_f32: {
5241 Info.opc = ISD::INTRINSIC_W_CHAIN;
5242 Info.memVT = MVT::v4f32;
5243 Info.ptrVal = I.getArgOperand(i: 0);
5244 Info.offset = 0;
5245 Info.flags = MachineMemOperand::MOLoad;
5246 Info.align.reset();
5247 Infos.push_back(Elt: Info);
5248 return;
5249 }
5250
5251 case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
5252 case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
5253 case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
5254 case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
5255 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
5256 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_i32:
5257 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_i32: {
5258 Info.opc = ISD::INTRINSIC_W_CHAIN;
5259 Info.memVT = MVT::v8i32;
5260 Info.ptrVal = I.getArgOperand(i: 0);
5261 Info.offset = 0;
5262 Info.flags = MachineMemOperand::MOLoad;
5263 Info.align.reset();
5264 Infos.push_back(Elt: Info);
5265 return;
5266 }
5267
5268 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_f32:
5269 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_f32: {
5270 Info.opc = ISD::INTRINSIC_W_CHAIN;
5271 Info.memVT = MVT::v8f32;
5272 Info.ptrVal = I.getArgOperand(i: 0);
5273 Info.offset = 0;
5274 Info.flags = MachineMemOperand::MOLoad;
5275 Info.align.reset();
5276 Infos.push_back(Elt: Info);
5277 return;
5278 }
5279
5280 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
5281 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
5282 case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
5283 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
5284 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
5285 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_i32:
5286 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_i32: {
5287 Info.opc = ISD::INTRINSIC_W_CHAIN;
5288 Info.memVT = MVT::v16i32;
5289 Info.ptrVal = I.getArgOperand(i: 0);
5290 Info.offset = 0;
5291 Info.flags = MachineMemOperand::MOLoad;
5292 Info.align.reset();
5293 Infos.push_back(Elt: Info);
5294 return;
5295 }
5296
5297 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_f32:
5298 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_f32: {
5299 Info.opc = ISD::INTRINSIC_W_CHAIN;
5300 Info.memVT = MVT::v16f32;
5301 Info.ptrVal = I.getArgOperand(i: 0);
5302 Info.offset = 0;
5303 Info.flags = MachineMemOperand::MOLoad;
5304 Info.align.reset();
5305 Infos.push_back(Elt: Info);
5306 return;
5307 }
5308
5309 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
5310 case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
5311 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
5312 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
5313 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
5314 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_i32:
5315 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_i32: {
5316 Info.opc = ISD::INTRINSIC_W_CHAIN;
5317 Info.memVT = MVT::v32i32;
5318 Info.ptrVal = I.getArgOperand(i: 0);
5319 Info.offset = 0;
5320 Info.flags = MachineMemOperand::MOLoad;
5321 Info.align.reset();
5322 Infos.push_back(Elt: Info);
5323 return;
5324 }
5325
5326 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_f32:
5327 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_f32: {
5328 Info.opc = ISD::INTRINSIC_W_CHAIN;
5329 Info.memVT = MVT::v32f32;
5330 Info.ptrVal = I.getArgOperand(i: 0);
5331 Info.offset = 0;
5332 Info.flags = MachineMemOperand::MOLoad;
5333 Info.align.reset();
5334 Infos.push_back(Elt: Info);
5335 return;
5336 }
5337
5338 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
5339 case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
5340 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
5341 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
5342 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
5343 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_i32:
5344 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_i32: {
5345 Info.opc = ISD::INTRINSIC_W_CHAIN;
5346 Info.memVT = MVT::v64i32;
5347 Info.ptrVal = I.getArgOperand(i: 0);
5348 Info.offset = 0;
5349 Info.flags = MachineMemOperand::MOLoad;
5350 Info.align.reset();
5351 Infos.push_back(Elt: Info);
5352 return;
5353 }
5354
5355 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_f32:
5356 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_f32: {
5357 Info.opc = ISD::INTRINSIC_W_CHAIN;
5358 Info.memVT = MVT::v64f32;
5359 Info.ptrVal = I.getArgOperand(i: 0);
5360 Info.offset = 0;
5361 Info.flags = MachineMemOperand::MOLoad;
5362 Info.align.reset();
5363 Infos.push_back(Elt: Info);
5364 return;
5365 }
5366
5367 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
5368 case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
5369 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
5370 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
5371 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
5372 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_i32:
5373 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_i32: {
5374 Info.opc = ISD::INTRINSIC_W_CHAIN;
5375 Info.memVT = MVT::v128i32;
5376 Info.ptrVal = I.getArgOperand(i: 0);
5377 Info.offset = 0;
5378 Info.flags = MachineMemOperand::MOLoad;
5379 Info.align.reset();
5380 Infos.push_back(Elt: Info);
5381 return;
5382 }
5383
5384 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_f32:
5385 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_f32: {
5386 Info.opc = ISD::INTRINSIC_W_CHAIN;
5387 Info.memVT = MVT::v128f32;
5388 Info.ptrVal = I.getArgOperand(i: 0);
5389 Info.offset = 0;
5390 Info.flags = MachineMemOperand::MOLoad;
5391 Info.align.reset();
5392 Infos.push_back(Elt: Info);
5393 return;
5394 }
5395
5396 case Intrinsic::nvvm_tcgen05_st_16x64b_x1:
5397 case Intrinsic::nvvm_tcgen05_st_32x32b_x1:
5398 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x1: {
5399 Info.opc = ISD::INTRINSIC_VOID;
5400 Info.memVT = MVT::i32;
5401 Info.ptrVal = I.getArgOperand(i: 0);
5402 Info.offset = 0;
5403 Info.flags = MachineMemOperand::MOStore;
5404 Info.align.reset();
5405 Infos.push_back(Elt: Info);
5406 return;
5407 }
5408
5409 case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
5410 case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
5411 case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
5412 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2: {
5413 Info.opc = ISD::INTRINSIC_VOID;
5414 Info.memVT = MVT::v2i32;
5415 Info.ptrVal = I.getArgOperand(i: 0);
5416 Info.offset = 0;
5417 Info.flags = MachineMemOperand::MOStore;
5418 Info.align.reset();
5419 Infos.push_back(Elt: Info);
5420 return;
5421 }
5422
5423 case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
5424 case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
5425 case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
5426 case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
5427 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4: {
5428 Info.opc = ISD::INTRINSIC_VOID;
5429 Info.memVT = MVT::v4i32;
5430 Info.ptrVal = I.getArgOperand(i: 0);
5431 Info.offset = 0;
5432 Info.flags = MachineMemOperand::MOStore;
5433 Info.align.reset();
5434 Infos.push_back(Elt: Info);
5435 return;
5436 }
5437
5438 case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
5439 case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
5440 case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
5441 case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
5442 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8: {
5443 Info.opc = ISD::INTRINSIC_VOID;
5444 Info.memVT = MVT::v8i32;
5445 Info.ptrVal = I.getArgOperand(i: 0);
5446 Info.offset = 0;
5447 Info.flags = MachineMemOperand::MOStore;
5448 Info.align.reset();
5449 Infos.push_back(Elt: Info);
5450 return;
5451 }
5452
5453 case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
5454 case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
5455 case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
5456 case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
5457 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16: {
5458 Info.opc = ISD::INTRINSIC_VOID;
5459 Info.memVT = MVT::v16i32;
5460 Info.ptrVal = I.getArgOperand(i: 0);
5461 Info.offset = 0;
5462 Info.flags = MachineMemOperand::MOStore;
5463 Info.align.reset();
5464 Infos.push_back(Elt: Info);
5465 return;
5466 }
5467
5468 case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
5469 case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
5470 case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
5471 case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
5472 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32: {
5473 Info.opc = ISD::INTRINSIC_VOID;
5474 Info.memVT = MVT::v32i32;
5475 Info.ptrVal = I.getArgOperand(i: 0);
5476 Info.offset = 0;
5477 Info.flags = MachineMemOperand::MOStore;
5478 Info.align.reset();
5479 Infos.push_back(Elt: Info);
5480 return;
5481 }
5482
5483 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
5484 case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
5485 case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
5486 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
5487 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64: {
5488 Info.opc = ISD::INTRINSIC_VOID;
5489 Info.memVT = MVT::v64i32;
5490 Info.ptrVal = I.getArgOperand(i: 0);
5491 Info.offset = 0;
5492 Info.flags = MachineMemOperand::MOStore;
5493 Info.align.reset();
5494 Infos.push_back(Elt: Info);
5495 return;
5496 }
5497
5498 case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
5499 case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
5500 case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
5501 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
5502 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128: {
5503 Info.opc = ISD::INTRINSIC_VOID;
5504 Info.memVT = MVT::v128i32;
5505 Info.ptrVal = I.getArgOperand(i: 0);
5506 Info.offset = 0;
5507 Info.flags = MachineMemOperand::MOStore;
5508 Info.align.reset();
5509 Infos.push_back(Elt: Info);
5510 return;
5511 }
5512 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
5513 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
5514 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
5515 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
5516 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
5517 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
5518 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
5519 case Intrinsic::
5520 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
5521 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
5522 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
5523 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
5524 case Intrinsic::
5525 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift: {
5526 // We are reading and writing back to TMem
5527 Info.opc = ISD::INTRINSIC_VOID;
5528 Info.memVT = MVT::v4i32;
5529 Info.ptrVal = I.getArgOperand(i: 0);
5530 Info.offset = 0;
5531 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
5532 Info.align = Align(16);
5533 Infos.push_back(Elt: Info);
5534 return;
5535 }
5536
5537 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
5538 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
5539 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
5540 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
5541 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
5542 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
5543 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
5544 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
5545 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
5546 case Intrinsic::
5547 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
5548 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
5549 case Intrinsic::
5550 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift: {
5551 // We are reading and writing back to TMem
5552 Info.opc = ISD::INTRINSIC_VOID;
5553 Info.memVT = MVT::v8i32;
5554 Info.ptrVal = I.getArgOperand(i: 0);
5555 Info.offset = 0;
5556 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
5557 Info.align = Align(16);
5558 Infos.push_back(Elt: Info);
5559 return;
5560 }
5561 }
5562}
5563
5564// Helper for getting a function parameter name. Name is composed from
5565// its index and the function name. Negative index corresponds to special
5566// parameter (unsized array) used for passing variable arguments.
5567std::string NVPTXTargetLowering::getParamName(const Function *F,
5568 int Idx) const {
5569 std::string ParamName;
5570 raw_string_ostream ParamStr(ParamName);
5571
5572 ParamStr << getTargetMachine().getSymbol(GV: F)->getName();
5573 if (Idx < 0)
5574 ParamStr << "_vararg";
5575 else
5576 ParamStr << "_param_" << Idx;
5577
5578 return ParamName;
5579}
5580
5581/// isLegalAddressingMode - Return true if the addressing mode represented
5582/// by AM is legal for this target, for a load/store of the specified type.
5583/// Used to guide target specific optimizations, like loop strength reduction
5584/// (LoopStrengthReduce.cpp) and memory optimization for address mode
5585/// (CodeGenPrepare.cpp)
5586bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
5587 const AddrMode &AM, Type *Ty,
5588 unsigned AS, Instruction *I) const {
5589 // AddrMode - This represents an addressing mode of:
5590 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
5591 //
5592 // The legal address modes are
5593 // - [avar]
5594 // - [areg]
5595 // - [areg+immoff]
5596 // - [immAddr]
5597
5598 // immoff must fit in a signed 32-bit int
5599 if (!APInt(64, AM.BaseOffs).isSignedIntN(N: 32))
5600 return false;
5601
5602 if (AM.BaseGV)
5603 return !AM.BaseOffs && !AM.HasBaseReg && !AM.Scale;
5604
5605 switch (AM.Scale) {
5606 case 0: // "r", "r+i" or "i" is allowed
5607 break;
5608 case 1:
5609 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
5610 return false;
5611 // Otherwise we have r+i.
5612 break;
5613 default:
5614 // No scale > 1 is allowed
5615 return false;
5616 }
5617 return true;
5618}
5619
5620//===----------------------------------------------------------------------===//
5621// NVPTX Inline Assembly Support
5622//===----------------------------------------------------------------------===//
5623
5624/// getConstraintType - Given a constraint letter, return the type of
5625/// constraint it is for this target.
5626NVPTXTargetLowering::ConstraintType
5627NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
5628 if (Constraint.size() == 1) {
5629 switch (Constraint[0]) {
5630 default:
5631 break;
5632 case 'b':
5633 case 'r':
5634 case 'h':
5635 case 'c':
5636 case 'l':
5637 case 'f':
5638 case 'd':
5639 case 'q':
5640 case '0':
5641 case 'N':
5642 return C_RegisterClass;
5643 }
5644 }
5645 return TargetLowering::getConstraintType(Constraint);
5646}
5647
5648std::pair<unsigned, const TargetRegisterClass *>
5649NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
5650 StringRef Constraint,
5651 MVT VT) const {
5652 if (Constraint.size() == 1) {
5653 switch (Constraint[0]) {
5654 case 'b':
5655 return std::make_pair(x: 0U, y: &NVPTX::B1RegClass);
5656 case 'c':
5657 case 'h':
5658 return std::make_pair(x: 0U, y: &NVPTX::B16RegClass);
5659 case 'r':
5660 case 'f':
5661 return std::make_pair(x: 0U, y: &NVPTX::B32RegClass);
5662 case 'l':
5663 case 'N':
5664 case 'd':
5665 return std::make_pair(x: 0U, y: &NVPTX::B64RegClass);
5666 case 'q': {
5667 if (STI.getSmVersion() < 70)
5668 report_fatal_error(reason: "Inline asm with 128 bit operands is only "
5669 "supported for sm_70 and higher!");
5670 return std::make_pair(x: 0U, y: &NVPTX::B128RegClass);
5671 }
5672 }
5673 }
5674 return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
5675}
5676
5677//===----------------------------------------------------------------------===//
5678// NVPTX DAG Combining
5679//===----------------------------------------------------------------------===//
5680
5681bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
5682 CodeGenOptLevel OptLevel) const {
5683 // Always honor command-line argument
5684 if (FMAContractLevelOpt.getNumOccurrences() > 0)
5685 return FMAContractLevelOpt > 0;
5686
5687 // Do not contract if we're not optimizing the code.
5688 if (OptLevel == CodeGenOptLevel::None)
5689 return false;
5690
5691 // Honor TargetOptions flags that explicitly say fusion is okay.
5692 if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast)
5693 return true;
5694
5695 return false;
5696}
5697
5698static bool isConstZero(const SDValue &Operand) {
5699 const auto *Const = dyn_cast<ConstantSDNode>(Val: Operand);
5700 return Const && Const->getZExtValue() == 0;
5701}
5702
5703/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
5704/// operands N0 and N1. This is a helper for PerformADDCombine that is
5705/// called with the default operands, and if that fails, with commuted
5706/// operands.
5707static SDValue
5708PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
5709 TargetLowering::DAGCombinerInfo &DCI) {
5710 EVT VT = N0.getValueType();
5711
5712 // Since integer multiply-add costs the same as integer multiply
5713 // but is more costly than integer add, do the fusion only when
5714 // the mul is only used in the add.
5715 // TODO: this may not be true for later architectures, consider relaxing this
5716 if (!N0.getNode()->hasOneUse())
5717 return SDValue();
5718
5719 // fold (add (select cond, 0, (mul a, b)), c)
5720 // -> (select cond, c, (add (mul a, b), c))
5721 //
5722 if (N0.getOpcode() == ISD::SELECT) {
5723 unsigned ZeroOpNum;
5724 if (isConstZero(Operand: N0->getOperand(Num: 1)))
5725 ZeroOpNum = 1;
5726 else if (isConstZero(Operand: N0->getOperand(Num: 2)))
5727 ZeroOpNum = 2;
5728 else
5729 return SDValue();
5730
5731 SDValue M = N0->getOperand(Num: (ZeroOpNum == 1) ? 2 : 1);
5732 if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
5733 return SDValue();
5734
5735 SDLoc DL(N);
5736 SDValue Mul =
5737 DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: M->getOperand(Num: 0), N2: M->getOperand(Num: 1));
5738 SDValue MAD = DCI.DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Mul, N2: N1);
5739 return DCI.DAG.getSelect(DL: SDLoc(N), VT, Cond: N0->getOperand(Num: 0),
5740 LHS: ((ZeroOpNum == 1) ? N1 : MAD),
5741 RHS: ((ZeroOpNum == 1) ? MAD : N1));
5742 }
5743
5744 return SDValue();
5745}
5746
5747static SDValue
5748PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
5749 TargetLowering::DAGCombinerInfo &DCI,
5750 CodeGenOptLevel OptLevel) {
5751 EVT VT = N0.getValueType();
5752 if (N0.getOpcode() == ISD::FMUL) {
5753 const auto *TLI = static_cast<const NVPTXTargetLowering *>(
5754 &DCI.DAG.getTargetLoweringInfo());
5755 if (!(TLI->allowFMA(MF&: DCI.DAG.getMachineFunction(), OptLevel) ||
5756 (N->getFlags().hasAllowContract() &&
5757 N0->getFlags().hasAllowContract())))
5758 return SDValue();
5759
5760 // For floating point:
5761 // Do the fusion only when the mul has less than 5 uses and all
5762 // are add.
5763 // The heuristic is that if a use is not an add, then that use
5764 // cannot be fused into fma, therefore mul is still needed anyway.
5765 // If there are more than 4 uses, even if they are all add, fusing
5766 // them will increase register pressue.
5767 //
5768 int numUses = 0;
5769 int nonAddCount = 0;
5770 for (const SDNode *User : N0.getNode()->users()) {
5771 numUses++;
5772 if (User->getOpcode() != ISD::FADD)
5773 ++nonAddCount;
5774 if (numUses >= 5)
5775 return SDValue();
5776 }
5777 if (nonAddCount) {
5778 int orderNo = N->getIROrder();
5779 int orderNo2 = N0.getNode()->getIROrder();
5780 // simple heuristics here for considering potential register
5781 // pressure, the logics here is that the differnce are used
5782 // to measure the distance between def and use, the longer distance
5783 // more likely cause register pressure.
5784 if (orderNo - orderNo2 < 500)
5785 return SDValue();
5786
5787 // Now, check if at least one of the FMUL's operands is live beyond the
5788 // node N, which guarantees that the FMA will not increase register
5789 // pressure at node N.
5790 bool opIsLive = false;
5791 const SDNode *left = N0.getOperand(i: 0).getNode();
5792 const SDNode *right = N0.getOperand(i: 1).getNode();
5793
5794 if (isa<ConstantSDNode>(Val: left) || isa<ConstantSDNode>(Val: right))
5795 opIsLive = true;
5796
5797 if (!opIsLive)
5798 for (const SDNode *User : left->users()) {
5799 int orderNo3 = User->getIROrder();
5800 if (orderNo3 > orderNo) {
5801 opIsLive = true;
5802 break;
5803 }
5804 }
5805
5806 if (!opIsLive)
5807 for (const SDNode *User : right->users()) {
5808 int orderNo3 = User->getIROrder();
5809 if (orderNo3 > orderNo) {
5810 opIsLive = true;
5811 break;
5812 }
5813 }
5814
5815 if (!opIsLive)
5816 return SDValue();
5817 }
5818
5819 return DCI.DAG.getNode(Opcode: ISD::FMA, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
5820 N2: N0.getOperand(i: 1), N3: N1);
5821 }
5822
5823 return SDValue();
5824}
5825
5826/// Fold unpacking movs into a load by increasing the number of return values.
5827///
5828/// ex:
5829/// L: v2f16,ch = load <p>
5830/// a: f16 = extractelt L:0, 0
5831/// b: f16 = extractelt L:0, 1
5832/// use(a, b)
5833///
5834/// ...is turned into...
5835///
5836/// L: f16,f16,ch = LoadV2 <p>
5837/// use(L:0, L:1)
5838static SDValue
5839combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5840 // Don't run this optimization before the legalizer
5841 if (!DCI.isAfterLegalizeDAG())
5842 return SDValue();
5843
5844 EVT ElementVT = N->getValueType(ResNo: 0);
5845 // Avoid non-packed types and v4i8
5846 if (!NVPTX::isPackedVectorTy(VT: ElementVT) || ElementVT == MVT::v4i8)
5847 return SDValue();
5848
5849 // Check whether all outputs are either used by an extractelt or are
5850 // glue/chain nodes
5851 if (!all_of(Range: N->uses(), P: [&](SDUse &U) {
5852 // Skip glue, chain nodes
5853 if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
5854 return true;
5855 if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
5856 if (N->getOpcode() != ISD::LOAD)
5857 return true;
5858 // Since this is an ISD::LOAD, check all extractelts are used. If
5859 // any are not used, we don't want to defeat another optimization that
5860 // will narrow the load.
5861 //
5862 // For example:
5863 //
5864 // L: v2f16,ch = load <p>
5865 // e0: f16 = extractelt L:0, 0
5866 // e1: f16 = extractelt L:0, 1 <-- unused
5867 // store e0
5868 //
5869 // Can be optimized by DAGCombiner to:
5870 //
5871 // L: f16,ch = load <p>
5872 // store L:0
5873 return !U.getUser()->use_empty();
5874 }
5875
5876 // Otherwise, this use prevents us from splitting a value.
5877 return false;
5878 }))
5879 return SDValue();
5880
5881 auto *LD = cast<MemSDNode>(Val: N);
5882 SDLoc DL(LD);
5883
5884 // the new opcode after we double the number of operands
5885 unsigned Opcode;
5886 SmallVector<SDValue> Operands(LD->ops());
5887 unsigned OldNumOutputs; // non-glue, non-chain outputs
5888 switch (LD->getOpcode()) {
5889 case ISD::LOAD:
5890 OldNumOutputs = 1;
5891 // Any packed type is legal, so the legalizer will not have lowered
5892 // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
5893 // here.
5894 Opcode = NVPTXISD::LoadV2;
5895 // append a "full" used bytes mask operand right before the extension type
5896 // operand, signifying that all bytes are used.
5897 Operands.push_back(Elt: DCI.DAG.getConstant(UINT32_MAX, DL, VT: MVT::i32));
5898 Operands.push_back(Elt: DCI.DAG.getIntPtrConstant(
5899 Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));
5900 break;
5901 case NVPTXISD::LoadV2:
5902 OldNumOutputs = 2;
5903 Opcode = NVPTXISD::LoadV4;
5904 break;
5905 case NVPTXISD::LoadV4:
5906 // V8 is only supported for f32/i32. Don't forget, we're not changing the
5907 // load size here. This is already a 256-bit load.
5908 if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
5909 return SDValue();
5910 OldNumOutputs = 4;
5911 Opcode = NVPTXISD::LoadV8;
5912 break;
5913 case NVPTXISD::LoadV8:
5914 // PTX doesn't support the next doubling of outputs
5915 return SDValue();
5916 }
5917
5918 // the non-glue, non-chain outputs in the new load
5919 const unsigned NewNumOutputs = OldNumOutputs * 2;
5920 SmallVector<EVT> NewVTs(NewNumOutputs, ElementVT.getVectorElementType());
5921 // add remaining chain and glue values
5922 NewVTs.append(in_start: LD->value_begin() + OldNumOutputs, in_end: LD->value_end());
5923
5924 // Create the new load
5925 SDValue NewLoad = DCI.DAG.getMemIntrinsicNode(
5926 Opcode, dl: DL, VTList: DCI.DAG.getVTList(VTs: NewVTs), Ops: Operands, MemVT: LD->getMemoryVT(),
5927 MMO: LD->getMemOperand());
5928
5929 // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
5930 // the outputs the same. These nodes will be optimized away in later
5931 // DAGCombiner iterations.
5932 SmallVector<SDValue> Results;
5933 for (unsigned I : seq(Size: OldNumOutputs))
5934 Results.push_back(Elt: DCI.DAG.getBuildVector(
5935 VT: ElementVT, DL, Ops: {NewLoad.getValue(R: I * 2), NewLoad.getValue(R: I * 2 + 1)}));
5936 // Add remaining chain and glue nodes
5937 for (unsigned I : seq(Size: NewLoad->getNumValues() - NewNumOutputs))
5938 Results.push_back(Elt: NewLoad.getValue(R: NewNumOutputs + I));
5939
5940 return DCI.DAG.getMergeValues(Ops: Results, dl: DL);
5941}
5942
5943/// Fold packing movs into a store.
5944///
5945/// ex:
5946/// v1: v2f16 = BUILD_VECTOR a:f16, b:f16
5947/// v2: v2f16 = BUILD_VECTOR c:f16, d:f16
5948/// StoreV2 v1, v2
5949///
5950/// ...is turned into...
5951///
5952/// StoreV4 a, b, c, d
5953static SDValue combinePackingMovIntoStore(SDNode *N,
5954 TargetLowering::DAGCombinerInfo &DCI,
5955 unsigned Front, unsigned Back) {
5956 // We want to run this as late as possible since other optimizations may
5957 // eliminate the BUILD_VECTORs.
5958 if (!DCI.isAfterLegalizeDAG())
5959 return SDValue();
5960
5961 // Get the type of the operands being stored.
5962 EVT ElementVT = N->getOperand(Num: Front).getValueType();
5963
5964 // Avoid non-packed types and v4i8
5965 if (!NVPTX::isPackedVectorTy(VT: ElementVT) || ElementVT == MVT::v4i8)
5966 return SDValue();
5967
5968 auto *ST = cast<MemSDNode>(Val: N);
5969
5970 // The new opcode after we double the number of operands.
5971 unsigned Opcode;
5972 switch (N->getOpcode()) {
5973 case ISD::STORE:
5974 // Any packed type is legal, so the legalizer will not have lowered
5975 // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
5976 // it here.
5977 Opcode = NVPTXISD::StoreV2;
5978 break;
5979 case NVPTXISD::StoreV2:
5980 Opcode = NVPTXISD::StoreV4;
5981 break;
5982 case NVPTXISD::StoreV4:
5983 // V8 is only supported for f32/i32. Don't forget, we're not changing the
5984 // store size here. This is already a 256-bit store.
5985 if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
5986 return SDValue();
5987 Opcode = NVPTXISD::StoreV8;
5988 break;
5989 case NVPTXISD::StoreV8:
5990 // PTX doesn't support the next doubling of operands
5991 return SDValue();
5992 default:
5993 llvm_unreachable("Unhandled store opcode");
5994 }
5995
5996 // Scan the operands and if they're all BUILD_VECTORs, we'll have gathered
5997 // their elements.
5998 SmallVector<SDValue, 4> Operands(N->ops().take_front(N: Front));
5999 for (SDValue BV : N->ops().drop_front(N: Front).drop_back(N: Back)) {
6000 if (BV.getOpcode() != ISD::BUILD_VECTOR)
6001 return SDValue();
6002
6003 // If the operand has multiple uses, this optimization can increase register
6004 // pressure.
6005 if (!BV.hasOneUse())
6006 return SDValue();
6007
6008 // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
6009 // any signs they may be folded by some other pattern or rule.
6010 for (SDValue Op : BV->ops()) {
6011 // Peek through bitcasts
6012 if (Op.getOpcode() == ISD::BITCAST)
6013 Op = Op.getOperand(i: 0);
6014
6015 // This may be folded into a PRMT.
6016 if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
6017 Op->getOperand(Num: 0).getValueType() == MVT::i32)
6018 return SDValue();
6019
6020 // This may be folded into cvt.bf16x2
6021 if (Op.getOpcode() == ISD::FP_ROUND)
6022 return SDValue();
6023 }
6024 Operands.append(IL: {BV.getOperand(i: 0), BV.getOperand(i: 1)});
6025 }
6026 Operands.append(in_start: N->op_end() - Back, in_end: N->op_end());
6027
6028 // Now we replace the store
6029 return DCI.DAG.getMemIntrinsicNode(Opcode, dl: SDLoc(N), VTList: N->getVTList(), Ops: Operands,
6030 MemVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
6031}
6032
6033static SDValue combineSTORE(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6034 const NVPTXSubtarget &STI) {
6035
6036 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::STORE) {
6037 // Here is our chance to custom lower a store with a non-simple type.
6038 // Unfortunately, we can't do this in the legalizer because there is no
6039 // way to setOperationAction for an non-simple type.
6040 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
6041 if (!ST->getValue().getValueType().isSimple())
6042 return lowerSTOREVector(Op: SDValue(ST, 0), DAG&: DCI.DAG, STI);
6043 }
6044
6045 return combinePackingMovIntoStore(N, DCI, Front: 1, Back: 2);
6046}
6047
6048static SDValue combineLOAD(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6049 const NVPTXSubtarget &STI) {
6050 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::LOAD) {
6051 // Here is our chance to custom lower a load with a non-simple type.
6052 // Unfortunately, we can't do this in the legalizer because there is no
6053 // way to setOperationAction for an non-simple type.
6054 if (!N->getValueType(ResNo: 0).isSimple())
6055 return lowerLoadVector(N, DAG&: DCI.DAG, STI);
6056 }
6057
6058 return combineUnpackingMovIntoLoad(N, DCI);
6059}
6060
6061/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
6062///
6063static SDValue PerformADDCombine(SDNode *N,
6064 TargetLowering::DAGCombinerInfo &DCI,
6065 CodeGenOptLevel OptLevel) {
6066 if (OptLevel == CodeGenOptLevel::None)
6067 return SDValue();
6068
6069 SDValue N0 = N->getOperand(Num: 0);
6070 SDValue N1 = N->getOperand(Num: 1);
6071
6072 // Skip non-integer, non-scalar case
6073 EVT VT = N0.getValueType();
6074 if (VT.isVector() || VT != MVT::i32)
6075 return SDValue();
6076
6077 // First try with the default operand order.
6078 if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
6079 return Result;
6080
6081 // If that didn't work, try again with the operands commuted.
6082 return PerformADDCombineWithOperands(N, N0: N1, N1: N0, DCI);
6083}
6084
6085/// Check if a v2f32 BUILD_VECTOR provably packs values from non-adjacent
6086/// register pairs (non-coalescable).
6087static bool isNonCoalescableBuildVector(const SDValue &BV) {
6088 if (BV.getOpcode() != ISD::BUILD_VECTOR || BV.getValueType() != MVT::v2f32)
6089 return false;
6090
6091 SDValue Elt0 = BV.getOperand(i: 0);
6092 SDValue Elt1 = BV.getOperand(i: 1);
6093
6094 bool IsExt0 = Elt0.getOpcode() == ISD::EXTRACT_VECTOR_ELT;
6095 bool IsExt1 = Elt1.getOpcode() == ISD::EXTRACT_VECTOR_ELT;
6096
6097 // If neither element is an EXTRACT_VECTOR_ELT they are free-standing
6098 // scalars and the register allocator can still place them side-by-side.
6099 if (!IsExt0 && !IsExt1)
6100 return false;
6101
6102 // If exactly one element is an EXTRACT_VECTOR_ELT, the other is a scalar
6103 // that cannot generally occupy the adjacent register slot.
6104 if (IsExt0 != IsExt1)
6105 return true;
6106
6107 // At this point both sources are extracting from vectors. If they are from
6108 // different vectors, then the BUILD_VECTOR is non-coalescable.
6109 SDValue Src0 = Elt0.getOperand(i: 0);
6110 SDValue Src1 = Elt1.getOperand(i: 0);
6111 if (Src0 != Src1)
6112 return true;
6113
6114 auto *Idx0 = dyn_cast<ConstantSDNode>(Val: Elt0.getOperand(i: 1));
6115 auto *Idx1 = dyn_cast<ConstantSDNode>(Val: Elt1.getOperand(i: 1));
6116 // If both indices are dynamic they will be lowered to
6117 // loads and the vector will be spilled to local memory. The register
6118 // allocator can easily place the results in adjacent registers.
6119 if (!Idx0 && !Idx1)
6120 return false;
6121
6122 // If one index is dynamic and the other is constant, the value from the
6123 // constant load will result in an additional register to pair with the result
6124 // from the dynamic load. We consider this non-coalescable.
6125 if ((Idx0 && !Idx1) || (!Idx0 && Idx1))
6126 return true;
6127
6128 // Both are constant, adjacent pairs are coalescable
6129 return std::abs(i: Idx0->getSExtValue() - Idx1->getSExtValue()) != 1;
6130}
6131
6132/// Scalarize a v2f32 arithmetic node (FADD, FMUL, FSUB, FMA) when at least
6133/// one operand is a BUILD_VECTOR that repacks values from non-adjacent register
6134/// pairs. Without this combine the BUILD_VECTOR forces allocation of a
6135/// temporary 64-bit register, increasing register pressure.
6136///
6137/// Example - before:
6138/// t0: v2f32,v2f32,ch = LoadV2 ...
6139/// t1: f32 = extract_vector_elt t0, 0
6140/// t2: f32 = extract_vector_elt t0:1, 0
6141/// t3: v2f32 = BUILD_VECTOR t1, t2 ;; non-coalescable repack
6142/// t4: v2f32 = fma t_a, t3, t_c
6143///
6144/// After:
6145/// t0: v2f32,v2f32,ch = LoadV2 ...
6146/// t1: f32 = extract_vector_elt t0, 0
6147/// t2: f32 = extract_vector_elt t0:1, 0
6148/// a0: f32 = extract_vector_elt t_a, 0
6149/// a1: f32 = extract_vector_elt t_a, 1
6150/// c0: f32 = extract_vector_elt t_c, 0
6151/// c1: f32 = extract_vector_elt t_c, 1
6152/// r0: f32 = fma a0, t1, c0
6153/// r1: f32 = fma a1, t2, c1
6154/// t4: v2f32 = BUILD_VECTOR r0, r1
6155static SDValue PerformScalarizeV2F32Op(SDNode *N,
6156 TargetLowering::DAGCombinerInfo &DCI) {
6157 EVT VT = N->getValueType(ResNo: 0);
6158 if (VT != MVT::v2f32)
6159 return SDValue();
6160
6161 // Only scalarize when at least one operand is a BUILD_VECTOR whose elements
6162 // are guaranteed to reside in different register pairs.
6163 if (none_of(Range: N->ops(), P: isNonCoalescableBuildVector))
6164 return SDValue();
6165
6166 SelectionDAG &DAG = DCI.DAG;
6167 SDLoc DL(N);
6168 EVT EltVT = VT.getVectorElementType();
6169 unsigned Opc = N->getOpcode();
6170
6171 // For each operand, get the scalar element at the given index: if the operand
6172 // is a BUILD_VECTOR, grab the element directly; otherwise, emit an
6173 // EXTRACT_VECTOR_ELT.
6174 auto GetElement = [&](SDValue Op, unsigned Index) -> SDValue {
6175 if (Op.getOpcode() == ISD::BUILD_VECTOR)
6176 return Op.getOperand(i: Index);
6177 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Op,
6178 N2: DAG.getVectorIdxConstant(Val: Index, DL));
6179 };
6180
6181 // Build scalar operand lists for element 0 and element 1.
6182 SmallVector<SDValue, 3> Ops0, Ops1;
6183 for (const SDValue &Op : N->ops()) {
6184 Ops0.push_back(Elt: GetElement(Op, 0));
6185 Ops1.push_back(Elt: GetElement(Op, 1));
6186 }
6187
6188 SDValue Res0 = DAG.getNode(Opcode: Opc, DL, VT: EltVT, Ops: Ops0, Flags: N->getFlags());
6189 SDValue Res1 = DAG.getNode(Opcode: Opc, DL, VT: EltVT, Ops: Ops1, Flags: N->getFlags());
6190
6191 return DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT, N1: Res0, N2: Res1);
6192}
6193
6194/// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
6195///
6196static SDValue PerformFADDCombine(SDNode *N,
6197 TargetLowering::DAGCombinerInfo &DCI,
6198 CodeGenOptLevel OptLevel) {
6199 SDValue N0 = N->getOperand(Num: 0);
6200 SDValue N1 = N->getOperand(Num: 1);
6201
6202 if (SDValue Result = PerformScalarizeV2F32Op(N, DCI))
6203 return Result;
6204
6205 EVT VT = N0.getValueType();
6206 if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
6207 return SDValue();
6208
6209 // First try with the default operand order.
6210 if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
6211 return Result;
6212
6213 // If that didn't work, try again with the operands commuted.
6214 return PerformFADDCombineWithOperands(N, N0: N1, N1: N0, DCI, OptLevel);
6215}
6216
6217/// Get 3-input version of a 2-input min/max opcode
6218static unsigned getMinMax3Opcode(unsigned MinMax2Opcode) {
6219 switch (MinMax2Opcode) {
6220 case ISD::FMAXNUM:
6221 case ISD::FMAXIMUMNUM:
6222 return NVPTXISD::FMAXNUM3;
6223 case ISD::FMINNUM:
6224 case ISD::FMINIMUMNUM:
6225 return NVPTXISD::FMINNUM3;
6226 case ISD::FMAXIMUM:
6227 return NVPTXISD::FMAXIMUM3;
6228 case ISD::FMINIMUM:
6229 return NVPTXISD::FMINIMUM3;
6230 default:
6231 llvm_unreachable("Invalid 2-input min/max opcode");
6232 }
6233}
6234
6235/// PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into
6236/// (fmaxnum3 a, b, c). Also covers other llvm min/max intrinsics.
6237static SDValue PerformFMinMaxCombine(SDNode *N,
6238 TargetLowering::DAGCombinerInfo &DCI,
6239 unsigned PTXVersion, unsigned SmVersion) {
6240
6241 // 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s
6242 EVT VT = N->getValueType(ResNo: 0);
6243 if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100)
6244 return SDValue();
6245
6246 SDValue Op0 = N->getOperand(Num: 0);
6247 SDValue Op1 = N->getOperand(Num: 1);
6248 unsigned MinMaxOp2 = N->getOpcode();
6249 unsigned MinMaxOp3 = getMinMax3Opcode(MinMax2Opcode: MinMaxOp2);
6250
6251 if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
6252 // (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
6253 SDValue A = Op0.getOperand(i: 0);
6254 SDValue B = Op0.getOperand(i: 1);
6255 SDValue C = Op1;
6256 return DCI.DAG.getNode(Opcode: MinMaxOp3, DL: SDLoc(N), VT, N1: A, N2: B, N3: C, Flags: N->getFlags());
6257 } else if (Op1.getOpcode() == MinMaxOp2 && Op1.hasOneUse()) {
6258 // (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
6259 SDValue A = Op0;
6260 SDValue B = Op1.getOperand(i: 0);
6261 SDValue C = Op1.getOperand(i: 1);
6262 return DCI.DAG.getNode(Opcode: MinMaxOp3, DL: SDLoc(N), VT, N1: A, N2: B, N3: C, Flags: N->getFlags());
6263 }
6264 return SDValue();
6265}
6266
6267static SDValue PerformREMCombine(SDNode *N,
6268 TargetLowering::DAGCombinerInfo &DCI,
6269 CodeGenOptLevel OptLevel) {
6270 assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM);
6271
6272 // Don't do anything at less than -O2.
6273 if (OptLevel < CodeGenOptLevel::Default)
6274 return SDValue();
6275
6276 SelectionDAG &DAG = DCI.DAG;
6277 SDLoc DL(N);
6278 EVT VT = N->getValueType(ResNo: 0);
6279 bool IsSigned = N->getOpcode() == ISD::SREM;
6280 unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
6281
6282 const SDValue &Num = N->getOperand(Num: 0);
6283 const SDValue &Den = N->getOperand(Num: 1);
6284
6285 for (const SDNode *U : Num->users()) {
6286 if (U->getOpcode() == DivOpc && U->getOperand(Num: 0) == Num &&
6287 U->getOperand(Num: 1) == Den) {
6288 // Num % Den -> Num - (Num / Den) * Den
6289 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Num,
6290 N2: DAG.getNode(Opcode: ISD::MUL, DL, VT,
6291 N1: DAG.getNode(Opcode: DivOpc, DL, VT, N1: Num, N2: Den),
6292 N2: Den));
6293 }
6294 }
6295 return SDValue();
6296}
6297
6298// (sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y)
6299static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6300 CodeGenOptLevel OptLevel) {
6301 if (OptLevel == CodeGenOptLevel::None)
6302 return SDValue();
6303
6304 SDValue Op = N->getOperand(Num: 0);
6305 if (!Op.hasOneUse())
6306 return SDValue();
6307 EVT ToVT = N->getValueType(ResNo: 0);
6308 EVT FromVT = Op.getValueType();
6309 if (!((ToVT == MVT::i32 && FromVT == MVT::i16) ||
6310 (ToVT == MVT::i64 && FromVT == MVT::i32)))
6311 return SDValue();
6312 if (!(Op.getOpcode() == ISD::MUL ||
6313 (Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Val: Op.getOperand(i: 1)))))
6314 return SDValue();
6315
6316 SDLoc DL(N);
6317 unsigned ExtOpcode = N->getOpcode();
6318 unsigned Opcode = 0;
6319 if (ExtOpcode == ISD::SIGN_EXTEND && Op->getFlags().hasNoSignedWrap())
6320 Opcode = NVPTXISD::MUL_WIDE_SIGNED;
6321 else if (ExtOpcode == ISD::ZERO_EXTEND && Op->getFlags().hasNoUnsignedWrap())
6322 Opcode = NVPTXISD::MUL_WIDE_UNSIGNED;
6323 else
6324 return SDValue();
6325 SDValue RHS = Op.getOperand(i: 1);
6326 if (Op.getOpcode() == ISD::SHL) {
6327 const auto ShiftAmt = Op.getConstantOperandVal(i: 1);
6328 const auto MulVal = APInt(FromVT.getSizeInBits(), 1) << ShiftAmt;
6329 RHS = DCI.DAG.getConstant(Val: MulVal, DL, VT: FromVT);
6330 }
6331 return DCI.DAG.getNode(Opcode, DL, VT: ToVT, N1: Op.getOperand(i: 0), N2: RHS);
6332}
6333
// Deduced signedness of a candidate mul.wide operand.
enum OperandSignedness {
  Signed = 0, // Operand is known to be sign-extended.
  Unsigned,   // Operand is known to be zero-extended.
  Unknown     // Signedness could not be determined.
};
6339
6340/// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
6341/// that can be demoted to \p OptSize bits without loss of information. The
6342/// signedness of the operand, if determinable, is placed in \p S.
6343static bool IsMulWideOperandDemotable(SDValue Op,
6344 unsigned OptSize,
6345 OperandSignedness &S) {
6346 S = Unknown;
6347
6348 if (Op.getOpcode() == ISD::SIGN_EXTEND ||
6349 Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
6350 EVT OrigVT = Op.getOperand(i: 0).getValueType();
6351 if (OrigVT.getFixedSizeInBits() <= OptSize) {
6352 S = Signed;
6353 return true;
6354 }
6355 } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
6356 EVT OrigVT = Op.getOperand(i: 0).getValueType();
6357 if (OrigVT.getFixedSizeInBits() <= OptSize) {
6358 S = Unsigned;
6359 return true;
6360 }
6361 }
6362
6363 return false;
6364}
6365
6366/// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
6367/// be demoted to \p OptSize bits without loss of information. If the operands
6368/// contain a constant, it should appear as the RHS operand. The signedness of
6369/// the operands is placed in \p IsSigned.
6370static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS,
6371 unsigned OptSize,
6372 bool &IsSigned) {
6373 OperandSignedness LHSSign;
6374
6375 // The LHS operand must be a demotable op
6376 if (!IsMulWideOperandDemotable(Op: LHS, OptSize, S&: LHSSign))
6377 return false;
6378
6379 // We should have been able to determine the signedness from the LHS
6380 if (LHSSign == Unknown)
6381 return false;
6382
6383 IsSigned = (LHSSign == Signed);
6384
6385 // The RHS can be a demotable op or a constant
6386 if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Val&: RHS)) {
6387 const APInt &Val = CI->getAPIntValue();
6388 if (LHSSign == Unsigned) {
6389 return Val.isIntN(N: OptSize);
6390 } else {
6391 return Val.isSignedIntN(N: OptSize);
6392 }
6393 } else {
6394 OperandSignedness RHSSign;
6395 if (!IsMulWideOperandDemotable(Op: RHS, OptSize, S&: RHSSign))
6396 return false;
6397
6398 return LHSSign == RHSSign;
6399 }
6400}
6401
6402/// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
6403/// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
6404/// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
6405/// amount.
6406static SDValue TryMULWIDECombine(SDNode *N,
6407 TargetLowering::DAGCombinerInfo &DCI) {
6408 EVT MulType = N->getValueType(ResNo: 0);
6409 if (MulType != MVT::i32 && MulType != MVT::i64) {
6410 return SDValue();
6411 }
6412
6413 SDLoc DL(N);
6414 unsigned OptSize = MulType.getSizeInBits() >> 1;
6415 SDValue LHS = N->getOperand(Num: 0);
6416 SDValue RHS = N->getOperand(Num: 1);
6417
6418 // Canonicalize the multiply so the constant (if any) is on the right
6419 if (N->getOpcode() == ISD::MUL) {
6420 if (isa<ConstantSDNode>(Val: LHS)) {
6421 std::swap(a&: LHS, b&: RHS);
6422 }
6423 }
6424
6425 // If we have a SHL, determine the actual multiply amount
6426 if (N->getOpcode() == ISD::SHL) {
6427 ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(Val&: RHS);
6428 if (!ShlRHS) {
6429 return SDValue();
6430 }
6431
6432 APInt ShiftAmt = ShlRHS->getAPIntValue();
6433 unsigned BitWidth = MulType.getSizeInBits();
6434 if (ShiftAmt.sge(RHS: 0) && ShiftAmt.slt(RHS: BitWidth)) {
6435 APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
6436 RHS = DCI.DAG.getConstant(Val: MulVal, DL, VT: MulType);
6437 } else {
6438 return SDValue();
6439 }
6440 }
6441
6442 bool Signed;
6443 // Verify that our operands are demotable
6444 if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, IsSigned&: Signed)) {
6445 return SDValue();
6446 }
6447
6448 EVT DemotedVT;
6449 if (MulType == MVT::i32) {
6450 DemotedVT = MVT::i16;
6451 } else {
6452 DemotedVT = MVT::i32;
6453 }
6454
6455 // Truncate the operands to the correct size. Note that these are just for
6456 // type consistency and will (likely) be eliminated in later phases.
6457 SDValue TruncLHS =
6458 DCI.DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DemotedVT, Operand: LHS);
6459 SDValue TruncRHS =
6460 DCI.DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DemotedVT, Operand: RHS);
6461
6462 unsigned Opc;
6463 if (Signed) {
6464 Opc = NVPTXISD::MUL_WIDE_SIGNED;
6465 } else {
6466 Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
6467 }
6468
6469 return DCI.DAG.getNode(Opcode: Opc, DL, VT: MulType, N1: TruncLHS, N2: TruncRHS);
6470}
6471
6472static bool isConstOne(const SDValue &Operand) {
6473 const auto *Const = dyn_cast<ConstantSDNode>(Val: Operand);
6474 return Const && Const->getZExtValue() == 1;
6475}
6476
6477static SDValue matchMADConstOnePattern(SDValue Add) {
6478 if (Add->getOpcode() != ISD::ADD)
6479 return SDValue();
6480
6481 if (isConstOne(Operand: Add->getOperand(Num: 0)))
6482 return Add->getOperand(Num: 1);
6483
6484 if (isConstOne(Operand: Add->getOperand(Num: 1)))
6485 return Add->getOperand(Num: 0);
6486
6487 return SDValue();
6488}
6489
6490static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
6491 TargetLowering::DAGCombinerInfo &DCI) {
6492
6493 if (SDValue Y = matchMADConstOnePattern(Add)) {
6494 SDValue Mul = DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: X, N2: Y);
6495 return DCI.DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Mul, N2: X);
6496 }
6497
6498 return SDValue();
6499}
6500
6501static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
6502 SDLoc DL,
6503 TargetLowering::DAGCombinerInfo &DCI) {
6504 if (Select->getOpcode() != ISD::SELECT)
6505 return SDValue();
6506
6507 SDValue Cond = Select->getOperand(Num: 0);
6508
6509 unsigned ConstOpNo;
6510 if (isConstOne(Operand: Select->getOperand(Num: 1)))
6511 ConstOpNo = 1;
6512 else if (isConstOne(Operand: Select->getOperand(Num: 2)))
6513 ConstOpNo = 2;
6514 else
6515 return SDValue();
6516
6517 SDValue Y = Select->getOperand(Num: (ConstOpNo == 1) ? 2 : 1);
6518
6519 // Do not combine if the resulting sequence is not obviously profitable.
6520 if (!matchMADConstOnePattern(Add: Y))
6521 return SDValue();
6522
6523 SDValue NewMul = DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: X, N2: Y);
6524
6525 return DCI.DAG.getNode(Opcode: ISD::SELECT, DL, VT, N1: Cond,
6526 N2: (ConstOpNo == 1) ? X : NewMul,
6527 N3: (ConstOpNo == 1) ? NewMul : X);
6528}
6529
6530static SDValue
6531PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
6532 TargetLowering::DAGCombinerInfo &DCI) {
6533
6534 EVT VT = N0.getValueType();
6535 if (VT.isVector())
6536 return SDValue();
6537
6538 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
6539 return SDValue();
6540
6541 SDLoc DL(N);
6542
6543 // (mul x, (add y, 1)) -> (add (mul x, y), x)
6544 if (SDValue Res = combineMADConstOne(X: N0, Add: N1, VT, DL, DCI))
6545 return Res;
6546 if (SDValue Res = combineMADConstOne(X: N1, Add: N0, VT, DL, DCI))
6547 return Res;
6548
6549 // (mul x, (select y, 1)) -> (select (mul x, y), x)
6550 if (SDValue Res = combineMulSelectConstOne(X: N0, Select: N1, VT, DL, DCI))
6551 return Res;
6552 if (SDValue Res = combineMulSelectConstOne(X: N1, Select: N0, VT, DL, DCI))
6553 return Res;
6554
6555 return SDValue();
6556}
6557
6558/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
6559static SDValue PerformMULCombine(SDNode *N,
6560 TargetLowering::DAGCombinerInfo &DCI,
6561 CodeGenOptLevel OptLevel) {
6562 if (OptLevel == CodeGenOptLevel::None)
6563 return SDValue();
6564
6565 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6566 return Ret;
6567
6568 SDValue N0 = N->getOperand(Num: 0);
6569 SDValue N1 = N->getOperand(Num: 1);
6570 return PerformMULCombineWithOperands(N, N0, N1, DCI);
6571}
6572
6573/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
6574static SDValue PerformSHLCombine(SDNode *N,
6575 TargetLowering::DAGCombinerInfo &DCI,
6576 CodeGenOptLevel OptLevel) {
6577 if (OptLevel > CodeGenOptLevel::None) {
6578 // Try mul.wide combining at OptLevel > 0
6579 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6580 return Ret;
6581 }
6582
6583 return SDValue();
6584}
6585
6586static SDValue PerformSETCCCombine(SDNode *N,
6587 TargetLowering::DAGCombinerInfo &DCI,
6588 unsigned int SmVersion) {
6589 EVT CCType = N->getValueType(ResNo: 0);
6590 SDValue A = N->getOperand(Num: 0);
6591 SDValue B = N->getOperand(Num: 1);
6592
6593 EVT AType = A.getValueType();
6594 if (!(CCType == MVT::v2i1 && (AType == MVT::v2f16 || AType == MVT::v2bf16)))
6595 return SDValue();
6596
6597 if (A.getValueType() == MVT::v2bf16 && SmVersion < 90)
6598 return SDValue();
6599
6600 SDLoc DL(N);
6601 // setp.f16x2 returns two scalar predicates, which we need to
6602 // convert back to v2i1. The returned result will be scalarized by
6603 // the legalizer, but the comparison will remain a single vector
6604 // instruction.
6605 SDValue CCNode = DCI.DAG.getNode(
6606 Opcode: A.getValueType() == MVT::v2f16 ? NVPTXISD::SETP_F16X2
6607 : NVPTXISD::SETP_BF16X2,
6608 DL, VTList: DCI.DAG.getVTList(VT1: MVT::i1, VT2: MVT::i1), Ops: {A, B, N->getOperand(Num: 2)});
6609 return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: CCType, N1: CCNode.getValue(R: 0),
6610 N2: CCNode.getValue(R: 1));
6611}
6612
/// Replace extract_vector_elt on small, non-packed vectors with a bitcast of
/// the whole vector to an integer, an arithmetic shift right to the element's
/// bit offset, and a truncate, so the extraction stays in-register.
static SDValue PerformEXTRACTCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
  SDValue Vector = N->getOperand(Num: 0);
  // Look through FREEZE so the checks below see the real producer.
  if (Vector->getOpcode() == ISD::FREEZE)
    Vector = Vector->getOperand(Num: 0);
  SDLoc DL(N);
  EVT VectorVT = Vector.getValueType();
  if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
      IsPTXVectorType(VT: VectorVT.getSimpleVT()))
    return SDValue(); // Native vector loads already combine nicely w/
                      // extract_vector_elt.
  // Don't mess with singletons or packed types (v2*32, v2*16, v4i8 and v8i8),
  // we already handle them OK.
  if (VectorVT.getVectorNumElements() == 1 ||
      NVPTX::isPackedVectorTy(VT: VectorVT) || VectorVT == MVT::v8i8)
    return SDValue();

  // Don't mess with undef values as sra may be simplified to 0, not undef.
  if (Vector->isUndef() || ISD::allOperandsUndef(N: Vector.getNode()))
    return SDValue();

  uint64_t VectorBits = VectorVT.getSizeInBits();
  // We only handle the types we can extract in-register.
  if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
    return SDValue();

  ConstantSDNode *Index = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
  // Index == 0 is handled by generic DAG combiner.
  if (!Index || Index->getZExtValue() == 0)
    return SDValue();

  MVT IVT = MVT::getIntegerVT(BitWidth: VectorBits);
  EVT EltVT = VectorVT.getVectorElementType();
  EVT EltIVT = EltVT.changeTypeToInteger();
  uint64_t EltBits = EltVT.getScalarSizeInBits();

  // (trunc (sra (bitcast vector-to-int), index * element-bits)): shift the
  // wanted element down to the low bits, then truncate to its integer width.
  SDValue Result = DCI.DAG.getNode(
      Opcode: ISD::TRUNCATE, DL, VT: EltIVT,
      Operand: DCI.DAG.getNode(
          Opcode: ISD::SRA, DL, VT: IVT, N1: DCI.DAG.getNode(Opcode: ISD::BITCAST, DL, VT: IVT, Operand: Vector),
          N2: DCI.DAG.getConstant(Val: Index->getZExtValue() * EltBits, DL, VT: IVT)));

  // If element has non-integer type, bitcast it back to the expected type.
  if (EltVT != EltIVT)
    Result = DCI.DAG.getNode(Opcode: ISD::BITCAST, DL, VT: EltVT, Operand: Result);
  // Past legalizer, we may need to extend i8 -> i16 to match the register
  // type.
  if (EltVT != N->getValueType(ResNo: 0))
    Result = DCI.DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: N->getValueType(ResNo: 0), Operand: Result);

  return Result;
}
6664
/// Transform patterns like:
/// (select (ugt shift_amt, BitWidth-1), 0, (srl/shl x, shift_amt))
/// (select (ult shift_amt, BitWidth), (srl/shl x, shift_amt), 0)
/// Into:
/// (NVPTXISD::SRL_CLAMP x, shift_amt) or (NVPTXISD::SHL_CLAMP x, shift_amt)
///
/// These patterns arise from C/C++ code like `shift >= 32 ? 0 : x >> shift`
/// which guards against undefined behavior. PTX shr/shl instructions clamp
/// shift amounts >= BitWidth to produce 0 for logical shifts, making the
/// guard redundant.
///
/// Note: We only handle SRL and SHL, not SRA, because arithmetic right
/// shifts could produce 0 or -1 when shift >= BitWidth.
/// Note: We don't handle uge or ule. These don't appear because of
/// canonicalization.
static SDValue PerformSELECTShiftCombine(SDNode *N,
                                         TargetLowering::DAGCombinerInfo &DCI) {
  // Only run once types are final, after DAG legalization.
  if (!DCI.isAfterLegalizeDAG())
    return SDValue();

  using namespace SDPatternMatch;
  unsigned BitWidth = N->getValueType(ResNo: 0).getSizeInBits();
  SDValue ShiftAmt, ShiftOp;

  // Match logical shifts where the shift amount in the guard matches the shift
  // amount in the operation. m_TruncOrSelf allows the shift itself to consume
  // a truncated copy of the (possibly wider) guard amount.
  auto LogicalShift =
      m_AllOf(preds: m_Value(N&: ShiftOp),
              preds: m_AnyOf(preds: m_Srl(L: m_Value(), R: m_TruncOrSelf(Op: m_Deferred(V&: ShiftAmt))),
                      preds: m_Shl(L: m_Value(), R: m_TruncOrSelf(Op: m_Deferred(V&: ShiftAmt)))));

  // shift_amt > BitWidth-1 ? 0 : shift_op
  bool MatchedUGT =
      sd_match(N, P: m_Select(Cond: m_SetCC(LHS: m_Value(N&: ShiftAmt),
                                   RHS: m_SpecificInt(V: APInt(BitWidth, BitWidth - 1)),
                                   CC: m_SpecificCondCode(CC: ISD::SETUGT)),
                           T: m_Zero(), F: LogicalShift));
  // shift_amt < BitWidth ? shift_op : 0
  bool MatchedULT =
      !MatchedUGT &&
      sd_match(N, P: m_Select(Cond: m_SetCC(LHS: m_Value(N&: ShiftAmt),
                                   RHS: m_SpecificInt(V: APInt(BitWidth, BitWidth)),
                                   CC: m_SpecificCondCode(CC: ISD::SETULT)),
                           T: LogicalShift, F: m_Zero()));

  if (!MatchedUGT && !MatchedULT)
    return SDValue();

  // Return a clamp shift operation, which has the same semantics as PTX shift.
  unsigned ClampOpc = ShiftOp.getOpcode() == ISD::SRL ? NVPTXISD::SRL_CLAMP
                                                      : NVPTXISD::SHL_CLAMP;
  return DCI.DAG.getNode(Opcode: ClampOpc, DL: SDLoc(N), VT: ShiftOp.getValueType(),
                         N1: ShiftOp.getOperand(i: 0), N2: ShiftOp.getOperand(i: 1));
}
6719
6720static SDValue PerformVSELECTCombine(SDNode *N,
6721 TargetLowering::DAGCombinerInfo &DCI) {
6722 SDValue VA = N->getOperand(Num: 1);
6723 EVT VectorVT = VA.getValueType();
6724 if (VectorVT != MVT::v4i8)
6725 return SDValue();
6726
6727 // We need to split vselect into individual per-element operations Because we
6728 // use BFE/BFI instruction for byte extraction/insertion, we do end up with
6729 // 32-bit values, so we may as well do comparison as i32 to avoid conversions
6730 // to/from i16 normally used for i8 values.
6731 SmallVector<SDValue, 4> E;
6732 SDLoc DL(N);
6733 SDValue VCond = N->getOperand(Num: 0);
6734 SDValue VB = N->getOperand(Num: 2);
6735 for (int I = 0; I < 4; ++I) {
6736 SDValue C = DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i1, N1: VCond,
6737 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32));
6738 SDValue EA = DCI.DAG.getAnyExtOrTrunc(
6739 Op: DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8, N1: VA,
6740 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32)),
6741 DL, VT: MVT::i32);
6742 SDValue EB = DCI.DAG.getAnyExtOrTrunc(
6743 Op: DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8, N1: VB,
6744 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32)),
6745 DL, VT: MVT::i32);
6746 E.push_back(Elt: DCI.DAG.getAnyExtOrTrunc(
6747 Op: DCI.DAG.getNode(Opcode: ISD::SELECT, DL, VT: MVT::i32, N1: C, N2: EA, N3: EB), DL, VT: MVT::i8));
6748 }
6749 return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MVT::v4i8, Ops: E);
6750}
6751
/// Fold (build_vector (trunc i32 A), (trunc i32 B)) for packed 2x16-bit
/// vector types into a single PRMT byte-permute of the two original i32
/// values, optionally absorbing a 16-bit shift-right of either source.
static SDValue
PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
  auto VT = N->getValueType(ResNo: 0);
  if (!DCI.isAfterLegalizeDAG() ||
      // only process v2*16 types
      !(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector() &&
        VT.getVectorNumElements() == 2))
    return SDValue();

  auto Op0 = N->getOperand(Num: 0);
  auto Op1 = N->getOperand(Num: 1);

  // Start out by assuming we want to take the lower 2 bytes of each i32
  // operand. Each hex digit is a PRMT byte selector: 0x10 picks bytes 1,0 of
  // the first source, 0x54 picks bytes 5,4 (i.e. bytes 1,0 of the second
  // source).
  uint64_t Op0Bytes = 0x10;
  uint64_t Op1Bytes = 0x54;

  // Pointers into the locals above so both operands share one loop body that
  // can rewrite the operand and its byte selector in place.
  std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
                                                {&Op1, &Op1Bytes}};

  // Check that each operand is an i16, truncated from an i32 operand. We'll
  // select individual bytes from those original operands. Optionally, fold in a
  // shift right of that original operand.
  for (auto &[Op, OpBytes] : OpData) {
    // Eat up any bitcast
    if (Op->getOpcode() == ISD::BITCAST)
      *Op = Op->getOperand(i: 0);

    if (!(Op->getValueType() == MVT::i16 && Op->getOpcode() == ISD::TRUNCATE &&
          Op->getOperand(i: 0).getValueType() == MVT::i32))
      return SDValue();

    // If the truncate has multiple uses, this optimization can increase
    // register pressure
    if (!Op->hasOneUse())
      return SDValue();

    // Step to the truncate's i32 source.
    *Op = Op->getOperand(i: 0);

    // Optionally, fold in a shift-right of the original operand and let permute
    // pick the two higher bytes of the original value directly.
    if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Val: Op->getOperand(i: 1))) {
      if (cast<ConstantSDNode>(Val: Op->getOperand(i: 1))->getZExtValue() == 16) {
        // Shift the PRMT byte selector to pick upper bytes from each respective
        // value, instead of the lower ones: 0x10 -> 0x32, 0x54 -> 0x76
        assert((*OpBytes == 0x10 || *OpBytes == 0x54) &&
               "PRMT selector values out of range");
        *OpBytes += 0x22;
        *Op = Op->getOperand(i: 0);
      }
    }
  }

  SDLoc DL(N);
  auto &DAG = DCI.DAG;

  // Combine the two per-operand byte selectors into one 16-bit PRMT selector
  // and emit a single permute over the two original i32 sources.
  auto PRMT =
      getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: Op0), B: DAG.getBitcast(VT: MVT::i32, V: Op1),
              Selector: (Op1Bytes << 8) | Op0Bytes, DL, DAG);
  return DAG.getBitcast(VT, V: PRMT);
}
6813
6814static SDValue combineADDRSPACECAST(SDNode *N,
6815 TargetLowering::DAGCombinerInfo &DCI) {
6816 auto *ASCN1 = cast<AddrSpaceCastSDNode>(Val: N);
6817
6818 if (auto *ASCN2 = dyn_cast<AddrSpaceCastSDNode>(Val: ASCN1->getOperand(Num: 0))) {
6819 assert(ASCN2->getDestAddressSpace() == ASCN1->getSrcAddressSpace());
6820
6821 // Fold asc[B -> A](asc[A -> B](x)) -> x
6822 if (ASCN1->getDestAddressSpace() == ASCN2->getSrcAddressSpace())
6823 return ASCN2->getOperand(Num: 0);
6824 }
6825
6826 return SDValue();
6827}
6828
// Given a constant selector value and a prmt mode, return the selector value
// normalized to the generic prmt mode. See the PTX ISA documentation for more
// details:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
  assert(Selector.getBitWidth() == 32 && "PRMT must have i32 operands");

  // Generic-mode selectors are already in canonical form.
  if (Mode == NVPTX::PTXPrmtMode::NONE)
    return Selector;

  // The specialized modes only consume the two low bits of the selector.
  const unsigned V = Selector.trunc(width: 2).getZExtValue();

  // Pack four 4-bit byte selectors (lowest byte first) into a generic-mode
  // selector value.
  const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
                              unsigned S3) {
    return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12));
  };

  switch (Mode) {
  case NVPTX::PTXPrmtMode::F4E: // forward 4-byte extract starting at V
    return GetSelector(V, V + 1, V + 2, V + 3);
  case NVPTX::PTXPrmtMode::B4E: // backward 4-byte extract, wrapping mod 8
    return GetSelector(V, (V - 1) & 7, (V - 2) & 7, (V - 3) & 7);
  case NVPTX::PTXPrmtMode::RC8: // replicate byte V into all four lanes
    return GetSelector(V, V, V, V);
  case NVPTX::PTXPrmtMode::ECL: // edge clamp left
    return GetSelector(V, std::max(a: V, b: 1U), std::max(a: V, b: 2U), 3U);
  case NVPTX::PTXPrmtMode::ECR: // edge clamp right
    return GetSelector(0, std::min(a: V, b: 1U), std::min(a: V, b: 2U), V);
  case NVPTX::PTXPrmtMode::RC16: { // replicate the selected 16-bit half
    unsigned V1 = (V & 1) << 1;
    return GetSelector(V1, V1 + 1, V1, V1 + 1);
  }
  default:
    llvm_unreachable("Invalid PRMT mode");
  }
}
6865
// Constant-fold a PRMT byte permutation of the two i32 inputs under the given
// selector and mode, returning the resulting 32-bit value.
static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
  assert(A.getBitWidth() == 32 && B.getBitWidth() == 32 &&
         Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
  // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
  APInt BitField = B.concat(NewLSB: A);
  APInt SelectorVal = getPRMTSelector(Selector, Mode);
  APInt Result(32, 0);
  // Each 4-bit selector nibble picks one of the eight source bytes (low three
  // bits); the nibble's high bit requests sign replication of that byte.
  for (unsigned I : llvm::seq(Size: 4U)) {
    APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
    unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
    unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
    APInt Byte = BitField.extractBits(numBits: 8, bitPosition: Idx * 8);
    // An arithmetic shift by the byte's full width fills it with copies of
    // its sign bit (0x00 or 0xFF).
    if (Sign)
      Byte = Byte.ashr(ShiftAmt: 8);
    Result.insertBits(SubBits: Byte, bitPosition: I * 8);
  }
  return Result;
}
6884
6885static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6886 CodeGenOptLevel OptLevel) {
6887 if (OptLevel == CodeGenOptLevel::None)
6888 return SDValue();
6889
6890 // Constant fold PRMT
6891 if (isa<ConstantSDNode>(Val: N->getOperand(Num: 0)) &&
6892 isa<ConstantSDNode>(Val: N->getOperand(Num: 1)) &&
6893 isa<ConstantSDNode>(Val: N->getOperand(Num: 2)))
6894 return DCI.DAG.getConstant(Val: computePRMT(A: N->getConstantOperandAPInt(Num: 0),
6895 B: N->getConstantOperandAPInt(Num: 1),
6896 Selector: N->getConstantOperandAPInt(Num: 2),
6897 Mode: N->getConstantOperandVal(Num: 3)),
6898 DL: SDLoc(N), VT: N->getValueType(ResNo: 0));
6899 return SDValue();
6900}
6901
// During call lowering we wrap the return values in a ProxyReg node which
// depend on the chain value produced by the completed call. This ensures that
// the full call is emitted in cases where libcalls are used to legalize
// operations. To improve the functioning of other DAG combines we pull all
// operations we can through one of these nodes, ensuring that the ProxyReg
// directly wraps a load. That is:
//
// (ProxyReg (zext (load retval0))) => (zext (ProxyReg (load retval0)))
//
// Returns the rewritten value, or a null SDValue if R cannot be sunk through.
static SDValue sinkProxyReg(SDValue R, SDValue Chain,
                            TargetLowering::DAGCombinerInfo &DCI) {
  switch (R.getOpcode()) {
  // Unary ops: sink into the single operand, then rebuild the op on top.
  case ISD::TRUNCATE:
  case ISD::ANY_EXTEND:
  case ISD::SIGN_EXTEND:
  case ISD::ZERO_EXTEND:
  case ISD::BITCAST: {
    if (SDValue V = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
      return DCI.DAG.getNode(Opcode: R.getOpcode(), DL: SDLoc(R), VT: R.getValueType(), Operand: V);
    return SDValue();
  }
  // Binary ops: both operands must sink, otherwise leave the node untouched.
  case ISD::SHL:
  case ISD::SRL:
  case ISD::SRA:
  case ISD::OR: {
    if (SDValue A = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
      if (SDValue B = sinkProxyReg(R: R.getOperand(i: 1), Chain, DCI))
        return DCI.DAG.getNode(Opcode: R.getOpcode(), DL: SDLoc(R), VT: R.getValueType(), N1: A, N2: B);
    return SDValue();
  }
  // Constants need no ProxyReg; pass them through unchanged.
  case ISD::Constant:
    return R;
  // Loads are the desired anchor: wrap them directly in a ProxyReg tied to
  // the call's chain.
  case ISD::LOAD:
  case NVPTXISD::LoadV2:
  case NVPTXISD::LoadV4: {
    return DCI.DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: SDLoc(R), VT: R.getValueType(),
                           Ops: {Chain, R});
  }
  case ISD::BUILD_VECTOR: {
    if (DCI.isBeforeLegalize())
      return SDValue();

    // Every element must be sinkable for the vector to be rebuilt.
    SmallVector<SDValue, 16> Ops;
    for (auto &Op : R->ops()) {
      SDValue V = sinkProxyReg(R: Op, Chain, DCI);
      if (!V)
        return SDValue();
      Ops.push_back(Elt: V);
    }
    return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL: SDLoc(R), VT: R.getValueType(), Ops);
  }
  case ISD::EXTRACT_VECTOR_ELT: {
    if (DCI.isBeforeLegalize())
      return SDValue();

    if (SDValue V = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
      return DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: SDLoc(R),
                             VT: R.getValueType(), N1: V, N2: R.getOperand(i: 1));
    return SDValue();
  }
  // Anything else cannot safely be pulled through the ProxyReg.
  default:
    return SDValue();
  }
}
6966
6967static unsigned getF16SubOpc(Intrinsic::ID AddIntrinsicID) {
6968 switch (AddIntrinsicID) {
6969 default:
6970 break;
6971 case Intrinsic::nvvm_add_rn_sat_f16:
6972 case Intrinsic::nvvm_add_rn_sat_v2f16:
6973 return NVPTXISD::SUB_RN_SAT;
6974 case Intrinsic::nvvm_add_rn_ftz_sat_f16:
6975 case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
6976 return NVPTXISD::SUB_RN_FTZ_SAT;
6977 }
6978 llvm_unreachable("Invalid F16 add intrinsic");
6979}
6980
6981static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG,
6982 Intrinsic::ID AddIntrinsicID) {
6983 SDValue Op1 = N->getOperand(Num: 1);
6984 SDValue Op2 = N->getOperand(Num: 2);
6985
6986 SDValue SubOp1, SubOp2;
6987
6988 if (Op1.getOpcode() == ISD::FNEG) {
6989 SubOp1 = Op2;
6990 SubOp2 = Op1.getOperand(i: 0);
6991 } else if (Op2.getOpcode() == ISD::FNEG) {
6992 SubOp1 = Op1;
6993 SubOp2 = Op2.getOperand(i: 0);
6994 } else {
6995 return SDValue();
6996 }
6997
6998 SDLoc DL(N);
6999 return DAG.getNode(Opcode: getF16SubOpc(AddIntrinsicID), DL, VT: N->getValueType(ResNo: 0),
7000 N1: SubOp1, N2: SubOp2);
7001}
7002
7003static SDValue combineIntrinsicWOChain(SDNode *N,
7004 TargetLowering::DAGCombinerInfo &DCI,
7005 const NVPTXSubtarget &STI) {
7006 unsigned IID = N->getConstantOperandVal(Num: 0);
7007
7008 switch (IID) {
7009 default:
7010 break;
7011 case Intrinsic::nvvm_add_rn_sat_f16:
7012 case Intrinsic::nvvm_add_rn_ftz_sat_f16:
7013 case Intrinsic::nvvm_add_rn_sat_v2f16:
7014 case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
7015 return combineF16AddWithNeg(N, DAG&: DCI.DAG, AddIntrinsicID: IID);
7016 }
7017 return SDValue();
7018}
7019
7020static SDValue combineProxyReg(SDNode *N,
7021 TargetLowering::DAGCombinerInfo &DCI) {
7022
7023 SDValue Chain = N->getOperand(Num: 0);
7024 SDValue Reg = N->getOperand(Num: 1);
7025
7026 // If the ProxyReg is not wrapping a load, try to pull the operations through
7027 // the ProxyReg.
7028 if (Reg.getOpcode() != ISD::LOAD) {
7029 if (SDValue V = sinkProxyReg(R: Reg, Chain, DCI))
7030 return V;
7031 }
7032
7033 return SDValue();
7034}
7035
// Dispatch point for all NVPTX-specific DAG combines: routes each opcode to
// its dedicated combine helper. Returns an empty SDValue when no combine
// applies, letting the generic DAGCombiner continue.
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                               DAGCombinerInfo &DCI) const {
  // Several combines are gated on the optimization level.
  CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
  switch (N->getOpcode()) {
  default:
    break;
  case ISD::ADD:
    return PerformADDCombine(N, DCI, OptLevel);
  case ISD::ADDRSPACECAST:
    return combineADDRSPACECAST(N, DCI);
  // Extensions feed the widening-multiply combine.
  case ISD::SIGN_EXTEND:
  case ISD::ZERO_EXTEND:
    return combineMulWide(N, DCI, OptLevel);
  case ISD::BUILD_VECTOR:
    return PerformBUILD_VECTORCombine(N, DCI);
  case ISD::EXTRACT_VECTOR_ELT:
    return PerformEXTRACTCombine(N, DCI);
  case ISD::FADD:
    return PerformFADDCombine(N, DCI, OptLevel);
  case ISD::FMA:
  case ISD::FMUL:
  case ISD::FSUB:
    return PerformScalarizeV2F32Op(N, DCI);
  // All min/max flavors share one combine, parameterized by PTX/SM version.
  case ISD::FMAXNUM:
  case ISD::FMINNUM:
  case ISD::FMAXIMUM:
  case ISD::FMINIMUM:
  case ISD::FMAXIMUMNUM:
  case ISD::FMINIMUMNUM:
    return PerformFMinMaxCombine(N, DCI, PTXVersion: STI.getPTXVersion(),
                                 SmVersion: STI.getSmVersion());
  case ISD::LOAD:
  case NVPTXISD::LoadV2:
  case NVPTXISD::LoadV4:
    return combineLOAD(N, DCI, STI);
  case ISD::MUL:
    return PerformMULCombine(N, DCI, OptLevel);
  case NVPTXISD::PRMT:
    return combinePRMT(N, DCI, OptLevel);
  case NVPTXISD::ProxyReg:
    return combineProxyReg(N, DCI);
  case ISD::SETCC:
    return PerformSETCCCombine(N, DCI, SmVersion: STI.getSmVersion());
  case ISD::SHL:
    return PerformSHLCombine(N, DCI, OptLevel);
  case ISD::SREM:
  case ISD::UREM:
    return PerformREMCombine(N, DCI, OptLevel);
  case ISD::STORE:
  case NVPTXISD::StoreV2:
  case NVPTXISD::StoreV4:
    return combineSTORE(N, DCI, STI);
  case ISD::SELECT:
    return PerformSELECTShiftCombine(N, DCI);
  case ISD::VSELECT:
    return PerformVSELECTCombine(N, DCI);
  case ISD::INTRINSIC_WO_CHAIN:
    return combineIntrinsicWOChain(N, DCI, STI);
  }
  return SDValue();
}
7097
7098static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
7099 SmallVectorImpl<SDValue> &Results) {
7100 // Handle bitcasting to v2i8 without hitting the default promotion
7101 // strategy which goes through stack memory.
7102 SDValue Op(Node, 0);
7103 EVT ToVT = Op->getValueType(ResNo: 0);
7104 if (ToVT != MVT::v2i8) {
7105 return;
7106 }
7107
7108 // Bitcast to i16 and unpack elements into a vector
7109 SDLoc DL(Node);
7110 SDValue AsInt = DAG.getBitcast(VT: MVT::i16, V: Op->getOperand(Num: 0));
7111 SDValue Vec0 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8, Operand: AsInt);
7112 SDValue Const8 = DAG.getConstant(Val: 8, DL, VT: MVT::i16);
7113 SDValue Vec1 =
7114 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8,
7115 Operand: DAG.getNode(Opcode: ISD::SRL, DL, VT: MVT::i16, Ops: {AsInt, Const8}));
7116 Results.push_back(
7117 Elt: DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MVT::v2i8, Ops: {Vec0, Vec1}));
7118}
7119
// Custom result legalization for INTRINSIC_W_CHAIN nodes. Handles the
// ldu.global intrinsics (vector and i8 forms) and the tcgen05.ld family.
// On success, the legalized value(s) followed by the chain are pushed into
// Results; returning with Results untouched falls back to default handling.
static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
                                     SmallVectorImpl<SDValue> &Results) {
  SDValue Chain = N->getOperand(Num: 0);
  SDValue Intrin = N->getOperand(Num: 1);
  SDLoc DL(N);

  // Get the intrinsic ID
  unsigned IntrinNo = Intrin.getNode()->getAsZExtVal();
  switch (IntrinNo) {
  default:
    return;
  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_p: {
    EVT ResVT = N->getValueType(ResNo: 0);

    if (ResVT.isVector()) {
      // Vector LDG/LDU

      unsigned NumElts = ResVT.getVectorNumElements();
      EVT EltVT = ResVT.getVectorElementType();

      // Since LDU/LDG are target nodes, we cannot rely on DAG type
      // legalization.
      // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
      // loaded type to i16 and propagate the "real" type as the memory type.
      bool NeedTrunc = false;
      if (EltVT.getSizeInBits() < 16) {
        EltVT = MVT::i16;
        NeedTrunc = true;
      }

      unsigned Opcode = 0;
      SDVTList LdResVTs;

      // Only 2- and 4-element vectors have a dedicated LDU node; any other
      // width falls back to default handling.
      switch (NumElts) {
      default:
        return;
      case 2:
        Opcode = NVPTXISD::LDUV2;
        LdResVTs = DAG.getVTList(VT1: EltVT, VT2: EltVT, VT3: MVT::Other);
        break;
      case 4: {
        Opcode = NVPTXISD::LDUV4;
        EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
        LdResVTs = DAG.getVTList(VTs: ListVTs);
        break;
      }
      }

      SmallVector<SDValue, 8> OtherOps;

      // Copy regular operands

      OtherOps.push_back(Elt: Chain); // Chain
      // Skip operand 1 (intrinsic ID)
      // Others
      OtherOps.append(in_start: N->op_begin() + 2, in_end: N->op_end());

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);

      // Re-emit as a target memory intrinsic, preserving the original memory
      // type and operand so alias info and the "real" element type survive.
      SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: LdResVTs, Ops: OtherOps,
                                              MemVT: MemSD->getMemoryVT(),
                                              MMO: MemSD->getMemOperand());

      SmallVector<SDValue, 4> ScalarRes;

      // Truncate each widened lane back to the original element type.
      for (unsigned i = 0; i < NumElts; ++i) {
        SDValue Res = NewLD.getValue(R: i);
        if (NeedTrunc)
          Res =
              DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ResVT.getVectorElementType(), Operand: Res);
        ScalarRes.push_back(Elt: Res);
      }

      // The chain is the result right after the NumElts data values.
      SDValue LoadChain = NewLD.getValue(R: NumElts);

      SDValue BuildVec =
          DAG.getBuildVector(VT: ResVT, DL, Ops: ScalarRes);

      Results.push_back(Elt: BuildVec);
      Results.push_back(Elt: LoadChain);
    } else {
      // i8 LDG/LDU
      assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
             "Custom handling of non-i8 ldu/ldg?");

      // Just copy all operands as-is
      SmallVector<SDValue, 4> Ops(N->ops());

      // Force output to i16
      SDVTList LdResVTs = DAG.getVTList(VT1: MVT::i16, VT2: MVT::Other);

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);

      // We make sure the memory type is i8, which will be used during isel
      // to select the proper instruction.
      SDValue NewLD =
          DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_W_CHAIN, dl: DL, VTList: LdResVTs, Ops,
                                  MemVT: MVT::i8, MMO: MemSD->getMemOperand());

      // Hand back an i8 value (truncated from the forced i16) plus the chain.
      Results.push_back(Elt: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8,
                                    Operand: NewLD.getValue(R: 0)));
      Results.push_back(Elt: NewLD.getValue(R: 1));
    }
    return;
  }

  // tcgen05.ld variants without an offset operand: lowered by a shared
  // helper that returns a (value, chain) pair on success.
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
    if (auto Res = lowerTcgen05Ld(N, DAG)) {
      Results.push_back(Elt: Res->first);
      Results.push_back(Elt: Res->second);
    }
    return;

  // 16x32bx2 variants carry an extra offset operand.
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
    if (auto Res = lowerTcgen05Ld(N, DAG, /*HasOffset=*/true)) {
      Results.push_back(Elt: Res->first);
      Results.push_back(Elt: Res->second);
    }
    return;

  // tcgen05.ld with reduction: the helper yields three results (pushed in
  // order), the last being the chain.
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_f32:
    if (auto Res = lowerTcgen05LdRed(N, DAG)) {
      Results.push_back(Elt: std::get<0>(t&: *Res));
      Results.push_back(Elt: std::get<1>(t&: *Res));
      Results.push_back(Elt: std::get<2>(t&: *Res));
    }
    return;
  }
}
7302
7303static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
7304 SmallVectorImpl<SDValue> &Results) {
7305 // Change the CopyFromReg to output 2 64-bit results instead of a 128-bit
7306 // result so that it can pass the legalization
7307 SDLoc DL(N);
7308 SDValue Chain = N->getOperand(Num: 0);
7309 SDValue Reg = N->getOperand(Num: 1);
7310 SDValue Glue = N->getOperand(Num: 2);
7311
7312 assert(Reg.getValueType() == MVT::i128 &&
7313 "Custom lowering for CopyFromReg with 128-bit reg only");
7314 SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(ResNo: 1),
7315 N->getValueType(ResNo: 2)};
7316 SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
7317
7318 SDValue NewValue = DAG.getNode(Opcode: ISD::CopyFromReg, DL, ResultTys: ResultsType, Ops: NewOps);
7319 SDValue Pair = DAG.getNode(Opcode: ISD::BUILD_PAIR, DL, VT: MVT::i128,
7320 Ops: {NewValue.getValue(R: 0), NewValue.getValue(R: 1)});
7321
7322 Results.push_back(Elt: Pair);
7323 Results.push_back(Elt: NewValue.getValue(R: 2));
7324 Results.push_back(Elt: NewValue.getValue(R: 3));
7325}
7326
7327static void replaceProxyReg(SDNode *N, SelectionDAG &DAG,
7328 const TargetLowering &TLI,
7329 SmallVectorImpl<SDValue> &Results) {
7330 SDValue Chain = N->getOperand(Num: 0);
7331 SDValue Reg = N->getOperand(Num: 1);
7332
7333 MVT VT = TLI.getRegisterType(Context&: *DAG.getContext(), VT: Reg.getValueType());
7334
7335 SDValue NewReg = DAG.getAnyExtOrTrunc(Op: Reg, DL: SDLoc(N), VT);
7336 SDValue NewProxy =
7337 DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: SDLoc(N), VT, Ops: {Chain, NewReg});
7338 SDValue Res = DAG.getAnyExtOrTrunc(Op: NewProxy, DL: SDLoc(N), VT: N->getValueType(ResNo: 0));
7339
7340 Results.push_back(Elt: Res);
7341}
7342
// Legalize a 128-bit ATOMIC_SWAP / ATOMIC_CMP_SWAP. Each i128 operand is
// split into lo/hi i64 halves and the op is re-emitted as the corresponding
// NVPTX B128 atomic node; the two i64 results are recombined with
// BUILD_PAIR. If the subtarget lacks 128-bit atomic support, a diagnostic
// is emitted and an UNDEF result is produced so compilation can continue.
static void replaceAtomicSwap128(SDNode *N, SelectionDAG &DAG,
                                 const NVPTXSubtarget &STI,
                                 SmallVectorImpl<SDValue> &Results) {
  assert(N->getValueType(0) == MVT::i128 &&
         "Custom lowering for atomic128 only supports i128");

  AtomicSDNode *AN = cast<AtomicSDNode>(Val: N);
  SDLoc dl(N);

  if (!STI.hasAtomSwap128()) {
    DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
        DAG.getMachineFunction().getFunction(),
        "Support for b128 atomics introduced in PTX ISA version 8.3 and "
        "requires target sm_90.",
        dl.getDebugLoc()));

    // Keep legalization going with a placeholder value and the input chain.
    Results.push_back(Elt: DAG.getUNDEF(VT: MVT::i128));
    Results.push_back(Elt: AN->getOperand(Num: 0)); // Chain
    return;
  }

  SmallVector<SDValue, 6> Ops;
  Ops.push_back(Elt: AN->getOperand(Num: 0)); // Chain
  Ops.push_back(Elt: AN->getOperand(Num: 1)); // Ptr
  // Remaining operands are i128 data values (swap value, and for cmpxchg the
  // expected value); split each into two i64 halves.
  for (const auto &Op : AN->ops().drop_front(N: 2)) {
    // Low part
    Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_ELEMENT, DL: dl, VT: MVT::i64, N1: Op,
                              N2: DAG.getIntPtrConstant(Val: 0, DL: dl)));
    // High part
    Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_ELEMENT, DL: dl, VT: MVT::i64, N1: Op,
                              N2: DAG.getIntPtrConstant(Val: 1, DL: dl)));
  }
  unsigned Opcode = N->getOpcode() == ISD::ATOMIC_SWAP
                        ? NVPTXISD::ATOMIC_SWAP_B128
                        : NVPTXISD::ATOMIC_CMP_SWAP_B128;
  SDVTList Tys = DAG.getVTList(VT1: MVT::i64, VT2: MVT::i64, VT3: MVT::Other);
  // Memory VT stays i128 so the memory operand describes the full access.
  SDValue Result = DAG.getMemIntrinsicNode(Opcode, dl, VTList: Tys, Ops, MemVT: MVT::i128,
                                           MMO: AN->getMemOperand());
  Results.push_back(Elt: DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: dl, VT: MVT::i128,
                                {Result.getValue(R: 0), Result.getValue(R: 1)}));
  Results.push_back(Elt: Result.getValue(R: 2));
}
7385
// Target hook: custom result legalization. Dispatches each opcode to its
// Replace*/replace* helper, which appends the legalized values to Results.
// Any opcode reaching the default case was marked Custom without a handler
// here, which is a backend bug -- hence the fatal error.
void NVPTXTargetLowering::ReplaceNodeResults(
    SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
  switch (N->getOpcode()) {
  default:
    report_fatal_error(reason: "Unhandled custom legalization");
  case ISD::BITCAST:
    ReplaceBITCAST(Node: N, DAG, Results);
    return;
  case ISD::LOAD:
  case ISD::MLOAD:
    replaceLoadVector(N, DAG, Results, STI);
    return;
  case ISD::INTRINSIC_W_CHAIN:
    ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
    return;
  case ISD::CopyFromReg:
    ReplaceCopyFromReg_128(N, DAG, Results);
    return;
  case NVPTXISD::ProxyReg:
    replaceProxyReg(N, DAG, TLI: *this, Results);
    return;
  case ISD::ATOMIC_CMP_SWAP:
  case ISD::ATOMIC_SWAP:
    replaceAtomicSwap128(N, DAG, STI, Results);
    return;
  }
}
7413
// Decide how AtomicExpandPass should lower an atomicrmw: None means the
// operation is natively supported; CmpXChg means it is emulated with a CAS
// loop. The decision depends on the operation, the operand width, and the
// subtarget's SM/PTX feature level.
NVPTXTargetLowering::AtomicExpansionKind
NVPTXTargetLowering::shouldExpandAtomicRMWInIR(const AtomicRMWInst *AI) const {
  Type *Ty = AI->getValOperand()->getType();

  if (AI->isFloatingPointOperation()) {
    // Only FAdd has native support, and only for certain type/target combos.
    if (AI->getOperation() == AtomicRMWInst::BinOp::FAdd) {
      // f16 requires sm_70+ and PTX 6.3+.
      if (Ty->isHalfTy() && STI.getSmVersion() >= 70 &&
          STI.getPTXVersion() >= 63)
        return AtomicExpansionKind::None;
      // bf16 requires sm_90+ and PTX 7.8+.
      if (Ty->isBFloatTy() && STI.getSmVersion() >= 90 &&
          STI.getPTXVersion() >= 78)
        return AtomicExpansionKind::None;
      // f32 is supported unconditionally.
      if (Ty->isFloatTy())
        return AtomicExpansionKind::None;
      // f64 needs the subtarget's 64-bit FP atomic-add feature.
      if (Ty->isDoubleTy() && STI.hasAtomAddF64())
        return AtomicExpansionKind::None;
    }
    // All other FP atomics are emulated with a CAS loop.
    return AtomicExpansionKind::CmpXChg;
  }

  assert(Ty->isIntegerTy() && "Ty should be integer at this point");
  const unsigned BitWidth = cast<IntegerType>(Val: Ty)->getBitWidth();

  switch (AI->getOperation()) {
  default:
    return AtomicExpansionKind::CmpXChg;
  case AtomicRMWInst::BinOp::Xchg:
    // 128-bit exchange has a dedicated lowering (see replaceAtomicSwap128).
    if (BitWidth == 128)
      return AtomicExpansionKind::None;
    [[fallthrough]];
  case AtomicRMWInst::BinOp::And:
  case AtomicRMWInst::BinOp::Or:
  case AtomicRMWInst::BinOp::Xor:
    switch (BitWidth) {
    case 8:
    case 16:
      return AtomicExpansionKind::CmpXChg;
    case 32:
      return AtomicExpansionKind::None;
    case 64:
      // 64-bit bitwise atomics are a subtarget feature.
      if (STI.hasAtomBitwise64())
        return AtomicExpansionKind::None;
      return AtomicExpansionKind::CmpXChg;
    case 128:
      return AtomicExpansionKind::CmpXChg;
    default:
      llvm_unreachable("unsupported width encountered");
    }
  case AtomicRMWInst::BinOp::Add:
  case AtomicRMWInst::BinOp::Sub:
  case AtomicRMWInst::BinOp::Max:
  case AtomicRMWInst::BinOp::Min:
  case AtomicRMWInst::BinOp::UMax:
  case AtomicRMWInst::BinOp::UMin:
    switch (BitWidth) {
    case 8:
    case 16:
      return AtomicExpansionKind::CmpXChg;
    case 32:
      return AtomicExpansionKind::None;
    case 64:
      // 64-bit min/max atomics are a subtarget feature.
      if (STI.hasAtomMinMax64())
        return AtomicExpansionKind::None;
      return AtomicExpansionKind::CmpXChg;
    case 128:
      return AtomicExpansionKind::CmpXChg;
    default:
      llvm_unreachable("unsupported width encountered");
    }
  case AtomicRMWInst::BinOp::UIncWrap:
  case AtomicRMWInst::BinOp::UDecWrap:
    switch (BitWidth) {
    case 32:
      return AtomicExpansionKind::None;
    case 8:
    case 16:
    case 64:
    case 128:
      return AtomicExpansionKind::CmpXChg;
    default:
      llvm_unreachable("unsupported width encountered");
    }
  }

  // NOTE(review): unreachable -- every path through the switch above
  // returns or hits llvm_unreachable; kept to satisfy the compiler.
  return AtomicExpansionKind::CmpXChg;
}
7500
bool NVPTXTargetLowering::shouldInsertFencesForAtomic(
    const Instruction *I) const {
  // This function returns true iff the operation is emulated using a CAS-loop,
  // or if it has the memory order seq_cst (which is not natively supported in
  // the PTX `atom` instruction).
  //
  // atomicrmw and cmpxchg instructions not efficiently supported by PTX
  // are lowered to CAS emulation loops that preserve their memory order,
  // syncscope, and volatile semantics. For PTX, it is more efficient to use
  // atom.cas.relaxed.sco instructions within the loop, and fences before and
  // after the loop to restore order.
  //
  // Atomic instructions efficiently supported by PTX are lowered to
  // `atom.<op>.<sem>.<scope` instruction with their corresponding memory order
  // and scope. Since PTX does not support seq_cst, we emulate it by lowering to
  // a fence.sc followed by an atom according to the PTX atomics ABI
  // https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/atomic-abi.html
  //
  // cmpxchg is emulated when narrower than the minimum supported width.
  if (auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: I))
    return (cast<IntegerType>(Val: CI->getCompareOperand()->getType())
                ->getBitWidth() < STI.getMinCmpXchgSizeInBits()) ||
           CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent;
  // atomicrmw is emulated exactly when shouldExpandAtomicRMWInIR says CmpXChg.
  if (auto *RI = dyn_cast<AtomicRMWInst>(Val: I))
    return shouldExpandAtomicRMWInIR(AI: RI) == AtomicExpansionKind::CmpXChg ||
           RI->getOrdering() == AtomicOrdering::SequentiallyConsistent;
  // All other instructions: no NVPTX-specific fencing.
  return false;
}
7527
AtomicOrdering NVPTXTargetLowering::atomicOperationOrderAfterFenceSplit(
    const Instruction *I) const {
  // If the operation is emulated by a CAS-loop, we lower the instruction to
  // atom.<op>.relaxed, since AtomicExpandPass will insert fences for enforcing
  // the correct memory ordering around the CAS loop.
  //
  // When the operation is not emulated, but the memory order is seq_cst,
  // we must lower to "fence.sc.<scope>; atom.<op>.acquire.<scope>;" to conform
  // to the PTX atomics ABI.
  // https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/atomic-abi.html
  // For such cases, emitLeadingFence() will separately insert the leading
  // "fence.sc.<scope>;". Here, we only set the memory order to acquire.
  //
  // Otherwise, the operation is not emulated, and the memory order is not
  // seq_cst. In this case, the LLVM memory order is natively supported by the
  // PTX `atom` instruction, and we just lower to the corresponding
  // `atom.<op>.relaxed|acquire|release|acq_rel". For such cases, this function
  // will NOT be called.
  // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
  // I before its memory order was modified.
  //
  // Non-emulated seq_cst cmpxchg (width at or above the minimum) -> acquire.
  if (auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: I);
      CI && CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent &&
      cast<IntegerType>(Val: CI->getCompareOperand()->getType())->getBitWidth() >=
          STI.getMinCmpXchgSizeInBits())
    return AtomicOrdering::Acquire;
  // Non-emulated seq_cst atomicrmw -> acquire.
  else if (auto *RI = dyn_cast<AtomicRMWInst>(Val: I);
           RI && RI->getOrdering() == AtomicOrdering::SequentiallyConsistent &&
           shouldExpandAtomicRMWInIR(AI: RI) == AtomicExpansionKind::None)
    return AtomicOrdering::Acquire;

  // Emulated operations run relaxed; fences restore the required ordering.
  return AtomicOrdering::Monotonic;
}
7560
7561Instruction *NVPTXTargetLowering::emitLeadingFence(IRBuilderBase &Builder,
7562 Instruction *Inst,
7563 AtomicOrdering Ord) const {
7564 // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
7565 // `Inst` before its memory order was modified. We cannot enforce this with an
7566 // assert, because AtomicExpandPass will have modified the memory order
7567 // between the initial call to shouldInsertFencesForAtomic() and the call to
7568 // this function.
7569 if (!isa<AtomicCmpXchgInst>(Val: Inst) && !isa<AtomicRMWInst>(Val: Inst))
7570 return TargetLoweringBase::emitLeadingFence(Builder, Inst, Ord);
7571
7572 // Specialize for cmpxchg and atomicrmw
7573 auto SSID = getAtomicSyncScopeID(I: Inst);
7574 assert(SSID.has_value() && "Expected an atomic operation");
7575
7576 if (isReleaseOrStronger(AO: Ord))
7577 return Builder.CreateFence(Ordering: Ord == AtomicOrdering::SequentiallyConsistent
7578 ? AtomicOrdering::SequentiallyConsistent
7579 : AtomicOrdering::Release,
7580 SSID: SSID.value());
7581
7582 return nullptr;
7583}
7584
7585Instruction *NVPTXTargetLowering::emitTrailingFence(IRBuilderBase &Builder,
7586 Instruction *Inst,
7587 AtomicOrdering Ord) const {
7588 // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
7589 // `Inst` before its memory order was modified. See `emitLeadingFence` for why
7590 // this cannot be enforced with an assert. Specialize for cmpxchg and
7591 // atomicrmw
7592 auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: Inst);
7593 auto *RI = dyn_cast<AtomicRMWInst>(Val: Inst);
7594 if (!CI && !RI)
7595 return TargetLoweringBase::emitTrailingFence(Builder, Inst, Ord);
7596
7597 auto SSID = getAtomicSyncScopeID(I: Inst);
7598 assert(SSID.has_value() && "Expected an atomic operation");
7599
7600 bool IsEmulated =
7601 CI ? cast<IntegerType>(Val: CI->getCompareOperand()->getType())
7602 ->getBitWidth() < STI.getMinCmpXchgSizeInBits()
7603 : shouldExpandAtomicRMWInIR(AI: RI) == AtomicExpansionKind::CmpXChg;
7604
7605 if (isAcquireOrStronger(AO: Ord) && IsEmulated)
7606 return Builder.CreateFence(Ordering: AtomicOrdering::Acquire, SSID: SSID.value());
7607
7608 return nullptr;
7609}
7610
7611// Rather than default to SINT when both UINT and SINT are custom, we only
7612// change the opcode when UINT is not legal and SINT is. UINT is preferred when
7613// both are custom since unsigned CVT instructions can lead to slightly better
7614// SASS code with fewer instructions.
7615unsigned NVPTXTargetLowering::getPreferredFPToIntOpcode(unsigned Op, EVT FromVT,
7616 EVT ToVT) const {
7617 if (isOperationLegal(Op, VT: ToVT))
7618 return Op;
7619 switch (Op) {
7620 case ISD::FP_TO_UINT:
7621 if (isOperationLegal(Op: ISD::FP_TO_SINT, VT: ToVT))
7622 return ISD::FP_TO_SINT;
7623 break;
7624 case ISD::STRICT_FP_TO_UINT:
7625 if (isOperationLegal(Op: ISD::STRICT_FP_TO_SINT, VT: ToVT))
7626 return ISD::STRICT_FP_TO_SINT;
7627 break;
7628 case ISD::VP_FP_TO_UINT:
7629 if (isOperationLegal(Op: ISD::VP_FP_TO_SINT, VT: ToVT))
7630 return ISD::VP_FP_TO_SINT;
7631 break;
7632 default:
7633 break;
7634 }
7635 return Op;
7636}
7637
// Pin NVPTXTargetObjectFile's vtables to this file by defining the
// destructor out-of-line here.
NVPTXTargetObjectFile::~NVPTXTargetObjectFile() = default;
7640
MCSection *NVPTXTargetObjectFile::SelectSectionForGlobal(
    const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const {
  // Every global is placed in the single data section, regardless of GO or
  // Kind.
  return getDataSection();
}
7645
// Compute known bits for an NVPTXISD::PRMT node. PRMT selects bytes out of
// the 64-bit field {B, A} according to a 4-nibble selector; each result
// byte's known bits are copied from the selected source byte. Requires a
// constant selector operand; otherwise nothing is known.
static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
                                    const SelectionDAG &DAG, unsigned Depth) {
  SDValue A = Op.getOperand(i: 0);
  SDValue B = Op.getOperand(i: 1);
  ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 2));
  unsigned Mode = Op.getConstantOperandVal(i: 3);

  if (!Selector)
    return;

  KnownBits AKnown = DAG.computeKnownBits(Op: A, Depth);
  KnownBits BKnown = DAG.computeKnownBits(Op: B, Depth);

  // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
  assert(AKnown.getBitWidth() == 32 && BKnown.getBitWidth() == 32 &&
         "PRMT must have i32 operands");
  assert(Known.getBitWidth() == 32 && "PRMT must have i32 result");
  KnownBits BitField = BKnown.concat(Lo: AKnown);

  // Normalize the selector for the PRMT mode, then walk the 4 result bytes.
  APInt SelectorVal = getPRMTSelector(Selector: Selector->getAPIntValue(), Mode);
  for (unsigned I : llvm::seq(Size: 4)) {
    APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
    // Low 3 selector bits pick the source byte, the high bit requests
    // sign-replication of that byte.
    unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
    unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
    KnownBits Byte = BitField.extractBits(NumBits: 8, BitPosition: Idx * 8);
    if (Sign)
      // Replicate the byte's sign bit across all 8 bits (arithmetic shift
      // by the full byte width -- presumably well-defined for KnownBits;
      // NOTE(review): confirm KnownBits::ashr semantics at shift == width).
      Byte = KnownBits::ashr(LHS: Byte, RHS: 8);
    Known.insertBits(SubBits: Byte, BitPosition: I * 8);
  }
}
7676
// Compute known bits for NVPTX vector load nodes (LoadV2/V4/V8): any bits
// above the loaded element width are known zero, unless the load
// sign-extends.
static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known) {
  MemSDNode *LD = cast<MemSDNode>(Val: Op);

  // We can't do anything without knowing the sign bit.
  // (The extension kind is encoded as the node's last operand.)
  auto ExtType = LD->getConstantOperandVal(Num: LD->getNumOperands() - 1);
  if (ExtType == ISD::SEXTLOAD)
    return;

  // ExtLoading to vector types is weird and may not work well with known bits.
  auto DestVT = LD->getValueType(ResNo: 0);
  if (DestVT.isVector())
    return;

  // For zero/any-extending loads, everything above the in-memory element
  // width is zero.
  assert(Known.getBitWidth() == DestVT.getSizeInBits());
  auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(Mem: LD);
  Known.Zero.setHighBits(Known.getBitWidth() - ElementBitWidth);
}
7694
// Target hook: known-bits analysis for NVPTX-specific nodes. Starts from
// "nothing known" and lets the per-opcode helpers refine.
void NVPTXTargetLowering::computeKnownBitsForTargetNode(
    const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
    const SelectionDAG &DAG, unsigned Depth) const {
  Known.resetAll();

  switch (Op.getOpcode()) {
  case NVPTXISD::PRMT:
    computeKnownBitsForPRMT(Op, Known, DAG, Depth);
    break;
  case NVPTXISD::LoadV2:
  case NVPTXISD::LoadV4:
  case NVPTXISD::LoadV8:
    computeKnownBitsForLoadV(Op, Known);
    break;
  default:
    // Unknown target node: leave Known fully reset.
    break;
  }
}
7713
7714static std::pair<APInt, APInt> getPRMTDemandedBits(const APInt &SelectorVal,
7715 const APInt &DemandedBits) {
7716 APInt DemandedLHS = APInt(32, 0);
7717 APInt DemandedRHS = APInt(32, 0);
7718
7719 for (unsigned I : llvm::seq(Size: 4)) {
7720 if (DemandedBits.extractBits(numBits: 8, bitPosition: I * 8).isZero())
7721 continue;
7722
7723 APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
7724 unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
7725 unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
7726
7727 APInt &Src = Idx < 4 ? DemandedLHS : DemandedRHS;
7728 unsigned ByteStart = (Idx % 4) * 8;
7729 if (Sign)
7730 Src.setBit(ByteStart + 7);
7731 else
7732 Src.setBits(loBit: ByteStart, hiBit: ByteStart + 8);
7733 }
7734
7735 return {DemandedLHS, DemandedRHS};
7736}
7737
7738// Replace undef with 0 as this is easier for other optimizations such as
7739// known bits.
7740static SDValue canonicalizePRMTInput(SDValue Op, SelectionDAG &DAG) {
7741 if (!Op)
7742 return SDValue();
7743 if (Op.isUndef())
7744 return DAG.getConstant(Val: 0, DL: SDLoc(), VT: MVT::i32);
7745 return Op;
7746}
7747
// SimplifyDemandedBits for NVPTXISD::PRMT: (1) collapse the PRMT to one of
// its inputs when the demanded bytes all come from that input in identity
// order; (2) otherwise, try to simplify each multi-use input against the
// bits the selector actually demands of it. Requires a constant selector.
static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
                                           const APInt &DemandedBits,
                                           SelectionDAG &DAG,
                                           const TargetLowering &TLI,
                                           unsigned Depth) {
  assert(PRMT.getOpcode() == NVPTXISD::PRMT);
  SDValue Op0 = PRMT.getOperand(i: 0);
  SDValue Op1 = PRMT.getOperand(i: 1);
  auto *SelectorConst = dyn_cast<ConstantSDNode>(Val: PRMT.getOperand(i: 2));
  if (!SelectorConst)
    return SDValue();

  unsigned Mode = PRMT.getConstantOperandVal(i: 3);
  const APInt Selector = getPRMTSelector(Selector: SelectorConst->getAPIntValue(), Mode);

  // Try to simplify the PRMT to one of the inputs if the used bytes are all
  // from the same input in the correct order.
  // (Only the selector nibbles for demanded bytes need to match the identity
  // patterns 0x3210 / 0x7654.)
  const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8;
  const unsigned SelBits = (4 - LeadingBytes) * 4;
  if (Selector.getLoBits(numBits: SelBits) == APInt(32, 0x3210).getLoBits(numBits: SelBits))
    return Op0;
  if (Selector.getLoBits(numBits: SelBits) == APInt(32, 0x7654).getLoBits(numBits: SelBits))
    return Op1;

  auto [DemandedLHS, DemandedRHS] = getPRMTDemandedBits(SelectorVal: Selector, DemandedBits);

  // Attempt to avoid multi-use ops if we don't need anything from them.
  SDValue DemandedOp0 =
      TLI.SimplifyMultipleUseDemandedBits(Op: Op0, DemandedBits: DemandedLHS, DAG, Depth: Depth + 1);
  SDValue DemandedOp1 =
      TLI.SimplifyMultipleUseDemandedBits(Op: Op1, DemandedBits: DemandedRHS, DAG, Depth: Depth + 1);

  // Canonicalize undef inputs to zero so later analyses see known bits.
  DemandedOp0 = canonicalizePRMTInput(Op: DemandedOp0, DAG);
  DemandedOp1 = canonicalizePRMTInput(Op: DemandedOp1, DAG);
  // Rebuild the PRMT only if at least one input actually changed.
  if ((DemandedOp0 && DemandedOp0 != Op0) ||
      (DemandedOp1 && DemandedOp1 != Op1)) {
    Op0 = DemandedOp0 ? DemandedOp0 : Op0;
    Op1 = DemandedOp1 ? DemandedOp1 : Op1;
    return getPRMT(A: Op0, B: Op1, Selector: Selector.getZExtValue(), DL: SDLoc(PRMT), DAG);
  }

  return SDValue();
}
7791
// Target hook: demanded-bits simplification for NVPTX nodes. Returns true
// (after CombineTo) when the node was replaced; otherwise falls back to
// plain known-bits computation and returns false.
bool NVPTXTargetLowering::SimplifyDemandedBitsForTargetNode(
    SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
    KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
  Known.resetAll();

  switch (Op.getOpcode()) {
  case NVPTXISD::PRMT:
    if (SDValue Result = simplifyDemandedBitsForPRMT(PRMT: Op, DemandedBits, DAG&: TLO.DAG,
                                                     TLI: *this, Depth)) {
      TLO.CombineTo(O: Op, N: Result);
      return true;
    }
    break;
  default:
    break;
  }

  // No simplification found; still report whatever bits are known.
  computeKnownBitsForTargetNode(Op, Known, DemandedElts, DAG: TLO.DAG, Depth);
  return false;
}
7812